AtCoder Beginner Contest 342 D

https://atcoder.jp/contests/abc342/tasks/abc342_d

基本的には、素因数分解して、指数が奇数の素数だけ残して、同じものを数えればよいです。

// Square Pair
#![allow(non_snake_case)]


//////////////////// library ////////////////////

fn read<T: std::str::FromStr>() -> T {
    let mut line = String::new();
    std::io::stdin().read_line(&mut line).ok();
    line.trim().parse().ok().unwrap()
}

fn read_vec<T: std::str::FromStr>() -> Vec<T> {
    read::<String>().split_whitespace()
            .map(|e| e.parse().ok().unwrap()).collect()
}

fn div_pow(n: u32, d: u32) -> (u32, u32) {
    let mut e: u32 = 0;
    let mut m = n;
    while m % d == 0 {
        e += 1;
        m /= d
    }
    (e, m)
}


//////////////////// Factors ////////////////////

struct Factors {
    fs: Vec<(u32, u32)>
}

impl Factors {
    fn make_min_prime_table(N: usize) -> Vec<u32> {
        let mut a: Vec<u32> = (0..N+1).map(|n| n as u32).collect();
        let mut b: Vec<u32> = vec![1_u32; N+1];
        for p in (2..).take_while(|&p| p*p <= N) {
            if a[p] == 1 {
                continue
            }
            for n in (p..N+1).step_by(p) {
                let (_, m) = div_pow(a[n], p as u32);
                a[n] = m;
                if b[n] == 1 {
                    b[n] = p as u32
                }
            }
        }
        
        for n in 2..N+1 {
            if a[n] > 1 {
                b[n] = a[n]
            }
        }
        b
    }
    
    fn factorize(n: u32, table: &Vec<u32>) -> Factors {
        if n == 1 {
            return Factors { fs: vec![] }
        }
        
        let p = table[n as usize];
        let (e, m) = div_pow(n, p);
        let fs1 = Factors { fs: vec![(p, e)] };
        let fs2 = Factors::factorize(m, table);
        fs1 * fs2
    }
}

use std::ops::Mul;

impl Mul for Factors {
    type Output = Self;
    
    fn mul(self, other: Self) -> Self {
        let mut fs = self.fs.to_vec();
        for &a in other.fs.iter() {
            fs.push(a)
        }
        Self { fs }
    }
}


//////////////////// process ////////////////////

fn read_input() -> Vec<u32> {
    let _N: usize = read();
    let A: Vec<u32> = read_vec();
    A
}

fn not_square(n: u32, table: &Vec<u32>) -> u32 {
    if n <= 3 {
        return n
    }
    
    let fs = Factors::factorize(n, table);
    let ps: Vec<u32> = fs.fs.into_iter().filter(|&(_, e)| e % 2 == 1).
                                            map(|(p, _)| p).collect();
    ps.into_iter().fold(1, |x, y| x*y)
}

use std::collections::HashMap;

fn count(B: Vec<u32>) -> HashMap<u32, usize> {
    let mut c: HashMap<u32, usize> = HashMap::new();
    for b in B.into_iter() {
        let e = c.entry(b).or_insert(0);
        (*e) += 1
    }
    c
}

fn F(A: Vec<u32>) -> usize {
    let N = A.len();
    let max_value = A.iter().map(|&n| n).max().unwrap();
    let table = Factors::make_min_prime_table(max_value as usize);
    let B: Vec<u32> = A.iter().map(|&n| not_square(n, &table)).collect();
    let c = count(B);
    let mut s = 0;
    for (m, f) in c.into_iter() {
        if m == 0 {
            s += f*N - f*(f+1)/2
        }
        else {
            s += f*(f-1)/2
        }
    }
    s
}

fn main() {
    let A = read_input();
    println!("{}", F(A))
}