エラトステネスのようなリストを使う計算はPythonでは非常に遅いですが、Numpyを使うと速くなります。http://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-pythonに書かれてあるprimesfrom2toという関数です。
import numpy def primesfrom2to(n): """ Input n>=6, Returns a array of primes, 2 <= p < n """ sieve = numpy.ones(n/3 + (n%6==2), dtype=numpy.bool) for i in xrange(1,int(n**0.5)/3+1): if sieve[i]: k=3*i+1|1 sieve[ k*k/3 ::2*k] = False sieve[k*(k-2*(i&1)+4)/3::2*k] = False return numpy.r_[2,3,((3*numpy.nonzero(sieve)[0][1:]+1)|1)]
これと、サラっと書いたコードと比較してみましょう。
from itertools import * def make_primes(N): a = [ True ] * (N + 1) for p in takewhile(lambda p: p * p <= N, (n for n in count(2) if a[n])): for k in xrange(p * p, N + 1, p): a[k] = False return [ n for n in xrange(2, N + 1) if a[n] ]
下のコードは35秒、上のNumPyのコードは0.97秒でした。下のコードで2と3の倍数を無視すると3倍くらい速くなります。
def make_primes3(N): a = [ False, True, False, False, False, True ] * (N / 6 + 1) for p in takewhile(lambda p: p * p <= N, (n for n in count(5) if a[n])): for k in xrange(p * p, N + 1, p * 6): a[k] = False for k in xrange(p * (p + (p + 3) % 6), N + 1, p * 6): a[k] = False return [ 2 ] + [ n for n in xrange(3, N + 1, 2) if a[n] ]
さて、NumPyのコードを見てもよくわからないので、上から理解していきます。
sieve = numpy.ones(n/3 + (n%6==2), dtype=numpy.bool)
numpy.onesは指定した長さの1だけの配列を作成します。長さは、n = 100なら33です。6の剰余が1, 5だけを素数かどうかの対象としているからnの1/3くらいの配列になります。要素はデータタイプnumpy.boolだからTrueです。
k=3*i+1|1
i = 0, 1, 2, 3, ...とすると、k = 1, 5, 7, 11, ...となります。
sieve[k*k/3::2*k] = False
これは、k = 5なら、[ 5 * 5 / 3 ] = 8から10ごとにFalseにします。分かりやすい例だと、
a = numpy.ones(10, dtype = numpy.bool) a[3::2] = False print a
[ True True True False True False True False True False]
インデックス3, 5, 7, 9をFalseにしています。
この行は例えばk = 5なら、25, 55, 85, ...を非素数にしていますが、次の行は35, 65, 95, ...を非素数にしています。
最後の行、numpy.r_は配列を構築します。
print numpy.r_[0,1,4] # [0 1 4] print numpy.r_[1,4:7] # [1 4 5 6]
numpy.nonzeroはゼロでない要素のインデックスを配列にします。
a = numpy.ones(10, dtype = numpy.bool) a[3::2] = False print numpy.nonzero(a) # (array([0, 1, 2, 4, 6, 8]),)
タプルになっているのは、恐らく多次元配列を前提としているからでしょう。
配列を3倍すると要素が全て3倍になります。
print 3 * numpy.r_[0,2,3] # [0 6 9]
+も|も同じですね。
print (3 * numpy.r_[1:8] + 1) | 1
は、
[ 5 7 11 13 17 19 23]
となります。細かいところは省きましたが、これで完全理解できました。