https://atcoder.jp/contests/abc368/tasks/abc368_d
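エッジからグラフ（隣接リスト）を作り、それを根付き木にして、根から辿ります。根には V に含まれる頂点を選びます。そうすると根は必ず求める部分木（シュタイナー木）に含まれるので、「自分の部分木に V の頂点を含む頂点」の個数がそのまま答えになります。V の外の頂点を根にすると、V の頂点どうしの LCA より上にある頂点まで数えてしまい、答えが大きくなります。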
// Minimum Steiner Tree (AtCoder ABC368 D): print the number of vertices of the
// smallest subtree of the input tree that contains every vertex listed in V.
use std::collections::HashSet;

//////////////////// library ////////////////////

/// Reads one line from stdin and parses its trimmed contents as `T`.
///
/// # Panics
/// Panics if stdin cannot be read or the line does not parse as `T`.
fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin()
        .read_line(&mut line)
        .expect("failed to read stdin");
    // `.ok()` drops the error value: `FromStr::Err` is not required to be `Debug`.
    line.trim().parse().ok().expect("failed to parse input")
}

/// Reads one line from stdin and parses each whitespace-separated token as `T`.
///
/// # Panics
/// Panics if stdin cannot be read or a token does not parse as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>()
        .split_whitespace()
        .map(|e| e.parse().ok().expect("failed to parse token"))
        .collect()
}

//////////////////// Tree ////////////////////

/// 0-based vertex id.
type Node = usize;
/// Undirected edge between two 0-based vertices.
type Edge = (Node, Node);

/// Reads one edge line `a b` (1-based ids) and converts it to 0-based ids.
fn read_edge() -> Edge {
    let v: Vec<Node> = read_vec();
    (v[0] - 1, v[1] - 1)
}

/// Undirected graph stored as adjacency lists: `g[v]` lists the neighbours of `v`.
struct Graph {
    g: Vec<Vec<Node>>,
}

impl Graph {
    /// Builds an undirected graph on `n` vertices from 0-based `edges`;
    /// each edge is recorded in both endpoints' adjacency lists.
    fn create(edges: Vec<Edge>, n: usize) -> Graph {
        let mut g: Vec<Vec<Node>> = vec![vec![]; n];
        for (v1, v2) in edges {
            g[v1].push(v2);
            g[v2].push(v1);
        }
        Graph { g }
    }
}

/// Rooted tree: `children[v]` lists the children of vertex `v`.
struct Tree {
    root: Node,
    children: Vec<Vec<Node>>,
}

impl Tree {
    /// Counts the vertices `u` in the subtree rooted at `v` (including `v`)
    /// whose own subtree contains at least one member of `nodes`; these are
    /// exactly the vertices on the paths from `v` to the members of `nodes`
    /// below it. Returns 0 when no member of `nodes` lies in `v`'s subtree.
    ///
    /// Implemented with an explicit stack instead of recursion so that a
    /// path-shaped tree with ~2*10^5 levels cannot overflow the call stack.
    fn count_nodes(&self, v: Node, nodes: &HashSet<Node>) -> i32 {
        // Pre-order of v's subtree: every vertex appears before its descendants.
        let mut order: Vec<Node> = Vec::new();
        let mut stack: Vec<Node> = vec![v];
        while let Some(u) = stack.pop() {
            order.push(u);
            stack.extend(self.children[u].iter().copied());
        }

        // Walking the pre-order backwards finishes all children before their parent.
        let mut count = vec![0i32; self.children.len()];
        for &u in order.iter().rev() {
            let below: i32 = self.children[u].iter().map(|&c| count[c]).sum();
            count[u] = if nodes.contains(&u) || below > 0 {
                below + 1
            } else {
                0
            };
        }
        count[v]
    }

    /// Roots graph `g` at `root` with an iterative depth-first traversal and
    /// records, for every reached vertex, the neighbours it discovered first.
    ///
    /// Vertices are marked in an explicit `visited` array when first pushed,
    /// so each vertex receives at most one parent.
    fn create(g: Graph, root: Node) -> Tree {
        let n = g.g.len();
        let mut children: Vec<Vec<Node>> = vec![vec![]; n];
        let mut visited = vec![false; n];
        visited[root] = true;
        let mut stack: Vec<Node> = vec![root];
        while let Some(v) = stack.pop() {
            for &c in &g.g[v] {
                if !visited[c] {
                    visited[c] = true;
                    children[v].push(c);
                    stack.push(c);
                }
            }
        }
        Tree { root, children }
    }
}

//////////////////// process ////////////////////

/// Reads `N K`, the `N - 1` edges and the line of `K` target vertices.
/// Returns `(N, edges, targets)` with all vertex ids converted to 0-based.
fn read_input() -> (usize, Vec<Edge>, Vec<Node>) {
    let header: Vec<usize> = read_vec();
    // header[1] is K; it equals the length of the target line, so it is not needed.
    let n = header[0];
    let edges = (0..n.saturating_sub(1)).map(|_| read_edge()).collect();
    let targets: Vec<Node> = read_vec::<Node>().into_iter().map(|v| v - 1).collect();
    (n, edges, targets)
}

/// Returns the number of vertices in the smallest subtree of the tree on `n`
/// vertices given by `edges` that contains every vertex of `targets` (0-based).
///
/// # Panics
/// Panics if `targets` is empty.
#[allow(non_snake_case)] // name kept from the original submission
fn F(n: usize, edges: Vec<Edge>, targets: Vec<Node>) -> i32 {
    let graph = Graph::create(edges, n);
    // Rooting at a target puts the root on the Steiner tree, so the answer is
    // the number of vertices whose subtree contains some target.
    let root = targets[0];
    let tree = Tree::create(graph, root);
    let nodes: HashSet<Node> = targets.into_iter().collect();
    tree.count_nodes(tree.root, &nodes)
}

fn main() {
    let (n, edges, targets) = read_input();
    println!("{}", F(n, edges, targets))
}