https://atcoder.jp/contests/abc357/tasks/abc357_e
各ノードから出る辺はちょうど1本なので、各連結成分ではノード数と辺数が等しく、木より辺が1本多くなります。そのため、連結成分ごとにちょうど1つのループ（閉路）ができます。ループに含まれないノードは、辺を辿っていくといずれループに行き着きます。
したがって、ループ上のノードから到達可能なノード数はループのサイズ（ループに含まれるノード数）に等しくなります。それ以外のノードでは、ループから辺を逆向きに1つ辿るごとに到達可能なノード数が1ずつ増えます。つまり、ループのサイズが L、ループまでの距離が d のノードからは L + d 個のノードに到達できます。
// Reachability in Functional Graph
#![allow(non_snake_case)]

use std::collections::{HashMap, HashSet};
use std::hash::Hash;

//////////////////// library ////////////////////

/// Reads one line from stdin and parses it (after trimming) as `T`.
///
/// # Panics
/// If the trimmed line cannot be parsed as `T`.
fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().expect("failed to parse input line")
}

/// Reads one line from stdin and parses each whitespace-separated token as `T`.
///
/// # Panics
/// If any token cannot be parsed as `T`.
fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>()
        .split_whitespace()
        .map(|e| e.parse().ok().expect("failed to parse input token"))
        .collect()
}

//////////////////// Union-Find ////////////////////

/// Disjoint-set forest over hashable keys using union by height.
///
/// There is no path compression (so `root` can take `&self`); with union by
/// height done on the roots, every tree has height O(log n).
struct UnionFind<T: Copy + Eq + Hash> {
    // Parent of each element; a root is its own parent.
    parents: HashMap<T, T>,
    // Height of the tree rooted at each element; only meaningful for roots.
    heights: HashMap<T, i32>,
}

impl<T: Copy + Eq + Hash> UnionFind<T> {
    /// Returns whether `v` is an element of the structure.
    fn contains(&self, v: &T) -> bool {
        self.parents.contains_key(v)
    }

    /// Merges the sets containing `v1` and `v2`.
    ///
    /// Returns `(r1, r2, h1, h2)`: the roots of `v1` and `v2` and their heights
    /// *before* the merge, which is what `restore` needs to undo it. If both
    /// are already in the same set, nothing is modified.
    ///
    /// # Panics
    /// If `v1` or `v2` is not an element.
    fn join(&mut self, v1: T, v2: T) -> (T, T, i32, i32) {
        let r1 = self.root(v1);
        let r2 = self.root(v2);
        // Heights are only maintained for roots, so they must be read there.
        let h1 = self.height(r1);
        let h2 = self.height(r2);
        let ret = (r1, r2, h1, h2);
        if r1 == r2 {
            return ret;
        }
        if h1 < h2 {
            self.parents.insert(r1, r2);
        } else if h1 > h2 {
            self.parents.insert(r2, r1);
        } else {
            // Equal heights: r2 becomes the new root, so *its* height grows.
            self.parents.insert(r1, r2);
            self.heights.insert(r2, h2 + 1);
        }
        ret
    }

    /// Undoes the `join` that returned `ret` by making both roots roots again
    /// with their previous heights. This is exact when joins are undone in
    /// reverse (LIFO) order.
    fn restore(&mut self, ret: (T, T, i32, i32)) {
        let (r1, r2, h1, h2) = ret;
        self.parents.insert(r1, r1);
        self.parents.insert(r2, r2);
        self.heights.insert(r1, h1);
        self.heights.insert(r2, h2);
    }

    /// Adds `v` as a new singleton set (overwriting any previous entry).
    fn add(&mut self, v: T) {
        self.parents.insert(v, v);
        self.heights.insert(v, 1);
    }

    /// Removes `v`'s entries. Children of `v` are not re-linked: if another
    /// element still has `v` as its parent, a later `root` on it will panic.
    fn remove(&mut self, v: T) {
        self.parents.remove(&v);
        self.heights.remove(&v);
    }

    /// Returns the representative (root) of the set containing `v`.
    ///
    /// # Panics
    /// If `v` (or one of its ancestors) is not an element.
    fn root(&self, mut v: T) -> T {
        loop {
            let p = self.parents[&v];
            if p == v {
                return v;
            }
            v = p;
        }
    }

