AtCoder Beginner Contest 248 C

https://atcoder.jp/contests/abc248/tasks/abc248_c

 1 \le A_i \le Nでなく、 0 \le A_i \lt Nとした方が書きやすいので、そうしましょう。そして、 L \equiv K - Nとします。

パット見、母関数ですが、C問題で母関数のはずないと思って考え直すと、単なるDPですね。

def update(dp, M, L):
    new_dp = [0] * (L + 1)
    for i, a in enumerate(dp):
        for j in range(M):
            k = i + j
            if k > L:
                break
            new_dp[k] += a
    return new_dp

def F(N, M, K):
    L = K - N
    dp = [1]
    for _ in range(N):
        dp = update(dp, M, L)
    return sum(dp) % D

計算量は、 O(NML)です。
母関数は G(x) = (1 + x + \cdots + x^{M-1})^Nなので、

def F(N, M, K):
    L = K - N
    f = [1] * min(M, L + 1)
    # poly_powは最後に
    g = poly_pow(f, N, L, D)
    return sum(g) % D

計算量はふつうに書くと O(L^2(1+\log{\frac{NM}{L}}))で、Karatsubaを使えば2乗のところが約1.6乗になります。

さて、
 \displaystyle \frac{1}{1-x} = 1 + x + x^2 + \cdots
を使うと、
 \displaystyle 1 + x + \cdots + x^{M-1} = \frac{1-x^M}{1-x}
だから、
 G(x) = \frac{(1-x^M)^N}{(1-x)^N}
です。

分母(の逆数)は (1 + x + x^2 + \cdots)^Nです。N=2なら 1 + 2x + 3x^2 + \cdots、N=3なら 1 + 3x + 6x^2 + \cdotsとなります。
N=3の x^kの係数は _{k+1}C_2だと容易に推察されますよね。これは、2個の中から重複を許してk個選ぶ組合せになるからです。 _2H_k = _{k+1}C_2です。
一般に x^kの係数は _NH_k = _{k+N-1}C_Nとなります。 x^Lの係数まで必要なので、分母の計算量は O(L)です。

分子は単なる二項係数なので、 x^{Mk}の係数は (-1)^k_NC_kです。分子の計算量は O(L/M)です。
これらを掛け算すればよいです。計算量は O(L^2/M)となります。 L = (M - 1)Nとすると、O(M)倍速くなりました。

N, M, K = 1000, 1000, 1000000としたところ、母関数でKaratsubaを使った方法で132秒、速くした方法で14秒でした。計算量はおよそ O(N^{3.2}) O(N^3)ですが、Karatubaは定数倍が大きいので差が付くのではないかと思います。
こういう工夫は単なる多項式の計算にしたから思いつくのです。母関数の力が発揮されたと言えるでしょう。

# coding: utf-8
# Dice Sum

from itertools import count


#################### library ####################

def read_tuple():
    return tuple(map(int, raw_input().split()))

# ax = by + c
def linear_diophantine(a, b, c):
    def f(a, b, c):
        if a == 1:
            return (b + c, 1)
        elif a == -1:
            return (-b - c, 1)
        
        d = b / a
        r = b % a
        t = f(r, a, -c)
        return (t[1] + d * t[0], t[0])
    
    return f(a, b, c)

def inverse(n, p):
    x, y = linear_diophantine(n, p, 1)
    return x


#################### polynomial ####################

def poly_add(f, g):
    if len(f) < len(g):
        return poly_add(g, f)
    h = f[:]
    for k in xrange(len(g)):
        h[k] += g[k]
    return h

def poly_sub(f, g):
    h = f[:]
    for k, a in enumerate(g):
        if k < len(h):
            h[k] -= g[k]
        else:
            h.append(-a)
    return h

def poly_mul(f, g, L, D):
    L1, L2 = len(f), len(g)
    L3 = min(L1+L2-1, L+1)
    h = [0] * L3
    if len(f) < 20 or len(g) < 20:
        for k in range(L1):
            for l in range(min(L2, L+1-k)):
                h[k+l] += f[k]*g[l]
        return [ int(c % D) for c in h ]
    else:
        # Karatsuba algorithm
        mid = min(len(f) / 2, len(g) / 2)
        f1 = f[:mid]
        f2 = f[mid:]
        g1 = g[:mid]
        g2 = g[mid:]
        h1 = poly_mul(f1, g1, L, D)
        h2 = poly_mul(f2, g2, L, D)
        h3 = poly_sub(poly_add(h2, h1),
                        poly_mul(poly_sub(f2, f1), poly_sub(g2, g1), L, D))
        for k, a in enumerate(h1):
            h[k] += a
        for k, a in enumerate(h3, mid):
            if k > L3:
                break
            h[k] += a
        for k, a in enumerate(h2, mid * 2):
            if k > L3:
                break
            elif k < len(h):
                h[k] += a
            else:
                h.append(a)
        return [ int(c % D) for c in h ]

def poly_pow(f, e, L, D):
    if e == 1:
        return f
    elif e%2 == 1:
        return poly_mul(f, poly_pow(f, e-1, L, D), L, D)
    else:
        g = poly_pow(f, e/2, L, D)
        return poly_mul(g, g, L, D)


#################### process ####################

def read_input():
    N, M, K = read_tuple()
    return (N, M, K)

def update(dp, M, L):
    new_dp = [0] * (L + 1)
    for i, a in enumerate(dp):
        for j in range(M):
            k = i + j
            if k > L:
                break
            new_dp[k] += a
    return new_dp

def F_gf(N, M, K):
    L = K - N
    f = [1] * min(M, L + 1)
    g = poly_pow(f, N, L, D)
    return sum(g) % D

# 1/(1-x)^N
def make_inv_poly(N, L):
    f = [1] * (L + 1)
    for k in range(1, L+1):
        f[k] = f[k-1] * (N+k-1) * inverse(k, D) % D
    return f

# (1-x)^N
def make_binomial(N, M, L):
    g = [1] * (L/M+1)
    for k in range(1, L/M+1):
        g[k] = -g[k-1] * (N-k+1) * inverse(k, D) % D
    return g

def F(N, M, K):
    L = K - N
    f = make_inv_poly(N, L)
    g = make_binomial(N, M, L)
    
    s = 0
    for i in range(L/M+1):
        for j in range(L - M*i + 1):
            s += g[i] * f[j]
    return s % D


#################### main ####################

import time
D = 998244353
#N, M, K = read_input()
N, M, K = 1000, 1000, 1000000
t0 = time.clock()
print F_gf(N, M, K)
print time.clock() - t0
t0 = time.clock()
print F(N, M, K)
print time.clock() - t0