MojoでProject Euler 39

https://projecteuler.net/problem=39

直角三角形の各辺は
 a = k(m^2-n^2)
 b = 2kmn
 c = k(m^2+n^2)
 (m>n、mとnは互いに素、mとnは偶奇が逆)なので、 p = 2km(m+n)となります。なので、約数が多い方が同じpで違う(k, m, n)の組が多そうです。ただ、そんなに簡単ではありません。
まず、m+nは奇数です。mは奇数でも偶数でも可能ですが、 m \lt m+n \lt 2mを満たさなければなりません。またmとm+nは互いに素です。kは任意です。例えば、p = 60で考えると、奇素数は3と5なので、3と5をkとmとm+nに分配すると、 (k, m, m+n) = (1, 1, 15), (1, 3, 5), (5, 1, 3), (3, 1, 5), (15, 1, 1)が考えられます。条件を満たすのは (k, m, m+n) = (1, 3, 5)のみですが、mには2を1回だけ掛けることができるので(2は2個あるが、 p = 2km(m+n)の頭に一つ取られるので一つだけ掛けることができる)、 (k, m, m+n) = (5, 2, 3)も条件を満たします。
ただ、2が十分たくさんあって例えば4つあれば、 (k, m, m+n) = (1, 8, 15), (8, 3, 5), (20, 2, 3), (6, 4, 5)が条件を満たします。 (k, m, m+n) = (15, 1, 1)はいくつ2を掛けても満たしません。3と5を分配するときにそれぞれ k, m, m+nのどこかに所属するので、3通りずつあって3 * 3 = 9通りあるのですが、 m \le m+nなので(1, 1)以外はm < m+nになるほうだけ取るので半分になって、(3*3-1)/2 = 4通りあります。
一般に奇素数のべき乗をeとすると、mに1~e個、m+nに1~e個またはmとm+nには一つも分配しないと2e+1通りあります。なので、2が十分にあれば、 p = 2^{e}\prod_i{q_i^{e_i}}とすれば \prod_i{(2e_i+1)}個組合せがあります。
つまり、pの素因数分解の指数が分かれば(k, m, n)の組の個数は上から押さえることができるので、ある個数のpがあると分かっていれば、指数を見たらそれを超えられないことが分かるかもしれません。実際に p \le 1000とすると、たぶん使われる素数が小さい方が個数が多くなるので、例えば p = 2^3 * 3 * 5 * 7 = 840の個数を数えてみると8つあるので、例えば p = 2^e * q^2 * rなら(5 * 3 - 1) / 2 = 7だから、この形の素因数分解ならもう実際に数えなくても8を上回ることができないことが分かります。
なので、各指数の組合せで1000を超えないのが可能なものを列挙して、その中で素数がなるべく小さいものについて実際に(k, m, n)の組合せを数えて、これを超える可能性がある指数の組合せだけを考えます。指数の順番を変えたものと素数を大きくしたものも同じように個数を上から押さえることができます。例えば、 p = 2^e * 3^2 * 5なら p = 2^e * 3 * 5^2も同じですし、 p = 2^e * 5 * 11^2も同じです。
これで計算してみると、 p \le 10^{12}で少し時間オーバーするくらいでした。
ただ、この2つ目のような例は大きな素数も試さないといけないです。実際のところはなるべく素数が小さいほうが組合せの個数が多くなるのですが、必ずしもそうとは言えないのが難しいところです。ただ、2が少なくなると個数が少なくなることは分かります。例えば、 p = 2 * q * r * sとすると、 2 \lt qrs,\ 2q \lt rs,\ 2r \lt qs,\ 2 \lt qr,\ 2 \lt qs,\ 2 \lt rs,\ 2 \lt p,\ 2 \lt r,\ 2 \lt sとなるので、 (k, m, m+n) = (1, 1, qrs), (1, q, rs), (1, r, qs), (s, 1, qr), (r, 1, qs), (q, 1, rs), (rs, 1, q), (qs, 1, r), (qr, 1, s)の9通りは成り立たないことが分かって、元々13通りなので、最大でも4個の組合せしかないことが分かって、8個を超えられないから、2が1つしかないときは最初から考えなくてもよく、素数の組合せを減らすことができます。いつかできたらと思います。

from collections import Set
import sys

#################### library ####################

fn div_pow(n: Int, d: Int) -> Tuple[Int, Int]:
    var m = n
    var e = 0
    while m % d == 0:
        e += 1
        m //= d
    return (e, m)

