AtCoder Beginner Contest 306 F

https://atcoder.jp/contests/abc306/tasks/abc306_f

ふつうに数えていったらとても間に合わないので、別の数え方を考えます。
入力例1を考えると、順位を0ベースとして、1と同じ列とその下に自分より小さい数が無いのでカウント無しで、3は1と2があるので、2カウントします。8は同じ列とその下に3つあるので、3カウントします。もう少し分かりやすくするために全てのAをソートして何列目かにすると、

1 2 1 3 3 2

となります。そのときに各列がその前にいくつあったかを調べると、

  1 2 1 3 3 2
0 1 1 2 2 2 2
0 0 1 1 1 1 2
0 0 0 0 1 2 2

となります。そして、1はおいておいて、その次の次の1はその前に1以上が2つあるので、2カウントです。最後の2は、2以上が1+2=3カウントです。
このように、上下に累積もすぐに分からないといけないので、BITでそれを実現しました。あと、同じ列は高さ分重複して数えるのと、0ベースにした分を考慮しました。

// Merge Sets
#![allow(non_snake_case)]


//////////////////// library ////////////////////

fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>().split_whitespace()
            .map(|e| e.parse().ok().unwrap()).collect()
}


//////////////////// BIT ////////////////////

struct BIT {
    n: i32,
    bit: Vec<i32>
}

impl BIT {
    fn new(m: i32) -> BIT {
        let n = m + 1;
        let bit: Vec<i32> = (0..n).map(|_| 0).collect();
        BIT { n, bit }
    }
    
    fn add(&mut self, i: i32, x: i32) {
        let mut idx = i;
        while idx < self.n {
            self.bit[idx as usize] += x;
            idx += idx & (-idx)
        }
    }
    
    fn sum(&self, i: i32) -> i32 {
        let mut s: i32 = 0;
        let mut idx = i;
        while idx > 0 {
            s += self.bit[idx as usize];
            idx -= idx & (-idx)
        }
        s
    }
}


//////////////////// process ////////////////////

fn read_input() -> Vec<Vec<i32>> {
    let v = read_vec();
    let N = v[0];
    let A: Vec<Vec<i32>> = (0..N).map(|_| read_vec::<i32>()).collect();
    A
}

fn floor_log2(n: usize) -> usize {
    for e in 0.. {
        if 1 << e >= n {
            return e
        }
    }
    0   // dummy
}

fn f(A: Vec<Vec<i32>>) -> i64 {
    let N = A.len();
    let M = A[0].len();
    let mut v1 = A.into_iter().enumerate().
                    map(|(i, w)| w.into_iter().map(move |e| (e, i))).
                    flatten().collect::<Vec<(i32, usize)>>();
    v1.sort();
    let v: Vec<usize> = v1.into_iter().map(|(_, i)| N - i).collect();
    let e = floor_log2(N);
    let n = 1 << e;
    let mut tree = BIT::new(n);
    let mut s: i64 = 0;
    for i in v.into_iter() {
        if i != 1 {
            s += tree.sum(i as i32) as i64;
        }
        tree.add(i as i32, 1)
    }
    (N*(N-1)*M/2 + (N-1)*(N-2)*M*(M-1)/4) as i64 + s
}


//////////////////// main ////////////////////

fn main() {
    let A = read_input();
    println!("{}", f(A))
}