https://atcoder.jp/contests/abc222/tasks/abc222_e
最短経路が辺を通る回数をとすると、母関数を考えて、
このの係数を求めればよいです。と置いて、
として、と置けば、
なので、を計算しての係数を求めればよいです。
同じの重複度をと書くと、
これを最初分割統治法的に求めたら、すごく遅いんですね。ランダムな経路で計算したら14秒でした。各多項式が「疎な」多項式なので順に積を求める方が速いんですね。そのようにしたら1.1秒になりました。
さらに、あとに掛けるほうが項数が少ない方が速くなるはずなので、でソートして掛けると、0.8秒になりました。
以下は、PyPy2のコードです。
# coding: utf-8 # Red and Blue Tree from collections import defaultdict, Counter from itertools import count #################### library #################### def read_int(): return int(raw_input()) def read_tuple(): return tuple(map(int, raw_input().split())) def read_list(): return list(map(int, raw_input().split())) #################### polynomial #################### def poly_mul(f, g, d, D): L1, L2 = len(f), len(g) L3 = L1+(L2-1)*d h = [0] * L3 for k in range(L2): for l in range(L1): h[l+k*d] += f[l]*g[k] return [ int(c % D) for c in h ] def poly_pow(f, e, D): if e == 1: return f elif e%2 == 1: return poly_mul(f, poly_pow(f, e-1, D), D) else: g = poly_pow(f, e/2, D) return poly_mul(g, g, D) def binomial(e): v = [1] c = 1 for k in range(1, e+1): c = c * (e - k + 1) / k v.append(c) return v #################### naive #################### #################### process #################### def read_input(): N, M, K = read_tuple() As = read_list() edges = [ read_tuple() for _ in range(N-1) ] return (As, edges, K) def make_tree(edges): graph = defaultdict(list) for U, V in edges: graph[U].append(V) graph[V].append(U) return graph def shortest_path(v, w, graph): if v == w: return [v] a = [v] paths = { v: [v] } for d in count(1): b = [] for v1 in a: for v2 in graph[v1]: if v2 in paths: continue paths[v2] = paths[v1] + [v2] if v2 == w: return paths[w] b.append(v2) a = b def search_all_path(seq, graph): all_path = [] for v, w in zip(seq, seq[1:]): path = shortest_path(v, w, graph) if all_path: all_path.extend(path[1:]) else: all_path.extend(path) return all_path def F(seq, graph, K): path = search_all_path(seq, graph) if len(path) % 2 == K % 2: return 0 counter = Counter() for v1, v2 in zip(path, path[1:]): counter[(min(v1, v2), max(v1, v2))] += 1 counter2 = Counter() for freq in counter.values(): counter2[freq] += 1 ens = counter2.items() s = sum(n for e, n in ens) f = [1] for d, e in sorted(ens, key=lambda (d, e): e, reverse=True): b = binomial(e) f = poly_mul(f, b, d, D) k = (len(path)-1+K)/2 if 0 <= k < len(f): return f[k] * pow(2, len(graph) - s - 1, D) % D else: return 0 #################### main #################### D = 998244353 seq, edges, K = read_input() graph = make_tree(edges) print F(seq, graph, K)