https://atcoder.jp/contests/math-and-algorithm/tasks/typical90_am
グラフを作って、1をルートにして木を辿っていくと計算できます。木のルートを端点とする距離とそれ以外の距離とを分けるとうまくいきます。
イジワルな入力で再帰が深くなるようなものだとうまくいかないかなと思ったのですが、そんなことはなかったですね。
// Tree Distance #![allow(non_snake_case)] use std::collections::HashMap; //////////////////// library //////////////////// fn read<T: std::str::FromStr>() -> T { let mut line = String::new(); std::io::stdin().read_line(&mut line).ok(); line.trim().parse().ok().unwrap() } fn read_vec<T: std::str::FromStr>() -> Vec<T> { read::<String>().split_whitespace() .map(|e| e.parse().ok().unwrap()).collect() } //////////////////// Graph //////////////////// type Node = i32; type Edge = (Node, Node); type Graph = HashMap<Node, Vec<Node>>; fn read_edge() -> Edge { let v: Vec<Node> = read_vec(); (v[0], v[1]) } fn make_graph(edges: Vec<Edge>) -> Graph { let mut graph: Graph = Graph::new(); for (u, v) in edges.into_iter() { let e1 = graph.entry(u).or_insert(vec![]); (*e1).push(v); let e2 = graph.entry(v).or_insert(vec![]); (*e2).push(u); } graph } //////////////////// process //////////////////// fn read_input() -> Vec<Edge> { let N: usize = read(); let edges: Vec<Edge> = (0..N-1).map(|_| read_edge()).collect(); edges } fn sum_dist(v: Node, parent: Node, graph: &Graph) -> (usize, usize, usize) { let children: Vec<Node> = graph.get(&v).unwrap().iter(). map(|&w| w).filter(|&w| w != parent). collect(); if children.is_empty() { (1, 0, 0) } else { let a: Vec<(usize, usize, usize)> = children.iter().map(|&c| sum_dist(c, v, graph)).collect(); let sn = a.iter().map(|&(n, _, _)| n).sum::<usize>(); let sr = a.iter().map(|&(_, r, _)| r).sum::<usize>(); let si = a.iter().map(|&(_, _, i)| i).sum::<usize>(); let snn = a.iter().map(|&(n, _, _)| n*n).sum::<usize>(); let snr = a.iter().map(|&(n, r, _)| n*r).sum::<usize>(); let n = sn + 1; let r = sn + sr; let i = sr + si + sn*sn - snn + sn*sr - snr; (n, r, i) } } fn f(edges: Vec<Edge>) -> usize { let graph = make_graph(edges); let (_, r, i) = sum_dist(1, 1, &graph); r + i } fn main() { let edges = read_input(); println!("{}", f(edges)) }