AtCoder Beginner Contest 234 G

https://atcoder.jp/contests/abc234/tasks/abc234_g

 O(N^2)でよければ、
 dp_0 = 1
 dp_n = \sum_{k=0}^{n-1}{dp_k(\max{(A_{k+1}, \cdots A_n)} - \min{(A_{k+1}, \cdots A_n)})}
とすればよいです。簡単ですね。ただ、この問題では N \le 3 \times 10^5なので、そうはいかないです。

解説を見ると、maxとminを分けて計算すると。なるほど。
ただ、解説文を読んでもよくわからなかったので、例えばこういう例で考えてみました。

 A = [4, 2, 5, 2, 3]
 dp_0 = 1
 dpmax_1 = 4dp_0 = 4
 dpmin_1 = 4dp_0 = 4
 dp_1 = dpmax_1 - dpmin_1 = 0
 dpmax_2 = 4dp_0 + 2dp_1 = dpmax_1 + 2dp_1 = 4
 dpmin_2 = 2(dp_0 + dp_1) = 2
 dp_2 = dpmax_2 - dpmin_2 = 2
 dpmax_3 = 5(dp_0 + dp_1 + dp_2) = 15
 dpmin_3 = 2(dp_0 + dp_1) + 5dp_2 = dpmin_2 + 5dp_2 = 12
 dp_3 = dpmax_3 - dpmin_3 = 3
 dpmax_4 = 5(dp_0 + dp_1 + dp_2) + 2dp_3 = dpmax_3 + 2dp_3 = 21
 dpmin_4 = 2(dp_0 + dp_1 + dp_2 + dp_3) = 12
 dp_4 = dpmax_4 - dpmin_4 = 9
 dpmax_5 = 5(dp_0 + dp_1 + dp_2) + 3dp_3 + 2dp_4 = dpmax_3 + 3(dp_3 + dp_4) = 51
 dpmin_5 = 2(dp_0 + dp_1 + dp_2 + dp_3) + 3dp_4 = dpmin_4 + 3dp_4 = 39
 dp_5 = dpmax_5 - dpmin_5 = 12

dpmaxだけ見てみると、Aの今の値より大きい場所を探して、無かったら、
 dpmax_n = A_n(dp_0 + \cdots + dp_{n-1})
あったらその位置をkとして、
 dpmax_n = dpmax_k + A_n(dp_k + \cdots + dp_{n-1})
となります。今の値より大きい場所を線形探索すると計算量が大きくなりそうですが、一度探索したらその値は用なしで消せばいいので、結局トータルとしては O(N)となります。

# coding: utf-8
# Divide a Sequence

from itertools import *
import sys


#################### library ####################

def read_int():
    return int(raw_input())

def read_list():
    return list(map(int, raw_input().split()))


#################### process ####################

def read_input():
    N = read_int()
    A = read_list()
    return A

def F(A):
    def find_nearest_max(m, i, a):
        for j in xrange(len(a)-1, -1, -1):
            value, index = a[j]
            if value > m:
                a.append((m, i))
                return index
            else:
                a.pop()
        else:
            a.append((m, i))
            return -1
    
    def find_nearest_min(m, i, a):
        for j in xrange(len(a)-1, -1, -1):
            value, index = a[j]
            if value < m:
                a.append((m, i))
                return index
            else:
                a.pop()
        else:
            a.append((m, i))
            return -1
    
    N = len(A)
    dp = [1] * (N + 1)
    sdp = [1] * (N + 1)
    dp_max = [0] * (N + 1)
    dp_min = [0] * (N + 1)
    a_max = []  # [(value, index)]
    a_min = []
    for n in range(1, N+1):
        i_max = find_nearest_max(A[n-1], n-1, a_max)
        if i_max == -1:
            dp_max[n] = A[n-1] * sdp[n-1] % D
        else:
            dp_max[n] = (dp_max[i_max+1] + A[n-1] * (sdp[n-1] - sdp[i_max])) % D
        i_min = find_nearest_min(A[n-1], n-1, a_min)
        if i_min == -1:
            dp_min[n] = A[n-1] * sdp[n-1] % D
        else:
            dp_min[n] = (dp_min[i_min+1] + A[n-1] * (sdp[n-1] - sdp[i_min])) % D
        dp[n] = dp_max[n] - dp_min[n]
        sdp[n] = (sdp[n-1] + dp[n]) % D
    return dp[N]

import random
def random_input(N):
    return [ random.randrange(1, 10**9) for _ in range(N) ]


#################### main ####################

D = 998244353
A = read_input()
print F(A) % D