fn make_prime_table(N: Int) -> List[Int]:
    var a = initialize_list(N+1, True)
    for p in range(2, N+1):
        if p * p > N:
            break
        elif not a[p]:
            continue
        
        for n in range(p*p, N+1, p):
            a[n] = False
    
    var ps = List[Int]()
    for n in range(2, N+1):
        if a[n]:
            ps.append(n)
    return ps


#################### List ####################

trait Printable(CollectionElement, Stringable):
    pass

fn initialize_list[T: CollectionElement](N: Int, init: T) -> List[T]:
    var a = List[T](capacity=N)
    for n in range(N):
        a.append(init)
    return a

fn print_list[T: Printable](a: List[T]):
    if a.size > 0:
        var s = "[" + str(a[0])
        for i in range(1, a.size):
            s += ", " + str(a[i])
        s += "]"
        print(s)
    else:
        print("[]")

fn copy_list[T: CollectionElement](a: List[T]) -> List[T]:
    return sublist(a, 0, len(a))

fn sublist[T: CollectionElement](a: List[T], first: Int, last: Int) -> List[T]:
    var b = List[T]()
    for i in range(first, last):
        b.append(a[i])
    return b

fn add_list[T: CollectionElement](a: List[T], b: List[T]) -> List[T]:
    var c = List[T]()
    for e in a:
        c.append(e[])
    for e in b:
        c.append(e[])
    return c

fn extend_list[T: CollectionElement](inout a: List[T], b: List[T]):
    for e in b:
        a.append(e[])

fn reverse_list[T: CollectionElement](a: List[T]) -> List[T]:
    var rev = List[T](capacity=len(a))
    for i in range(len(a)-1, -1, -1):
        rev.append(a[i])
    return rev


#################### Factors ####################

struct Factors(ComparableCollectionElement):
    var ps: List[Int]
    var es: List[Int]
    var value: Int
    
    fn __init__(inout self, ps: List[Int], es: List[Int]):
        self.ps = ps
        self.es = es
        self.value = 1
        self.value = self.calc_value()
    
    fn __copyinit__(inout self, other: Factors):
        self.ps = other.ps
        self.es = other.es
        self.value = other.value
    
    fn __moveinit__(inout self, owned other: Factors):
        self.ps = other.ps^
        self.es = other.es^
        self.value = other.value
    
    fn __len__(self) -> Int:
        return len(self.ps)
    
    fn __lt__(self, other: Self) -> Bool:
        return self.value < other.value
    
    fn __le__(self, other: Self) -> Bool:
        return self.value <= other.value
    
    fn __eq__(self, other: Self) -> Bool:
        return self.value == other.value
    
    fn __ne__(self, other: Self) -> Bool:
        return self.value != other.value
    
    fn __gt__(self, other: Self) -> Bool:
        return self.value > other.value
    
    fn __ge__(self, other: Self) -> Bool:
        return self.value >= other.value
    
    fn __mul__(self, other: Factors) -> Factors:
        var ps = List[Int]()
        var es = List[Int]()
        var L1 = self.ps.size
        var L2 = other.ps.size
        var k = 0
        var l = 0
        while k < L1 and l < L2:
            var p1 = self.ps[k]
            var e1 = self.es[k]
            var p2 = other.ps[l]
            var e2 = other.es[l]
            if p1 == p2:
                ps.append(p1)
                es.append(e1+e2)
                k += 1
                l += 1
            elif p1 < p2:
                ps.append(p1)
                es.append(e1)
                k += 1
            else:
                ps.append(p2)
                es.append(e2)
                l += 1
        
        for k1 in range(k, L1):
            ps.append(self.ps[k1])
            es.append(self.es[k1])
        for l1 in range(l, L2):
            ps.append(other.ps[l1])
            es.append(other.es[l1])
        return Factors(ps, es)
    
    fn __str__(self) -> String:
        if len(self) == 0:
            return "1"
        
        var s = str(self.ps[0])
        if self.es[0] > 1:
            s += "^" + str(self.es[0])
        for i in range(1, len(self)):
            s += " * " + str(self.ps[i])
            if self.es[i] > 1:
                s += "^" + str(self.es[i])
        return s
    
    fn calc_value(self) -> Int:
        var n = 1
        for i in range(len(self.ps)):
            var p = self.ps[i]
            var e = self.es[i]
            n *= p**e
        return n
    
    @staticmethod
    fn create(n: Int) -> Factors:
        var ps = List[Int]()
        var es = List[Int]()
        var m = n
        for p in range(2, n+1):
            if p*p > m:
                break
            elif m%p == 0:
                var a = div_pow(m, p)
                var e = a.get[0, Int]()
                m = a.get[1, Int]()
                ps.append(p)
                es.append(e)
        if m > 1:
            ps.append(m)
            es.append(1)
        return Factors(ps, es)


