Project Euler 30

http://projecteuler.net/index.php?section=problems&id=30

以前のと違う方法をふと思いついたので書きます。

例えば、6桁として456789を考えます。

45 + 55 + 65 + 75 + 85 + 95 - 456789

が0ならいいわけです。ここで3桁ずつにわけて考えましょう。456000と789に分けます。そうすると、

45 + 55 + 65 - 456000 = 789 - (75 + 85 + 95)

ならいいわけです。上限を354294とすると、上3桁は0, 1000, ... , 354000、下3桁は0〜999です。この範囲で上のそれぞれ左右の辺を計算します。そうすると整数の集合ができます。それぞれの集合から一つずつ取ってきて両方が等しければいいわけです。すなわちマージ法が使えます。ソートしておいて比較していけばいいわけです。同じ値になる可能性もあるので、groupbyを使います。

マージ法は単純で誰にでも思いつきそうなアルゴリズムにもかかわらず非常に強力なので是非覚えましょう。

この方法は以前の方法に対して9乗までは速かったです。

# 123 -> 1 1234 -> 2
def num_half_digits(n):
    return len(list(digits(n))) / 2

def divide(E):
    def group(iterable):
        return [ (d, list(a)) for d, a in
                    groupby(sorted(iterable), key = lambda x: x[0]) ]
    
    limit = calc_limit(E)
    half = num_half_digits(limit)
    D = 10 ** half
    lowers = group((sum_pows(n) - n, n) for n in xrange(D))
    uppers = group((n * D - sum_pows(n), n * D) for n in xrange(limit / D + 1))
    return lowers, uppers

def merge(a, b):
    k, l = 0, 0
    m, n = a[k][0], b[l][0]
    while True:
        if m == n:
            u, v = a[k][1], b[l][1]
            yield sum(n for _, n in u) * len(v) + sum(n for _, n in v) * len(u)
            k, l = k + 1, l + 1
            if k == len(a) or l == len(b):
                break
            m, n = a[k][0], b[l][0]
        elif m < n:
            k += 1
            if k == len(a):
                break
            m = a[k][0]
        else:
            l += 1
            if l == len(b):
                break
            n = b[l][0]

t0 = time.clock()
E = 9
pows = [ d ** E for d in xrange(10) ]
lowers, uppers = divide(E)
print sum(merge(lowers, uppers)) - 1
print time.clock() - t0