Project Euler 128

http://projecteuler.net/index.php?section=problems&id=128


素直に回って座標を出します。ここでは、1を原点、2を(1, 0)、3を(0, 1)とする空間(平面)で考えます。今考えている座標より1層多く計算している素数の個数を計算できるので、1層分早いのと2つジェネレータを使います。しかし、これではなかなか素数が3つになる数が出てきません。
よく見ると、素数の個数が3つなのは限られた場所になります。まず、六角形の辺上だと隣との差が1で、内側の層の隣は2つあって並んでいるので、差が偶数と奇数になって、どちらかしか素数になりえません。外側の層も同じなので最大でも素数は2つです。同様に考えると、層の境目、1の真上か1つ右にずれた数しか3つにならないことがわかります。

from itertools import *

bit_primes = []
L = 32000
M = L / 32

def get_bit(a, n):
    m = n & 31
    if m < 31:
        return ((a[n>>5] >> m) & 1) == 1
    else:
        return (a[n>>5] & ~0x7fffffff) == ~0x7fffffff

def unset_bit(a, n):
    m = n & 31
    if m < 31:
        a[n>>5] &= ~(1 << m)
    else:
        a[n>>5] &= 0x7fffffff

def is_prime(n):
    if n >= len(bit_primes) * 32:
        sieve(n)
    
    return get_bit(bit_primes, n)

def sieve(n):
    m0 = len(bit_primes) / M
    m1 = n / L
    for m in xrange(m0, m1 + 1):
        a = [ ~0 ] * M
        start_n = L * m
        end_n = start_n + L
        g = takewhile(lambda p: p * p < end_n, count(2))
        if m == 0:
            a[0] = ~3
            for p in ifilter(lambda n: get_bit(a, n), g):
                for n in xrange(p * 2, L, p):
                    unset_bit(a, n)
        else:
            for p in ifilter(lambda n: is_prime, g):
                for n in xrange((start_n + p - 1) / p * p, end_n, p):
                    unset_bit(a, n - start_n)
        
        bit_primes.extend(a)

nei = ((-1, 0), (0, -1), (1, -1), (1, 0), (0, 1), (-1, 1))

def add(p, q):
    return (p[0] + q[0], p[1] + q[1])

def pt2v(pt):
    x, y = pt
    if x == 0 and y == 0:
        return 1
    elif x <= 0:
        if -x < y:
            r, l, t = y, 0, -x
        elif y > 0:
            r, l, t = -x, 1, -x - y
        else:
            r, l, t = -x - y, 2, -y
    else:
        if x < -y:
            r, l, t = -y, 3, x
        elif y < 0:
            r, l, t = x, 4, x + y
        else:
            r, l, t = x + y, 5, y
    
    return 3 * r * (r - 1) + r * l + t + 2

def calc_PD(pt):
    def gen_neighbors(pt0):
        return imap(pt2v, (add(pt0, v) for v in nei))
    
    n = pt2v(pt)
    return sum(is_prime(abs(m - n)) for m in gen_neighbors(pt))

def gen_pt():
    yield (0, 0)
    for r in count(1):
        yield (0, r)
        yield (1, r - 1)

def gen_PD():
    return ((pt2v(pt), calc_PD(pt)) for pt in gen_pt())

N = 2000
print next(n for k, n in izip(count(1),
            (n for n, PD in gen_PD() if PD == 3)) if k == N)