AtCoder Beginner Contest 357 E

https://atcoder.jp/contests/abc357/tasks/abc357_e

各ノードからちょうど1本のエッジが出ているので、各連結成分ではエッジ数がノード数と等しく（木よりエッジが一つ多く）、連結成分ごとにちょうど一つのループができます。ループに含まれないノードも、エッジを辿っていけば必ずそのループに行き着きます。
なので、ループ上のノードから到達可能なノードの個数（自分自身を含む）はループのサイズに等しく、それ以外のノードはループから逆向きに一歩辿るごとに到達可能な個数が1ずつ増えていきます。答えはこれらの個数の総和です。

// Reachability in Functional Graph
#![allow(non_snake_case)]

use std::collections::{HashSet, HashMap};
use std::hash::Hash;


//////////////////// library ////////////////////

/// Reads a single line from standard input and parses the trimmed text as `T`.
///
/// Read errors are deliberately ignored; whatever ended up in the buffer is
/// parsed.
///
/// # Panics
/// Panics if the trimmed line does not parse as `T`.
fn read<T: std::str::FromStr>() -> T {
    let mut buf = String::new();
    let _ = std::io::stdin().read_line(&mut buf);
    let parsed: Option<T> = buf.trim().parse().ok();
    parsed.unwrap()
}

/// Reads one line from standard input and parses every whitespace-separated
/// token on it as `T`.
///
/// # Panics
/// Panics on the first token that does not parse as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    let line: String = read();
    let mut items = Vec::new();
    for token in line.split_whitespace() {
        let item: Option<T> = token.parse().ok();
        items.push(item.unwrap());
    }
    items
}


//////////////////// Union-Find ////////////////////

/// Disjoint-set (union-find) forest over arbitrary hashable node values.
///
/// Uses union by height without path compression, so `root` never mutates
/// the forest and each `join` can be undone with `restore`.
struct UnionFind<T: Copy + Eq + Hash> {
    /// Parent pointer of every node; a root is its own parent.
    parents: HashMap<T, T>,
    /// Height of the tree rooted at each node. Only kept up to date for
    /// roots; the entries of non-root nodes are stale.
    heights: HashMap<T, i32>
}

impl<T: Copy + Eq + Hash> UnionFind<T> {
    /// Returns whether `v` is a node of this forest.
    fn contains(&self, v: &T) -> bool {
        self.parents.contains_key(v)
    }

    /// Merges the sets containing `v1` and `v2`, attaching the lower tree
    /// under the higher one.
    ///
    /// Returns `(r1, r2, h1, h2)`: the roots of `v1` and `v2` and their
    /// heights *before* the merge, which is what `restore` needs to undo it.
    /// If both nodes are already in the same set, the forest is unchanged.
    ///
    /// # Panics
    /// Panics if either node is not in the forest.
    fn join(&mut self, v1: T, v2: T) -> (T, T, i32, i32) {
        let r1 = self.root(v1);
        let r2 = self.root(v2);
        // Heights are only maintained for roots, so they must be read there.
        let h1 = self.height(r1);
        let h2 = self.height(r2);
        let ret = (r1, r2, h1, h2);
        if r1 != r2 {
            match h1.cmp(&h2) {
                std::cmp::Ordering::Less => {
                    self.parents.insert(r1, r2);
                }
                std::cmp::Ordering::Greater => {
                    self.parents.insert(r2, r1);
                }
                std::cmp::Ordering::Equal => {
                    // r2 stays the root, so it is r2's tree that gets taller.
                    self.parents.insert(r1, r2);
                    self.heights.insert(r2, h2 + 1);
                }
            }
        }
        ret
    }

    /// Undoes a `join`, given the tuple it returned.
    ///
    /// Resets both former roots to be their own parents with their old
    /// heights. This is an exact undo only when joins are undone in
    /// last-in, first-out order.
    fn restore(&mut self, ret: (T, T, i32, i32)) {
        let (r1, r2, h1, h2) = ret;
        self.parents.insert(r1, r1);
        self.parents.insert(r2, r2);
        self.heights.insert(r1, h1);
        self.heights.insert(r2, h2);
    }

    /// Adds `v` as a new singleton set (overwriting any existing entry for `v`).
    fn add(&mut self, v: T) {
        self.parents.insert(v, v);
        self.heights.insert(v, 1);
    }

    /// Removes `v` from the forest.
    ///
    /// Nodes whose parent is `v` are not re-linked, so calling `root` on
    /// them afterwards panics.
    fn remove(&mut self, v: T) {
        self.parents.remove(&v);
        self.heights.remove(&v);
    }

    /// Returns the root of `v`'s tree by following parent pointers
    /// (no path compression).
    ///
    /// # Panics
    /// Panics if `v` or one of its ancestors is not in the forest.
    fn root(&self, mut v: T) -> T {
        loop {
            let &parent = self.parents.get(&v).unwrap();
            if parent == v {
                return v;
            }
            v = parent;
        }
    }

    /// Returns the stored height of `v`; only meaningful when `v` is a root.
    ///
    /// # Panics
    /// Panics if `v` is not in the forest.
    fn height(&self, v: T) -> i32 {
        *self.heights.get(&v).unwrap()
    }

    /// Returns the number of disjoint sets.
    fn num_connected(&self) -> usize {
        self.parents
            .keys()
            .map(|&v| self.root(v))
            .collect::<HashSet<T>>()
            .len()
    }

