Scalaで素数判定(2)

エラトステネスのふるいです。ふつうに書いてみます。

def sieve1(N :Int) :List[Int] = {
    val a = new Array[Boolean](N + 1)
    for(n <- 2 to N) a(n) = true
    for(p <- Iterator.range(2, N + 1).takeWhile(p => p * p <= N) if a(p);
        k <- Iterator.range(p * 2, N + 1, p))
            a(k) = false
    List.range(2, N + 1).filter(n => a(n))
}

45msでした。Pythonの5倍程度の速さですね。イマイチですね。forを二重にすると、

def sieve1a(N :Int) :List[Int] = {
    val a = new Array[Boolean](N + 1)
    for(p <- Iterator.range(2, N + 1).takeWhile(p => p * p <= N) if !a(p))
        for(k <- p * 2 to N by p)
            a(k) = true
    List.range(2, N + 1).filter(n => !a(n))
}

32ms、少し効果ありますね。
2と3の倍数を無視すると、

def sieve3(N :Int) :List[Int] = {
    val M = if((N - 1) % 6 < 4) N / 6 * 2 + 1 else (N + 1) / 3
    val a = new Array[Boolean](M)
    for(n <- 1 to M - 1) a(n) = true
    for(n <- Iterator.range(1, N + 1).takeWhile(p => p * p <= N) if a(n)) {
        val p = n * 3 + (n & 1) + 1
        for(k <- Iterator.range(n + p * 2, M, p * 2))
            a(k) = false
        for(k <- Iterator.range(-n - 1 + p * 2, M, p * 2))
            a(k) = false
    }
    2 :: 3 :: List.range(1, M).filter(n => a(n)).map(n => n * 3 + (n & 1) + 1)
}

23msでした。Pythonの3倍程度ですね。たいしたことないですね。ちなみに、C++では3msとかでした。1000万までで5倍くらい速いです。

#include <iostream>
#include <vector>
#include <windows.h>
#pragma comment(lib, "winmm.lib")

using namespace std;

void sieve3(int N, vector<int>& primes) {
    const int   M = (N - 1) % 6 < 4 ? N / 6 * 2 + 1 : (N + 1) / 3;
    vector<bool>    a(M, true);
    for(int n = 1; n < M; n++) a[n] = true;
    for(int n = 1; n * n <= N; n++) {
        if(a[n]) {
            const int   p = n * 3 + (n & 1) + 1;
            for(int k = n + p * 2; k < M; k += p * 2)
                a[k] = false;
            for(int k = -n - 1 + p * 2; k < M; k += p * 2)
                a[k] = false;
        }
    }
    
    primes.push_back(2);
    primes.push_back(3);
    for(int n = 1; n < M; n++) {
        if(a[n])
            primes.push_back(n * 3 + (n & 1) + 1);
    }
}

int main() {
    const int   t0 = timeGetTime();
    const int   N = (int)1e7;
    vector<int> primes;
    sieve3(N, primes);
    cout << primes.size() << endl;
    cout << timeGetTime() - t0 << endl;
}