AtCoder Beginner Contest 368 D

https://atcoder.jp/contests/abc368/tasks/abc368_d

辺の一覧から隣接リストのグラフを作り、それを根付き木に変換して、根から辿ります。根には V に含まれる頂点を選びます。こうすると根は必ず答えの部分木に含まれるので、「自分自身または子孫に V の頂点を含む頂点」を数えるだけで、V の頂点をすべて含む最小の部分木の頂点数が正しく求まります。

// Minimum Steiner Tree
#![allow(non_snake_case)]

use std::collections::HashSet;


//////////////////// library ////////////////////

/// Reads one line from stdin and parses its trimmed contents as `T`.
///
/// An I/O error is ignored; whatever was read (possibly nothing) is parsed.
/// Panics if the parse fails.
fn read<T: std::str::FromStr>() -> T {
    let mut buf = String::new();
    // Deliberately ignore read errors: they surface as a parse failure below.
    let _ = std::io::stdin().read_line(&mut buf);
    // `.ok()` first because `T::Err` is not required to implement `Debug`.
    let parsed: Option<T> = buf.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line from stdin and parses each whitespace-separated token as `T`.
///
/// Panics if any token fails to parse.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line = read::<String>();
    let mut values = Vec::new();
    for token in line.split_whitespace() {
        values.push(token.parse::<T>().ok().unwrap());
    }
    values
}


//////////////////// Tree ////////////////////

/// Index of a vertex. Stored 0-based: input vertex numbers are 1-based and
/// are decremented when read.
type Node = usize;
/// An undirected edge, given by its two endpoint vertices.
type Edge = (Node, Node);

/// Reads one edge line `a b` (1-based vertex numbers) and returns it 0-based.
fn read_edge() -> Edge {
    let ends: Vec<Node> = read_vec();
    let a = ends[0] - 1;
    let b = ends[1] - 1;
    (a, b)
}

/// Undirected graph stored as adjacency lists.
struct Graph {
    /// `g[v]` lists every neighbour of node `v`.
    g: Vec<Vec<Node>>,
}

impl Graph {
    /// Builds an undirected graph on nodes `0..N` from an edge list.
    ///
    /// Each edge `(a, b)` is recorded in both directions, in input order.
    /// Panics if an endpoint is `>= N`.
    fn create(edges: Vec<Edge>, N: usize) -> Graph {
        let mut adj: Vec<Vec<Node>> = (0..N).map(|_| Vec::new()).collect();
        edges.iter().for_each(|&(a, b)| {
            adj[a].push(b);
            adj[b].push(a);
        });
        Graph { g: adj }
    }
}

/// A tree rooted at `root`, stored as per-node child lists.
struct Tree {
    /// The node the tree hangs from.
    root: Node,
    /// `children[v]` holds the children of node `v` (empty for leaves).
    children: Vec<Vec<Node>>,
}

impl Tree {
    /// Counts the nodes in the subtree rooted at `v` that are in `nodes` or
    /// have a descendant in `nodes`.
    ///
    /// This equals the vertex count of the smallest connected subgraph that
    /// contains `v` and every member of `nodes` lying in `v`'s subtree, or 0
    /// if no member of `nodes` is in that subtree. When `v` is itself in
    /// `nodes`, it is the minimum Steiner tree size for those nodes.
    ///
    /// Implemented iteratively so that very deep trees (e.g. a path of
    /// 2 * 10^5 nodes) cannot overflow the call stack.
    fn count_nodes(&self, v: Node, nodes: &HashSet<Node>) -> i32 {
        // Pre-order of `v`'s subtree: every parent is listed before its children.
        let mut order: Vec<Node> = Vec::new();
        let mut stack: Vec<Node> = vec![v];
        while let Some(u) = stack.pop() {
            order.push(u);
            stack.extend(self.children[u].iter().copied());
        }

        // Walking that order backwards handles children before parents, so each
        // node's count is final once its children's counts are known.
        let mut count = vec![0i32; self.children.len()];
        for &u in order.iter().rev() {
            let below: i32 = self.children[u].iter().map(|&c| count[c]).sum();
            if nodes.contains(&u) || below > 0 {
                count[u] = below + 1;
            }
        }
        count[v]
    }

    /// Converts an undirected graph into a tree rooted at `root` using DFS.
    ///
    /// Expects `g` to be a tree (connected and acyclic). Nodes not reachable
    /// from `root` get no parent and do not appear in any child list.
    fn create(g: Graph, root: Node) -> Tree {
        let n = g.g.len();
        let mut children: Vec<Vec<Node>> = vec![vec![]; n];
        // Explicit visited flags. Testing `children[c].is_empty()` instead only
        // works for exact trees and silently duplicates nodes if a cycle exists.
        let mut visited = vec![false; n];
        visited[root] = true;
        let mut stack: Vec<Node> = vec![root];
        while let Some(v) = stack.pop() {
            for &c in g.g[v].iter() {
                if !visited[c] {
                    visited[c] = true;
                    children[v].push(c);
                    stack.push(c);
                }
            }
        }
        Tree { root, children }
    }
}


//////////////////// process ////////////////////

/// Reads the whole problem input from stdin.
///
/// Expected layout: a header line `N K`, then `N - 1` edge lines `A B`, then
/// one line listing the terminal vertices. All vertex numbers are 1-based on
/// input. Returns `(N, edges, V)` with every vertex converted to 0-based.
fn read_input() -> (usize, Vec<Edge>, Vec<Node>) {
    let header: Vec<usize> = read_vec();
    let N = header[0];
    // K is not needed: the terminal count is implied by the length of the last line.
    let _K = header[1];
    let mut edges: Vec<Edge> = Vec::new();
    for _ in 0..N - 1 {
        edges.push(read_edge());
    }
    let V: Vec<Node> = read_vec::<Node>().into_iter().map(|x| x - 1).collect();
    (N, edges, V)
}

/// Returns the number of vertices in the smallest subtree of the input tree
/// that contains every vertex of `V`.
///
/// The tree has `N` nodes (0-based) and is described by `edges`. It is rooted
/// at a member of `V`, so the root always belongs to the answer. The answer is
/// then the number of nodes that have a member of `V` at or below them.
/// An empty `V` yields 0 (the empty tree) instead of panicking.
fn F(N: usize, edges: Vec<Edge>, V: Vec<Node>) -> i32 {
    let root = match V.first() {
        Some(&r) => r,
        None => return 0,
    };
    let graph = Graph::create(edges, N);
    let tree = Tree::create(graph, root);
    let nodes: HashSet<Node> = V.into_iter().collect();
    tree.count_nodes(tree.root, &nodes)
}

/// Entry point: reads the input, solves it, and prints the answer on one line.
fn main() {
    let (N, edges, V) = read_input();
    let answer = F(N, edges, V);
    println!("{}", answer);
}