メルセンヌ数を10進表示する

「これまでで最大の素数」を発見 << WIRED.jp

久しぶりに新たなメルセンヌ素数が発見されたようです。メルセンヌ数は、

Mn = 2n - 1

という形をしているのですが、なぜこの形の素数を求めようとするかというと、非常に高速な素数判定法があるからです。

さて、新たに得られたメルセンヌ素数を仮にMとしましょう。これは

M = 257885161 - 1

と表されます。2進数で表すのは簡単ですが、実は10進表示するのは難しいです。例えばPythonで、

print 2 ** 57885 - 1

とすると0.064秒、

print 2 ** 578851 - 1

とすると6.4秒かかりました。一桁指数が増えると100倍の時間がかかるということです。この調子だと64000秒かかりそうです。

この10進表示するというのは多項式の計算のようなものと考えることができます。2桁なら2項の多項式と考えます。例えば16なら、

6 + x

でこれを2乗すると、

(6 + x)2 = 36 + 12x + x2

1項で1桁なのでそれぞれ繰り上がって、

6 + 5x + 2x2

となって、256に対応しています。2^nを10進表示するというのは、このような多項式の乗算を繰り返すことになります。

これを実装してみましょう。実際には1桁ずつでなく4桁とします。多項式の乗算は単純に組むとn項同士ならO(n2)となりますが、Karatsuba法を使うとO(n1.6)くらいになります。Pythonで実装してみました。

確かに指数が2倍になると時間が3倍になりますが、元々が遅いので23000秒くらいかかりそうです。

多項式の演算というとNumpyを使えば速そうなので組んでみましたが、指数2倍になると4倍かかります。どうやらNumpyの多項式の乗算はKaratsuba法を使っていないようですね。5800秒くらいかかりそうです。

これでもいいかなと思って走らせていると、こんな記事が話題になっていました。

Rails Hub情報局: 本家の5倍速? Pythonで実装したRuby処理系の「Topaz」が登場

そういえば、PyPyってありましたね。ただ、前に試したらうまく動かなかったような。いちおういろいろ試してみると、どんなコードでも動いて変わらない値を出してくれます。

速度はいろいろです。Project EulerのProblem 413は1.4倍くらいにしかなりませんでした。理由はdictのキーでしょうね。Problem 412はほとんど変わらず。どうも時間が短すぎる場合には速くならないようです。しかし、最初の遅いコードだと5倍速くなりました。

上のKaratsuba法のコードを試してみるとかなり速くなりそうです。結局2913秒かかりました。リダイレクトしたファイルのプロパティを見てみると、17425172バイト。改行で2バイト増えているので、17425170桁であってますね。下50桁くらいなら一瞬で計算できます。

print (pow(2, 57885161, 10 ** 50) - 1)

こうですね。

10942833323095203705645658725746141988071724285951

たぶんあってました。
最後にコードを示します。

from itertools import *
import time
import sys

def add(f, g):
    if len(f) < len(g):
        return add(g, f)
    h = f[:]
    for k in xrange(len(g)):
        h[k] += g[k]
    return h

def sub(f, g):
    h = f[:]
    for k, a in enumerate(g):
        if k < len(h):
            h[k] -= g[k]
        else:
            h.append(-a)
    return h

def mul(f, g):
    h = [ 0 ] * (len(f) + len(g) - 1)
    if len(f) < 20 or len(g) < 20:
        # normal multiply
        for k in xrange(len(f)):
            for l in xrange(len(g)):
                h[k+l] += f[k] * g[l]
    else:
        # Karatsuba algorithm
        mid = min(len(f) / 2, len(g) / 2)
        f1 = f[:mid]
        f2 = f[mid:]
        g1 = g[:mid]
        g2 = g[mid:]
        h1 = mul(f1, g1)
        h2 = mul(f2, g2)
        h3 = sub(add(h2, h1), mul(sub(f2, f1), sub(g2, g1)))
        for k, a in enumerate(h1):
            h[k] += a
        for k, a in enumerate(h3, mid):
            h[k] += a
        for k, a in enumerate(h2, mid * 2):
            if k < len(h):
                h[k] += a
            else:
                h.append(a)
    
    normalize(h)
    return h

def pow_dec(f, e):
    if e == 1:
        return f
    elif e % 2 == 1:
        return mul(pow_dec(f, e - 1), f)
    else:
        g = pow_dec(f, e / 2)
        h = mul(g, g)
        print >>sys.stderr, e, time.clock() - t0
        return h

def normalize(f):
    c = 0
    for k, a in enumerate(f):
        c, r = divmod(a + c, 10000)
        f[k] = int(r)
    while c > 0:
        c, r = divmod(c, 10000)
        f.append(int(r))

t0 = time.clock()
E = 57885161
f = pow_dec([ 2 ], E)
f[0] -= 1
g = chain(("%d" % (f[-1],),), (("%04d" % a) for a in reversed(f[:-1])))
print "".join(g)
print >>sys.stderr, time.clock() - t0