MojoでProject Euler 38

https://projecteuler.net/problem=38

一般的にB進法で考えます。単純に実行するとすぐにオーバーフローするので、B進のBigIntegerを作ります。
n=2のときは調べなければならない被乗数の個数が多いので、ここが問題になります。詰めて考えると、10進のとき、被乗数は大きいほうから考えて9xxxですが、99xxはもちろんダメで、95xx~98xxも2倍すると19xxxとなって9が重複するのでダメ、94xxは2倍すると188xxか189xxとなって8か9が重複するのでダメで、93xxで初めて条件を満たしうる数が出てきます。このように絞り込みをすれば高速に処理できます。ただし、時間内に答えが出たのはB=20まででした。

Mojoでは、Stringableというトレイトを実装すると、printで出力できるようになります。これで、print_listをGenericで書けるようになりました。ただ、型パラメータに複数のトレイトを要求する方法がわからなかったので、複数のトレイト(CollectionElementとStringable)をまとめた別のトレイト(Printable)を作って、それを制約として使うようにしました。

#################### library ####################

import sys

#################### library ####################

fn initialize_list[T: CollectionElement](N: Int, init: T) -> List[T]:
    """Return a new list holding `N` copies of `init`."""
    var filled = List[T](capacity=N)
    var count = 0
    while count < N:
        filled.append(init)
        count += 1
    return filled

# Bundles CollectionElement and Stringable into one trait, so a generic
# function such as print_list can require "storable in a List and
# convertible with str()" through a single type-parameter bound.
trait Printable(CollectionElement, Stringable):
    pass

fn print_list[T: Printable](a: List[T]):
    """Print the elements of `a` on one line, separated by single spaces.

    An empty list prints an empty line.
    """
    if a.size == 0:
        print()
        return
    var line = str(a[0])
    for idx in range(1, a.size):
        line += " "
        line += str(a[idx])
    print(line)

fn reverse_list[T: CollectionElement](a: List[T]) -> List[T]:
    """Return a new list with the elements of `a` in reverse order."""
    var count = len(a)
    var reversed_items = List[T](capacity=count)
    for k in range(count):
        reversed_items.append(a[count - 1 - k])
    return reversed_items


#################### BigInteger ####################

