https://atcoder.jp/contests/abc304/tasks/abc304_f
これは、包除原理ですね。
題意を満たすシフトの数をF(M)とすると、N = 12で考えると、M=6とM=4はそれぞれF(6)とF(4)がありますが、両方にF(2)個分共通のシフトが含まれることが分かるので、結局トータルでは
F(6) + F(4) - F(2)
となります。
一般に、
ここで、はメビウス関数です。
このように包除原理を使えば簡単なのですが、Pythonと同じように書くと負数の剰余が負になるというトラップが出てきますね。
// Shift Table #![allow(non_snake_case)] //////////////////// constants //////////////////// const D: i64 = 998244353; //////////////////// library //////////////////// fn read<T: std::str::FromStr>() -> T { let mut line = String::new(); std::io::stdin().read_line(&mut line).ok(); line.trim().parse().ok().unwrap() } //////////////////// Factors //////////////////// type Factors = Vec<(u32, u32)>; fn div_pow(n: u32, d: u32) -> (u32, u32) { if n % d == 0 { let (e, m) = div_pow(n / d, d); (e + 1, m) } else { (0, n) } } fn factorize(n: u32, p0: u32) -> Factors { if n == 1 { return vec![] } for p in p0.. { if p * p > n { break } if n % p == 0 { let (e, m) = div_pow(n, p); let factors = factorize(m, p + 1); let mut new_factors = vec![(p, e)]; new_factors.extend(&factors); return new_factors } } vec![(n, 1)] } fn divisors(fs: &Factors) -> Vec<(bool, u32)> { if fs.len() == 1 { let (p, _e) = fs[0]; vec![(false, 1), (true, p)] } else { let mid = fs.len() / 2; let fs1 = (0..mid).map(|i| fs[i]).collect::<Factors>(); let fs2 = (mid..fs.len()).map(|i| fs[i]).collect::<Factors>(); let ds1 = divisors(&fs1); let ds2 = divisors(&fs2); let mut ds: Vec<(bool, u32)> = Vec::new(); for (s1, d1) in ds1.iter() { for (s2, d2) in ds2.iter() { ds.push((s1 ^ s2, d1 * d2)) } } ds } } //////////////////// process //////////////////// fn read_input() -> (u32, String) { let N = read(); let S = read(); (N, S) } fn num_valid_shifts(M: u32, S: &String) -> i64 { let mut a: Vec<bool> = (0..M).map(|_| false).collect(); for (i, c) in S.chars().enumerate() { if c == '.' { a[i%M as usize] = true } } a.into_iter().fold(1i64, |x, y| { if y { x } else { x*2 % D } }) } fn f(N: u32, S: String) -> i64 { let fs = factorize(N, 2); let ds = divisors(&fs); let mut s: i64 = 0; for (sign, d) in ds.into_iter() { if d == 1 { continue } let M = N / d; if sign { s = (s + num_valid_shifts(M, &S)) % D } else { s = (s - num_valid_shifts(M, &S)) % D } } (s + D) % D } //////////////////// main //////////////////// fn main() { let (N, S) = read_input(); println!("{}", f(N, S)) }