どうしても素数判定に時間がかかりますが、どうにかならないのでしょうか。素数判定を高速に行うといえばエラトステネスのふるいですが、ふるい的な手法は使えないでしょうか。実は使えます。
対角線上の数を分けて考えましょう。まず、右下は平方数なので考える必要がありません。左上は4n2 + 1です。これは、n2 + 1でnが偶数のときを取ったものです。n2 + 1はふるい的手法が使えます。
さて、右上と左下ですが、右上は4n2 - 2n + 1で、左下は4n2 + 2n + 1です。なのですが、n2 + n + 1を考えると、3, 7, 13, 21, 31, 43, ...で右上と左下が交互になっています。なので、この2次式について考えればよいです。これはn2 + 1と同様にふるい的に素数判定ができます。やってみましょう。まず、n2 + n + 1を並べます。
1 3 7 13 21 31 43 57 73 91 111
1は無視して3が素数です。n = 1のときにp = 3ですが、このnをkとおくと、k + plでpで割り切れることになります。
(k + pl)(k + pl + 1) + 1 ≡ k2 + k + 1 ≡ 0 (mod p)
また、p - k - 1 + plでもpで割り切れます。
(p - k - 1 + pl)(p - k + pl) + 1 ≡ k2 + k + 1 ≡ 0 (mod p)
3の場合はどちらも1, 4, 7, ...になるので最初だけ残してその位置を3で割ります。
1 3 7 13 7 31 43 19 73 91 37
次に7は1でないので素数です。素数でなかったらより小さい素数で割り切れることになりますが、それはkより小さいものがあるはずなので、すでに出てきているはずです。k = 2なので、k = 2, 9, 16, ...とk = 4, 11, 18, ...で7で割り切れます。
1 3 7 13 1 31 43 19 73 13 37
同様に13も素数で、
1 3 7 13 1 31 43 19 73 1 37
で、全く割っていない3, 7, 31, 73が素数ということになります。
実装は、めんどうなのでふるいの範囲を決め打ちして、解が出てこなかったら範囲を倍にしています。これでも前の解法よりかなり速くなって、10%で50ms、9%で200ms、8%で460ms程度でした。
import scala.testing.Benchmark def div_pow(n :Long, d :Long) :Long = if(n % d != 0) n else div_pow(n / d, d) def long_range(first :Long, last :Long, diff :Long) :Iterator[Long] = if(diff > 0) Iterator.iterate(first)(diff +).takeWhile(last >) else if(diff < 0) Iterator.iterate(first)(diff +).takeWhile(last <) else Iterator() // prime numbers of n^2 + 1 // 1 2 5 10 17 26 37 50 def sieve1(max_n :Int) :Iterator[Int] = { val a = Array.range(0, max_n + 1).map(n => n.toLong * n + 1) for(k <- 3 to max_n by 2) a(k) /= 2 for(k <- Iterator.range(2, max_n + 1); p = a(k) if a(k) != 1L && p <= 2000000000; init = if(a(k) == k.toLong * k + 1) p + k else k; m <- long_range(init, max_n + 1, p) ++ long_range(p - k, max_n + 1, p)) a(m.toInt) = div_pow(a(m.toInt), p) Iterator.range(1, max_n + 1).map( k => if(a(k) == k.toLong * k + 1) 1 else 0) } // prime numbers of n^2 + n + 1 // 3 7 13 21 31 43 57 73 91 111 def sieve2(max_n :Int) :Iterator[Int] = { val a = Array.range(0, max_n + 1).map(n => n.toLong * n + n + 1) for(k <- 4 to max_n by 3) a(k) /= 3 for(k <- Iterator.range(2, max_n + 1); p = a(k) if a(k) != 1L; init = if(a(k) == k.toLong * k + k + 1) p + k else k; m <- long_range(init, max_n + 1, p) ++ long_range(p - k - 1, max_n + 1, p)) a(m.toInt) = div_pow(a(m.toInt), p) Iterator.range(1, max_n + 1).map( k => if(a(k) == k.toLong * k + k + 1) 1 else 0) } def calc_solution(size :Int) :Option[Int] = { def next(s :(Int,Int,Int,Iterator[Int],Iterator[Int])) = { val (len, num_p, num, it1, it2) = s it1.next val np = List(it1.next, it2.next, it2.next).sum (len + 2, num_p + np, num + 4, it1, it2) } val N = 8 val it1 = sieve1(size) val it2 = sieve2(size) val it = Iterator.iterate((1, 0, 1, it1, it2))(next).drop(1). take(size / 2 - 1).filter(s => s._2 * 100 < s._3 * N).map(_._1) if(it.hasNext) Some(it.next) else None } def solve() = { val it = Iterator.iterate((1000, calc_solution(1000)))( s => (s._1 * 2, calc_solution(s._1 * 2))) println (it.map(s => s._2).filter(None !=).next) } object Test extends Benchmark { def run() = solve } println (Test.runBenchmark(10))