struct BigInteger(CollectionElement, Stringable):
    """Non-negative arbitrary-precision integer in an arbitrary base `B`.

    Digits live in `v` in little-endian order (v[0] is the least significant
    digit). Zero is represented either by an empty list or by [0]. The
    ordering operators compare digit counts first, so they assume there are
    no leading (most significant) zero digits.
    """
    var v: List[Int]  # digits, least significant first
    var B: Int        # radix

    fn __init__(inout self, v: List[Int], B: Int):
        self.v = v
        self.B = B

    fn __copyinit__(inout self, other: BigInteger):
        self.v = other.v
        self.B = other.B

    fn __moveinit__(inout self, owned other: BigInteger):
        self.v = other.v^
        self.B = other.B

    fn __eq__(self, other: BigInteger) -> Bool:
        """Digit-wise equality (the radix is not compared)."""
        if len(self.v) != len(other.v):
            return False
        for i in range(len(self.v)):
            if self.v[i] != other.v[i]:
                return False
        return True

    fn __ne__(self, other: BigInteger) -> Bool:
        return not (self == other)

    fn __lt__(self, other: BigInteger) -> Bool:
        # More digits means larger (assumes no leading zeros); otherwise the
        # most significant differing digit decides.
        if len(self.v) != len(other.v):
            return len(self.v) < len(other.v)
        for i in range(len(self.v)-1, -1, -1):
            if self.v[i] != other.v[i]:
                return self.v[i] < other.v[i]
        return False

    fn __le__(self, other: BigInteger) -> Bool:
        return not (other < self)

    fn __gt__(self, other: BigInteger) -> Bool:
        return other < self

    fn __ge__(self, other: BigInteger) -> Bool:
        return not (self < other)

    fn __add__(self, other: BigInteger) -> BigInteger:
        """Return self + other (both assumed to use the same radix)."""
        var v = List[Int]()
        var carry = 0
        for i in range(max(self.v.size, other.v.size)):
            var d1 = self.v[i] if i < self.v.size else 0
            var d2 = other.v[i] if i < other.v.size else 0
            var n = d1 + d2 + carry
            v.append(n % self.B)
            carry = n // self.B
        if carry > 0:
            v.append(carry)
        return BigInteger(v, self.B)

    fn __sub__(self, other: BigInteger) -> BigInteger:
        """Return self - other. Assumes self >= other (non-negative result)."""
        var v = List[Int]()
        var borrow = 0
        for i in range(max(self.v.size, other.v.size)):
            var d1 = self.v[i] if i < self.v.size else 0
            var d2 = other.v[i] if i < other.v.size else 0
            var n = d1 - d2 + borrow
            # Relies on floor semantics of % and //: the digit stays in
            # [0, B) and the borrow becomes 0 or -1.
            v.append(n % self.B)
            borrow = n // self.B
        # Strip every leading zero so comparisons stay valid; keep one digit
        # so that zero is represented as [0].
        while v.size > 1 and v[v.size-1] == 0:
            _ = v.pop()
        return BigInteger(v, self.B)

    fn __mul__(self, other: Int) -> BigInteger:
        """Multiply by a machine integer (assumes other >= 0)."""
        var v = List[Int]()
        var carry = 0
        for d in self.v:
            var n = d[] * other + carry
            v.append(n % self.B)
            carry = n // self.B
        while carry > 0:
            var r = carry % self.B
            carry //= self.B
            v.append(r)
        return BigInteger(v, self.B)

    fn __floordiv__(self, d: Int) -> BigInteger:
        """Return floor(self / d) for a positive machine integer `d`.

        Long division from the most significant digit down, in radix B.
        A zero quotient is returned as an empty digit list.
        """
        var v = List[Int](capacity=len(self.v))
        for _ in range(len(self.v)):
            v.append(0)
        var rem = 0
        for i in range(len(self.v)-1, -1, -1):
            rem = rem * self.B + self.v[i]   # shift the remainder by one digit of base B
            v[i] = rem // d
            rem %= d
        # Drop leading zeros of the quotient.
        while len(v) > 0 and v[len(v)-1] == 0:
            _ = v.pop()
        return BigInteger(v, self.B)

    fn is_zero(self) -> Bool:
        return len(self.v) == 0 or (len(self.v) == 1 and self.v[0] == 0)

    fn decrement(inout self):
        """Subtract 1 in place. Assumes self is non-zero."""
        self.decrement_core(0)

    fn decrement_core(inout self, i: Int):
        # Borrow propagation from digit i upward: a 0 becomes B-1 and the
        # borrow moves on; a most significant 1 that would become 0 is
        # removed so the representation keeps no leading zero.
        if i == len(self.v) - 1 and self.v[i] == 1:
            _ = self.v.pop()
        elif self.v[i] == 0:
            self.v[i] = self.B - 1
            self.decrement_core(i + 1)
        else:
            self.v[i] -= 1

    fn __str__(self) -> String:
        """Digits 0-9 print as '0'-'9', digits 10 and up as 'a', 'b', ..."""
        if len(self.v) == 0:
            return "0"

        var s: String = ""
        for i in range(self.v.size-1, -1, -1):
            if self.v[i] < 10:
                s += chr(self.v[i] + 48)
            else:
                s += chr(self.v[i] + 87)
        return s

    @staticmethod
    fn create(n: Int, B: Int) -> BigInteger:
        """Single-digit value `n` in base `B` (assumes 0 <= n < B)."""
        var v = List[Int]()
        v.append(n)
        return BigInteger(v, B)


#################### process ####################

