AtCoder Beginner Contest 405 E

https://atcoder.jp/contests/abc405/tasks/abc405_e

まず、リンゴとブドウを並べます。そして、オレンジを配置します。オレンジは各リンゴの左とリンゴとブドウの間に配置できます。リンゴとブドウの間にオレンジをn個置くとします。そうすると、その他のオレンジは各リンゴの左に置けるので、場合の数は重複組み合わせで、
 \displaystyle _AH_{B-n}
となります。そのときバナナは各オレンジとブドウの左と一番右に置けるので、場合の数はこれも重複組み合わせで、
 \displaystyle _{n+D+1}H_C
となります。結局、
 \displaystyle \sum_{n=0}^B{_AH_{B-n\ n+D+1}H_C}
を計算すればよいです。

// Fruit Lineup
#![allow(non_snake_case)]


//////////////////// library ////////////////////

fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>().split_whitespace()
            .map(|e| e.parse().ok().unwrap()).collect()
}

// ax = by + 1 (a, b > 0)
fn linear_diophantine(a: i64, b: i64) -> Option<(i64, i64)> {
    if a == 1 {
        return Some((1, 0))
    }
    
    let q = b / a;
    let r = b % a;
    if r == 0 {
        return None
    }
    let (x1, y1) = linear_diophantine(r, a)?;
    Some((-q * x1 - y1, -x1))
}

fn inverse(a: i64, d: i64) -> i64 {
    let (x, _y) = linear_diophantine(a, d).unwrap();
    if x >= 0 {
        x % d
    }
    else {
        x % d + d
    }
}


//////////////////// process ////////////////////

const E: i64 = 998244353;

fn read_input() -> (i64, i64, i64, i64) {
    let v: Vec<i64> = read_vec();
    let (A, B, C, D) = (v[0], v[1], v[2], v[3]);
    (A, B, C, D)
}

fn fac_table(N: i64) -> Vec<i64> {
    let mut fac: Vec<i64> = vec![1; N as usize + 1];
    for i in 1..N+1 {
        fac[i as usize] = fac[i as usize - 1] * i % E
    }
    fac
}

use std::cmp::max;

fn F(A: i64, B: i64, C: i64, D:i64) -> i64 {
    let N: i64 = max(A+B-1, D+C+B);
    let fac = fac_table(N);
    let mut s: i64 = 0;
    for n in 0..B+1 {
        let num = fac[(A+B-n-1) as usize] * fac[(D+C+n) as usize] % E;
        let den = fac[(B-n) as usize] * fac[(A-1) as usize] % E *
                    fac[C as usize] % E * fac[(D+n) as usize] % E;
        s += num * inverse(den, E) % E
    }
    s % E
}

fn main() {
    let (A, B, C, D) = read_input();
    println!("{}", F(A, B, C, D))
}