結合則が成り立たない場合のfold系の高速化

Modegramming Style: Scalazの型クラス

これを読んでいてふと思ったこと。
結合則が成り立つ場合、畳込み(fold)を行うときリストを分割して並列で計算させれば速くなります。例えばsumですね。1〜100を足し合わせるときに、1〜50と51〜100に分けてそれぞれを別プロセスで計算して、最後にその2つを足し合わせます。

from multiprocessing import Pool
import functools
import time

NUM_PROCESS = 2

def sum_range(first, last, k):
    return sum(xrange(first + (last - first) * k / NUM_PROCESS,
                        first + (last - first) * (k + 1) / NUM_PROCESS))

def sum_parallel(first, last):
    p = Pool(NUM_PROCESS)
    f = functools.partial(sum_range, first, last)
    return sum(p.map(f, range(NUM_PROCESS)))

if __name__ == '__main__':
    N = 10 ** 8
    t0 = time.clock()
    print sum(xrange(N))        # 4999999950000000
    print time.clock() - t0     # 9.09830021404
    t0 = time.clock()
    print sum_parallel(0, N)    # 4999999950000000
    print time.clock() - t0     # 5.12969264883

場合によっては並列化しなくても速くなります。例えば階乗です。100 !を計算するのに、半分に分けて1 * ... * 50と51 * ... * 100を計算します。これを再帰的に行います。

from multiprocessing import Pool
import functools
import time

def factorial(n):
    return reduce(lambda x, y: x * y, xrange(1, n + 1), 1)

def factorial2(n):
    def f(first, last):
        if first == last - 1:
            return first
        
        mid = (first + last) / 2
        return f(first, mid) * f(mid, last)
    
    return f(1, n + 1)

for N in map(lambda n: n * 10000, (1, 2, 5, 10, 20)):
    t0 = time.clock()
    a = factorial(N)
    print N, time.clock() - t0
    t0 = time.clock()
    b = factorial2(N)
    print N, time.clock() - t0
    print a == b

n !を1 * 2 * ...と順に計算すると、k番目の掛け算の結果がO(k log k)桁になります。kは小さいので掛け算の計算量は桁数に比例するとして、全体の計算量はO(n2 log n)となります。一方、分割した場合、掛け算にはKaratsuba法を使用したとしてO(k)桁同士の掛け算の計算量はO(k1.6)とすると、最初のn/2分割した領域での掛け算の計算量はO(n (log n)1.6)、n/4分割した領域での計算量はO(n/4 (2log n)1.6) = O(n20.6(log n)1.6)、最後はO(n n0.6(log n)1.6)で、これを足し合わせて、O((n log n)1.6)となります。実際に計算すると、

順番に掛ける方法はO(n2.27log n)、分割する方法はO(n1.57log n)でした。


さて、本題です。演算の結合則が成り立たない場合はどうすればいいでしょう。fold系で使う演算でなじみがあるのは数字のリストを10進数に直すときのものです。

lambda x, y: x * 10 + y

これで、[ 3, 1, 4 ]を畳込むと314となります。この演算を⊗とすると、

(3 ⊗ 1) ⊗ 4 = 314
3 ⊗ (1 ⊗ 4) = 44

となってしまいます。どうすればいいでしょう。少し考えるとこうすればよいことがわかります。

lambda (p, x), (q, y): (p * q, q * x + y)

という演算を考えてこれを⊗と書くと、

( (10, 3) ⊗ (10, 1)) ⊗ (10, 4) = (100, 31) ⊗ (10, 4) = (1000, 314)
(10, 3) ⊗ ( (10, 1) ⊗ (10, 4)) = (10, 3) ⊗ (100, 41) = (1000, 314)

とうまくいきます。これをコードにしてみましょう。

from multiprocessing import Pool
import time

def number(ds):
    return reduce(lambda x, y: x * 10 + y, ds, 0)

def number2(ds):
    def mul((p, x), (q, y)):
        return (p * q, x * q + y)
    
    def f(first, last):
        if first == last - 1:
            return (10, ds[first])
        
        mid = (first + last) / 2
        return mul(f(first, mid), f(mid, last))
    
    return f(0, len(ds))[1]

def g(N):
    for _ in xrange(N):
        yield 1
        yield 2

for N in map(lambda n: n * 10000, (1, 2, 5, 10, 20)):
    t0 = time.clock()
    a = number(g(N))
    print N, time.clock() - t0
    t0 = time.clock()
    b = number2(list(g(N)))
    print N, time.clock() - t0
    print a == b

40万桁で60倍くらい速くなりました。
しかし、一般的にはどうしたらこのように結合的でない演算を結合的な演算にすることができるのでしょうか。また、できない演算はどのようなものでしょうか。一般的にはほとんどはできないような気がします。例えば、

lambda x, y: x * y + 1

これなど結合則が成り立ちませんが、どうすれば結合則が成り立つようにできるかわかりません。