AtCoder Beginner Contest 406 E

https://atcoder.jp/contests/abc406/tasks/abc406_e

難しい問題ではないです。例の20は1足して21にすると考えやすいです。これを二進にすると10101です。
二進で10000は16ですが、[0, 16)の範囲はまとめて考えられます。1は2個なので、4C2=6通りあります。各ビット当たり何回1になるかというと、まずどこかのビットが1だとすると、残り3ビットで1個なので、3C1=3通りあります。全てのビットで3回1になります。なので和は15*3=45です。
次に[16, 20)の範囲をまとめて考えられます。残りの1は1個なので、2C1=2通りあって、各ビットは1回なので和は3*1=3です。しかし、16が2通りあるので、これが32です。このようにビットが立っている場所で計算します。

しかし、998244353の剰余群で逆数を求めるのを避けるためにメモ化をしようとすると、Pythonのようにmemoをグローバルに置くのが難しいです。散々試したのですが、うまくいきませんでした。
次に、998244353をGenericにすると割り算が掛け算になると思うのですが、これは下のコードのようにできました。
最後に、i64で計算するとどこかでエラーが出るらしく、i128にしたらあっさり通りました。

// Popcount Sum 3
#![allow(non_snake_case)]


//////////////////// library ////////////////////

fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>().split_whitespace()
            .map(|e| e.parse().ok().unwrap()).collect()
}

// ax = by + 1 (a, b > 0)
fn linear_diophantine(a: i128, b: i128) -> Option<(i128, i128)> {
    if a == 1 {
        return Some((1, 0))
    }
    
    let q = b / a;
    let r = b % a;
    if r == 0 {
        return None
    }
    let (x1, y1) = linear_diophantine(r, a)?;
    Some((-q * x1 - y1, -x1))
}

fn inverse<const D: i128>(a: i128) -> i128 {
    let (x, _y) = linear_diophantine(a, D).unwrap();
    if x >= 0 {
        x % D
    }
    else {
        x % D + D
    }
}

fn binary(mut n: i128) -> Vec<i128> {
    let mut bs: Vec<i128> = vec![];
    while n > 0 {
        bs.push(n & 1);
        n >>= 1
    }
    bs
}

fn C<const D: i128>(n: i128, m: i128) -> i128 {
    (0..m).fold(1, |x, y| x * (n - y) % D * inverse::<D>(y + 1) % D)
}


//////////////////// process ////////////////////

const D: i128 = 998244353;

type Test = (i128, i128);

fn read_test() -> Test {
    let v: Vec<i128> = read_vec();
    (v[0], v[1])
}

fn read_input() -> Vec<Test> {
    let T: usize = read();
    let tests: Vec<Test> = (0..T).map(|_| read_test()).collect();
    tests
}

// 桁数, 1の個数 -> (場合の数, 和)
fn g(n: i128, k: i128) -> (i128, i128) {
    if k == 0 {
        (1, 0)
    }
    else if n < k {
        (0, 0)
    }
    else {
        (C::<D>(n, k), C::<D>(n-1, k-1) * (((1i128 << n) - 1) % D) % D)
    }
}

fn F_each(N: i128, K: i128) -> i128 {
    let mut n = N + 1;
    let mut k = K;
    let bs = binary(n);
    let mut s: i128 = 0;
    for i in (0..bs.len()).rev() {
        if bs[i] == 0 {
            continue
        }
        let (num, sum) = g(i as i128, k);   // [n, n+2^i)
        s = (s + (N + 1 - n) * num + sum) % D;
        n -= 1i128 << i;
        k -= 1;
        if k < 0 {
            break
        }
    }
    s
}

fn F(tests: Vec<Test>) {
    for (N, K) in tests {
        println!("{}", F_each(N, K))
    }
}

fn main() {
    let tests = read_input();
    F(tests)
}