https://atcoder.jp/contests/abc248/tasks/abc248_c
でなく、とした方が書きやすいので、そうしましょう。そして、とします。
パット見、母関数ですが、C問題で母関数のはずないと思って考え直すと、単なるDPですね。
def update(dp, M, L): new_dp = [0] * (L + 1) for i, a in enumerate(dp): for j in range(M): k = i + j if k > L: break new_dp[k] += a return new_dp def F(N, M, K): L = K - N dp = [1] for _ in range(N): dp = update(dp, M, L) return sum(dp) % D
計算量は、です。
母関数はなので、
def F(N, M, K): L = K - N f = [1] * min(M, L + 1) # poly_powは最後に g = poly_pow(f, N, L, D) return sum(g) % D
計算量はふつうに書くとで、Karatsubaを使えば2乗のところが約1.6乗になります。
さて、
を使うと、
だから、
です。
分母(の逆数)はです。N=2なら、N=3ならとなります。
N=3のの係数はだと容易に推察されますよね。これは、2個の中から重複を許してk個選ぶ組合せになるからです。です。
一般にの係数はとなります。の係数まで必要なので、分母の計算量はです。
分子は単なる二項係数なので、の係数はです。分子の計算量はです。
これらを掛け算すればよいです。計算量はとなります。とすると、O(M)倍速くなりました。
N, M, K = 1000, 1000, 1000000としたところ、母関数でKaratsubaを使った方法で132秒、速くした方法で14秒でした。計算量はおよそとですが、Karatubaは定数倍が大きいので差が付くのではないかと思います。
こういう工夫は単なる多項式の計算にしたから思いつくのです。母関数の力が発揮されたと言えるでしょう。
# coding: utf-8 # Dice Sum from itertools import count #################### library #################### def read_tuple(): return tuple(map(int, raw_input().split())) # ax = by + c def linear_diophantine(a, b, c): def f(a, b, c): if a == 1: return (b + c, 1) elif a == -1: return (-b - c, 1) d = b / a r = b % a t = f(r, a, -c) return (t[1] + d * t[0], t[0]) return f(a, b, c) def inverse(n, p): x, y = linear_diophantine(n, p, 1) return x #################### polynomial #################### def poly_add(f, g): if len(f) < len(g): return poly_add(g, f) h = f[:] for k in xrange(len(g)): h[k] += g[k] return h def poly_sub(f, g): h = f[:] for k, a in enumerate(g): if k < len(h): h[k] -= g[k] else: h.append(-a) return h def poly_mul(f, g, L, D): L1, L2 = len(f), len(g) L3 = min(L1+L2-1, L+1) h = [0] * L3 if len(f) < 20 or len(g) < 20: for k in range(L1): for l in range(min(L2, L+1-k)): h[k+l] += f[k]*g[l] return [ int(c % D) for c in h ] else: # Karatsuba algorithm mid = min(len(f) / 2, len(g) / 2) f1 = f[:mid] f2 = f[mid:] g1 = g[:mid] g2 = g[mid:] h1 = poly_mul(f1, g1, L, D) h2 = poly_mul(f2, g2, L, D) h3 = poly_sub(poly_add(h2, h1), poly_mul(poly_sub(f2, f1), poly_sub(g2, g1), L, D)) for k, a in enumerate(h1): h[k] += a for k, a in enumerate(h3, mid): if k > L3: break h[k] += a for k, a in enumerate(h2, mid * 2): if k > L3: break elif k < len(h): h[k] += a else: h.append(a) return [ int(c % D) for c in h ] def poly_pow(f, e, L, D): if e == 1: return f elif e%2 == 1: return poly_mul(f, poly_pow(f, e-1, L, D), L, D) else: g = poly_pow(f, e/2, L, D) return poly_mul(g, g, L, D) #################### process #################### def read_input(): N, M, K = read_tuple() return (N, M, K) def update(dp, M, L): new_dp = [0] * (L + 1) for i, a in enumerate(dp): for j in range(M): k = i + j if k > L: break new_dp[k] += a return new_dp def F_gf(N, M, K): L = K - N f = [1] * min(M, L + 1) g = poly_pow(f, N, L, D) return sum(g) % D # 1/(1-x)^N def make_inv_poly(N, L): f = [1] * (L + 1) for k in range(1, L+1): f[k] = f[k-1] * (N+k-1) * inverse(k, D) % D return f # (1-x)^N def make_binomial(N, M, L): g = [1] * (L/M+1) for k in range(1, L/M+1): g[k] = -g[k-1] * (N-k+1) * inverse(k, D) % D return g def F(N, M, K): L = K - N f = make_inv_poly(N, L) g = make_binomial(N, M, L) s = 0 for i in range(L/M+1): for j in range(L - M*i + 1): s += g[i] * f[j] return s % D #################### main #################### import time D = 998244353 #N, M, K = read_input() N, M, K = 1000, 1000, 1000000 t0 = time.clock() print F_gf(N, M, K) print time.clock() - t0 t0 = time.clock() print F(N, M, K) print time.clock() - t0