Project Euler 237

プロジェクトオイラー
http://projecteuler.net/

Q237.
4×1012の格子の上にルートを作る。そのルートは左上と左下がスタート・ゴールになっている。上下左右に動ける。どのマスも必ず一度だけ通る。ルートはいくつあるか。

左端と中間と右端に分けて、バイナリ法的に計算する。最初、n列の部分の入り口と出口のどこが外と繋がっているかだけ分かればいいと思っていたので、4ビットのタプルを状態としていたが、それだとループになる場合を排除できないことに気がついた。もう少しややこしく、そこを通るルートはどこから入ってどこに出るのかを入り口と出口番号のペアのタプルにした。



def next_node(n1, s):
for e2 in s:
if n1 >= 4:
if n1 == e2[0] + 4:
s.remove(e2)
return e2[1]
if n1 == e2[1] + 4:
s.remove(e2)
return e2[0]
else:
if n1 == e2[0] - 4:
s.remove(e2)
return e2[1]
if n1 == e2[1] - 4:
s.remove(e2)
return e2[0]
return -3

def end_node(n1, s1, s2, turn):
while True:
if turn == 1:
if n1 < 4:
break
n1 = next_node(n1, s2)
else:
if n1 >= 4:
break
n1 = next_node(n1, s1)
if n1 == -3:
return n1
turn = 3 - turn # 1 <=> 2
return n1

def connect(s1, s2):
a = [ ]
set1 = set(s1)
set2 = set(s2)
for e1 in s1:
if e1 not in set1:
continue
set1.remove(e1)
n1 = end_node(e1[0], set1, set2, 1)
if n1 == -3:
return False
n2 = end_node(e1[1], set1, set2, 1)
if n2 == -3:
return False
a.append( (n1, n2) )

if len(set2) != 0:
s2 = tuple(set2)
for e2 in s2:
if e2 not in set2:
continue
set2.remove(e2)
n1 = end_node(e2[0], set2, set1, 2)
if n1 == -3:
return False
n2 = end_node(e2[1], set2, set1, 2)
if n2 == -3:
return False
a.append( (n1, n2) )

return tuple(a)

def add_dic(d, k, n):
if k in d:
d[k] = (d[k] + n) % M
else:
d[k] = n % M

def mul(S1, S2):
d = { }
for s1, n1 in S1.iteritems():
for s2, n2 in S2.iteritems():
s = connect(s1, s2)
if s:
add_dic(d, s, n1 * n2)
return d

def num_mul(S1, S2):
n = 0
for s1, n1 in S1.iteritems():
for s2, n2 in S2.iteritems():
s = connect(s1, s2)
if s:
n += n1 * n2
return n % M

def T(n):
A = calc_A(n / 2)
C = calc_C(n - n / 2)
return num_mul(A, C)

def calc_A(n):
if n == 1:
return { ( (-1, 4), (5, 6), (-2, 7) ): 1,
( (-1, 4), (-2, 5) ): 1,
( (-1, 5), (-2, 6) ): 1,
( (-1, 6), (-2, 7) ): 1
}

A = calc_A(n / 2)
B = calc_B(n - n / 2)
return mul(A, B)

def calc_B(n):
if n == 1:
return { ( (0, 4), (1, 5), (2, 6), (3, 7) ): 1,
( (0, 4), (1, 5), (2, 3) ): 1,
( (0, 4), (1, 5), (6, 7) ): 1,
( (0, 1), (2, 6), (3, 7) ): 1,
( (4, 5), (2, 6), (3, 7) ): 1,
( (0, 4), (1, 7) ): 1,
( (0, 4), (3, 5) ): 1,
( (3, 7), (0, 6) ): 1,
( (3, 7), (2, 4) ): 1,
( (0, 4), (1, 2), (3, 7) ): 1,
( (0, 4), (5, 6), (3, 7) ): 1,
( (0, 5), (3, 6) ): 1,
( (1, 4), (2, 7) ): 1
}

if n in dicB:
return dicB[n]
B1 = calc_B(n / 2)
B = mul(B1, B1)
if n % 2 == 1:
B = mul(calc_B(1), B)

dicB[n] = B
return B

def calc_C(n):
if n == 1:
return { ( (0, 1), (2, 3) ): 1, ( (0, 3), ): 1 }

B = calc_B(n / 2)
C = calc_C(n - n / 2)
return mul(B, C)

dicB = { }
N = 10 ** 12
M = 10 ** 8
print T(N)