AtCoder Beginner Contest 430 D

https://atcoder.jp/contests/abc430/tasks/abc430_d

BTreeSetを使って、挿入した座標の前後の要素（それぞれ2つまで）を拾ってこればよいです。新しい点を追加したとき、最も近い人までの距離が変わりうるのはその点自身と直前・直後の点だけなので、それらの距離の差分だけ合計を更新すれば十分です。

// Neighbor Distance
#![allow(non_snake_case)]


//////////////////// library ////////////////////

/// Reads a single line from standard input and parses its trimmed contents as `T`.
///
/// Any I/O error from `read_line` is deliberately ignored. Parsing failure panics,
/// which is acceptable for trusted contest input.
fn read<T: std::str::FromStr>() -> T {
    let mut buf = String::new();
    let _ = std::io::stdin().read_line(&mut buf);
    let parsed: Option<T> = buf.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line from standard input and parses every whitespace-separated
/// token on it as `T`, in order.
///
/// Panics on the first token that fails to parse.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line: String = read();
    let mut values = Vec::new();
    for token in line.split_whitespace() {
        values.push(token.parse::<T>().ok().unwrap());
    }
    values
}


//////////////////// process ////////////////////

/// Reads the problem input: a count line followed by a line of coordinates.
///
/// The count is parsed (so a malformed first line still panics) but otherwise
/// discarded, because the returned vector already carries its own length.
fn read_input() -> Vec<i64> {
    let _: usize = read();
    read_vec::<i64>()
}

use std::cmp::min;
use std::collections::BTreeSet;
use std::io::{BufWriter, Write};

/// Sentinel meaning "no element on this side", returned by `prev_xs`/`next_xs`
/// and recognised by `min_dist`.
///
/// It must be large enough that a gap to the sentinel (`INF - x`) never beats a
/// real gap between two coordinates.
/// NOTE(review): 2e9 presumably assumes coordinates are at most 1e9 — confirm
/// against the problem constraints.
const INF: i64 = 2_000_000_000;

/// Returns the two largest elements of `tree` strictly less than `x`,
/// nearest first, as `(prev1, prev2)`.
///
/// Any element that does not exist is reported as `INF`.
///
/// The set is only read. It is taken by shared reference, and callers that
/// still pass `&mut tree` keep compiling via `&mut T -> &T` coercion.
fn prev_xs(x: i64, tree: &BTreeSet<i64>) -> (i64, i64) {
    // Walk downward from just below `x`.
    let mut below = tree.range(..x).rev();
    let prev1 = below.next().copied().unwrap_or(INF);
    let prev2 = below.next().copied().unwrap_or(INF);
    (prev1, prev2)
}

/// Returns the two smallest elements of `tree` strictly greater than `x`,
/// nearest first, as `(next1, next2)`.
///
/// Any element that does not exist is reported as `INF`.
///
/// The lower bound is expressed as `Excluded(x)` rather than `x + 1..`, so
/// `x == i64::MAX` cannot overflow. The set is only read; callers passing
/// `&mut tree` still compile via `&mut T -> &T` coercion.
fn next_xs(x: i64, tree: &BTreeSet<i64>) -> (i64, i64) {
    let mut above = tree.range((std::ops::Bound::Excluded(x), std::ops::Bound::Unbounded));
    let next1 = above.next().copied().unwrap_or(INF);
    let next2 = above.next().copied().unwrap_or(INF);
    (next1, next2)
}

/// Distance from `x2` to its nearest neighbor, given its left neighbor `x1`
/// and right neighbor `x3`, with `INF` standing for a missing neighbor.
///
/// - Neither neighbor present: 0.
/// - Only the right neighbor present: `x3 - x2`.
/// - Left neighbor present: `min(x2 - x1, x3 - x2)`. If `x3` is `INF`, the left
///   gap wins as long as coordinates stay well below `INF`.
///
/// When `x2` and `x3` are both `INF` (an absent point), every branch yields 0.
/// Callers can therefore evaluate a missing point without special-casing it.
fn min_dist(x1: i64, x2: i64, x3: i64) -> i64 {
    let left_missing = x1 == INF;
    let right_missing = x3 == INF;
    match (left_missing, right_missing) {
        (true, true) => 0,
        (true, false) => x3 - x2,
        (false, _) => min(x2 - x1, x3 - x2),
    }
}

fn F(X: Vec<i64>) {
    let mut tree: BTreeSet<i64> = BTreeSet::new();
    tree.insert(0);
    let mut total_dists: i64 = 0;
    for x in X {
        tree.insert(x);
        let (prev1, prev2) = prev_xs(x, &mut tree);
        let (next1, next2) = next_xs(x, &mut tree);
        let xs = [prev2, prev1, x, next1, next2];
        let new_x = min_dist(xs[1], xs[2], xs[3]);
        let old_prev1 = min_dist(xs[0], xs[1], xs[3]);
        let new_prev1 = min_dist(xs[0], xs[1], xs[2]);
        let old_next1 = min_dist(xs[1], xs[3], xs[4]);
        let new_next1 = min_dist(xs[2], xs[3], xs[4]);
        total_dists += new_x + new_prev1 - old_prev1 + new_next1 - old_next1;
        println!("{}", total_dists)
    }
}

/// Entry point: reads the coordinates from stdin and prints the running totals.
fn main() {
    F(read_input());
}