Project Euler 9

素直に書くととても遅い。


n = 1000
m = div n 2
is_right_angle (a, b, c) = a * a + b * b == c * c
product_t (a, b, c) = a * b * c
triplets = [ (a, b, c) | a <- [ 1..m ], b <- [ 1..m ], c <- [ 1..m ],
a + b + c == n && a < b && b < c ]
main = print(map product_t (filter is_right_angle triplets))

直積をタプルのリストで表現して、タプルの要素を取り出すのは、


product_t (a, b, c) = a * b * c


はじめから和が1000になるような組を出すと速くなる。1000を2つ和に分割して、さらに一方を分割する。


n = 1000
is_right_angle (a:b:c:[]) = a * a + b * b == c * c

-- nをm個のl以上の和に分割する
divide n m l | n < l = []
| m == 1 = [ [n] ]
divide n m l = foldr (++) [] [ f p n m | p <- [ l..(div n m) ] ]
f p n m = map (\x -> [p] ++ x) (divide (n - p) (m - 1) (p + 1))

main = print(map product (filter is_right_angle (divide n 3 1)))

タプルではなく、リストにすると再帰が使える。


本当は、ピタゴラス数の生成の公式を使うと速い。


n = 1000000

divisors n = [ d | d <- [1..n], mod n d == 0 ]
divide n 2 = [ [d, div n d] | d <- divisors n ]
divide n m = foldl (++) []
(map (\(p:ps) -> [ [d, div p d] ++ ps | d <- divisors p ])
(divide n (m - 1)))

is_valid (l:m:mn:[]) = mod mn 2 == 1
&& m < mn && mn < m * 2 && (gcd m mn) == 1
mult (l:m:mn:[]) = 2 * l^3 * (m^4 - (mn - m)^4) * m * (mn - m)
main = print(map mult (filter is_valid (divide (div n 2) 3)))

コードは長くなるが、素因数分解を使うとさらに速い。


n = 1000000
m = div n 2
primes = 2:[3,5..]

div_pow n p | mod n p /= 0 = (n, 0)
| otherwise = (\(n, e) -> (n, e + 1)) (div_pow (div n p) p)

factorize n (p:ps) | n == 1 = []
| n < p * p = [ (n, 1) ]
| mod n p == 0 = (\(p:ps) (n, e) -> (p, e) :
(factorize n ps)) (p:ps) (div_pow n p)
| n > 0 = factorize n ps

div_f f [] = f
div_f ((p1,e1):f1) ((p2,e2):f2)
| p1 < p2 = (p1,e1):(div_f f1 ( (p2,e2):f2))
| p1 == p2 = (p1, e1 - e2):(div_f f1 f2)

divisors [] = [[]]
divisors (f:fs) = [ ((fst f), e):y | e <- [0..(snd f)], y <- divisors fs ]

divide n 2 = [ [d, div_f n d] | d <- divisors n ]
divide n m = foldl (++) []
(map (\(p:ps) -> [ [d, div_f p d] ++ ps | d <- divisors p ])
(divide n (m - 1)))

value f = product (map (\(p,e) -> p^e) f)

is_valid (l:m:mn:[]) = mod mn 2 == 1
&& m < mn && mn < m * 2 && (gcd m mn) == 1
mult (l:m:mn:[]) = 2 * l^3 * (m^4 - (mn - m)^4) * m * (mn - m)

divs = [ [ value f | f <- c ] | c <- (divide (factorize m primes) 3) ]
main = print(map mult (filter is_valid divs))