fn distribute(owned e: Int, n: Int) -> List[Int]:
    """Split `e` into `n` near-equal parts, smaller parts first.

    With q = e // n and r = e % n, the first n - r parts are q and the last
    r parts are q + 1, so the parts always sum to `e`.
    """
    var ds = List[Int](capacity=n)
    var q = e // n
    var r = e % n
    for k in range(n):
        # Only the last r slots receive the extra unit.
        ds.append(q if k < n - r else q + 1)
    return ds

fn concatenate(m: BigInteger, n: Int) -> BigInteger:
    """Return the concatenated product m*1 || m*2 || ... || m*n in base m.B.

    Digits are little-endian, so m*n is appended first (least significant
    digits) and m*1 last (most significant digits). Pure computation: no
    output is printed.
    """
    var digits = List[Int]()
    for k in range(n, 0, -1):
        var product = m * k
        for d in product.v:
            digits.append(d[])
    return BigInteger(digits, m.B)

fn is_pandigital(n: BigInteger) -> Bool:
    """True iff the digits of `n` are distinct, contain no 0, and together
    are exactly {1, ..., B-1}."""
    var seen = 0  # bit k set <=> digit k already encountered
    for digit in n.v:
        var bit = 1 << digit[]
        if digit[] == 0:
            return False
        if (seen & bit) != 0:
            return False
        seen |= bit
    var full_mask = (1 << n.B) - 2  # bits 1..B-1 set, bit 0 clear
    return seen == full_mask

fn power(e: Int, B: Int) -> BigInteger:
    """Return B**e, i.e. `e` zero digits followed by a leading 1 (little-endian)."""
    var digits = List[Int](capacity=e + 1)
    for _ in range(e):
        digits.append(0)
    digits.append(1)
    return BigInteger(digits, B)

fn calc_range(n: Int, B: Int) -> Tuple[BigInteger, BigInteger]:
    """Return (first, last), bounds on the multiplicand m for n multiples.

    The concatenation m*1 || ... || m*n must have exactly B-1 digits. With
    q = (B-1) // n and r = (B-1) % n, m has q digits and the products for
    k = n-r+1 .. n carry one extra digit.
    """
    var q = (B-1) // n
    var r = (B-1) % n
    var Bq = power(q, B)
    if r == 0:
        # All n products have q digits: B**(q-1) <= m and m*n <= B**q - 1.
        var first = power(q-1, B)
        Bq.decrement()
        var last = Bq // n
        return (first, last)
    else:
        # Products k <= n-r have q digits, k >= n-r+1 have q+1 digits:
        #   m*(n-r) <= B**q - 1   and   m*(n-r+1) >= B**q.
        # The floor below may undershoot the true lower bound by one; such a
        # candidate yields too few digits and so cannot pass is_pandigital.
        var first = Bq // (n-r+1)
        Bq.decrement()
        var last = Bq // (n-r)
        if first.is_zero():
            first = BigInteger.create(1, B)
        return (first, last)

fn next_digits(first_half: Bool, B: Int) -> List[Int]:
    """Candidate (digit, carry) pairs for the next digit of m, best first.

    Each entry encodes `d*2 + c`. `d` (entry >> 1) is the next digit of the
    multiplicand m. `c` (entry & 1) is the carry that the following, lower
    digit sends into this position when m is doubled; c == 1 means that
    lower digit must be >= B/2. find_max_pandigital decodes entries this way
    and itself subtracts B for upper-half digits, so entries must never be
    pre-reduced by B here.

    first_half: True when the current digit must be < B/2, False when it
    must be >= B/2. Digit 0 is never offered, and pairs whose doubled digit
    would be 0 or B are omitted.
    """
    var v = List[Int](capacity=B)
    var half = B // 2
    if first_half:
        if B % 2 != 0:
            # Odd B: digit B//2 only works without a carry, because
            # 2*(B//2) + 1 == B would overflow the position.
            v.append(half * 2)
        for d in range(half - 1, 0, -1):
            v.append(d * 2 + 1)
            v.append(d * 2)
    else:
        for d in range(B - 1, half, -1):
            v.append(d * 2 + 1)
            v.append(d * 2)
        if B % 2 == 0:
            # Even B: digit B/2 without a carry would double to 0, so only
            # the carry-1 variant is offered.
            v.append(B + 1)
    return v

