AtCoder Beginner Contest 222 E

https://atcoder.jp/contests/abc222/tasks/abc222_e

最短経路が辺 iを通る回数を d_iとすると、母関数 Gを考えて、
 \displaystyle G(x) = \prod_i{(x^{-d_i}+x^{d_i})}
この x^Kの係数を求めればよいです。 s = \sum_i{d_i}と置いて、
 \displaystyle G(x) = x^{-s}\prod_i{(1+x^{2d_i})}
 X = x^2として、 \displaystyle G_1(X) = \prod_i{(1+X^{d_i})}と置けば、
 G(x) = x^{-s}G_1(X)
なので、 G_1(X)を計算して X^{\frac{K+s}{2}}の係数を求めればよいです。

同じ dの重複度を dup(d)と書くと、
 \displaystyle G_1(X) = \prod_d{(1+X^d)^{dup(d)}}

これを最初分割統治法的に求めたら、すごく遅いんですね。ランダムな経路で計算したら14秒でした。各多項式が「疎な」多項式なので順に積を求める方が速いんですね。そのようにしたら1.1秒になりました。
さらに、あとに掛けるほうが項数が少ない方が速くなるはずなので、 dup(d)でソートして掛けると、0.8秒になりました。

以下は、PyPy2のコードです。

# coding: utf-8
# Red and Blue Tree

from collections import defaultdict, Counter
from itertools import count


#################### library ####################

def read_int():
    return int(raw_input())

def read_tuple():
    return tuple(map(int, raw_input().split()))

def read_list():
    return list(map(int, raw_input().split()))


#################### polynomial ####################

def poly_mul(f, g, d, D):
    L1, L2 = len(f), len(g)
    L3 = L1+(L2-1)*d
    h = [0] * L3
    for k in range(L2):
        for l in range(L1):
            h[l+k*d] += f[l]*g[k]
    return [ int(c % D) for c in h ]

def poly_pow(f, e, D):
    if e == 1:
        return f
    elif e%2 == 1:
        return poly_mul(f, poly_pow(f, e-1, D), D)
    else:
        g = poly_pow(f, e/2, D)
        return poly_mul(g, g, D)

def binomial(e):
    v = [1]
    c = 1
    for k in range(1, e+1):
        c = c * (e - k + 1) / k
        v.append(c)
    return v


#################### naive ####################


#################### process ####################

def read_input():
    N, M, K = read_tuple()
    As = read_list()
    edges = [ read_tuple() for _ in range(N-1) ]
    return (As, edges, K)

def make_tree(edges):
    graph = defaultdict(list)
    for U, V in edges:
        graph[U].append(V)
        graph[V].append(U)
    return graph

def shortest_path(v, w, graph):
    if v == w:
        return [v]
    
    a = [v]
    paths = { v: [v] }
    for d in count(1):
        b = []
        for v1 in a:
            for v2 in graph[v1]:
                if v2 in paths:
                    continue
                paths[v2] = paths[v1] + [v2]
                if v2 == w:
                    return paths[w]
                b.append(v2)
        a = b

def search_all_path(seq, graph):
    all_path = []
    for v, w in zip(seq, seq[1:]):
        path = shortest_path(v, w, graph)
        if all_path:
            all_path.extend(path[1:])
        else:
            all_path.extend(path)
    return all_path

def F(seq, graph, K):
    path = search_all_path(seq, graph)
    if len(path) % 2 == K % 2:
        return 0
    
    counter = Counter()
    for v1, v2 in zip(path, path[1:]):
        counter[(min(v1, v2), max(v1, v2))] += 1
    
    counter2 = Counter()
    for freq in counter.values():
        counter2[freq] += 1
    
    ens = counter2.items()
    s = sum(n for e, n in ens)
    f = [1]
    for d, e in sorted(ens, key=lambda (d, e): e, reverse=True):
        b = binomial(e)
        f = poly_mul(f, b, d, D)
    k = (len(path)-1+K)/2
    if 0 <= k < len(f):
        return f[k] * pow(2, len(graph) - s - 1, D) % D
    else:
        return 0


#################### main ####################

D = 998244353
seq, edges, K = read_input()
graph = make_tree(edges)
print F(seq, graph, K)