競プロ典型 005(2)

https://atcoder.jp/contests/typical90/tasks/typical90_e

多項式の掛け算に時間がかかっていたので、Karatsuba法を実装してみました。
最長実行時間が210msから66msに短縮されました。

// Restricted Digits
#![allow(non_snake_case)]

use std::cmp::{min, max};


//////////////////// constants ////////////////////

/// Prime modulus for all polynomial coefficient arithmetic (10^9 + 7).
const D: i64 = 1_000_000_007;


//////////////////// library ////////////////////

/// Reads one line from stdin and parses its trimmed contents as `T`.
///
/// A failed read is ignored (the buffer stays empty), so it surfaces as a parse failure.
///
/// # Panics
/// Panics if the trimmed line does not parse as `T`.
fn read<T: std::str::FromStr>() -> T {
    let mut buf = String::new();
    let _ = std::io::stdin().read_line(&mut buf);
    let parsed: Option<T> = buf.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line from stdin and parses each whitespace-separated token as `T`.
///
/// # Panics
/// Panics if any token does not parse as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line: String = read();
    line.split_whitespace()
        .map(|tok| tok.parse::<T>().ok().unwrap())
        .collect()
}

/// Computes `n^e mod d` by iterative binary exponentiation.
///
/// The result is always in `0..d`, including for `e == 1` and for `n >= d`.
/// `e == 0` yields `1 % d`; the previous recursive version never terminated for it.
/// Intermediate products use `u128`, so they cannot overflow for any `usize` modulus.
///
/// # Panics
/// Panics if `d == 0`.
fn pow(n: usize, e: usize, d: usize) -> usize {
    let m = d as u128;
    let mut base = n as u128 % m;
    let mut exp = e;
    let mut acc = 1 % m;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    acc as usize
}


//////////////////// Polynomial ////////////////////

/// Dense polynomial with `i64` coefficients, lowest degree first (`cs[i]` multiplies `x^i`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct Polynomial {
    cs: Vec<i64>,
}

use std::ops::{Add, Sub, Mul};

impl Add for Polynomial {
    type Output = Self;
    
    fn add(self, other: Self) -> Self {
        let N = max(self.len(), other.len());
        let mut cs: Vec<i64> = vec![0; N];
        for i in 0..self.cs.len() {
            cs[i] += self.cs[i]
        }
        for i in 0..other.cs.len() {
            cs[i] += other.cs[i]
        }
        Self { cs }
    }
}

impl Sub for Polynomial {
    type Output = Self;
    
    fn sub(self, other: Self) -> Self {
        let N = max(self.len(), other.len());
        let mut cs: Vec<i64> = vec![0; N];
        for i in 0..self.cs.len() {
            cs[i] += self.cs[i]
        }
        for i in 0..other.cs.len() {
            cs[i] -= other.cs[i]
        }
        Self { cs }
    }
}

impl Mul for Polynomial {
    type Output = Self;
    
    fn mul(self, other: Self) -> Self {
        let cs = Polynomial::mul(&self.cs[..], &other.cs[..]);
        Self { cs }
    }
}

impl Polynomial {
    /// If either operand of `mul` is shorter than this, schoolbook multiplication is used.
    const KARATSUBA_THRESHOLD: usize = 20;

    /// Number of stored coefficients (including any trailing zeros).
    fn len(&self) -> usize {
        self.cs.len()
    }

    /// Coefficient-wise `f + g`, every coefficient reduced with `% D`.
    ///
    /// The result has length `max(f.len(), g.len())`; missing coefficients count as 0.
    /// Because `%` on `i64` keeps the dividend's sign, results can be negative.
    fn add(f: &[i64], g: &[i64]) -> Vec<i64> {
        Self::combine(f, g, |a, b| a + b)
    }

    /// Coefficient-wise `f - g`, every coefficient reduced with `% D`.
    ///
    /// The result has length `max(f.len(), g.len())`; missing coefficients count as 0.
    fn sub(f: &[i64], g: &[i64]) -> Vec<i64> {
        Self::combine(f, g, |a, b| a - b)
    }

    /// Applies `op` coefficient-wise (missing entries treated as 0) and reduces each result `% D`.
    fn combine(f: &[i64], g: &[i64], op: impl Fn(i64, i64) -> i64) -> Vec<i64> {
        let n = max(f.len(), g.len());
        (0..n)
            .map(|i| {
                let a = f.get(i).copied().unwrap_or(0);
                let b = g.get(i).copied().unwrap_or(0);
                op(a, b) % D
            })
            .collect()
    }

