MojoでProject Euler 32

https://projecteuler.net/problem=32

問題をB進法に拡張します。
単にしらみつぶししてもいいのですが、それだけだと面白くないのでもう少し工夫します。B-1の剰余を考えると、各桁を足した和の剰余と同じになります。そうすると剰余の組み合わせが限られます。被乗数の剰余をx、乗数の剰余をyとすると、10進なら、

 x + y + xy \equiv 0

なので、

 (x - 1)(y - 1) \equiv 1

となって、x=0ならy=0に限られて、x=1なら対応するyはありません。
ただし、これはBが偶数の時で、奇数だと次のようになります。

 \displaystyle x + y + xy \equiv \frac{B-1}{2}

また、Bが偶数のときは、例えばB=10なら、2桁×3桁=4桁みたいなので、被乗数を決めると乗数は上から抑えられます。逆にBが奇数のときは下から抑えられます。なので、偶数と奇数でコードを分けるとよいです。

from collections import Set
import sys


#################### library ####################

def digits(n: Int, b: Int) -> List[Int]:
    var m = n;
    var ds = List[Int]()
    while m > 0:
        var d = m % b
        ds.append(d)
        m //= b
    return ds

# n以上でdの剰余がrの最小の整数
def ceil_mod(n: Int, r: Int, d: Int) -> Int:
    return (n+d-r-1) // d * d + r


#################### process ####################

fn find_num_digits_combs(B: Int) -> List[Tuple[Int, Int]]:
    var v = List[Tuple[Int, Int]]()
    for i in range(1, B//4+1):
        v.append((i, B//2-i))
    return v

fn find_mod_combs(B: Int) -> List[Tuple[Int, Int]]:
    # x + y + xy \equiv 0 (mod B-1)
    var v = List[Tuple[Int, Int]]()
    var sum_r = 0 if B % 2 == 0 else (B-1)//2
    for x in range(B-1):
        for y in range(B-1):
            if (x + y + x*y) % (B-1) == sum_r:
                v.append((x, y))
    return v

fn set_digits(n: Int, s: Int, B: Int) -> Int:
    var s1 = s
    try:
        var ds = digits(n, B)
        for d in ds:
            if ((s1 >> d[]) & 1) == 1:
                return -1
            else:
                s1 |= 1 << d[]
        return s1
    except:
        return -1

fn sum(s: Set[Int]) -> Int:
    var s1 = 0
    for n in s:
        s1 += n[]
    return s1

fn f_even(B: Int) raises -> Int:
    var s = Set[Int]()
    var v = find_num_digits_combs(B)
    var w = find_mod_combs(B)
    for t in v:
        var nd1 = t[].get[0, Int]()
        var nd2 = t[].get[1, Int]()
        var nd3 = B // 2 - 1
        for u in w:
            var r1 = u[].get[0, Int]()
            var r2 = u[].get[1, Int]()
            var first1 = ceil_mod(B**(nd1-1), r1, B-1)
            var first2 = ceil_mod(B**(nd2-1), r2, B-1)
            for n1 in range(first1, B**nd1, B-1):
                var s1 = set_digits(n1, 0, B)
                if s1 == -1:
                    continue
                # 100
                var last2 = B**nd3 // n1 + 1
                for n2 in range(first2, last2):
                    var n3 = n1 * n2
                    var s2 = set_digits(n2, s1, B)
                    if s2 == -1:
                        continue
                    var s3 = set_digits(n3, s2, B)
                    if s3 == (1 << B) - 2:
                        s.add(n3)
                        print(n1, n2, n3)
    return sum(s)

fn f_odd(B: Int) raises -> Int:
    var s = Set[Int]()
    var v = find_num_digits_combs(B)
    var w = find_mod_combs(B)
    for t in v:
        var nd1 = t[].get[0, Int]()
        var nd2 = t[].get[1, Int]()
        var nd3 = B // 2
        for u in w:
            var r1 = u[].get[0, Int]()
            var r2 = u[].get[1, Int]()
            var first1 = ceil_mod(B**(nd1-1), r1, B-1)
            for n1 in range(first1, B**nd1, B-1):
                var s1 = set_digits(n1, 0, B)
                if s1 == -1:
                    continue
                var first2 = ceil_mod(10**(nd3-1)//n1, r2, B-1)
                for n2 in range(first2, B**nd2):
                    var n3 = n1 * n2
                    var s2 = set_digits(n2, s1, B)
                    if s2 == -1:
                        continue
                    var s3 = set_digits(n3, s2, B)
                    if s3 == (1 << B) - 2:
                        s.add(n3)
                        print(n1, n2, n3)
    return sum(s)

fn f(B: Int) raises -> Int:
    if B % 2 == 0:
        return f_even(B)
    else:
        return f_odd(B)

fn main() raises:
    var args = sys.argv()
    var B = atol(args[1])
    print(f(B))