ScalaでProject Euler(69)

Problem 39

めんどうなので、前回n'と書いていたものをnとします。
もうちょっとなんとかならないでしょうか。
例えば、p / 2 = q素数だったとしましょう。そうすると、

(l, m, n) = (q, 1, 1), (1, q, 1), (1, 1, q)

はどれも直角三角形ではないので、0個です。p / 2 = qrqrの2つの素数なる場合、q < r < 2qを満たせば、

(l, m, n) = (1, q, r)

の1個が直角三角形になります。すなわち、2つの素数の積で表されるなら直角三角形の個数は高々1個ということになります。
このように直角三角形の個数を素因数分解の形から概算して上から押さえることはできないでしょうか。
例えば、p / 2 = 360 = 23・32・5ならどうでしょう。まず、lが2のみから成る場合を考えます。そうすると、2を除いて、

(m, n) = (32, 5), (5, 32)

のどちからです。ここに2の因子を加えます。といっても加えられるのはmのほうだけなので、(32, 5) = (9, 5)はどう2の因子を加えても(2e・9, 5)でm < nになりえません。(5, 32)の方は、これでm < n < 2mになっています。これに2の因子を加えるとこの条件を満たさなくなります。一般に2の因子が十分多くあれば、mに2の因子をe個加えてこの条件を満たせるようなeがただ一つ存在します。mnに3と5の因子が一つもなければ条件を満たしません。3の因子のみまたは5の因子のみなら一通り、3の因子も5の因子もあればmnのどちらを3と5が占めるかで4通りあり、mnのどちらが大きいかでペアになっていてそのうちnが大きい方が条件を満たしうるので2通りあります。具体的には、p / 2 = 23・3・5として、

(l, m, n) = (23・5, 2, 3), (24・3, 22, 5), (23, 3, 5), (1, 23, 3・5)

となります。p / 2 = 2e13e25e3e1が十分に大きければ、2e2e3 + e2 + e3通りで、

(2e2 + 1)(2e3 + 1) / 2

となります(1/2は切り捨てます)。p / 2 = 2e13e25e37e4e1が十分に大きければ、

(2e2 + 1)(2e3 + 1)(2e4 + 1) / 2

となります。
今までの議論は2の因子が十分にあるときの話です。2が少ないときは直角三角形の個数をかなり過剰に見積もっていることになります。2の代わりに3が多い場合、たとえば、p / 2 = 34・7・11なら、

(l, m, n) = (32・11, 7, 32), (32・7, 32, 11), (34, 7, 11), (33, 11, 3・7), (1, 7・11, 34)

となります。すなわち、7, 11という分配があったときに3の調整で2通りの直角三角形ができます。だから、2で調整したときの2倍の個数の可能性があるということです。5以上で調整するときは高々1個です。だから、3の指数が他の指数の最大の2倍以下なら3で調整し、そうでなければ最大の指数を持つ素数で調整することになります。例えば、

2・33・52 -> (2 + 1)(6 + 1) / 2 = 10
2・35・52 -> (2 + 1)(4 + 1) - 1 = 14

このように概算して、それまでの実際の最大の個数以下ならちゃんと直角三角形の個数を計算せずに捨てます。そうでなければ計算します。
N = 106で計算したところ、前回の20倍程度の速度でした。

import scala.testing.Benchmark

def pow(n :Int, e :Int) = (1 to e).foldLeft(1)((x, y) => x * n)

type Facts = List[(Int,Int)]

def value(fs :Facts) =
    fs.foldLeft(1)((x, y) => x * pow(y._1, y._2))

def divide(fs1 :Facts, fs2 :Facts) :Facts = (fs1, fs2) match {
    case (_, Nil) => fs1
    case ((p1, e1) :: t1, (p2, e2) :: t2) if p1 == p2 =>
                if(e1 == e2) divide(t1, t2) else (p1, e1 - e2) :: divide(t1, t2)
    case (h1 :: t1, _) => h1 :: divide(t1, fs2)
    case _ => Nil
}

def sieve(N :Int) = {
    val a = Array.range(0, N)
    for(p <- Iterator.from(2).takeWhile(n => n * n < N).filter(n => a(n) == n);
                    k <- Iterator.range(p * 2, N, p))
        a(k) = p
    a
}

def divs(fs :Facts) :Iterator[Facts] = fs match {
    case Nil => Iterator(Nil)
    case (p, e) :: fst => for(fs2 <- divs(fst); e1 <- Iterator.range(0, e + 1))
                            yield if(e1 == 0) fs2 else (p, e1) :: fs2
}

def divs2(fs :Facts) :Iterator[Facts] = fs match {
    case Nil => Iterator(Nil)
    case (2, e) :: fst => divs2(fst).map((2, e) :: _)
    case f :: fst => divs2(fst).flatMap(x => Iterator(x, f :: x))
}

def divs3(fs :Facts) :Iterator[(Int,Int,Int)] =
    for(fs1 <- divs(fs); fs2 = divide(fs, fs1); fs3 <- divs2(fs2))
        yield (value(fs1), value(fs3), value(divide(fs2, fs3)))

def div_pow(n :Int, d :Int) :(Int,Int) =
    if(n % d != 0)
        (0, n)
    else {
        val (e, m) = div_pow(n / d, d)
        (e + 1, m)
    }

def is_right_triangle(x :(Int,Int,Int)) = {
    val (l, m, n) = x
    m < n && n < m * 2 && n % 2 == 1
}

def solve() = {
    val a = sieve(N / 2 + 1)
    
    def factorize(n :Int) :Facts =
        if(n == 1)
            Nil
        else {
            val p = a(n)
            val (e, m) = div_pow(n, p)
            (p, e) :: factorize(m)
        }
    
    def max(x :(Int,Int), y :(Int, Int)) = (x, y) match {
        case ((n1, m1), (n2, m2)) if m1 > m2 => (n1, m1)
        case ((n1, m1), (n2, m2))            => (n2, m2)
    }
    
    def estimate(fs :Facts) :Int = {
        def greater(m :Int, p :Int, mx :Int, p0 :Int) :Boolean =
            if(p == 3)
                m > mx * 2
            else if(p0 == 3)
                m * 2 > mx
            else
                m > mx
        
        def f(fs :Facts) :(Int,Int,Int) =
            fs match {
                case Nil => (1, 1, 1)
                case (p, e) :: fst => {
                    val m = 2 * e + 1
                    val (mx, pr, p0) = f(fst)
                    if(greater(m, p, mx, p0)) (m, pr * mx, p)
                    else (mx, pr * m, p0)
                }
            }
        
        val (_, m, p) = f(fs)
        if(p == 3) m - 1 else m / 2
    }
    
    def next(s :(Int,Int), n :Int) :(Int,Int) = {
        val (m0, n0) = s
        val fs = factorize(n / 2)
        val m_est = estimate(fs)
        if(m0 >= m_est)
            s
        else {
            val m = divs3(fs).filter(is_right_triangle).size
            if(m0 > m)
                s
            else
                (m, n)
        }
    }
    
    println (Iterator.range(2, N + 1, 2).foldLeft((1, 0))(next))
}

object Test extends Benchmark {
    def run() = solve
}

val N = 1e6.toInt
println (Test.runBenchmark(5))