AtCoder Beginner Contest 234 F

https://atcoder.jp/contests/abc234/tasks/abc234_f

例えば、S = 'aabbcc'を考えると、6文字ではaの場所を6つの中から2つ選んで、bの場所を残りの4つの中から2つ選ぶので、

 \displaystyle _6C_4 \times _4C_2 = \frac{6!}{2!4!}\frac{4!}{2!2!} = \frac{6!}{2!2!2!}

となります。指数型母関数をこうします。

 \displaystyle G(x) = g_2(x)g_2(x)g_2(x)-1
 \displaystyle g_n(x) = \sum_{i=0}^{n}{\frac{1}{i!}x^i}

長さ6の文字列の個数は \frac{x^6}{6!}の係数に対応して、 \frac{6!}{2!2!2!}となり直接場合の数を考えた場合と一致します。定数項は長さ0の文字列に対応しているので、その分は引いておきます。
同じ文字が出てくる回数nの文字の数をσ(n)とすると、母関数をこう書くと速くなります。

 \displaystyle G(x) = \prod_{n}{g_n(x)^{\sigma(n)}}-1

例えば、S = 'aabcc'なら、 G(x) = g_2(x)^2g_1(x)-1となります。

# coding: utf-8
# Reordering

from itertools import *
from collections import Counter
import sys


#################### library ####################

# ax = by + c
def linear_diophantine(a, b, c):
    def f(a, b, c):
        if a == 1:
            return (b + c, 1)
        elif a == -1:
            return (-b - c, 1)
        
        d = b / a
        r = b % a
        t = f(r, a, -c)
        return (t[1] + d * t[0], t[0])
    
    return f(a, b, c)

def inverse(n, p):
    x, y = linear_diophantine(n, p, 1)
    return x


#################### polynomial ####################

def poly_mul(f, g, D):
    L1, L2 = len(f), len(g)
    L3 = L1+L2-1
    h = [0] * L3
    for k in range(L1):
        for l in range(L2):
            h[k+l] += f[k]*g[l]
    return [ int(c % D) for c in h ]

def poly_pow(f, e, D):
    if e == 1:
        return f
    elif e%2 == 1:
        return poly_mul(f, poly_pow(f, e-1, D), D)
    else:
        g = poly_pow(f, e/2, D)
        return poly_mul(g, g, D)


#################### process ####################

def read_input():
    S = raw_input()
    return S

def F(S, D):
    def make_poly(n):
        f = [1]
        a = 1
        for k in range(1, n+1):
            a = a * inverse(k, D) % D
            f.append(a)
        return f
    
    c = Counter(S)
    c2 = Counter(c.values())
    f = reduce(lambda x, (n, e): poly_mul(x, poly_pow(make_poly(n), e, D), D),
                                                                c2.items(), [1])
    
    s = 0
    a = 1
    for i in range(1, len(f)):
        a = a * i % D
        s += f[i] * a
    return s % D


#################### main ####################

D = 998244353
S = read_input()
print F(S, D)