AtCoder Beginner Contest 360 E

https://atcoder.jp/contests/abc360/tasks/abc360_e

例えばN=3で考えると、1から2に移動する確率は2/9で、1から1に移動する確率は1-2/9*2=5/9です。一般には、ある位置から他に移動する確率は、 \frac{2}{N^2}で、同じところにとどまる確率は 1 - \frac{2}{N^2}(N-1) = \frac{N^2-2N+2}{N^2}となります。なので、1回での状態遷移行列は、

 \displaystyle T_1 = aA + bI
と書けます。ここで、Aは全ての要素が1の行列、Iは単位行列 a = \frac{2}{N^2} b = \frac{N-2}{N}です。

K回の状態遷移確率は T = T_1^Kです。また、Tが与えられたときの期待値Eは、

 \displaystyle E = \sum_i{\sum_j{iT_{ij}{\delta}_{j1}}} = \sum_{i=1}^N{iT_{i1}}

となります。
 T = (aA + bI)^K = \sum_{k=0}^K{_KC_ka^kA^kb^{K-k}I}ですが、 A^2 = NAだから、 A^k = N^{k-1}A (k \ge 1)を考慮すると、 b^KI + \frac{1}{N}\sum_{k=1}^K{_KC_k(aN)^kb^{K-k}A} = b^KI + ((aN+b)^K - b^K)A
ここで、aN+b=1だから、
 T = b^KI + (1 - b^K)A
また、 \sum_i{iI_{i1}} = 1 \sum_i{iA_{i1}} = \frac{N(N+1)}{2}だから、 E = b^K + \frac{N+1}{2}(1 - b^K)

結局、
 \displaystyle E = \frac{N+1}{2} - \frac{N-1}{2}b^K
となります。

// Random Swaps of Balls
#![allow(non_snake_case)]


//////////////////// library ////////////////////

fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>().split_whitespace()
            .map(|e| e.parse().ok().unwrap()).collect()
}

// ax = by + 1 (a, b > 0)
fn linear_diophantine(a: i64, b: i64) -> Option<(i64, i64)> {
    if a == 1 {
        return Some((1, 0))
    }
    
    let q = b / a;
    let r = b % a;
    if r == 0 {
        return None
    }
    let (x1, y1) = linear_diophantine(r, a)?;
    Some((-q * x1 - y1, -x1))
}

fn inverse(a: i64, d: i64) -> i64 {
    let (x, _y) = linear_diophantine(a, d).unwrap();
    if x >= 0 {
        x % d
    }
    else {
        x % d + d
    }
}

fn pow(n: i64, e: u32, d: i64) -> i64 {
    if e == 1 {
        n
    }
    else if e % 2 == 1 {
        n * pow(n, e-1, d) % d
    }
    else {
        let p = pow(n, e/2, d);
        p * p % d
    }
}


//////////////////// process ////////////////////

const D: i64 = 998244353;

fn read_input() -> (i64, u32) {
    let v: Vec<u32> = read_vec();
    let (N, K) = (v[0] as i64, v[1]);
    (N, K)
}

fn F(N: i64, K: u32) -> i64 {
    let N_inv = inverse(N, D);
    let b = (N - 2) * N_inv % D;
    let c1 = (N + 1) * inverse(2, D) % D;
    let c2 = (N - 1) * inverse(2, D) % D * pow(b, K, D) % D;
    (c1 - c2 + D) % D
}

fn main() {
    let (N, K) = read_input();
    println!("{}", F(N, K))
}