#################### process ####################

fn divide_into_two(fs: Factors) -> Tuple[List[Int], List[Int]]:
    if len(fs) == 1:
        var p = fs.ps[0]
        var e = fs.es[0]
        var ds1 = List[Int](1)
        var ds2 = List[Int](1)
        for e1 in range(1, e+1):
            ds1.append(1)
            ds2.append(p**e1)
            ds1.append(p**e1)
            ds2.append(1)
        return (ds1, ds2)
    else:
        var N = len(fs)
        var mid = N // 2
        var fs1 = Factors(sublist(fs.ps, 0, mid), sublist(fs.es, 0, mid))
        var fs2 = Factors(sublist(fs.ps, mid, N), sublist(fs.es, mid, N))
        var pair1 = divide_into_two(fs1)
        var ds11 = pair1.get[0, List[Int]]()
        var ds12 = pair1.get[1, List[Int]]()
        var pair2 = divide_into_two(fs2)
        var ds21 = pair2.get[0, List[Int]]()
        var ds22 = pair2.get[1, List[Int]]()
        var ds1 = List[Int]()
        var ds2 = List[Int]()
        for i in range(len(ds11)):
            for j in range(len(ds21)):
                ds1.append(ds11[i] * ds21[j])
                ds2.append(ds12[i] * ds22[j])
        return (ds1, ds2)

fn num_right_angle_triangles(fs: Factors) -> Int:
    var e2 = fs.es[0]
    var n2 = 1 << e2
    
    # 2を削除
    var ps = List[Int]()
    var es = List[Int]()
    for i in range(1, len(fs)):
        ps.append(fs.ps[i])
        es.append(fs.es[i])
    var fs1 = Factors(ps, es)
    
    var ds = divide_into_two(fs1)
    var ds1 = ds.get[0, List[Int]]()
    var ds2 = ds.get[1, List[Int]]()
    var counter = 0
    for i in range(len(ds1)):
        var d1 = ds1[i]
        var d2 = ds2[i]
        if d1 < d2 and d1 * n2 > d2:
            counter += 1
    return counter

fn next_factors(fs: Factors, primes: List[Int]) -> List[Factors]:
    var fss = List[Factors]()
    
    # 左端は常に積める
    var e3 = fs.es[0]
    var ps1 = copy_list(fs.ps)
    var es1 = copy_list(fs.es)
    es1[0] += 1
    fss.append(Factors(ps1, es1))
    
    # 左端と違う指数が前と一つ差なら積める
    for i in range(1, len(fs)):
        var e = fs.es[i]
        if e == e3 - 1:
            var ps2 = copy_list(fs.ps)
            var es2 = copy_list(fs.es)
            es2[i] += 1
            fss.append(Factors(ps2, es2))
            break
        elif e != e3:
            break
    else:
        # 全て1なら右に伸ばせる
        if e3 == 1:
            var ps3 = copy_list(fs.ps)
            var es3 = copy_list(fs.es)
            ps3.append(primes[len(fs)+1])
            es3.append(1)
            fss.append(Factors(ps3, es3))
    return fss

fn create_shapes(N: Int, primes: List[Int]) -> List[List[Int]]:
    var ess = List[List[Int]]()
    var stack = List[Factors]()
    stack.append(Factors.create(3))
    while len(stack) > 0:
        var fs = stack.pop()
        ess.append(fs.es)
        for fs1 in next_factors(fs, primes):
            if fs1[].value * 2 <= N:
                stack.append(fs1[])
    return ess

# Nを超えないように2を掛ける
fn mul_two_pow(fs: Factors, N: Int) -> Factors:
    var ps = List[Int](2)
    var es = List[Int](0)
    extend_list(ps, fs.ps)
    extend_list(es, fs.es)
    var value = fs.value
    while value * 2 <= N:
        es[0] += 1
        value *= 2
    return Factors(ps, es)

# esの形で左から詰まった素因数分解を作る
# es: [2, 1], N: 1000 => 2^4*3^2*5
fn create_minimal_factors(es1: List[Int], N: Int, primes: List[Int]) -> Factors:
    var ps = List[Int]()
    var es = List[Int]()
    for i in range(len(es1)):
        if es1[i] != 0:
            ps.append(primes[i+1])  # 2は飛ばす
            es.append(es1[i])
    var fs = Factors(ps, es)
    return mul_two_pow(fs, N)

