https://atcoder.jp/contests/abc347/tasks/abc347_f
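3つの正方形の配置は、横一列に3つ並ぶもの、縦一列に3つ並ぶもの、そして直線で1つと2つに分け、2つの側をさらにそれと垂直な直線で分けるもの（4通り）の、計6通りに分類できます。どの配置でも境界となる座標を1つか2つ固定すれば、答えは3つの長方形範囲それぞれにおける最大値の和になるので、2次元セグメント木で求められます。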
計算量は O(N² log² N) で、Rustでも実行時間はかなり際どいです。全範囲や1点のクエリのように、木を辿らずに直接値を取れるところはそうするようにして、ようやく通りました。
// AtCoder ABC347 F - Non-overlapping Squares
//
// Place three non-overlapping M x M squares in an N x N grid so that the
// total of the covered cells is maximal.
//
// Approach: three disjoint squares always fall into one of six layouts (three
// horizontal strips, three vertical strips, or a straight cut with one square
// on one side and the other two split by a perpendicular cut).  Once one or
// two boundary coordinates are fixed, a layout's best value is the sum of
// three rectangular range maxima over the table of square sums, which a 2D
// max segment tree answers.  Overall O(N^2 log^2 N).
//
// Requires 3 * M <= N: the layout loops subtract 2 * M from N.

use std::cmp::max;

//////////////////// library ////////////////////

/// Reads one line from stdin and parses it (trimmed) as `T`; panics on bad input.
fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

/// Reads one line from stdin and parses every whitespace-separated token as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>()
        .split_whitespace()
        .map(|e| e.parse().ok().unwrap())
        .collect()
}

//////////////////// Segment Tree ////////////////////

type Val = i64;
/// Half-open index range `[start, end)`.
type Range = (usize, usize);

/// Max segment tree stored as an implicit binary heap: node `i` has children
/// `2i+1` and `2i+2`, and element `k` is the leaf at index `n - 1 + k`.
struct SegmentTree {
    n: usize,    // number of leaves (a power of two)
    m: usize,    // number of real elements
    v: Vec<Val>, // 2n - 1 nodes
}

impl SegmentTree {
    /// Maximum over the non-empty range `rng` (must lie within `[0, m)`).
    fn max(&self, rng: Range) -> Val {
        debug_assert!(rng.0 < rng.1 && rng.1 <= self.m, "bad range {:?}", rng);
        if rng == (0, self.m) {
            // Whole array: the root already holds the answer.
            self.v[0]
        } else if rng.0 + 1 == rng.1 {
            // Single element: read the leaf directly instead of walking the tree.
            self.v[rng.0 + self.n - 1]
        } else {
            self.max_core(rng, 0, self.n, 0)
        }
    }

    /// Recursive query for `rng` below node `i`, which covers `[first, last)`.
    fn max_core(&self, rng: Range, first: usize, last: usize, i: usize) -> Val {
        if rng.0 <= first && last <= rng.1 {
            return self.v[i];
        }
        let mid = (first + last) / 2;
        if rng.1 <= mid {
            self.max_core(rng, first, mid, i * 2 + 1)
        } else if rng.0 >= mid {
            self.max_core(rng, mid, last, i * 2 + 2)
        } else {
            max(
                self.max_core(rng, first, mid, i * 2 + 1),
                self.max_core(rng, mid, last, i * 2 + 2),
            )
        }
    }

    /// Smallest power of two that is `>= n` (1 for `n == 0`).
    fn ceil_two_pow(n: usize) -> usize {
        n.next_power_of_two()
    }

    /// Builds a tree over `a`.  Leaves past `a.len()` hold `Val::MIN`, the
    /// identity of `max`, so they can never change a real maximum.
    fn create(a: &[Val]) -> SegmentTree {
        let m = a.len();
        let n = SegmentTree::ceil_two_pow(m);
        let mut v: Vec<Val> = vec![Val::MIN; n * 2 - 1];
        v[n - 1..n - 1 + m].copy_from_slice(a);
        for i in (0..n - 1).rev() {
            v[i] = max(v[i * 2 + 1], v[i * 2 + 2]);
        }
        SegmentTree { n, m, v }
    }

    /// Element-wise maximum of two trees of identical shape.
    fn max_segment(s1: &SegmentTree, s2: &SegmentTree) -> SegmentTree {
        debug_assert_eq!(s1.v.len(), s2.v.len(), "trees must have the same shape");
        let v: Vec<Val> = s1.v.iter().zip(&s2.v).map(|(&a, &b)| max(a, b)).collect();
        SegmentTree { n: s1.n, m: s1.m, v }
    }
}

//////////////////// Segment Tree 2D ////////////////////

/// 2D max segment tree: an outer heap-shaped tree over rows whose nodes are
/// column trees.  Node `i` is the element-wise maximum of every row it covers.
struct SegmentTree2D {
    n: usize, // number of leaves of the outer tree (a power of two)
    m: usize, // number of real rows
    v: Vec<SegmentTree>,
}

impl SegmentTree2D {
    /// Maximum over rows `rng1` x columns `rng2` (both non-empty and in bounds).
    fn max(&self, rng1: Range, rng2: Range) -> Val {
        debug_assert!(rng1.0 < rng1.1 && rng1.1 <= self.m, "bad range {:?}", rng1);
        if rng1 == (0, self.m) {
            // All rows: the root's column tree answers directly.
            self.v[0].max(rng2)
        } else if rng1.0 + 1 == rng1.1 {
            // One row: query that row's own tree.
            self.v[rng1.0 + self.n - 1].max(rng2)
        } else {
            self.max_core(rng1, rng2, 0, self.n, 0)
        }
    }

    /// Recursive query below outer node `i`, which covers rows `[first, last)`.
    fn max_core(&self, rng1: Range, rng2: Range, first: usize, last: usize, i: usize) -> Val {
        if rng1.0 <= first && last <= rng1.1 {
            return self.v[i].max(rng2);
        }
        let mid = (first + last) / 2;
        if rng1.1 <= mid {
            self.max_core(rng1, rng2, first, mid, i * 2 + 1)
        } else if rng1.0 >= mid {
            self.max_core(rng1, rng2, mid, last, i * 2 + 2)
        } else {
            max(
                self.max_core(rng1, rng2, first, mid, i * 2 + 1),
                self.max_core(rng1, rng2, mid, last, i * 2 + 2),
            )
        }
    }

    /// Maximum of row `i` over all columns (root of that row's tree).
    fn row_max(&self, i: usize) -> Val {
        self.v[i + self.n - 1].v[0]
    }

    /// Maximum of column `j` over all rows (leaf `j` of the root's tree).
    fn column_max(&self, j: usize) -> Val {
        let tree = &self.v[0];
        tree.v[j + tree.n - 1]
    }

    /// Builds the tree over the rows of `a`; every row must have the same length.
    fn create(a: &[Vec<Val>]) -> SegmentTree2D {
        let m = a.len();
        let n = SegmentTree::ceil_two_pow(m);
        let width = a.first().map_or(0, |row| row.len());
        // Padding rows get the same shape as real rows and hold only the
        // identity of max, so merging them never alters a real maximum.
        let padding = vec![Val::MIN; width];

        // Placeholders, all overwritten below.
        let mut v: Vec<SegmentTree> = (0..n * 2 - 1)
            .map(|_| SegmentTree { n: 1, m: 0, v: Vec::new() })
            .collect();
        for (k, row) in a.iter().enumerate() {
            v[n - 1 + k] = SegmentTree::create(row);
        }
        for k in m..n {
            v[n - 1 + k] = SegmentTree::create(&padding);
        }
        for i in (0..n - 1).rev() {
            v[i] = SegmentTree::max_segment(&v[i * 2 + 1], &v[i * 2 + 2]);
        }
        SegmentTree2D { n, m, v }
    }
}

//////////////////// Matrix ////////////////////

type Matrix = Vec<Vec<i64>>;

//////////////////// Table ////////////////////

/// Range-max structure over the sums of all M x M squares of an N x N grid,
/// indexed by the square's top-left corner (rows/columns `0..=N-M`).
struct Table {
    n: usize, // side length of the grid
    m: usize, // side length of each square
    tree: SegmentTree2D,
}

impl Table {
    /// Best total over all six layouts.
    fn max(&self) -> Val {
        [
            self.max1(),
            self.max2(),
            self.max3(),
            self.max4(),
            self.max5(),
            self.max6(),
        ]
        .iter()
        .fold(Val::MIN, |acc, &x| max(acc, x))
    }

    //   +---+---+
    //   | 1 | 2 |
    //   +---+---+
    //   |   3   |
    //   +-------+
    // i = top row of 3, j = left column of 1; 2 lies right of 1.
    fn max1(&self) -> Val {
        let mut max_value = Val::MIN;
        for i in self.m..self.n - self.m + 1 {
            let above = self.above(i);
            let max3 = self.tree.row_max(i);
            for j in 0..self.n - self.m * 2 + 1 {
                let max1 = self.tree.max(above, (j, j + 1));
                let max2 = self.tree.max(above, self.right(j));
                max_value = max(max_value, max1 + max2 + max3);
            }
        }
        max_value
    }

    //   +-------+
    //   |   3   |
    //   +---+---+
    //   | 1 | 2 |
    //   +---+---+
    // i = top row of 3, j = left column of 1; 2 lies right of 1.
    fn max2(&self) -> Val {
        let mut max_value = Val::MIN;
        for i in 0..self.n - self.m * 2 + 1 {
            let below = self.below(i);
            let max3 = self.tree.row_max(i);
            for j in 0..self.n - self.m * 2 + 1 {
                let max1 = self.tree.max(below, (j, j + 1));
                let max2 = self.tree.max(below, self.right(j));
                max_value = max(max_value, max1 + max2 + max3);
            }
        }
        max_value
    }

    //   +---+---+
    //   | 1 |   |
    //   +---+ 2 |
    //   | 3 |   |
    //   +---+---+
    // i = top row of 1, j = left column of 2; 1 and 3 lie left of 2.
    fn max3(&self) -> Val {
        let mut max_value = Val::MIN;
        for i in 0..self.n - self.m * 2 + 1 {
            let below = self.below(i);
            for j in self.m..self.n - self.m + 1 {
                let left = self.left(j);
                let max1 = self.tree.max((i, i + 1), left);
                let max2 = self.tree.column_max(j);
                let max3 = self.tree.max(below, left);
                max_value = max(max_value, max1 + max2 + max3);
            }
        }
        max_value
    }

    //   +---+---+
    //   |   | 1 |
    //   | 2 +---+
    //   |   | 3 |
    //   +---+---+
    // i = top row of 1, j = left column of 2; 1 and 3 lie right of 2.
    fn max4(&self) -> Val {
        let mut max_value = Val::MIN;
        for i in 0..self.n - self.m * 2 + 1 {
            let below = self.below(i);
            for j in 0..self.n - self.m * 2 + 1 {
                let right = self.right(j);
                let max1 = self.tree.max((i, i + 1), right);
                let max2 = self.tree.column_max(j);
                let max3 = self.tree.max(below, right);
                max_value = max(max_value, max1 + max2 + max3);
            }
        }
        max_value
    }

    //   +---+---+---+
    //   | 1 | 2 | 3 |
    //   +---+---+---+
    // j = left column of the middle square.
    fn max5(&self) -> Val {
        let all_rows = (0, self.n - self.m + 1);
        let mut max_value = Val::MIN;
        for j in self.m..self.n - self.m * 2 + 1 {
            let max1 = self.tree.max(all_rows, self.left(j));
            let max2 = self.tree.column_max(j);
            let max3 = self.tree.max(all_rows, self.right(j));
            max_value = max(max_value, max1 + max2 + max3);
        }
        max_value
    }

    //   +---+
    //   | 1 |
    //   +---+
    //   | 2 |
    //   +---+
    //   | 3 |
    //   +---+
    // i = top row of the middle square.
    fn max6(&self) -> Val {
        let all_cols = (0, self.n - self.m + 1);
        let mut max_value = Val::MIN;
        for i in self.m..self.n - self.m * 2 + 1 {
            let max1 = self.tree.max(self.above(i), all_cols);
            let max2 = self.tree.row_max(i);
            let max3 = self.tree.max(self.below(i), all_cols);
            max_value = max(max_value, max1 + max2 + max3);
        }
        max_value
    }

    /// Top rows of squares lying entirely above a square whose top row is `i`.
    fn above(&self, i: usize) -> Range {
        (0, i - self.m + 1)
    }

    /// Top rows of squares lying entirely below a square whose top row is `i`.
    fn below(&self, i: usize) -> Range {
        (i + self.m, self.n - self.m + 1)
    }

    /// Left columns of squares lying entirely left of a square whose left column is `j`.
    fn left(&self, j: usize) -> Range {
        (0, j - self.m + 1)
    }

    /// Left columns of squares lying entirely right of a square whose left column is `j`.
    fn right(&self, j: usize) -> Range {
        (j + self.m, self.n - self.m + 1)
    }

    /// 2D prefix sums of the square matrix `a`: `b[i][j]` = sum of `a[..i][..j]`.
    fn accumulate(a: &[Vec<i64>]) -> Matrix {
        let n = a.len();
        let mut b: Matrix = vec![vec![0; n + 1]; n + 1];
        for i in 0..n {
            for j in 0..n {
                b[i + 1][j + 1] = b[i + 1][j] + b[i][j + 1] - b[i][j] + a[i][j];
            }
        }
        b
    }

    /// From prefix sums `b`, the sum of every M x M square keyed by top-left corner.
    fn make_sum_matrix(m: usize, b: &[Vec<i64>]) -> Matrix {
        let n = b.len() - 1;
        let l = n - m + 1;
        let mut c: Matrix = vec![vec![0; l]; l];
        for i in 0..l {
            for j in 0..l {
                c[i][j] = b[i + m][j + m] - b[i][j + m] - b[i + m][j] + b[i][j];
            }
        }
        c
    }
}

//////////////////// process ////////////////////

/// Reads `N M` followed by N rows of the grid; returns `(M, grid)`.
fn read_input() -> (usize, Matrix) {
    let v: Vec<usize> = read_vec();
    let (n, m) = (v[0], v[1]);
    let a: Matrix = (0..n).map(|_| read_vec::<i64>()).collect();
    (m, a)
}

/// Builds the square-sum table and its 2D range-max tree for grid `a`.
fn make_table(m: usize, a: Matrix) -> Table {
    let n = a.len();
    let b = Table::accumulate(&a);
    let c = Table::make_sum_matrix(m, &b);
    let tree = SegmentTree2D::create(&c);
    Table { n, m, tree }
}

/// Maximum total of three non-overlapping `m` x `m` squares in the square grid `a`.
#[allow(non_snake_case)]
fn F(m: usize, a: Matrix) -> i64 {
    make_table(m, a).max()
}

fn main() {
    let (m, a) = read_input();
    println!("{}", F(m, a));
}