    /// Product `f * g` with coefficients reduced by `% D`.
    ///
    /// The result has length `f.len() + g.len() - 1`, or is empty if either input is empty.
    /// Uses schoolbook multiplication when either operand is shorter than
    /// `KARATSUBA_THRESHOLD`, and Karatsuba recursion otherwise.
    fn mul(f: &[i64], g: &[i64]) -> Vec<i64> {
        if f.is_empty() || g.is_empty() {
            // Also prevents the `len - 1` underflow below.
            return Vec::new();
        }
        let n = f.len() + g.len() - 1;
        let mut h: Vec<i64> = vec![0; n];
        if f.len() < Self::KARATSUBA_THRESHOLD || g.len() < Self::KARATSUBA_THRESHOLD {
            for (i, &c1) in f.iter().enumerate() {
                for (j, &c2) in g.iter().enumerate() {
                    // Reducing after every step keeps |h[i + j]| < D, so the
                    // product of two sub-D values plus h fits in i64.
                    h[i + j] = (h[i + j] + c1 * c2) % D;
                }
            }
        } else {
            // Karatsuba: write f = f1 + f2·x^mid and g = g1 + g2·x^mid. Then
            //   f·g = h1 + h3·x^mid + h2·x^(2·mid),
            // where h1 = f1·g1, h2 = f2·g2 and
            //   h3 = f1·g2 + f2·g1 = h1 + h2 - (f2 - f1)(g2 - g1),
            // which needs three recursive products instead of four.
            let mid = min(f.len() / 2, g.len() / 2);
            let (f1, f2) = f.split_at(mid);
            let (g1, g2) = g.split_at(mid);
            let h1 = Self::mul(f1, g1);
            let h2 = Self::mul(f2, g2);
            let h3 = Self::sub(
                &Self::add(&h2, &h1),
                &Self::mul(&Self::sub(f2, f1), &Self::sub(g2, g1)),
            );
            for (k, c) in h1.into_iter().enumerate() {
                h[k] += c;
            }
            for (k, c) in h3.into_iter().enumerate() {
                h[k + mid] += c;
            }
            for (k, c) in h2.into_iter().enumerate() {
                h[k + mid * 2] += c;
            }
            // Each slot received at most three sub-D terms; bring it back below D.
            for c in h.iter_mut() {
                *c %= D;
            }
        }
        h
    }
}


//////////////////// process ////////////////////

/// Reads the problem input from stdin.
///
/// Line 1 supplies `N` and `B` (its first two tokens); line 2 lists the allowed digits.
/// Returns `(N, B, digits)`.
fn read_input() -> (usize, usize, Vec<i64>) {
    let first_line: Vec<usize> = read_vec();
    let n = first_line[0];
    let b = first_line[1];
    let digits: Vec<i64> = read_vec();
    (n, b, digits)
}

/// Histogram of the entries of `cs` by remainder modulo `B`, as a length-`B` polynomial.
///
/// Coefficient `r` is the number of entries `c` in `cs` with `c ≡ r (mod B)`.
/// `rem_euclid` is used so that a negative entry lands in `0..B` instead of wrapping
/// through `as usize`.
///
/// # Panics
/// Panics if `B == 0` and `cs` is non-empty.
fn initialize_counter(cs: &[i64], B: usize) -> Polynomial {
    let mut counter: Vec<i64> = vec![0; B];
    for &c in cs {
        counter[c.rem_euclid(B as i64) as usize] += 1;
    }
    Polynomial { cs: counter }
}

fn mul2(f: &Polynomial, a: usize) -> Polynomial {
    let N = f.len();
    let mut cs: Vec<i64> = vec![0; N];
    for (r, &c) in f.cs.iter().enumerate() {
        cs[r*a%N] += c
    }
    normalize(Polynomial { cs }, N)
}

/// Folds `f` modulo `x^B - 1`: coefficient `i` is accumulated into slot `i % B`.
///
/// Each addition is immediately reduced with `% D`, so slots stay in `(-D, D)`.
///
/// # Panics
/// Panics if `B == 0` and `f` has any coefficients.
fn normalize(f: Polynomial, B: usize) -> Polynomial {
    let mut folded = vec![0i64; B];
    for (deg, coef) in f.cs.into_iter().enumerate() {
        let slot = &mut folded[deg % B];
        *slot = (*slot + coef) % D;
    }
    Polynomial { cs: folded }
}

fn DC(n: usize, B: usize, cs: &Vec<i64>) -> Polynomial {
    if n == 1 {
        initialize_counter(cs, B)
    }
    else if n % 2 == 1 {
        let f = DC(n-1, B, cs);
        let g = initialize_counter(cs, B);
        let fB = mul2(&f, 10 % B);
        normalize(fB * g, B)
    }
    else {
        let f = DC(n/2, B, cs);
        let g = mul2(&f, pow(10, n/2, B));
        normalize(f * g, B)
    }
}

/// Number of `N`-digit strings over the digits `cs` whose value is divisible by `B`,
/// returned as a canonical residue in `[0, D)`.
fn F(N: usize, B: usize, cs: Vec<i64>) -> i64 {
    let f = DC(N, B, &cs);
    // DC's coefficients may be negative (signed `%`). For N == 1 they are raw,
    // unreduced counts. `rem_euclid` maps both cases into [0, D).
    f.cs[0].rem_euclid(D)
}

/// Entry point: reads `N`, `B` and the allowed digits, then prints the count modulo `D`.
fn main() {
    let (n, b, digits) = read_input();
    let answer = F(n, b, digits);
    println!("{}", answer);
}