fn normalize_es(fs: Factors, primes: List[Int]) -> List[Int]:
    var es = List[Int]()
    # 2は飛ばす
    var i = 1
    var j = 1
    var max_p = fs.ps[len(fs)-1]
    while primes[i] <= max_p:
        if primes[i] == fs.ps[j]:
            es.append(fs.es[j])
            j += 1
        else:
            es.append(0)
        i += 1
    return es

fn swap_es(fs: Factors, i: Int) -> Factors:
    var ps = copy_list(fs.ps)
    var es = copy_list(fs.es)
    var tmp = es[i]
    es[i] = es[i+1]
    es[i+1] = tmp
    return Factors(ps, es)

fn possible_max_num_right_triangles(es: List[Int]) -> Int:
    var n = 1
    for e in es:
        n *= e[]*2 + 1
    return n // 2

fn collect_all_permutations(fs0: Factors, N: Int) -> List[Factors]:
    var fss = List[Factors]()
    var s = Set[Int](fs0.value)
    var stack = List[Factors](fs0)
    while len(stack) > 0:
        var fs = stack.pop()
        fss.append(fs)
        for i in range(len(fs)-1):
            if fs.es[i] > fs.es[i+1]:
                var fs1 = swap_es(fs, i)
                if fs1.value <= N and fs1.value not in s:
                    stack.append(fs1)
                    s.add(fs1.value)
    return fss

fn move_states(es: List[Int], N: Int, primes: List[Int]) -> List[List[Int]]:
    # 指数を移動できる
    var ess = List[List[Int]]()
    
    # 最後の0の手前を右に一つ移動できる
    var found_zero = False
    for i in range(len(es)-1, -1, -1):
        if es[i] == 0:
            found_zero = True
        elif found_zero:
            var es1 = copy_list(es)
            es1[i+1] = es1[i]
            es1[i] = 0
            ess.append(es1)
            break
    
    # 最後が0以外でその前が0の0回以上の連続でその前は全て0以外のときのみ
    # 最後を一つ右に進められる
    var mode = 0
    for i in range(len(es)-1, -1, -1):
        if mode == 0:
            mode = 1
        elif mode == 1:
            if es[i] != 0:
                mode = 2
        elif mode == 2:
            if es[i] == 0:
                mode = 3
                break
    
    if mode == 2:
        var es2 = copy_list(es)
        es2.append(es[len(es)-1])
        es2[len(es)-1] = 0
        ess.append(es2)
    
    return ess

fn make_factors(es0: List[Int], primes: List[Int]) -> Factors:
    var ps = List[Int]()
    var es = List[Int]()
    for i in range(len(es0)):
        if es0[i] != 0:
            ps.append(primes[i])
            es.append(es0[i])
    return Factors(ps, es)

fn max_num_right_triangles(es0: List[Int], N: Int,
                                    primes: List[Int]) -> Tuple[Int, Int]:
    var upper_num = possible_max_num_right_triangles(es0)
    var fs0 = make_factors(es0, primes)
    var max_num = num_right_angle_triangles(fs0)
    var max_p = fs0.value
    for fs in collect_all_permutations(fs0, N):
        var stack = List[List[Int]](fs[].es)
        while len(stack) > 0:
            var es = stack.pop()
            var fs = create_minimal_factors(es, N, primes)
            if fs.es[0] == 0:   # 2が無い
                continue
            if fs.value > N:
                continue
            var num = num_right_angle_triangles(fs)
            if num > max_num:
                max_num = num
                max_p = fs.value
            var ess = move_states(es, N, primes)
            for es1 in ess:
                stack.append(es1[])
    return (max_num, max_p)

fn f(N: Int) -> Int:
    var M = 1000
    var primes = make_prime_table(M)
    var ess = create_shapes(N, primes)
    var max_num_triangles = 0
    var max_p = 0
    for es in ess:
        var fs = create_minimal_factors(es[], N, primes)
        var n = num_right_angle_triangles(fs)
        if n > max_num_triangles:
            max_p = fs.value
            max_num_triangles = n
    
    # esを絞る
    var ess1 = List[List[Int]]()
    for es in ess:
        if possible_max_num_right_triangles(es[]) > max_num_triangles:
            ess1.append(es[])
    
    for es in ess1:
        var max_num = possible_max_num_right_triangles(es[])
        if max_num <= max_num_triangles:
            continue
        var fs = make_factors(es[], primes)
        var pair = max_num_right_triangles(es[], N, primes)
        var num = pair.get[0, Int]()
        var p = pair.get[1, Int]()
        if num > max_num_triangles:
            max_num_triangles = num
            max_p = p
    
    print(max_num_triangles)
    return max_p

fn main() raises:
    var args = sys.argv()
    var N = atol(args[1])
    print(f(N))