Project Euler 29

http://projecteuler.net/index.php?section=problems&id=29

この問題は、解くだけなら非常に易しいです。Pythonなら1行で書けるレベルです。
しかし、工夫をすれば大きな数でも解けることに昔気づきました。フォーラムに今までそういう書き込みは無かったのでそのような解法を書き込むと、Kudos(いいねみたいなもの)がたくさん付くも、なぜか委員会に認められずに、自動消去されるという憂き目に。

さて、おさらいしておきましょう。
ab(2 <= a, b <= N)は(N-1)2個ありますが、同じ値を取る(a, b)の組み合わせがあるために問題が成り立つのですね。ここでは、より小さい底(a)で同じ値になる組合せをカウントして、(N-1)2から引きます。
例えば、N=10で考えます。より小さい底になるのは、

a=4のとき、

42 = 24, 43 = 26, 44 = 28, 45 = 210

ここまでの4つです。次は、46 = 212となって、指数が10を超えてしまうからです。

a=8のときは、2つの底2, 4を考えられて、

82 = 26, 83 = 29
82 = 43, 84 = 46, 86 = 49

の5つですが、82がダブっているので、4つですね。

a=9のとき、

92 = 34, 93 = 36, 94 = 38, 95 = 310

の4つです。
ここで、分かるのは、4と9は全く同じということです。つまり、同じ指数なら1回計算すればよく、あとは同じ指数がいくつあるか数えればいいだけです。例えば、指数2は、int(sqrt(N))-1からべき乗数を引いた数なので、簡単に再帰的に求められます。

それから、底が小さくなるのは、もちろんべき乗数のときです。このようにすれば、O(N(logN)^2)程度の計算量になると思います。

ここまでは復習で、ここからが本番です。さきほどのa=8のとき、上の段と下の段のそれぞれの個数は簡単に求まります。しかし、重複があるので難しくなるわけです。もう少し見やすくするために、N=20で見て見ましょう。

82 = 26, 83 = 29, 84 = 212, 85 = 215, 86 = 218
82 = 43, 84 = 46, 86 = 49, 88 = 412, 810 = 415, 812 = 418

8の指数は、上の段は1刻みで[2, 6]の範囲、下の段は2刻みで[2,12]なので、どちらも5つあります。そして、どちらにも含まれるのは、2刻みで[2,6]の範囲なので、3つです。つまり、指数3のべき乗数が底のとき、5+5-3=7つの指数で底が小さくなり得ることになります。

これを一般化すると包除原理を使うことになります。そして、いつものように分割統治法を使います。しかし、このままだと分割統治法を使う意味があまりないので、ある程度まとめた計算をしなければなりません。このとき、同じ刻み幅でまとめることができます。上で見たように、範囲があるので、これをまとめると階段関数になります。階段が上下する点も少ないので、これである程度速くなります。
最後に、刻み幅は例えば、4と6だったら、共通部分は12、すなわち最小公倍数になります。これが意外と遅い。なので、分割統治法の最初の分割で互いに共通因子が無いようにします。例えば、1〜13なら、7,11,13とそれ以外に分けます。そうすると、最後で最小公倍数の代わりに掛け算を使うことができ、さらにほかの高速化もできます。

結果はこうです。

F(10^12) = 999999494802129045868607 0.375sec
F(10^13) = 99999984077613685612998021 0.750sec
F(10^14) = 9999999497624726793960874036 0.609sec
F(10^15) = 999999984137697761708961227516 1.484sec
F(10^16) = 99999999498907434707071928610549 3.547sec
F(10^17) = 9999999984165144778636197325836865 3.672sec
#!/usr/bin/python
# coding: utf-8

from itertools import *
from collections import defaultdict
from fractions import gcd
import sys
import time

def accumulate(iterable):
    dic = defaultdict(int)
    for k, n in iterable:
        dic[k] += n
    return dic

def lcm(n, m):
    return n / gcd(n, m) * m

def int_log(b, n):
    m = 1
    for e in count(1):
        m *= b
        if m > n:
            return e - 1

def int_root(n, e):
    r0 = int(n ** (1./e))
    if r0 ** e <= n:
        for r in count(r0 + 1):
            if r ** e > n:
                return r - 1
    else:
        for r in count(r0 - 1, -1):
            if r ** e <= n:
                return r

def divisors(n):
    for d in takewhile(lambda d: d * d <= n, count(1)):
        if n % d == 0:
            yield d
            if d * d < n:
                yield n / d

def is_prime(n):
    if n < 2:
        return False
    return all(n % d != 0 for d in takewhile(lambda d: d * d <= n, count(2)))

def F_naive(N):
    def make_pows(N):
        a = [ (n, 1) for n in range(N + 1) ]
        for e in range(2, int_log(2, N) + 1):
            for n in range(2, int_root(N, e) + 1):
                a[n**e] = (n, e)
        return a
    
    pows = make_pows(N)     # [(0, 1), (1, 1), (2, 1), (3, 1), (2, 2), ... ]
    s = set()
    for a in range(2, N + 1):
        n, e = pows[a]
        for b in range(2, N + 1):
            s.add((n, e * b))
    return len(s)

def num_not_pows(N):
    s = N - 1
    for e in range(2, int_log(2, N) + 1):
        s -= num_not_pows(int_root(N, e))
    return s

def num_pows(U):
    a = [ 0, 0 ]
    for e in takewhile(lambda e: 2 ** e <= U, count(2)):
        num = int_root(U, e)    # int(N**(1/e))
        a.append(num - 1 - sum(num_pows(num)))
    return a

