Project Euler 74

Problem 74
145という数は各桁の階乗の和が145に等しいという性質がよく知られている:
1! + 4! + 5! = 1 + 24 + 120 = 145
恐らく169はあまり知られていないが、169に戻ってくる最も長い数のチェーンを生成する。このようなループは3つしか存在しないことがわかっている:
169 → 363601 → 1454 → 169
871 → 45361 → 871
872 → 45362 → 872
どの数も最後にはループに捕まることを証明するのは難しくない。例えば、
69 → 363600 → 1454 → 169 → 363601 (→ 1454)
78 → 45360 → 871 → 45361 → (→ 871)
540 → 145 (→ 145)
69で始まると5つの繰り返しでない項から成るチェーンを生成するが、最も長い繰り返しのない100万より小さい数から始まるチェーンは60項である。
100万より小さい開始数で60の繰り返しのない項を持つチェーンはいくつあるか。
http://projecteuler.net/index.php?section=problems&id=74

個々の開始数について単純に繰り返しが出てくるまで単純にステップを繰り返しても答えは出ます。ただし、遅いです。メモ化をすれば速くなります。すなわち、1454は3項であることが分かっていれば、69が開始数なら

69 → 363600 → 1454

と1454にぶつかれば、さらに2項足して5項であることがわかります(Code1)。
もっと速くならないでしょうか。69から始めると5項のチェーンになりましたが、96でも同じことです。

96 → 363600 → …

つまり、数字を並べ替えただけなら次は同じ数になるということです。また、0! = 1!なので、例えば169は、196,619,609,691,690,916,906,961,960と同じです。これらを類として、類のチェーンを考えましょう。代表元として0を使わない最小の数を使います。上の類なら169が代表元です。このときこの類を[169]と表すことにします。[169]で始まるチェーンは、

[169] → [113366] → [1445] (→ [169])

3タイプのチェーンが考えられます。

1: [169] → [113366] → [1445] (→ [169])
2: [23366] → [1445] → [169] → [113366] (→ [1445])
3: [69] → [113366] → [1445] → [169](→ [113366])

タイプ1は開始類がループ内になります。このとき169が開始数のチェーンの長さは3ですが、619は

619 → 363601 → 1454 → 169 (→ 363601)

と、チェーンの長さは1長くなります。[169]の中で169だけが長さ3で、他は4となります。
タイプ2は、ループの外からやってきてループに到達したところの数が一周してきたときにまた同じになります。[23366]のどの元が開始数でも長さ4になります。
タイプ3は、逆に一周して違う数になるタイプです。[69]のどの元でも長さ5になります。

ですから、長さが59か60になる開始数の類を探して個数を数えればよいです(Code2)。

# Code1
from itertools import imap

def gen_digits(n):
    while n:
        yield n % 10
        n /= 10

def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n - 1)

def step(n):
    return sum(imap(lambda m: factorial(m), gen_digits(n)))

def calc_length(n):
    if n <= N:
        if memo[n] != 0:
            return memo[n]
    
    def f(chain, s, n):
        chain.append(n)
        s.add(n)
        m = step(n)
        if m <= N and memo[m] != 0:
            return memo[m]
        elif m in s:
            return -m
        else:
            return f(chain, s, m)
    
    chain = [ ]
    s = set()
    l = f(chain, s, n)
    if l > 0:
        for k in reversed(chain):
            l += 1
            if k <= N:
                memo[k] = l
    else:
        m = -l
        l = len(chain)
        loop = False
        for k in chain:
            if not loop:
                if k <= N:
                    memo[k] = l
                if k == m:
                    loop = True
                else:
                    l -= 1
            else:
                if k <= N:
                    memo[k] = l
    
    return memo[n]

N = 10 ** 6
M = 60
memo = [ 0 ] * (N + 1)
print len(filter(lambda n: calc_length(n) == M, xrange(1, N + 1)))
# Code2
from itertools import imap

def gen_digits(n):
    while n:
        yield n % 10
        n /= 10

def numerize(a):
    return reduce(lambda x, y: x * 10 + y, a, 0)

def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n - 1)

def gen_rep_comb(a, l, k0 = 0):
    if l == 0:
        yield [ ]
    else:
        for k in xrange(k0, len(a)):
            e = a[k]
            if l > 1:
                yield [ e ]
            for b in gen_rep_comb(a, l - 1, k):
                yield [ e ] + b

def step(a):
    return sum(imap(lambda m: factorial(m), a))

def sorted_list(n):
    a = map(lambda d: 1 if d == 0 else d, gen_digits(n))
    a.sort()
    return a

def calc_length(a):
    def f(chain, s, a):
        n = numerize(a)
        s.add(n)
        m0 = step(a)
        a1 = sorted_list(m0)
        m = numerize(a1)
        chain.append((m, m0))
        if m in memo:
            return chain, memo[m]
        
        if m in s:
            return chain, (-1, 0)
        else:
            return f(chain, s, a1)
    
    n = numerize(a)
    chain = [ (n, n) ]
    return f(chain, set(), a)

def count_length(a):
    def count_perm(a, head = True):
        if len(a) == 0:
            return 1
        else:
            s = 0
            for d in range(1, 10):
                try:
                    k = a.index(d)
                except:
                    continue
                n = count_perm(a[:k] + a[k+1:], False)
                if d == 1 and not head:
                    s += n * 2
                else:
                    s += n
            return s
    
    def index(m, k = 0):
        if chain[k][0] == m:
            return k
        return index(m, k + 1)
    
    def count_(a, l, t):
        if l == M:
            if t == 1:
                return 1
            elif t == 2:
                return count_perm(a)
            else:
                return 0
        elif l == M - 1:
            if t == 1:
                return count_perm(a) - 1
            elif t == 2:
                return 0
            else:
                return count_perm(a)
        else:
            return 0
    
    n = numerize(a)
    if n in memo:
        return count_(a, memo[n][0], memo[n][1])
    
    chain, memo_x = calc_length(a)
    m, m0 = chain[-1]
    memo_l, memo_t = memo_x[:2]
    if memo_l == -1:
        l = len(chain) - 1
        p = index(m)
        if p == 0:
            t = 1
        else:
            t = 2 if m0 == m else 3
        
        if t == 1:
            for k in xrange(1, len(chain)):
                m, m0 = chain[k]
                memo[m] = (l, t, m0)
        else:
            for k in xrange(p):
                memo[chain[k][0]] = (l - k, t)
            for k in xrange(p + 1, l + 1):
                m, m0 = chain[k]
                memo[m] = (l - p, 1, m0)
    else:
        l = memo_l + len(chain) - 1
        if memo_t == 1:
            t = 2 if m0 == memo_x[2] else 3
        else:
            t = memo_t
        
        for k in xrange(len(chain) - 1):
            memo[chain[k][0]] = (l - k, t)
    
    return count_(a, l, t)

def f(n):
    return count_length(list(sorted(gen_digits(n))))

N = 6
M = 60
memo = { }
print sum(imap(count_length, gen_rep_comb(range(1, 10), N)))