Project Euler 229

プロジェクトオイラー
http://projecteuler.net/

Q229.
n = a12 + b12
n = a22 + 2b22
n = a32 + 3b32
n = a42 + 7b42
と表せる20億以下のnの個数。

まず、オイラーが証明したところによると、2つの平方の和で表される整数の条件は、素因数分解して、奇数乗の素数は4の剰余が1または2となる。同様の条件が他の3つにもあって、これらを総合すると、奇数乗の素数の168の剰余が1,25,121が条件となる。ただし、これはどちらかの平方が0である場合も含まれるため、次のように考える。まず、上の条件を満たす素数を列挙する。これにはエラトステネスのふるいの変形版を用いた。ここではなるべくメモリを消費しないような工夫を凝らした。そして、これらの異なる素数の積(一つでもよい)と平方(1でもよい)の積の個数を数えた。次に、平方について考えた。x2 + ky2 = z2を満たすx,y,zは、(m2 - kn2, 2mn, m2 + kn2)(とこれの整数倍)である。これら4つをすべて満たすzをエラトステネスのふるい的に求める。ただし、mnが互いに素でも、xyは互いに素とは限らないので注意が必要。



from itertools import count
from math import sqrt
import fractions

def is_prime(n):
for p in primes:
if p * p > n:
return True
elif n % p == 0:
return False
return True

def make_primes(n):
m = int(sqrt(n + 0.5))
for p in xrange(3, m + 1, 2):
if is_prime(p):
primes.append(p)

a = [ 0 ] * ((n + 15) / 16)
for p in primes:
if p * p > n:
break
if p == 2:
continue
for k in xrange(3 * p, n, 2 * p):
a[k>>4] |= 1 << (k & 15)

for k in xrange((m + 1) / 2 * 2 + 1, n + 1, 2):
if (a[k>>4] & (1 << (k & 15))) == 0:
primes.append(k)

def gen_special_number(n):
if n == 1:
yield 1
else:
yield 2

def calc_inverse():
a = [ ]
m = 168
for k in range(m):
if k % 2 == 0 or k % 3 == 0 or k % 7 == 0:
a.append(0)
else:
for l in range(m):
if k * l % m == 1:
a.append(l)
break
return a

def encode(n):
m = n % 168
l = n / 168 * 3
if m == 1:
return l
elif m == 25:
return l + 1
elif m == 121:
return l + 2
else:
return -1

def decode(k):
m = k % 3
l = k / 3 * 168
if m == 0:
return l + 1
elif m == 1:
return l + 25
else:
return l + 121

def sieve_core(a, n, p, a_inv, m):
r = a_inv[p%168]
x0 = r * m % 168
if x0 == 1:
x0 += 168
for k in xrange(p * x0, n + 1, p * 168):
l = encode(k)
a[l>>4] |= 1 << (l & 15)

# get an array of prime numbers s.t. p % 168 == 1, 25, 121
def sieve(n):
a_inv = calc_inverse()
a = [ 0 ] * (n * 3 / 168 / 16 + 1)
for p in primes:
if p == 2:
continue
sieve_core(a, n, p, a_inv, 1)
sieve_core(a, n, p, a_inv, 25)
sieve_core(a, n, p, a_inv, 121)

b = [ ]
for k in xrange(1, n * 3 / 168 + 1):
if (a[k>>4] & (1 << (k & 15))) == 0:
b.append(decode(k))
return b

# count p1 * p2 ... * pn * n^2 s.t. pk % 168 == 1, 25, 121
def count_normal(n, k = 0):
if n < primes4[k]:
return 0

s = 0
for l in xrange(k, len(primes4)):
p = primes4[l]
m = n / p
s += int(sqrt(m + 0.5))
if l < len(primes4) - 1:
s += count_normal(m, l + 1)
return s

def count_square(limit):
limit2 = int( (limit + 0.5) ** 0.5)
a = [ 0 ] * (limit2 + 1)
for k in [ 1, 2, 3, 7 ]:
for m in count(1):
m2 = m * m
if m2 >= limit2 * k:
break
for n in count(1):
if fractions.gcd(m, n) != 1:
continue
x = m2 - k * n * n
if x <= 0:
break
y = 2 * m * n
z = m2 + k * n * n
if z > limit2 * k:
break
d = fractions.gcd(z, y)
if d > 1:
z /= d
for p in xrange(z, limit2 + 1, z):
a[p] |= 1 << k
return len(filter(lambda e: e == 0x8e, a))

N = 2 * 10 ** 9
primes = [ 2 ]
make_primes(int(N ** 0.5 + 0.5))
primes4 = sieve(N)
print count_normal(N) + count_square(N)