Project Euler 92

Problem 92
前に出てきた数になるまで新しい数を作るよう数のチェーンが数の各桁の平方和を連続的に加えることにより作られる。
例えば、
44 → 32 → 13 → 10 → 11
85 → 89 → 145 → 42 → 20 → 4 → 16 → 37 → 58 → 89
ゆえに1か89に到達するどのチェーンもループに捉われる。驚くべきことにどの数からはじめても最後には1か89に到達する。
1000万より小さい開始数が89に到達するものはいくつあるか。
http://projecteuler.net/index.php?section=problems&id=92

nに対して次のステップの数をf(n)と書くことにします。
まず、4桁の数を考えましょう。次のステップで最大になるのは9999で、f(9999) = 92*4 = 324だから、どの4桁の数も3桁以下になります。5桁以降の数も1ステップで桁数が小さくなります。また、3桁で次のステップが最大の999はf(999) = 243になるので、244以上のnについて

f(n) < n

となることが簡単にわかります。すなわちループになる可能性があるのは243以下の数だけなので、これらの数についてのみ調べれば必ず1か89に到達するかどうかわかります。

さて、全てのnについてfを次々に作用させていき、1か89になるのを待ってもよいです(Code1)。ただし遅いです。

メモ化すると速くなります。ここでは243以下の結果を格納することにしましょう。それ以上の数は急速に小さくなるのでメモ化しないことにします(Code2)。これでもまだ遅いです。

どうすれば速くなるでしょうか。例えば、

f(112) = f(121) = 6

です。このように数字の順番が入れ替わっても次のステップは同じです。これはProblem 74と同じです。89に到達したら、同じ数字の組合せの個数を数えます(Code3)。

# Code1
def gen_digits(n):
    while n:
        yield n % 10
        n /= 10

def next(n):
    return sum(x * x for x in gen_digits(n))

def is_arrived_at_89(n):
    if n == 89:
        return True
    elif n == 1:
        return False
    else:
        return is_arrived_at_89(next(n))

N = 10 ** 7
print sum(is_arrived_at_89(n) for n in xrange(1, N))
# Code2
def gen_digits(n):
    while n:
        yield n % 10
        n /= 10

def next(n):
    return sum(x * x for x in gen_digits(n))

def is_arrived_at_89(n):
    if n <= M:
        if a[n]:
            return a[n] == 89
        else:
            if n == 89:
                b = True
            elif n <= 1:
                b = False
            else:
                b = is_arrived_at_89(next(n))
            a[n] = 89 if b else 1
            return b
    else:
        return is_arrived_at_89(next(n))

N = 10 ** 7
M = 9 ** 2 * 3
a = [ 0 ] * (M + 1)
print sum(is_arrived_at_89(n) for n in xrange(1, N))
# Code3
def gen_repeated_combination(a, n, k0 = 0):
    if n == 0:
        yield ()
    else:
        for k in xrange(k0, len(a)):
            for rc in gen_repeated_combination(a, n - 1, k):
                yield (k,) + rc

def gen_digits(n):
    while n:
        yield n % 10
        n /= 10

def numerize(a):
    return reduce(lambda x, y: x * 10 + y, a)

def next(n):
    return sum(x * x for x in gen_digits(n))

def is_arrived_at_89(n):
    if n <= M:
        if a[n]:
            return a[n] == 89
        else:
            if n == 89:
                b = True
            elif n <= 1:
                b = False
            else:
                b = is_arrived_at_89(next(n))
            a[n] = 89 if b else 1
            return b
    else:
        return is_arrived_at_89(next(n))

N = 7
M = 9 ** 2 * 3
f = reduce(lambda x, y: x * y, range(1, N + 1)) # N!
def num(rc, n = f, k = 0, m = 1):
    if k == N:
        return n
    
    m = m + 1 if k > 0 and rc[k] == rc[k-1] else 1
    return num(rc, n, k + 1, m) / m

a = [ 0 ] * (M + 1)
print sum(is_arrived_at_89(numerize(rc)) * num(rc)
                for rc in gen_repeated_combination(range(10), N))