def F(N):
    def count_duplicated(e1):
        v = [ 0 ] * (N + 1)
        for e2 in range(1, e1):
            d = e2 / gcd(e2, e1)
            for n in xrange(max(2, d), N*e2/e1 + 1, d):
                v[n] = 1
        return sum(v)
    
    def count_duplicated(e1):
        def g(v):
            if len(v) == 1:
                d, L = v[0]
                yield (1, 1, N)
                yield (-1, d, L)
                return
            
            result1 = list(g(v[::2]))
            result2 = list(g(v[1::2]))
            for sign1, d1, L1 in result1:
                for sign2, d2, L2 in result2:
                    sign = sign1 * sign2
                    d = lcm(d1, d2)
                    L = min(L1, L2)
                    if d <= L:
                        yield (sign, d, L)
        
        dic = { }
        for e2 in range(1, e1):
            d = gcd(e2, e1)
            e2r = e2 / d
            dic[e2r] = N * e2 / e1
        
        s = N - 1
        for sign, d, L in g(dic.items()):
            s -= sign * (L / d)
        return s
    
    def count_duplicated2(e1):
        def g(v):
            if len(v) == 1:
                d, L = v[0]
                yield ((1, N), 1)
                yield ((d, L), -1)
                return
            
            result1 = accumulate(g(v[::2])).items()
            result2 = accumulate(g(v[1::2])).items()
            for (d1, L1), sign1 in result1:
                if sign1 == 0:
                    continue
                for (d2, L2), sign2 in result2:
                    if sign2 == 0:
                        continue
                    sign = sign1 * sign2
                    d = lcm(d1, d2)
                    L = min(L1, L2)
                    if d <= L:
                        yield ((d, L), sign)
        
        dic = { }
        for e2 in range(1, e1):
            d = gcd(e2, e1)
            dic[e2/d] = N * e2 / e1
        s = N - 1
        for (d, L), sign in g(dic.items()):
            s -= sign * (L / d)
        return s
    
    def count_duplicated3(e1):
        def mul_step(f, g):
            h = []
            L1, L2 = len(f), len(g)
            k, l = 0, 0
            prev_v = None
            while k < L1 and l < L2:
                first1, last1, v1 = f[k]
                first2, last2, v2 = g[l]
                v = v1 * v2
                if prev_v is None:
                    first = max(first1, first2)
                    last  = min(last1, last2)
                    prev_v = v
                elif v == prev_v:
                    last  = min(last1, last2)
                else:
                    h.append((first, last, prev_v))
                    first = max(first1, first2)
                    last  = min(last1, last2)
                    prev_v = v
                
                if last1 <= last2:
                    k += 1
                if last1 >= last2:
                    l += 1
            h.append((first, last, prev_v))
            return h
        
        def add_step(f, g):
            h = []
            L1, L2 = len(f), len(g)
            k, l = 0, 0
            prev_v = None
            while k < L1 and l < L2:
                first1, last1, v1 = f[k]
                first2, last2, v2 = g[l]
                v = v1 + v2
                if prev_v is None:
                    first = max(first1, first2)
                    last  = min(last1, last2)
                    prev_v = v
                elif v == prev_v:
                    last  = min(last1, last2)
                else:
                    h.append((first, last, prev_v))
                    first = max(first1, first2)
                    last  = min(last1, last2)
                    prev_v = v
                
                if last1 <= last2:
                    k += 1
                if last1 >= last2:
                    l += 1
            h.append((first, last, prev_v))
            return h
        
        def collect(iterable):
            dic = { }
            for d, f in iterable:
                if d in dic:
                    dic[d] = add_step(dic[d], f)
                else:
                    dic[d] = f
            return [ (d, f) for d, f in dic.items()
                                        if len(f) != 1 or f[0][2] != 0 ]
            return dic.items()
        
        def g(v):
            if len(v) == 1:
                d, L = v[0]
                if d == 1:
                    yield (1, [(1, L, 0), (L, N, 1)])
                else:
                    yield (1, [(1, N, 1)])
                    yield (d, [(1, L, -1), (L, N, 0)])
                return
            
            result1 = collect(g(v[::2]))
            result2 = collect(g(v[1::2]))
            for d1, f1 in result1:
                for d2, f2 in result2:
                    d = lcm(d1, d2)
                    if d > N:
                        continue
                    
                    f = mul_step(f1, f2)
                    if len(f) != 1 or f[0][2] != 0:
                        yield (d, f)
        
        if e1 == 2:
            return N / 2 - 1
        
        dic = { }
        for e2 in range(1, e1):
            d = gcd(e2, e1)
            dic[e2/d] = N * e2 / e1
        for e, L in dic.items():
            if any(d in dic and dic[d] >= L for d in divisors(e) if 1 < d < e):
                del dic[e]
        s = N - 1
        isolated_ps = [ d for d in dic.keys() if d * 2 > e1 and is_prime(d) ]
        b = sorted(g([ (d, L) for d, L in dic.items() if d in isolated_ps ]))
        for d1, f1 in g([ (d, L) for d, L in dic.items()
                                                if d not in isolated_ps ]):
            for first, last, v in f1:
                if v == 0:
                    continue
                for d2, f2 in b:
                    _, L2, c = f2[0]
                    d = d1 * d2
                    if last < d:
                        break
                    if L2 >= last:
                        s -= (last / d - first / d) * v * c
                    elif L2 > first:
                        s -= (L2 / d - first / d) * v * c
        return s
    
    s = (N - 1) ** 2
    a = num_pows(N)
    for e, n in enumerate(a[2:], 2):
        c = count_duplicated3(e)
        s -= c * n
    return s

for E in range(12, 18):
    t0 = time.clock()
    N = 10 ** E
    print "F(10^%d) = %d %.3fsec" % (E, F(N), time.clock() - t0)