Project Euler 103(2)

http://projecteuler.net/index.php?section=problems&id=103


このシリーズの最後に、この問題をチェックする集合のペアを減らして解きましょう。
n = 7なら要素数2つずつまたは3つずつの集合のペアについてチェックすればよいです。チェックするペアはあらかじめ決めておけます。
また、2番目のルールを満たすような和が決まった集合Aを生成するようにしました。
これで、23秒程度でした。

from itertools import *

def get_pairs_need_to_check(n):
    def gen_pairs(S, diff = 1, rev = False):
        if S == ():
            if rev:
                yield (), ()
        else:
            e = S[0]
            S1 = S[1:]
            if diff == len(S):
                if rev:
                    for A, B in gen_pairs(S1, diff - 1, rev):
                        yield A, (e,) + B
            elif diff == -len(S):
                if rev:
                    for A, B in gen_pairs(S1, diff + 1, rev):
                        yield (e,) + A, B
            else:
                for A, B in gen_pairs(S1, diff - 1, rev or diff == 0):
                    yield A, (e,) + B
                for A, B in gen_pairs(S1, diff + 1, rev):
                    yield (e,) + A, B
    
    return [ (S[:1] + A, B) for m in range(4, n + 1, 2)
                    for S in combinations(range(n), m)
                        for A, B in gen_pairs(S[1:]) ]

def gen_sets(n, s, k = 0, d0 = 0, min_a = 1):
    if k == n:
        if s == 0:
            yield ()
    else:
        nB = (n + 1) / 2
        nC = (n - 1) / 2
        def sum_seq(begin, l):  # begin + ... + (begin + l - 1)
            return (begin * 2 + l - 1) * l / 2
        
        def lower_limit(a):
            if k < nB:
                return d0 + sum_seq(a, nB - k) <= \
                                sum_seq(a + n - nC - k, nC)
            elif k < n - nC:
                return d0 <= sum_seq(a + n - nC - k, nC)
            else:
                return d0 <= sum_seq(a, n - k)
        
        def upper_limit(a):
            return sum_seq(a, n - k) <= s
        
        if k != n - nC or s < d0:
            for a in dropwhile(lower_limit,
                        takewhile(upper_limit, count(min_a))):
                d = d0 + a if k < nB else d0 if k < n - nC else d0 - a
                for seq in gen_sets(n, s - a, k + 1, d, a + 1):
                    yield (a,) + seq

def condition1(a):
    def c1(S):
        B, C = S
        return sum(B) != sum(C)
    
    def select(a, e):
        return (tuple([ a[k] for k in e[0] ]),
                tuple([ a[k] for k in e[1] ]))
    
    return all(c1(select(a, e)) for e in pairs)

def join(a):
    S = list(a)
    S.sort()
    return "".join(map(str, S))

N = 7
pairs = list(get_pairs_need_to_check(N))
print (join(a) for n in count(1) for a in gen_sets(N, n)
                                    if condition1(a)).next()