MojoでProject Euler 30

https://projecteuler.net/problem=30

2から順に等式が成り立つかどうか調べていっても十分間に合うのですが、これは重複組み合わせを使うと計算量が減ります。1634を分解して4乗和を取ると1634になりますが、1346や6431も1634になります。4桁で重複が無いから24個の数がが同じになることが分かります。6桁だと重複組み合わせは _{10}H_6 = 5005通りしかありません。なので、重複組合せをだして、[1, 3, 4, 6]のとき、4乗和が1634になって、それを数字に分解してソートすると、[1, 3, 4, 6]で元と同じになるので、1634が該当する数になります。この方法で17乗まで計算できました。

最新のバージョンで、DynamicVectorはListになって、Pythonと同じような便利なメソッドも追加されました。メソッドではないですが、ソートもできるんですね。

https://docs.modular.com/mojo/stdlib/collections/list#list

from algorithm.sort import sort
from math import min
import sys


#################### library ####################

fn print_vector(v: List[Int]):
    for k in range((v.size + 9) // 10):
        var s = str("")
        for i in range(k*10, min(k*10+10, v.size)):
            s += str(v[i]) + " "
        print(s)

fn digits(n: Int) -> List[Int]:
    var m = n
    var ds = List[Int]()
    while m > 0:
        var d = m % 10
        m //= 10
        ds.append(d)
    return ds


#################### process ####################

fn duplicate_combinations(a: List[Int], n: Int) -> List[List[Int]]:
    var v = List[List[Int]]()
    if n == 1:
        for e in a:
            var w = List[Int](e[])
            v.append(w)
    elif n % 2 == 1:
        var v1 = duplicate_combinations(a, n-1)
        for e in a:
            for w1 in v1:
                if e[] > w1[][0]:
                    continue
                var w = List[Int](e[])
                w.extend(w1[])
                v.append(w)
    else:
        var m = n // 2
        var v1 = duplicate_combinations(a, m)
        for w1 in v1:
            for w2 in v1:
                if w1[][m-1] > w2[][0]:
                    continue
                var w = w1[]
                w.extend(w2[])
                v.append(w)
    return v

fn upper_num_digits(E: Int) -> Int:
    for n in range(1, 20):
        if n * 9**E < 10**(n-1):
            return n - 1
    return 0

fn is_coincident(w: List[Int], v: List[Int]) -> Bool:
    var delta = v.size - w.size
    for i in range(delta):
        if v[i] != 0:
            return False
    for i in range(w.size):
        if w[i] != v[i+delta]:
            return False
    return True

fn valid_sum_pows(v: List[Int], E: Int) -> Int:
    var s = 0
    for e in v:
        s += e[]**E
    
    var w = digits(s)
    sort(w)
    
    if is_coincident(w, v):
        return s
    else:
        return 0

fn f(E: Int) -> Int:
    var N = upper_num_digits(E)
    var ds = List[Int]()
    for d in range(10):
        ds.append(d)
    var v = duplicate_combinations(ds, N)
    var s = 0
    for w in v:
        s += valid_sum_pows(w[], E)
    return s - 1

fn main() raises:
    var args = sys.argv()
    var E = atol(args[1])
    print(f(E))