NumPyでエラトステネスのふるいを完全理解する

エラトステネスのようなリストを使う計算はPythonでは非常に遅いですが、Numpyを使うと速くなります。http://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-pythonに書かれてあるprimesfrom2toという関数です。

import numpy

def primesfrom2to(n):
    """ Input n>=6, Returns a array of primes, 2 <= p < n """
    sieve = numpy.ones(n/3 + (n%6==2), dtype=numpy.bool)
    for i in xrange(1,int(n**0.5)/3+1):
        if sieve[i]:
            k=3*i+1|1
            sieve[       k*k/3     ::2*k] = False
            sieve[k*(k-2*(i&1)+4)/3::2*k] = False
    return numpy.r_[2,3,((3*numpy.nonzero(sieve)[0][1:]+1)|1)]

これと、サラっと書いたコードと比較してみましょう。

from itertools import *

def make_primes(N):
    a = [ True ] * (N + 1)
    for p in takewhile(lambda p: p * p <= N, (n for n in count(2) if a[n])):
        for k in xrange(p * p, N + 1, p):
            a[k] = False
    return [ n for n in xrange(2, N + 1) if a[n] ]

下のコードは35秒、上のNumPyのコードは0.97秒でした。下のコードで2と3の倍数を無視すると3倍くらい速くなります。

def make_primes3(N):
    a = [ False, True, False, False, False, True ] * (N / 6 + 1)
    for p in takewhile(lambda p: p * p <= N, (n for n in count(5) if a[n])):
        for k in xrange(p * p, N + 1, p * 6):
            a[k] = False
        for k in xrange(p * (p + (p + 3) % 6), N + 1, p * 6):
            a[k] = False
    return [ 2 ] + [ n for n in xrange(3, N + 1, 2) if a[n] ]

さて、NumPyのコードを見てもよくわからないので、上から理解していきます。

    sieve = numpy.ones(n/3 + (n%6==2), dtype=numpy.bool)

numpy.onesは指定した長さの1だけの配列を作成します。長さは、n = 100なら33です。6の剰余が1, 5だけを素数かどうかの対象としているからnの1/3くらいの配列になります。要素はデータタイプnumpy.boolだからTrueです。

k=3*i+1|1

i = 0, 1, 2, 3, ...とすると、k = 1, 5, 7, 11, ...となります。

sieve[k*k/3::2*k] = False

これは、k = 5なら、[ 5 * 5 / 3 ] = 8から10ごとにFalseにします。分かりやすい例だと、

a = numpy.ones(10, dtype = numpy.bool)
a[3::2] = False
print a
[ True  True  True False  True False  True False  True False]

インデックス3, 5, 7, 9をFalseにしています。
この行は例えばk = 5なら、25, 55, 85, ...を非素数にしていますが、次の行は35, 65, 95, ...を非素数にしています。

最後の行、numpy.r_は配列を構築します。

print numpy.r_[0,1,4]   # [0 1 4]
print numpy.r_[1,4:7]   # [1 4 5 6]

numpy.nonzeroはゼロでない要素のインデックスを配列にします。

a = numpy.ones(10, dtype = numpy.bool)
a[3::2] = False
print numpy.nonzero(a)  # (array([0, 1, 2, 4, 6, 8]),)

タプルになっているのは、恐らく多次元配列を前提としているからでしょう。
配列を3倍すると要素が全て3倍になります。

print 3 * numpy.r_[0,2,3]   # [0 6 9]

+も|も同じですね。

print (3 * numpy.r_[1:8] + 1) | 1

は、

[ 5  7 11 13 17 19 23]

となります。細かいところは省きましたが、これで完全理解できました。