Project Euler 74

プロジェクトオイラー
http://projecteuler.net/index.php?section=problems&id=74


この問題もちゃんと考えると面白い。
これも以前解いたときは、題意をそのままコードにしただけだった。


from math import factorial

def next(n):
s = 0
while n:
m = n % 10
s += factorial(m)
n /= 10
return s

def get_length(n):
s = set()
while n not in s:
s.add(n)
n = next(n)
return len(s)

N = 10 ** 5
LEN = 2
print sum(map(lambda n: n == LEN, map(get_length, xrange(1, N))))

この問題は、数学的には特に工夫の余地はないと思うが、アルゴリズムで速くできる。
まず、例えば、123と231は次が同じ数になるので、このような数は数字が昇順のものを代表して計算すればよい。すなわち、123だけ計算すればよい。また、0!と1!は等しいので、0は考えなくてもよい。
ただし、例えば、178 → 45361 → 871 → 45361なので、178に代表される数は長さ3となるが、871だけは2となるところを注意しなければならない。
代表した数と同類項の数は、同じ数字の個数を、

n1, …, nm
n = n1 + … + nm

とすると、

n! / n1!…nm!

ただし、最初の数字が1だと、0か1のどちらかを選択することになるから、2n1倍する必要がある。さらに、最初が0は無いので、その分は差し引く。

2n1n! / n1!…nm! - 2n1-1(n-1)! / (n1-1)!n2!…nm!

これで、3分近くかかっていたものが、0.8sくらいになった。さらに辞書に長さを記録して計算の重複を避けたところ、0.2sくらいになった。



from itertools import imap
from math import factorial

def next(n):
s = 0
while n:
m = n % 10
s += factorial(m)
n /= 10
return s

def encode(n):
a = [ ]
while n:
d = n % 10
if d == 0:
d = 1
a.append(d)
n /= 10
return a

def decode(a):
return reduce(lambda x, y: x * 10 + y, a, 0)

def gen_sorted_number(d1, d2, pos, length):
if length == 1:
if pos == 0:
for d in xrange(d1, d2 + 1):
yield d
else:
yield d1
else:
l = length / 2
for d3 in xrange(d1, d2 + 1):
for n1 in gen_sorted_number(d1, d3, pos, l):
for n2 in gen_sorted_number(d3, d2, pos + l, length - l):
yield n1 * 10 ** (length - l) + n2

def gen_number(n):
for k in range(1, n + 1):
for m in gen_sorted_number(1, 9, 0, k):
yield m

def calc_length_by_cache(n, a):
d = dic[n]
if type(d) == int:
return d + len(a)
else:
return d[0] + len(a)

def add_chain_core(n, a):
d = dic[n]
if type(d) == int:
l = d + 2
else:
l = d[0] + 1
for m in reversed(a):
dic[m] = l
l += 1

def get_loop(n, a):
pos = a.index(n)
len_loop = len(a) - pos
dic[n] = (len_loop, a[-1])
for k in xrange(pos + 1, len(a)):
dic[a[k]] = (len_loop, a[k-1])
return len_loop

def add_chain(n):
if n in dic:
return

s = set()
a = [ ]
s.add(n)
a.append(n)

n = next(n)
while n not in s:
s.add(n)
if n in dic:
add_chain_core(n, a)
return

a.append(n)
prev = n
n = next(n)

# new loop
len_loop = get_loop(n, a)
len_queue = len(a)
for k in xrange(len_queue - len_loop):
dic[a[k]] = len_queue - k

def make_chains():
for n in gen_number(N):
add_chain(n)

def is_same_digits(n1, n2):
a = encode(n2)
a.sort()
return decode(a) == n1

def get_length(n):
d = dic[n]
if type(d) == int:
n2 = next(n)
d2 = dic[n2]
if type(d2) == int:
return d, False
else:
if is_same_digits(n, d2[1]):
return d - 1, True
else:
return d, False
else:
return d[0], True

def pack(a):
b = [ ]
c = [ ]
prev = 0
for e in reversed(a):
if e == prev:
c[-1] += 1
else:
b.append(e)
c.append(1)
prev = e
return b, c

def combination(c):
fn = factorial(sum(c))
return reduce(lambda x, y: x / y, map(factorial, c), fn)

def num_of_cases(n):
d, c = pack(encode(n))
m = combination(c)
if d[0] == 1:
m *= (sum(c) * 2 - c[0]) << (c[0] - 1)
m /= sum(c)
return m

def weight(n, e):
length, loop = e
if loop:
if length == LEN:
return 1
elif length == LEN - 1:
return num_of_cases(n) - 1
else:
return 0
else:
if length == LEN:
return num_of_cases(n)
else:
return 0

dic = { }
N = 6
LEN = 60
make_chains()
print sum(imap(lambda n: weight(n, get_length(n)), gen_number(N)))