ScalaでProject Euler(91)

Problem 58

どうしても素数判定に時間がかかりますが、どうにかならないのでしょうか。素数判定を高速に行うといえばエラトステネスのふるいですが、ふるい的な手法は使えないでしょうか。実は使えます。
対角線上の数を分けて考えましょう。まず、右下は平方数なので考える必要がありません。左上は4n2 + 1です。これは、n2 + 1でnが偶数のときを取ったものです。n2 + 1はふるい的手法が使えます
さて、右上と左下ですが、右上は4n2 - 2n + 1で、左下は4n2 + 2n + 1です。なのですが、n2 + n + 1を考えると、3, 7, 13, 21, 31, 43, ...で右上と左下が交互になっています。なので、この2次式について考えればよいです。これはn2 + 1と同様にふるい的に素数判定ができます。やってみましょう。まず、n2 + n + 1を並べます。

1 3 7 13 21 31 43 57 73 91 111

1は無視して3が素数です。n = 1のときにp = 3ですが、このnkとおくと、k + plpで割り切れることになります。

(k + pl)(k + pl + 1) + 1 ≡ k2 + k + 1 ≡ 0 (mod p)

また、p - k - 1 + plでもpで割り切れます。

(p - k - 1 + pl)(p - k + pl) + 1 ≡ k2 + k + 1 ≡ 0 (mod p)

3の場合はどちらも1, 4, 7, ...になるので最初だけ残してその位置を3で割ります。

1 3 7 13 7 31 43 19 73 91 37

次に7は1でないので素数です。素数でなかったらより小さい素数で割り切れることになりますが、それはkより小さいものがあるはずなので、すでに出てきているはずです。k = 2なので、k = 2, 9, 16, ...とk = 4, 11, 18, ...で7で割り切れます。

1 3 7 13 1 31 43 19 73 13 37

同様に13も素数で、

1 3 7 13 1 31 43 19 73 1 37

で、全く割っていない3, 7, 31, 73が素数ということになります。

実装は、めんどうなのでふるいの範囲を決め打ちして、解が出てこなかったら範囲を倍にしています。これでも前の解法よりかなり速くなって、10%で50ms、9%で200ms、8%で460ms程度でした。

import scala.testing.Benchmark

def div_pow(n :Long, d :Long) :Long =
    if(n % d != 0) n else div_pow(n / d, d)

def long_range(first :Long, last :Long, diff :Long) :Iterator[Long] =
    if(diff > 0)
        Iterator.iterate(first)(diff +).takeWhile(last >)
    else if(diff < 0)
        Iterator.iterate(first)(diff +).takeWhile(last <)
    else
        Iterator()

// prime numbers of n^2 + 1
// 1 2 5 10 17 26 37 50
def sieve1(max_n :Int) :Iterator[Int] = {
    val a = Array.range(0, max_n + 1).map(n => n.toLong * n + 1)
    for(k <- 3 to max_n by 2) a(k) /= 2
    for(k <- Iterator.range(2, max_n + 1);
        p = a(k) if a(k) != 1L && p <= 2000000000;
        init = if(a(k) == k.toLong * k + 1) p + k else k;
        m <- long_range(init, max_n + 1, p) ++
             long_range(p - k, max_n + 1, p))
                a(m.toInt) = div_pow(a(m.toInt), p)
    
    Iterator.range(1, max_n + 1).map(
                k => if(a(k) == k.toLong * k + 1) 1 else 0)
}

// prime numbers of n^2 + n + 1
// 3 7 13 21 31 43 57 73 91 111
def sieve2(max_n :Int) :Iterator[Int] = {
    val a = Array.range(0, max_n + 1).map(n => n.toLong * n + n + 1)
    for(k <- 4 to max_n by 3) a(k) /= 3
    for(k <- Iterator.range(2, max_n + 1);
        p = a(k) if a(k) != 1L;
        init = if(a(k) == k.toLong * k + k + 1) p + k else k;
        m <- long_range(init, max_n + 1, p) ++
             long_range(p - k - 1, max_n + 1, p))
                a(m.toInt) = div_pow(a(m.toInt), p)
    
    Iterator.range(1, max_n + 1).map(
                k => if(a(k) == k.toLong * k + k + 1) 1 else 0)
}

def calc_solution(size :Int) :Option[Int] = {
    def next(s :(Int,Int,Int,Iterator[Int],Iterator[Int])) = {
        val (len, num_p, num, it1, it2) = s
        it1.next
        val np = List(it1.next, it2.next, it2.next).sum
        (len + 2, num_p + np, num + 4, it1, it2)
    }
    
    val N = 8
    val it1 = sieve1(size)
    val it2 = sieve2(size)
    val it = Iterator.iterate((1, 0, 1, it1, it2))(next).drop(1).
            take(size / 2 - 1).filter(s => s._2 * 100 < s._3 * N).map(_._1)
    if(it.hasNext)
        Some(it.next)
    else
        None
}

def solve() = {
    val it = Iterator.iterate((1000, calc_solution(1000)))(
                            s => (s._1 * 2, calc_solution(s._1 * 2)))
    println (it.map(s => s._2).filter(None !=).next)
}

object Test extends Benchmark {
    def run() = solve
}

println (Test.runBenchmark(10))