MojoでProject Euler 44

https://projecteuler.net/problem=44

 P_i + P_j = P_kとすると、
 \displaystyle \frac{i(3i-1)}{2} + \frac{j(3j-1)}{2} = \frac{k(3k-1)}{2}
両辺を24倍して2を足すと、
 (6i-1)^2 + (6j-1)^2 = (6k-1)^2 + 1
 x \equiv 6i-1 \ \ y \equiv 6j-1 \ \ z \equiv 6k-1
とすると、
 (x-1)(x+1) = (z-y)(z+y)
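だから、ふるいで素因数分解をしておいて、xの小さいほうから(x-1)(x+1)の素因数分解を求めます。そうすると約数が簡単に列挙できるので、(x-1)(x+1) = d1 * d2 (d1 < d2、d1とd2の偶奇が同じ)と分けて y = (d2-d1)/2、z = (d1+d2)/2 とします。yとzがともに6で割って5余るなら P_y + P_x = P_z、すなわち差 P_z - P_y = P_x が五角数になるので、とりあえず一方の条件は満たします。そして和 P_y + P_z も五角数かを調べます。xを小さいほうから調べているので、最初に見つかったP_xが差の最小値になります。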

import sys

#################### library ####################

fn int_sqrt(n: Int) -> Int:
    """Return floor(sqrt(n)) for n >= 0 using integer Newton iteration.

    The first estimate (n + 1) // 2 is >= sqrt(n) for every n >= 1, so the
    iterates decrease until they reach floor(sqrt(n)); the first step that
    does not decrease marks the answer.

    Inputs below 2 are returned as-is. This also prevents the division by
    zero the loop would hit for n == 0, where the first estimate is 0.
    Negative n is not meaningful here.
    """
    if n < 2:
        return n
    var x = (n + 1) // 2
    while True:
        var x1 = (x + n // x) // 2
        if x1 >= x:
            return x
        x = x1

fn div_pow(n: Int, d: Int) -> Tuple[Int, Int]:
    """Divide d out of n as many times as possible.

    Returns (e, m) with n == d^e * m, where e counts how many times d
    divides n and m is what remains.

    Degenerate inputs for which repeated division never terminates
    (n == 0, d == 1, d == -1), or would divide by zero (d == 0), return
    (0, n) instead of looping forever or crashing.
    """
    if n == 0 or (d >= -1 and d <= 1):
        return (0, n)
    var m = n
    var e = 0
    while m % d == 0:
        e += 1
        m //= d
    return (e, m)

fn make_prime_table(N: Int) -> List[Int]:
    """Return every prime in [2, N] in increasing order (sieve of Eratosthenes)."""
    # sieve[k] stays True while k is still a prime candidate.
    var sieve = initialize_list(N+1, True)
    var p = 2
    while p * p <= N:
        if sieve[p]:
            # Multiples below p*p were already crossed out by smaller primes.
            for multiple in range(p*p, N+1, p):
                sieve[multiple] = False
        p += 1

    var primes = List[Int]()
    for k in range(2, N+1):
        if sieve[k]:
            primes.append(k)
    return primes


#################### List ####################

# Element bound for print_list: the type can be stored in a List
# (CollectionElement) and passed to str() (Stringable), which print_list
# calls on every element.
trait Printable(CollectionElement, Stringable):
    pass

fn initialize_list[T: CollectionElement](N: Int, init: T) -> List[T]:
    """Return a new List holding N copies of init (empty when N <= 0)."""
    var filled = List[T](capacity=N)
    var remaining = N
    while remaining > 0:
        filled.append(init)
        remaining -= 1
    return filled

fn print_list[T: Printable](a: List[T]):
    """Print a in bracketed, comma-separated form, e.g. "[1, 2, 3]" or "[]"."""
    if a.size == 0:
        print("[]")
        return
    var text = String("[")
    for i in range(a.size):
        if i > 0:
            text += ", "
        text += str(a[i])
    text += "]"
    print(text)

fn copy_list[T: CollectionElement](a: List[T]) -> List[T]:
    """Return a new List with the same elements as a, in the same order."""
    var duplicate = List[T]()
    for i in range(len(a)):
        duplicate.append(a[i])
    return duplicate

fn sublist[T: CollectionElement](a: List[T], first: Int, last: Int) -> List[T]:
    """Return a new List with a[first], ..., a[last-1].

    The indices are not validated here; last <= first yields an empty List.
    """
    var part = List[T]()
    var i = first
    while i < last:
        part.append(a[i])
        i += 1
    return part

fn add_list[T: CollectionElement](a: List[T], b: List[T]) -> List[T]:
    """Return a new List containing the elements of a followed by those of b."""
    var joined = List[T]()
    for i in range(len(a)):
        joined.append(a[i])
    for i in range(len(b)):
        joined.append(b[i])
    return joined

fn extend_list[T: CollectionElement](inout a: List[T], b: List[T]):
    """Append every element of b, in order, to the end of a (in place)."""
    for i in range(len(b)):
        a.append(b[i])

fn reverse_list[T: CollectionElement](a: List[T]) -> List[T]:
    """Return a new List with the elements of a in reverse order."""
    var count = len(a)
    var rev = List[T](capacity=count)
    for i in range(count):
        rev.append(a[count - 1 - i])
    return rev


#################### Factors ####################

struct Factors(ComparableCollectionElement):
    """A factorization value == prod(ps[i] ** es[i]).

    ps holds the primes and es the matching exponents, and value caches the
    product. __mul__ merges two factorizations assuming both ps lists are
    sorted ascending, so every operation here keeps ps sorted.
    Ordering and equality compare value only.
    """
    var ps: List[Int]
    var es: List[Int]
    var value: Int
    
    fn __init__(inout self, ps: List[Int], es: List[Int]):
        """Build from parallel lists of primes (ascending) and exponents."""
        self.ps = ps
        self.es = es
        # All fields must be initialized before calling a method on self,
        # so value gets a placeholder before calc_value() runs.
        self.value = 1
        self.value = self.calc_value()
    
    fn __copyinit__(inout self, other: Factors):
        self.ps = other.ps
        self.es = other.es
        self.value = other.value
    
    fn __moveinit__(inout self, owned other: Factors):
        self.ps = other.ps^
        self.es = other.es^
        self.value = other.value
    
    fn __len__(self) -> Int:
        """Number of distinct primes stored."""
        return len(self.ps)
    
    fn __lt__(self, other: Self) -> Bool:
        return self.value < other.value
    
    fn __le__(self, other: Self) -> Bool:
        return self.value <= other.value
    
    fn __eq__(self, other: Self) -> Bool:
        return self.value == other.value
    
    fn __ne__(self, other: Self) -> Bool:
        return self.value != other.value
    
    fn __gt__(self, other: Self) -> Bool:
        return self.value > other.value
    
    fn __ge__(self, other: Self) -> Bool:
        return self.value >= other.value
    
    fn __mul__(self, other: Factors) -> Factors:
        """Return the product of two factorizations.

        Merges the two ascending prime lists like a merge-sort step and adds
        the exponents of primes that appear in both.
        """
        var ps = List[Int]()
        var es = List[Int]()
        var L1 = self.ps.size
        var L2 = other.ps.size
        var k = 0
        var l = 0
        while k < L1 and l < L2:
            var p1 = self.ps[k]
            var e1 = self.es[k]
            var p2 = other.ps[l]
            var e2 = other.es[l]
            if p1 == p2:
                ps.append(p1)
                es.append(e1+e2)
                k += 1
                l += 1
            elif p1 < p2:
                ps.append(p1)
                es.append(e1)
                k += 1
            else:
                ps.append(p2)
                es.append(e2)
                l += 1
        
        # At most one of the two inputs still has primes left over.
        for k1 in range(k, L1):
            ps.append(self.ps[k1])
            es.append(self.es[k1])
        for l1 in range(l, L2):
            ps.append(other.ps[l1])
            es.append(other.es[l1])
        return Factors(ps, es)
    
    fn __imul__(inout self, other: Factors):
        """Multiply self by other in place.

        Delegates to __mul__ so the primes stay in ascending order. Appending
        unseen primes at the end would break the ordering that __mul__'s
        merge depends on.
        """
        var product = self * other
        self.ps = product.ps
        self.es = product.es
        self.value = product.value
    
    fn __str__(self) -> String:
        """Render as e.g. "2^3 * 5"; the empty factorization renders as "1"."""
        if len(self) == 0:
            return "1"
        
        var s = str(self.ps[0])
        if self.es[0] > 1:
            s += "^" + str(self.es[0])
        for i in range(1, len(self)):
            s += " * " + str(self.ps[i])
            if self.es[i] > 1:
                s += "^" + str(self.es[i])
        return s
    
    fn calc_value(self) -> Int:
        """Return prod(ps[i] ** es[i]) (1 for the empty factorization)."""
        var n = 1
        for i in range(len(self.ps)):
            var p = self.ps[i]
            var e = self.es[i]
            n *= p**e
        return n
    
    @staticmethod
    fn create(n: Int) -> Factors:
        """Factorize n by trial division; primes come out in ascending order.

        For n <= 1 the result is the empty factorization (value 1).
        """
        var ps = List[Int]()
        var es = List[Int]()
        var m = n
        for p in range(2, n+1):
            if p*p > m:
                break
            elif m%p == 0:
                var a = div_pow(m, p)
                var e = a.get[0, Int]()
                m = a.get[1, Int]()
                ps.append(p)
                es.append(e)
        # Whatever is left has no factor <= sqrt(m), so it is a single prime.
        if m > 1:
            ps.append(m)
            es.append(1)
        return Factors(ps, es)
    
    @staticmethod
    fn make_partial_factors_table(first: Int, last: Int,
                                    primes: List[Int]) -> List[Factors]:
        """Factorize every integer in [first, last) with a segmented sieve.

        Element k of the result is the factorization of first + k.
        Assumes first >= 1, and that primes is sorted ascending and contains
        every prime p with p * p < last. After those primes are divided out,
        any leftover cofactor > 1 is recorded as one prime.
        """
        # Append p^e to fs; callers supply primes in ascending order, which
        # keeps fs.ps sorted.
        fn mul(inout fs: Factors, p: Int, e: Int):
            fs.ps.append(p)
            fs.es.append(e)
            fs.value *= p**e
        
        var N = last - first
        # a[k] starts as first + k and is reduced as primes are divided out.
        var a = List[Int](capacity=N)
        for n in range(first, last):
            a.append(n)
        var b = initialize_list(N, Factors.create(1))
        for p in primes:
            if p[] * p[] >= last:
                break
            # Smallest multiple of p that is >= first.
            var n0 = (first + p[] - 1) // p[] * p[]
            for n in range(n0, last, p[]):
                var t = div_pow(a[n-first], p[])
                var e = t.get[0, Int]()
                a[n-first] = t.get[1, Int]()
                mul(b[n-first], p[], e)
        
        for k in range(last - first):
            if a[k] != 1:
                mul(b[k], a[k], 1)
        return b
    
    @staticmethod
    fn divisors_core(ps: List[Int], es: List[Int]) -> List[Int]:
        """Return all divisors of prod(ps[i] ** es[i]), in no particular order.

        Splits the primes in half, recurses on each half, and multiplies
        every pair of partial divisors.
        """
        # Without this guard, an empty input would recurse forever, since
        # mid == 0 yields two empty halves.
        if len(ps) == 0:
            return List[Int](1)
        if len(es) == 1:
            var ds = List[Int](1)
            var p = ps[0]
            var n = 1
            for _ in range(es[0]):
                n *= p
                ds.append(n)
            return ds
        else:
            var mid = len(ps) // 2
            var ps1 = sublist(ps, 0, mid)
            var es1 = sublist(es, 0, mid)
            var ps2 = sublist(ps, mid, len(ps))
            var es2 = sublist(es, mid, len(ps))
            var ds1 = Factors.divisors_core(ps1, es1)
            var ds2 = Factors.divisors_core(ps2, es2)
            var ds = List[Int](capacity=len(ds1)*len(ds2))
            for d1 in ds1:
                for d2 in ds2:
                    ds.append(d1[] * d2[])
            return ds
    
    @staticmethod
    fn divisors(fs: Factors) -> List[Int]:
        """Return all divisors of fs.value ([1] for the empty factorization)."""
        if len(fs) == 0:
            return List[Int](1)
        else:
            return Factors.divisors_core(fs.ps, fs.es)


#################### process ####################

fn is_pentagonal(n: Int) -> Bool:
    """Return True when n >= 0 equals k(3k-1)/2 for some integer k >= 1.

    Since 24 * k(3k-1)/2 + 1 == (6k-1)^2, n is pentagonal exactly when
    24n + 1 is a perfect square r^2 with r + 1 divisible by 6.
    """
    var disc = 24 * n + 1
    var root = int_sqrt(disc)
    return root * root == disc and (root + 1) % 6 == 0

# Identity: n(3n-1)/2 * 24 + 1 = (6n-1)^2, so with w = 6n-1, P_n = (w^2-1)/24.
# In f, x, y and z denote such w values (all congruent to 5 mod 6), and
# P_x + P_y = P_z  <=>  x^2 + y^2 = z^2 + 1  <=>  x^2 - 1 = z^2 - y^2.
fn f() -> Int:
    """Solve Project Euler 44 and return the minimal difference D.

    x is scanned upward in segments of M values. For each x, every split
    (x-1)(x+1) = lo * hi with lo < hi of equal parity gives
    y = (hi-lo)/2 and z = (hi+lo)/2 with z^2 - y^2 = x^2 - 1. When y and z
    are both 5 mod 6, they encode pentagonal numbers with P_z - P_y = P_x.
    The first pair whose sum P_y + P_z is also pentagonal is returned. The
    returned difference is P_x, and x only grows, so the first hit is the
    smallest.
    """
    var primes = make_prime_table(10000)
    var M = 6000
    var first = 4
    while True:
        var table = Factors.make_partial_factors_table(first, first+M+1, primes)
        # first starts at 4 and M is a multiple of 6, so x is always 5 mod 6.
        for x in range(first+1, first+M-1, 6):
            var target = (x - 1) * (x + 1)
            var divs = Factors.divisors(table[x-first-1] * table[x-first+1])
            for idx in range(len(divs)):
                var lo = divs[idx]
                var hi = target // lo
                if hi > lo and (hi - lo) % 2 != 1:
                    var y = (hi - lo) // 2
                    var z = (lo + hi) // 2
                    if y % 6 == 5 and z % 6 == 5:
                        var smaller = (y * y - 1) // 24
                        var larger = (z * z - 1) // 24
                        if is_pentagonal(smaller + larger):
                            return larger - smaller
        first += M

fn main():
    # Print the Project Euler 44 answer computed by f().
    var answer = f()
    print(answer)