ScalaでProject Euler(58)

Problem 29

少し戻って、ScalaでProject Euler(51)の続きです。

5乗のときで考えてみます。例えば32です。前回は、

U - (S1 ∪ S2 ∪ S3 ∪ S4) = (U - S1) ∩ (U - S2) ∩ (U - S3) ∩ (U - S4)

を利用しましたが、今回は包除原理そのもの、

|S1 ∪ S2 ∪ S3 ∪ S4| = |S1| + |S2| + |S3| + |S4| - |S1 ∩ S2| - |S1 ∩ S3| - |S1 ∩ S4| - |S2 ∩ S3| - |S2 ∩ S4| - |S3 ∩ S4| + |S1 ∩ S2 ∩ S3| + |S1 ∩ S3 ∩ S4| + |S1 ∩ S2 ∩ S4| + |S2 ∩ S3 ∩ S4| - |S1 ∩ S2 ∩ S3 ∩ S4|

を使います。4つの集合の各交わりを計算して奇数個の交わりなら足して、偶数個なら引くのですね。ただし少し工夫します。この交わりを公差は別にして範囲が同じものに分類します。すなわち、この問題の場合は、使われている集合の添え字が最も小さいもので分類します。その添え字をkとして、それらの集合をQkとすると、

Q4 = { S4 }
Q3 = { S3, S3 ∩ S4 }
Q2 = { S2, S2 ∩ S3, S2 ∩ S3 ∩ S4, S2 ∩ S4 }
Q_1 : 残り

となります。ここで、Q1はS1のみと、Q2〜Q4の要素とS1の交わりから成ります。というように計算すると速いです。なぜかというと、例えばQ1にはS1 ∩ S4とS1 ∩ S2 ∩ S4が含まれますが、ともに範囲はN/5までで同じで公差も4、しかし交わりの個数が2個と3個でなので符号が逆なのでキャンセルします。そうするとこのあとさらに狭い範囲を考えるときにもうこの集合は考えなくてもよくなります。範囲は同じなので、公差だけ考えればよくなります。公差が同じものはまとめて、

Q_4 : [(4, 1)]
Q_3 : [(3, 1),(12, -1)]
Q_2 : [(2, 1),(6, -1),(12, 1),(4, -1)]
Q_1 : [(1, 1),(2, -1),(6, 1),(3, -1)]

となります。タプルの第1項が公差で、第2項が同じ公差の数で1と-1があるとキャンセルします。
こうした考えで組んでみました。そうすると、N = 109で前回のコードより3ケタ速く、N = 6×1012でも1秒を切っていました。ただ、そのあとおかしな値になるのですが。

(追記)
N = 7 × 1012で早くもLongの範囲を超えていました。少し直したら1017までそれらしい答えが出ました。13秒くらいです。

import scala.math
import scala.collection.mutable.Map
import scala.testing.Benchmark

def gcd(n :Long, m :Long) :Long =
    if(m == 0) n else gcd(m, n % m)

def lcm(n :Long, m :Long) = n / gcd(n, m) * m

def pow(n :Long, e :Long) :Long =
    if(e == 0)
        1
    else {
        val m = pow(n, e / 2)
        if(e % 2 == 1)
            m * m * n
        else
            m * m
    }

// [n^(1/e)]
def int_root(n :Long, e :Long) = {
    def dec(m :Long) :Long =
        if(pow(m - 1, e) < n) m else dec(m - 1)
    
    def inc(m :Long) :Long =
        if(pow(m + 1, e) > n) m else inc(m + 1)
    
    val m = math.pow(n, 1. / e).toLong
    if(pow(m, e) > n) dec(m) else inc(m)
}

def rangeLong(a :Long, b :Long) =
    Iterator.iterate(a)(1L +).takeWhile(b >)

def divs(n :Long) =
    rangeLong(1L, n).filter(n % _ == 0)

def num_pows() = {
    val a = Iterator.from(2).map(int_root(N, _) - 1).takeWhile(0 <).toArray
    for(e <- Iterator.range(a.size + 1, 3, -1); e2 <- divs(e).drop(1))
        a(e2.toInt - 2) -= a(e.toInt - 2)
    Iterator.range(0, a.size).map(a(_))
}

def fraction(num :Long, den :Long) = {
    val d = gcd(num, den)
    (num / d, den / d)
}

def count_overlap(e :Int) :Long = {
    def add_value(m :Map[Long,Long], key :Long, value :Long) = {
        if(m.contains(key))
            m(key) += value
        else
            m(key) = value
    }
    
    def enum_terms(f :(Long,Long), as :List[List[(Long,Long)]]) = {
        val m = Map[Long,Long]()
        m(f._1) = 1
        val upper = N * f._1 / f._2
        for(a <- as; (d, v) <- a; d1 = lcm(d, f._1) if d1 <= upper)
            add_value(m, lcm(d, d1), -v)
        m.toList.filter(_._2 != 0)
    }
    
    def enum_t(fs :List[(Long,Long)], ms :List[List[(Long,Long)]] = Nil)
                                                :List[List[(Long,Long)]] =
        if(fs == Nil)
            Nil
        else {
            val m = enum_terms(fs.head, ms)
            m :: enum_t(fs.tail, m :: ms)
        }
    
    val fs = List.range(1, e.toLong).map(fraction(_, e.toLong)).
                    toMap.toList.sort((x, y) => x._1 * y._2 < x._2 * y._1)
    val ms = enum_t(fs.reverse).reverse
    val a = for(((d, n), b) <- fs.zip(ms); (d1, n1) <- b)
                yield N / n / (d1 / d) * n1
    a.sum - 1
}

val N = 6e12.toLong
def solve() = {
    val a = num_pows()
    val it = a.zip(Iterator.from(2).map(count_overlap))
    val Nsq = BigInt(N - 1) * (N - 1)
    println (Nsq - it.map(x => x._1.toLong * x._2).sum)
}

object Test extends Benchmark {
    def run() = solve
}

println (Test.runBenchmark(10))