AtCoder Beginner Contest 231 G

https://atcoder.jp/contests/abc231/tasks/abc231_g

最初の例で、2×2×3 + 1×3×3 + 1×2×4の計算をしていますが、これは母関数を考えるとすぐに出てきます。
 (1+2x)(2+3x)(3+4x)のxの係数を見ればよいです。
 (1+2x)(2+3x)(3+4x) = 6+29x+46x^2+24x^3
だから、2×2×3 + 1×3×3 + 1×2×4 = 29です。

2番目の例を見てみましょう。同様に考えると、
 (1+2x+3x^2)(2+3x+4x^2) = 12+17x+16x^2+7x^3+2x^4
なので、 x^2の係数の16となりますが、22になるはずです。
これは、箱1→箱2と選ぶのと箱2→箱1と選ぶのは別カウントだからです。ボールをK個入れるとして箱が2つあって箱1にn個入れるなら、 _KC_n通りあります。

こういう時は、指数型母関数を使うとよいです。こういう形です。
 \displaystyle \sum_k{a_k\frac{x^k}{k!}}
例えば、箱1に3個、箱2に2個ボールを入れるとすると、
 a_3\frac{x^3}{3!} \times b_2\frac{x^2}{2!} = a_3b_2\frac{x^5}{3!2!} = a_3b_2\frac{5!}{3!2!}\frac{x^5}{5!}
となり、ちょうど _5C_3通りあることと一致します。

Kがいくら大きくてもいいように、箱iの項を \displaystyle \sum_{k=0}^{\infty}{(A_i+k)\frac{x^k}{k!}}とすると、これは、
 \displaystyle \sum_k{A_i\frac{x^k}{k!}} + \sum_k{\frac{x^k}{(k-1)!}} = \sum_k{A_i\frac{x^k}{k!}} + \sum_k{x\frac{x^{k-1}}{(k-1)!}} = (A_i+x)e^x
となるので、全ての箱の積は、
 \displaystyle e^{Nx}\prod_i{(A_i+x)}
となります。まず、第2項を計算すると、N次多項式なので、割り算を除くと O(N^2)で計算できます。第1項は無限級数ですが、K次の項だけ欲しいので、{第1項のK次の項}×{第2項の0次の項} + {第1項の(K-1)次の項}×{第2項の1次の項} + ...とすれば O(N)で済みます。

# coding: utf-8
# Balls in Boxes

from itertools import *
import sys


#################### library ####################

def read_tuple():
    return tuple(map(int, raw_input().split()))

def read_list():
    return list(map(int, raw_input().split()))

# ax = by + c
def linear_diophantine(a, b, c):
    def f(a, b, c):
        if a == 1:
            return (b + c, 1)
        elif a == -1:
            return (-b - c, 1)
        
        d = b / a
        r = b % a
        t = f(r, a, -c)
        return (t[1] + d * t[0], t[0])
    
    return f(a, b, c)

def inverse(n, p):
    x, y = linear_diophantine(n, p, 1)
    return x


#################### polynomial ####################

def poly_mul(f, g, D):
    L1, L2 = len(f), len(g)
    L3 = L1+L2-1
    h = [0] * L3
    for k in range(L1):
        for l in range(L2):
            h[k+l] += f[k]*g[l]
    return [ int(c % D) for c in h ]


#################### process ####################

def read_input():
    N, K = read_tuple()
    A = read_list()
    return (K, A)

def F(K, A):
    def mul(fs):
        if len(fs) == 1:
            return fs[0]
        
        mid = len(fs) / 2
        f1 = mul(fs[:mid])
        f2 = mul(fs[mid:])
        return poly_mul(f1, f2, D)
    
    N = len(A)
    f = mul([[a, 1] for a in A])
    
    c = pow(N, K, D)
    cs = [c]
    L = min(K, N)
    for k in range(1, L+1):
        c = cs[-1] * (K-k+1) * inverse(N, D) % D
        cs.append(c)
    
    s = 0
    for k in range(L+1):
        s += cs[k] * f[k]
    return s * inverse(pow(N, K, D), D) % D


#################### main ####################

D = 998244353
K, A = read_input()
print F(K, A)