# Depth-first search for the largest multiplicand m such that m || 2*m uses
# every digit 1..B-1 exactly once (the n == 2 case, driven by f2).
#
# v:          digits of m chosen so far, most significant first; appended to
#             and popped back during the search, so it is restored on return.
# s0:         bitmask of digits already used by m and 2*m (bit k <=> digit k).
# first_half: True if the next digit of m must be < B/2, False if it must be
#             >= B/2; this decides whether 2*d is reduced by B below.
# Returns m as a little-endian BigInteger, or an empty (zero) BigInteger
# when no completion exists.
fn find_max_pandigital(inout v: List[Int], s0: Int,
                        first_half: Bool, B: Int) -> BigInteger:
    if s0 == (1 << B) - 2:  # every digit except 0 is present
        if first_half:
            # There is no lower digit to send a carry, so the search may only
            # finish in the "next digit sends no carry" state.
            var w = reverse_list(v)
            return BigInteger(w, B)
        else:
            var w = List[Int]()
            return BigInteger(w, B)
    
    # Candidates come largest first (see next_digits), so the first success
    # found is expected to be the maximum.
    var ds = next_digits(first_half, B)
    for pair in ds:
        var d = pair[] >> 1            # digit of m
        var b = (pair[] & 1) == 0      # True: the next lower digit sends no carry
        var s = s0
        if ((s >> d) & 1) == 1:
            continue
        s |= 1 << d
        # Digit of 2*m at this position: 2*d plus the carry from below,
        # minus B when d is in the upper half.
        var d2 = d * 2 if b else d * 2 + 1
        if not first_half:
            d2 -= B
        if ((s >> d2) & 1) == 0:
            v.append(d)
            s |= 1 << d2
            var n = find_max_pandigital(v, s, b, B)
            s ^= 1 << d2
            _ = v.pop()
            if not n.is_zero():
                return n
    return BigInteger(List[Int](), B)

fn f2(B: Int) -> BigInteger:
    """Largest pandigital concatenated product m || 2*m in base B.

    For even B, 2*m has one more digit than m and that extra leading digit is
    1, so digit 1 starts out used and m's leading digit is >= B/2. For odd B,
    m and 2*m have the same length and m's leading digit is < B/2.
    """
    var even = B % 2 == 0
    var used = 2 if even else 0
    var digits = List[Int]()
    var m = find_max_pandigital(digits, used, not even, B)
    return concatenate(m, 2)

fn f_each(n: Int, B: Int) -> BigInteger:
    """Largest pandigital concatenated product of some m with (1, ..., n).

    n == 2 is delegated to the dedicated search in f2. Otherwise m is
    scanned downward from the upper bound given by calc_range, stopping at
    the first pandigital hit; 0 is returned when nothing in range works.
    """
    if n == 2:
        return f2(B)

    var bounds = calc_range(n, B)
    var lower = bounds.get[0, BigInteger]()
    var m = bounds.get[1, BigInteger]()
    var best = BigInteger.create(0, B)
    while m >= lower:
        var candidate = concatenate(m, n)
        if is_pandigital(candidate):
            if candidate > best:
                best = candidate
            break
        m.decrement()
    return best

fn f(B: Int) -> BigInteger:
    """Largest pandigital concatenated product over n = 2 .. B-1 in base B.

    Prints each n together with its best result as progress output.
    """
    var best = BigInteger.create(0, B)
    var n = 2
    while n < B:
        var candidate = f_each(n, B)
        print(n, candidate)
        if candidate > best:
            best = candidate
        n += 1
    return best

fn main() raises:
    """Read the base B from the first command-line argument and print the answer."""
    var args = sys.argv()
    # Fail with a clear message instead of indexing past the argument list.
    if len(args) < 2:
        raise Error("usage: <program> B  (B: number base, e.g. 10)")
    var B = atol(args[1])
    print(f(B))