AtCoder Beginner Contest 328 E

https://atcoder.jp/contests/abc328/tasks/abc328_e

全域木をすべて調べればよいのですが、2頂点がすでに繋がっているかどうかの判定にUnion-Findを使います。しかし、Union-Findからエッジを取り除くのは難しいので、エッジを追加するときに、書き換える前の状態（親ポインタと高さを1か所ずつ）を保存しておき、後で元に戻せるようにします。ただ、N ≤ 8 と小さいので、Union-Findを丸ごとコピーしても間に合いそうですね。

// Modulo MST
#![allow(non_snake_case)]

use std::cmp::{min, max};


//////////////////// library ////////////////////

/// Reads one line from standard input and parses its trimmed contents as `T`.
///
/// A read error is ignored; a parse failure panics. `.ok().unwrap()` is used rather than
/// `.unwrap()` so that `T::Err` is not required to implement `Debug`.
fn read<T: std::str::FromStr>() -> T {
    let mut buffer = String::new();
    let _ = std::io::stdin().read_line(&mut buffer);
    let parsed: Option<T> = buffer.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line and parses every whitespace-separated token on it as `T`.
///
/// Panics on the first token that fails to parse.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line = read::<String>();
    let mut values = Vec::new();
    for token in line.split_whitespace() {
        let parsed: Option<T> = token.parse().ok();
        values.push(parsed.unwrap());
    }
    values
}


//////////////////// Edge ////////////////////

/// Vertex index; `read_edge` converts the 1-based input to 0-based.
type Node = usize;
/// Edge weight; the same type also holds weight sums and the modulus `K`.
type Weight = i64;
/// Undirected edge `(u, v, w)`: endpoints `u`, `v` and weight `w`.
type Edge = (Node, Node, Weight);

/// Reads one edge line `u v w` (1-based vertices) and returns it with 0-based vertices.
fn read_edge() -> Edge {
    let fields: Vec<usize> = read_vec();
    let (from, to) = (fields[0] - 1, fields[1] - 1);
    let weight = fields[2] as Weight;
    (from, to, weight)
}


//////////////////// UnionFind ////////////////////

/// Union-Find (disjoint-set) forest whose `join` can be undone with `remove`.
///
/// Path compression is deliberately not used: `root` only reads `parents`, so each `join`
/// changes exactly one parent pointer and one height, and `remove` can restore both from
/// the record `join` returns. Union by height keeps the trees shallow instead.
/// The structure also tracks how many edges have been joined and their total weight.
#[derive(Clone, Debug)]
struct UnionFind {
    /// `parents[v]` is the parent of `v`; a root points to itself.
    parents: Vec<Node>,
    /// Number of levels in the tree rooted at `v` (starts at 1); only kept current for roots.
    heights: Vec<i32>,
    /// Number of `join` calls not yet reverted by `remove`.
    num_edges: usize,
    /// Total weight of those joined edges.
    sum_weights: Weight,
}

impl UnionFind {
    /// Creates `N` singleton sets with no edges joined.
    fn new(N: usize) -> UnionFind {
        UnionFind {
            parents: (0..N).collect(),
            heights: vec![1; N],
            num_edges: 0,
            sum_weights: 0,
        }
    }

    /// Merges the sets containing `edge.0` and `edge.1` and records the edge's weight.
    ///
    /// Returns an undo record `(child, old_parent, new_root, old_height)`; pass it, together
    /// with the edge weight, to `remove` to revert this call. If both endpoints are already
    /// in the same set nothing is merged and the record is a no-op, but the edge is still
    /// counted so that the matching `remove` stays exactly symmetric.
    fn join(&mut self, edge: &Edge) -> (Node, Node, Node, i32) {
        let (u, v, w) = *edge;
        let ru = self.root(u);
        let rv = self.root(v);
        self.num_edges += 1;
        self.sum_weights += w;
        if ru == rv {
            // Callers are expected to check `is_connected` first. Pointing the root back at
            // itself (with its current height) makes `remove` harmless instead of rewriting
            // some unrelated vertex's parent and height.
            return (ru, self.parents[ru], ru, self.heights[ru]);
        }
        // Hang the shorter tree below the taller one; on a tie, u's tree goes below v's.
        let (child, parent) = if self.heights[ru] <= self.heights[rv] {
            (ru, rv)
        } else {
            (rv, ru)
        };
        let undo = (child, self.parents[child], parent, self.heights[parent]);
        self.parents[child] = parent;
        self.heights[parent] = max(self.heights[parent], self.heights[child] + 1);
        undo
    }

