AtCoder Beginner Contest 347 F

https://atcoder.jp/contests/abc347/tasks/abc347_f

3つの正方形の位置関係は6通りに分類できます。どの配置でも、高々2つの座標を固定すれば、答えはいくつかの長方形範囲における最大値の和になります(各左上座標に対する正方形の和を2次元セグメント木に載せて求めます)。
Rustでも実行時間はかなり際どいです。範囲全体や1点のクエリのように、木を辿らずに直接値を取り出せるところはそうするようにして、ようやく通りました。

// Non-overlapping Squares
#![allow(non_snake_case)]

use std::cmp::max;


//////////////////// library ////////////////////

/// Reads one line from standard input and parses its trimmed contents as `T`.
///
/// Panics if the trimmed line cannot be parsed as `T`.
fn read<T: std::str::FromStr>() -> T {
    let mut buf = String::new();
    let stdin = std::io::stdin();
    // A read error is deliberately ignored; whatever was buffered (possibly
    // nothing) is what gets parsed.
    stdin.read_line(&mut buf).ok();
    let parsed: Option<T> = buf.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line from standard input and parses every whitespace-separated
/// token as `T`, in order.
///
/// Panics on the first token that cannot be parsed as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line: String = read();
    let mut values = Vec::new();
    for token in line.split_whitespace() {
        values.push(token.parse::<T>().ok().unwrap());
    }
    values
}


//////////////////// Segment Tree ////////////////////

/// Value stored in the segment trees.
type Val = i64;
/// Half-open index range `[start, end)`.
type Range = (usize, usize);

/// Range-maximum segment tree over a fixed array of `Val`.
///
/// `v` is a heap-ordered complete binary tree of `2 * n - 1` nodes: node `i`
/// has children `2 * i + 1` and `2 * i + 2`, and the leaves `n - 1 .. 2 * n - 1`
/// hold the array. `n` is the power-of-two capacity and `m` the real length.
/// Unused leaves hold `Val::MIN` so they never raise a maximum.
struct SegmentTree {
    n: usize,
    m: usize,
    v: Vec<Val>,
}

impl SegmentTree {
    /// Maximum of the elements in the half-open range `rng`.
    ///
    /// An empty range yields `Val::MIN`, the identity of `max`.
    fn max(&self, rng: Range) -> Val {
        if rng.0 >= rng.1 {
            Val::MIN
        } else if rng == (0, self.m) {
            // Whole array: the root already holds the answer.
            self.v[0]
        } else if rng.0 + 1 == rng.1 {
            // Single element: read the leaf directly instead of descending.
            self.v[rng.0 + self.n - 1]
        } else {
            self.max_core(rng, 0, self.n, 0)
        }
    }

    /// Recursive descent: node `i` covers `[first, last)`, and `rng` is known
    /// to intersect it.
    fn max_core(&self, rng: Range, first: usize, last: usize, i: usize) -> Val {
        if rng.0 <= first && last <= rng.1 {
            return self.v[i];
        }
        let mid = (first + last) / 2;
        if rng.1 <= mid {
            self.max_core(rng, first, mid, i * 2 + 1)
        } else if rng.0 >= mid {
            self.max_core(rng, mid, last, i * 2 + 2)
        } else {
            max(
                self.max_core(rng, first, mid, i * 2 + 1),
                self.max_core(rng, mid, last, i * 2 + 2),
            )
        }
    }

    /// Smallest power of two that is `>= n`; returns 1 for `n == 0`.
    fn ceil_two_pow(n: usize) -> usize {
        n.next_power_of_two()
    }

    /// Builds a tree over `a`. Leaves beyond `a.len()` are padded with
    /// `Val::MIN`, so the root is the true maximum even when every value is negative.
    fn create(a: &[Val]) -> SegmentTree {
        let m = a.len();
        let n = SegmentTree::ceil_two_pow(m);
        let mut v: Vec<Val> = vec![Val::MIN; n * 2 - 1];
        v[n - 1..n - 1 + m].copy_from_slice(a);
        for i in (0..n - 1).rev() {
            v[i] = max(v[i * 2 + 1], v[i * 2 + 2]);
        }
        SegmentTree { n, m, v }
    }

    /// Node-wise maximum of two trees of the same shape.
    ///
    /// Every internal node is the max of its children, so the result is a valid
    /// tree for the element-wise maximum of the two underlying arrays. The
    /// shape `(n, m)` is taken from `s1`. If the node vectors differ in length,
    /// the result is truncated to the shorter one.
    fn max_segment(s1: &SegmentTree, s2: &SegmentTree) -> SegmentTree {
        let v: Vec<Val> = s1
            .v
            .iter()
            .zip(s2.v.iter())
            .map(|(&a, &b)| max(a, b))
            .collect();
        SegmentTree { n: s1.n, m: s1.m, v }
    }
}