    /// Creates a forest in which every node of `nodes` is its own singleton set.
    fn new(nodes: &[T]) -> UnionFind<T> {
        let parents: HashMap<T, T> = nodes.iter().map(|&v| (v, v)).collect();
        let heights: HashMap<T, i32> = nodes.iter().map(|&v| (v, 1)).collect();
        UnionFind { parents, heights }
    }
}


//////////////////// Graph ////////////////////

/// Node identifier.
type Node = usize;
/// An edge, stored as a pair of endpoint nodes.
type Edge = (Node, Node);

/// Reads one line holding two node numbers and returns them as an edge,
/// subtracting 1 from each (presumably converting 1-based input to 0-based).
///
/// # Panics
/// Panics if the line has fewer than two numbers, or (in debug builds) if
/// either number is 0.
fn read_edge() -> Edge {
    let ends: Vec<Node> = read_vec();
    let from = ends[0] - 1;
    let to = ends[1] - 1;
    (from, to)
}

// directed
struct Graph {
    g: HashMap<Node, Vec<Node>>
}

impl Graph {
    fn divide_into_connected(&self) -> Vec<Graph> {
        let N = self.g.len();
        let nodes: Vec<Node> = (1..N+1).collect();
        let mut uf = UnionFind::new(&nodes);
        for (&u, vs) in self.g.iter() {
            for &v in vs.iter() {
                uf.join(u, v);
            }
        }
        
        let mut m: HashMap<Node, Vec<Node>> = HashMap::new();
        for v in 1..N+1 {
            let r = uf.root(v);
            let e = m.entry(r).or_insert(vec![]);
            (*e).push(v)
        }
        
        let vss: Vec<Vec<Node>> = m.into_iter().map(|(_, vs)| vs.to_vec()).
                                                                    collect();
        let mut subgraphs: Vec<Graph> = vec![];
        for us in vss.into_iter() {
            let mut g: HashMap<Node, Vec<Node>> = HashMap::new();
            for u in us.into_iter() {
                let vs = self.g.get(&u).unwrap();
                g.insert(u, vs.to_vec());
            }
            subgraphs.push(Graph { g })
        }
        subgraphs
    }
    
    fn reverse(&self) -> Graph {
        let mut g: HashMap<Node, Vec<Node>> = HashMap::new();
        for &v in self.g.keys() {
            g.insert(v, vec![]);
        }
        for (&u, vs) in self.g.iter() {
            for &v in vs.iter() {
                g.get_mut(&v).unwrap().push(u)
            }
        }
        Graph { g }
    }
    
    fn create(A: Vec<Node>) -> Graph {
        let mut g: HashMap<Node, Vec<Node>> = HashMap::new();
        for (i, v) in A.into_iter().enumerate() {
            g.insert(i+1, vec![v]);
        }
        Graph { g }
    }
}


//////////////////// process ////////////////////

/// Reads the problem input: a count on the first line (parsed as `usize`
/// but otherwise unused) and the list of edge targets on the second line.
fn read_input() -> Vec<Node> {
    // The count line still has to be consumed before the list can be read.
    let _len: usize = read();
    read_vec::<Node>()
}

/// Follows first out-edges from an arbitrary node of `graph` until a node
/// repeats, then returns the cycle through that node in traversal order,
/// starting at the repeated node.
///
/// Only the first out-edge of each node is followed.
///
/// # Panics
/// Panics if the graph is empty, if a reached node has no entry in
/// `graph.g`, or if its out-edge list is empty.
fn find_loop(graph: &Graph) -> Vec<Node> {
    let next = |u: Node| -> Node { graph.g.get(&u).unwrap()[0] };

    // Walk forward until a node repeats; the repeated node lies on the cycle.
    let mut seen: HashSet<Node> = HashSet::new();
    let mut cur = *graph.g.keys().next().unwrap();
    while seen.insert(cur) {
        cur = next(cur);
    }

    // Go around the cycle once, starting and stopping at that node.
    let start = cur;
    let mut cycle = vec![start];
    let mut u = next(start);
    while u != start {
        cycle.push(u);
        u = next(u);
    }
    cycle
}

/// Sums, over all nodes of one component, the number of nodes reachable from
/// that node (itself included).
///
/// Expects `graph` to be a single connected component of a functional graph
/// (as produced by `Graph::create` followed by `divide_into_connected`).
/// Every cycle node reaches exactly the cycle. Every other node reaches one
/// more node than the node its edge points to.
fn F_each(graph: Graph) -> u64 {
    let cycle = find_loop(&graph);
    let cycle_len = cycle.len() as u64;
    let mut reach: HashMap<Node, u64> = HashMap::new();
    for &v in &cycle {
        reach.insert(v, cycle_len);
    }

    // Walk edges backwards from the cycle into the trees hanging off it;
    // each predecessor reaches one node more than its successor.
    let predecessors = graph.reverse();
    let mut pending: Vec<Node> = cycle;
    while let Some(u) = pending.pop() {
        let count_u = *reach.get(&u).unwrap();
        for &p in predecessors.g.get(&u).unwrap().iter() {
            if reach.contains_key(&p) {
                continue;
            }
            reach.insert(p, count_u + 1);
            pending.push(p);
        }
    }

    reach.values().sum::<u64>()
}

fn F(A: Vec<Node>) -> u64 {
    let graph = Graph::create(A);
    let subgraphs = graph.divide_into_connected();
    subgraphs.into_iter().map(|g| F_each(g)).sum::<u64>()
}

/// Reads the input from stdin and prints the answer.
fn main() {
    let answer = F(read_input());
    println!("{}", answer)
}