    /// Reverts a `join`, given its undo record `(r1, p1, r2, h2)` and edge weight `w`.
    ///
    /// Records restore absolute values, so they must be applied in the reverse order of the
    /// `join` calls that produced them.
    fn remove(&mut self, r1: Node, p1: Node, r2: Node, h2: i32, w: Weight) {
        self.parents[r1] = p1;
        self.heights[r2] = h2;
        self.num_edges -= 1;
        self.sum_weights -= w;
    }

    /// Returns the root of the tree containing `v0` (read-only walk, no path compression).
    fn root(&self, v0: Node) -> Node {
        let mut v = v0;
        while self.parents[v] != v {
            v = self.parents[v];
        }
        v
    }

    /// Returns whether `v1` and `v2` are currently in the same set.
    fn is_connected(&self, v1: Node, v2: Node) -> bool {
        self.root(v1) == self.root(v2)
    }
}


//////////////////// process ////////////////////

/// Reads the header line `N M K` followed by `M` edge lines.
///
/// Returns `(N, K, edges)`; `M` is only used to know how many edge lines to read.
fn read_input() -> (usize, Weight, Vec<Edge>) {
    let header: Vec<usize> = read_vec();
    let (n, m) = (header[0], header[1]);
    let k = header[2] as Weight;
    let mut edges: Vec<Edge> = Vec::with_capacity(m);
    for _ in 0..m {
        edges.push(read_edge());
    }
    (n, k, edges)
}

/// Returns the minimum of (total edge weight mod `K`) over all spanning trees that extend
/// the forest currently held in `uf` using only edges from `edges`.
///
/// Each call branches on the first edge: skip it, or, if it does not close a cycle, join
/// it, recurse, and undo the join. `uf` is therefore left exactly as it was found.
/// Returns `K` itself when no spanning tree can be completed. Every real answer is a
/// remainder mod `K` and hence `< K`, so `K` acts as "no tree" inside the `min`.
fn min_weight_tree(uf: &mut UnionFind, K: Weight, edges: &[Edge]) -> Weight {
    let n = uf.parents.len();
    // A spanning tree on n vertices has n - 1 edges; saturate so n == 0 cannot underflow.
    let tree_size = n.saturating_sub(1);
    if uf.num_edges == tree_size {
        return uf.sum_weights % K;
    }
    // Prune: every further join consumes a distinct remaining edge, so with fewer edges left
    // than the tree still needs, no spanning tree is reachable from this state.
    if edges.len() < tree_size.saturating_sub(uf.num_edges) {
        return K;
    }
    let (edge, rest) = match edges.split_first() {
        Some((&edge, rest)) => (edge, rest),
        None => return K,
    };

    // Branch 1: leave `edge` out of the tree.
    let without_edge = min_weight_tree(uf, K, rest);
    // An edge inside a single component would close a cycle, so it can never be taken.
    if uf.is_connected(edge.0, edge.1) {
        return without_edge;
    }
    // Branch 2: take `edge`, recurse, then undo the join to restore `uf`.
    let (r1, p1, r2, h2) = uf.join(&edge);
    let with_edge = min_weight_tree(uf, K, rest);
    uf.remove(r1, p1, r2, h2, edge.2);
    min(without_edge, with_edge)
}

/// Solves one instance: the minimum, over all spanning trees on `N` vertices, of the total
/// edge weight modulo `K` (or `K` itself if no spanning tree exists).
fn F(N: usize, K: Weight, edges: Vec<Edge>) -> Weight {
    // The search starts from N isolated vertices and adds edges itself.
    min_weight_tree(&mut UnionFind::new(N), K, &edges)
}

/// Reads the instance from standard input and prints the answer.
fn main() {
    let (n, k, edges) = read_input();
    let answer = F(n, k, edges);
    println!("{}", answer);
}