//////////////////// Segment Tree 2D ////////////////////

struct SegmentTree2D {
    n: usize,
    m: usize,
    v: Vec<SegmentTree>,
}

impl SegmentTree2D {
    fn max(&self, rng1: Range, rng2: Range) -> Val {
        if rng1 == (0, self.m) {
            self.v[0].max(rng2)
        }
        else if rng1.0 == rng1.1 - 1 {
            self.v[rng1.0+self.n-1].max(rng2)
        }
        else {
            self.max_core(rng1, rng2, 0, self.n, 0)
        }
    }
    
    fn max_core(&self, rng1: Range, rng2: Range,
                        first: usize, last: usize, i: usize) -> Val {
        if rng1.0 <= first && last <= rng1.1 {
            self.v[i].max(rng2)
        }
        else {
            let mid = (first + last) / 2;
            if rng1.1 <= mid {
                self.max_core(rng1, rng2, first, mid, i*2+1)
            }
            else if rng1.0 >= mid {
                self.max_core(rng1, rng2, mid, last, i*2+2)
            }
            else {
                max(self.max_core(rng1, rng2, first, mid, i*2+1),
                    self.max_core(rng1, rng2, mid, last, i*2+2))
            }
        }
    }
    
    fn row_max(&self, i: usize) -> i64 {
        self.v[i+self.n-1].v[0]
    }
    
    fn column_max(&self, j: usize) -> i64 {
        let tree = &self.v[0];
        tree.v[j+tree.n-1]
    }
    
    fn create(A: &Vec<Vec<Val>>) -> SegmentTree2D {
        let m = A.len();
        let n = SegmentTree::ceil_two_pow(m);
        let mut v: Vec<SegmentTree> = vec![];
        for _ in 0..n*2-1 {
            // 仮のSegmentTreeを登録
            let w: Vec<Val> = vec![0];
            v.push(SegmentTree { n: 1, m, v: w })
        }
        for i in n-1..n+m-1 {
            v[i] = SegmentTree::create(&A[i+1-n])
        }
        for i in n+m-1..n*2-1 {
            v[i] = SegmentTree::create(&vec![0; n*2-1])
        }
        for i in (0..n-1).rev() {
            v[i] = SegmentTree::max_segment(&v[i*2+1], &v[i*2+2])
        }
        SegmentTree2D { n, m, v }
    }
}


//////////////////// Matrix ////////////////////

/// Grid of `i64` stored as a vector of rows, indexed as `matrix[row][column]`.
type Matrix = Vec<Vec<i64>>;


//////////////////// Table ////////////////////

/// Searches placements of three non-overlapping `M` x `M` squares in an
/// `N` x `N` grid.
///
/// `tree` is indexed by a square's top-left corner `(row, column)`, each
/// coordinate in `0..N-M+1`. The intended contents are the square sums built
/// by `make_sum_matrix`.
struct Table {
    N: usize,
    M: usize,
    tree: SegmentTree2D,
}

impl Table {
    /// Best total over the six placement patterns (never below 0).
    fn max(&self) -> i64 {
        let candidates = [
            self.max1(),
            self.max2(),
            self.max3(),
            self.max4(),
            self.max5(),
            self.max6(),
        ];
        let mut best: i64 = 0;
        for &c in candidates.iter() {
            best = max(best, c);
        }
        best
    }

    // 1 2
    //  3
    /// Square 3 sits below squares 1 and 2, which sit side by side.
    /// Enumerates 3's top row and 1's exact column.
    fn max1(&self) -> i64 {
        let mut best: i64 = 0;
        for row3 in self.M..self.N - self.M + 1 {
            let rows_above = self.above(row3);
            let bottom = self.tree.row_max(row3);
            for col1 in 0..self.N - self.M * 2 + 1 {
                let first = self.tree.max(rows_above, (col1, col1 + 1));
                let second = self.tree.max(rows_above, self.right(col1));
                best = max(best, first + second + bottom);
            }
        }
        best
    }

    //  3
    // 1 2
    /// Square 3 sits above squares 1 and 2, which sit side by side.
    /// Enumerates 3's top row and 1's exact column.
    fn max2(&self) -> i64 {
        let mut best: i64 = 0;
        let limit = self.N - self.M * 2 + 1;
        for row3 in 0..limit {
            let rows_below = self.below(row3);
            let top = self.tree.row_max(row3);
            for col1 in 0..limit {
                let first = self.tree.max(rows_below, (col1, col1 + 1));
                let second = self.tree.max(rows_below, self.right(col1));
                best = max(best, first + second + top);
            }
        }
        best
    }

