ビットごとのエラトステネスのふるい

Pythonで何も考えずにエラトステネスのふるいのコードを書くとこんな感じでしょうか。

def sieve(max_n):
    a = [ True ] * L
    for p in takewhile(lambda n: n * n < L,
                    (n for n in xrange(2, L) if a[n])):
        for k in xrange(p * 2, L, p):
            a[k] = False

Booleanの配列を作ってふるいにかけます。このBooleanをビットにしようというものです。こうするとメモリが節約できます。さらに2の倍数、3の倍数は最初から無視する、すなわち6で割って余り1または5の数しか考えないと、1/3にメモリを節約できます。整数1個で32ビットなので96までの素数性を表すことができます。1億までで4MBくらいで済むはずですね。この配列をそのまま保持します。こうすると、メモリが節約できるだけでなく次のようなメリットがあります。

比較的速く素数を列挙できる
素数性の判定がO(1)でできる

最初のほうは、N付近で平均logN / 3個に1個素数があるはずなので、N = 10000くらいだと3個に1個ということになります。
ただし、この方法は6で割って余りが1と5ということで、きれいにはいかない部分があります。該当する数が5 7 11 13…と差が2 4 2 4…となるからです。1 5 7…とそのインデックスで変換する方法がひつようです。

k 0 1 2  3  4  5  6  7  8 …
n 1 5 7 11 13 17 19 23 25 …

下から上は簡単です。3で割るだけですね。

k = [n / 3]

上から下は場合分けして考えましょう。kが偶数のとき、

n = 3k + 1

kが奇数のとき、

n = 3k + 2

これは次のようにまとめられます。

n = k * 3 + 1 + (k & 1)

これでコードが組めます。

from itertools import *

class cPrimes:
    def __init__(self, N):
        self.N = N
        self.L = N / 96 + 1
        self.sieve()
    
    def sieve(self):
        self.bits = [ -1 ] * self.L
        for p in takewhile(lambda p: p * p < N,
                            islice(self.enumerate(), 2, None)):
            k0 = p / 3
            for k in xrange(k0 + p * 2, N / 3 + 1, p * 2):
                self.erase_bit(k)
            for k in xrange(-k0 - 1 + p * 2, N / 3 + 1, p * 2):
                self.erase_bit(k)
    
    def enumerate(self):
        yield 2
        yield 3
        k0 = 1
        for m, B in enumerate(self.bits):
            for k in xrange(k0, 32):
                if ((B >> k) & 1) == 1:
                    yield m * 96 + k * 3 + 1 + (k & 1)
            k0 = 0
    
    def is_prime(self, n):
        if n < 4:
            return n >= 2
        else:
            q, r = divmod(n, 6)
            if r == 1:
                k = q >> 4
                l = (q & 15) << 1
                return ((self.bits[q>>4] >> l) & 1) == 1
            elif r == 5:
                k = q >> 4
                l = ((q & 15) << 1) + 1
                return ((self.bits[q>>4] >> l) & 1) == 1
            else:
                return False
    
    def erase_bit(self, k):
        m = k >> 5
        l = k & 31
        if l != 31:
            self.bits[m] &= ~(1 << l)
        else:
            self.bits[m] &= ~(-1 << 31)

N = 10 ** 7
primes = cPrimes(N)
print sum(1 for p in takewhile(lambda n: n < N, primes.enumerate()))