    /// Returns the stored height of `v` (meaningful only when `v` is a root).
    ///
    /// # Panics
    /// If `v` is not an element.
    fn height(&self, v: T) -> i32 {
        self.heights[&v]
    }

    /// Number of disjoint sets, i.e. the number of elements that are their own
    /// parent.
    fn num_connected(&self) -> usize {
        self.parents.iter().filter(|(v, p)| v == p).count()
    }

    /// Creates a structure in which every element of `nodes` is its own set.
    fn new(nodes: &[T]) -> UnionFind<T> {
        let parents = nodes.iter().map(|&v| (v, v)).collect();
        let heights = nodes.iter().map(|&v| (v, 1)).collect();
        UnionFind { parents, heights }
    }
}
//////////////////// Graph //////////////////// type Node = usize; type Edge = (Node, Node); fn read_edge() -> Edge { let v = read_vec::<Node>(); (v[0]-1, v[1]-1) } // directed struct Graph { g: HashMap<Node, Vec<Node>> } impl Graph { fn divide_into_connected(&self) -> Vec<Graph> { let N = self.g.len(); let nodes: Vec<Node> = (1..N+1).collect(); let mut uf = UnionFind::new(&nodes); for (&u, vs) in self.g.iter() { for &v in vs.iter() { uf.join(u, v); } } let mut m: HashMap<Node, Vec<Node>> = HashMap::new(); for v in 1..N+1 { let r = uf.root(v); let e = m.entry(r).or_insert(vec![]); (*e).push(v) } let vss: Vec<Vec<Node>> = m.into_iter().map(|(_, vs)| vs.to_vec()). collect(); let mut subgraphs: Vec<Graph> = vec![]; for us in vss.into_iter() { let mut g: HashMap<Node, Vec<Node>> = HashMap::new(); for u in us.into_iter() { let vs = self.g.get(&u).unwrap(); g.insert(u, vs.to_vec()); } subgraphs.push(Graph { g }) } subgraphs } fn reverse(&self) -> Graph { let mut g: HashMap<Node, Vec<Node>> = HashMap::new(); for &v in self.g.keys() { g.insert(v, vec![]); } for (&u, vs) in self.g.iter() { for &v in vs.iter() { g.get_mut(&v).unwrap().push(u) } } Graph { g } } fn create(A: Vec<Node>) -> Graph { let mut g: HashMap<Node, Vec<Node>> = HashMap::new(); for (i, v) in A.into_iter().enumerate() { g.insert(i+1, vec![v]); } Graph { g } } } //////////////////// process //////////////////// fn read_input() -> Vec<Node> { let _N: usize = read(); let A: Vec<Node> = read_vec(); A } fn find_loop(graph: &Graph) -> Vec<Node> { let mut v = *graph.g.keys().next().unwrap(); let mut visited = HashSet::<Node>::new(); while !visited.contains(&v) { visited.insert(v); let vs = graph.g.get(&v).unwrap(); v = vs[0]; } let mut node_loop = Vec::<Node>::new(); let v0 = v; loop { node_loop.push(v); let vs = graph.g.get(&v).unwrap(); v = vs[0]; if v == v0 { break } } node_loop } fn F_each(graph: Graph) -> u64 { let node_loop = find_loop(&graph); let mut nums = HashMap::<Node, u64>::new(); for &v in 
node_loop.iter() { nums.insert(v, node_loop.len() as u64); } let rev_g = graph.reverse(); let mut stack = node_loop.to_vec(); while let Some(u) = stack.pop() { let vs = rev_g.g.get(&u).unwrap(); for &v in vs.iter() { if !nums.contains_key(&v) { nums.insert(v, *(nums.get(&u).unwrap()) + 1); stack.push(v) } } } nums.values().sum::<u64>() } fn F(A: Vec<Node>) -> u64 { let graph = Graph::create(A); let subgraphs = graph.divide_into_connected(); subgraphs.into_iter().map(|g| F_each(g)).sum::<u64>() } fn main() { let A = read_input(); println!("{}", F(A)) }