    // 1
    //   2
    // 3
    /// Square 2 is to the right; squares 1 and 3 are stacked to its left.
    /// Enumerates 1's exact row and 2's column.
    fn max3(&self) -> i64 {
        let mut best: i64 = 0;
        let limit = self.N - self.M * 2 + 1;
        for row1 in 0..limit {
            let rows_below = self.below(row1);
            for col2 in self.M..self.N - self.M + 1 {
                let cols_left = self.left(col2);
                let first = self.tree.max((row1, row1 + 1), cols_left);
                let second = self.tree.column_max(col2);
                let third = self.tree.max(rows_below, cols_left);
                best = max(best, first + second + third);
            }
        }
        best
    }

    //   1
    // 2
    //   3
    /// Square 2 is to the left; squares 1 and 3 are stacked to its right.
    /// Enumerates 1's exact row and 2's column.
    fn max4(&self) -> i64 {
        let mut best: i64 = 0;
        let limit = self.N - self.M * 2 + 1;
        for row1 in 0..limit {
            let rows_below = self.below(row1);
            for col2 in 0..limit {
                let cols_right = self.right(col2);
                let first = self.tree.max((row1, row1 + 1), cols_right);
                let second = self.tree.column_max(col2);
                let third = self.tree.max(rows_below, cols_right);
                best = max(best, first + second + third);
            }
        }
        best
    }

    // 1 2 3
    /// Three squares in separate column bands; enumerates the middle column.
    fn max5(&self) -> i64 {
        let end = self.N - self.M * 2 + 1;
        let all_rows = (0, self.N - self.M + 1);
        (self.M..end)
            .map(|col2| {
                self.tree.max(all_rows, self.left(col2))
                    + self.tree.max(all_rows, (col2, col2 + 1))
                    + self.tree.max(all_rows, self.right(col2))
            })
            .fold(0, max)
    }

    // 1
    // 2
    // 3
    /// Three squares in separate row bands; enumerates the middle row.
    fn max6(&self) -> i64 {
        let end = self.N - self.M * 2 + 1;
        let all_cols = (0, self.N - self.M + 1);
        (self.M..end)
            .map(|row2| {
                self.tree.max(self.above(row2), all_cols)
                    + self.tree.row_max(row2)
                    + self.tree.max(self.below(row2), all_cols)
            })
            .fold(0, max)
    }

    /// Top-left rows of squares lying wholly above a square whose top row is `i`.
    fn above(&self, i: usize) -> Range {
        (0, i - self.M + 1)
    }

    /// Top-left rows of squares lying wholly below a square whose top row is `i`.
    fn below(&self, i: usize) -> Range {
        (i + self.M, self.N - self.M + 1)
    }

    /// Top-left columns of squares lying wholly left of a square at column `j`.
    fn left(&self, j: usize) -> Range {
        (0, j - self.M + 1)
    }

    /// Top-left columns of squares lying wholly right of a square at column `j`.
    fn right(&self, j: usize) -> Range {
        (j + self.M, self.N - self.M + 1)
    }

    /// 2D prefix sums: `out[i][j]` is the sum of `A[..i][..j]` (square input).
    fn accumulate(A: &Matrix) -> Matrix {
        let size = A.len();
        let mut prefix: Matrix = vec![vec![0; size + 1]; size + 1];
        for (i, row) in A.iter().enumerate() {
            for j in 0..size {
                prefix[i + 1][j + 1] =
                    prefix[i + 1][j] + prefix[i][j + 1] - prefix[i][j] + row[j];
            }
        }
        prefix
    }

    /// From prefix sums `B`, the sum of every `M` x `M` square, keyed by its
    /// top-left corner.
    fn make_sum_matrix(M: usize, B: &Matrix) -> Matrix {
        let size = B.len() - 1;
        let corners = size - M + 1;
        (0..corners)
            .map(|i| {
                (0..corners)
                    .map(|j| B[i + M][j + M] - B[i][j + M] - B[i + M][j] + B[i][j])
                    .collect::<Vec<i64>>()
            })
            .collect()
    }
}


//////////////////// process ////////////////////

/// Reads the header line `N M`, then `N` grid rows; returns `(M, grid)`.
fn read_input() -> (usize, Matrix) {
    let header: Vec<usize> = read_vec();
    let rows = header[0];
    let side = header[1];
    let mut grid: Matrix = Vec::new();
    for _ in 0..rows {
        grid.push(read_vec::<i64>());
    }
    (side, grid)
}

/// Builds the search table: square sums per top-left corner, loaded into a
/// 2D range-max tree.
fn make_table(M: usize, A: Matrix) -> Table {
    let prefix = Table::accumulate(&A);
    let sums = Table::make_sum_matrix(M, &prefix);
    Table {
        N: A.len(),
        M,
        tree: SegmentTree2D::create(&sums),
    }
}

/// Largest total of three non-overlapping `M` x `M` squares in grid `A`.
fn F(M: usize, A: Matrix) -> i64 {
    make_table(M, A).max()
}

fn main() {
    let (side, grid) = read_input();
    let answer = F(side, grid);
    println!("{}", answer);
}