MojoでProject Euler 66

https://projecteuler.net/problem=66

Pell方程式 x^2 - Dy^2 = 1 は連分数と密接に関連しています。Pell方程式の解は \sqrt{D}の連分数を途中で打ち切った分数（近似分数）の分子と分母になっています。なので、漸化式で近似分数の分子と分母を順に求めて、最初にPell方程式を満たしたものが最小解です。
例えば、 D=2とすると、 \sqrt{2}の整数部分は1、 \sqrt{2} - 1の逆数を取ると \sqrt{2} + 1、これの整数部分は2、以降は2の繰り返しなので、連分数は[1; 2, 2, 2, \dots]（循環部分を上線で表して[1; \overline{2}]）と表されます。連分数を途中で打ち切った分数の分子分母をそれぞれ a_n,\ b_nとすると、
 a_0 = 1,\ b_0 = 0,\ a_1 = 1,\ b_1 = 1
として、
 a_1^2 - 2b_1^2 = -1
なのでPell方程式を満たしません。
 a_2 = 2a_1 + a_0 = 3
 b_2 = 2b_1 + b_0 = 2
 a_2^2 - 2b_2^2 = 1
なので、これがPell方程式の最小解です。

import sys


#################### List ####################

fn initialize_list[T: CollectionElement](N: Int, init: T) -> List[T]:
    """Return a new List holding N copies of `init` (empty when N <= 0)."""
    var filled = List[T](capacity=N)
    var remaining = N
    while remaining > 0:
        filled.append(init)
        remaining -= 1
    return filled

trait Printable(CollectionElement, Stringable):
    """An element type that can live in a List and be rendered with str()."""
    pass

fn print_list[T: Printable](a: List[T]):
    """Print `a` as "[x0, x1, ...]"; an empty list prints "[]"."""
    var text: String = "["
    for i in range(len(a)):
        # Separator goes before every element except the first.
        if i > 0:
            text += ", "
        text += str(a[i])
    text += "]"
    print(text)


#################### BigInteger ####################

@value
struct BigInteger(Stringable):
    """Non-negative arbitrary-precision integer.

    `v` stores base-10 digits, least significant first (create(123) gives
    [3, 2, 1]). Zero is the empty list, which is what create(0) produces.
    Results are kept free of most-significant zero digits, so two values
    are equal exactly when their digit lists are equal, and the ordering
    operators may compare lengths first.
    """
    var v: List[Int]

    @staticmethod
    fn _normalize(inout v: List[Int]):
        """Drop most-significant zero digits in place; zero becomes []."""
        while v.size > 0 and v[v.size-1] == 0:
            _ = v.pop()

    fn __eq__(self, other: BigInteger) -> Bool:
        if len(self.v) != len(other.v):
            return False
        for i in range(len(self.v)):
            if self.v[i] != other.v[i]:
                return False
        return True

    fn __ne__(self, other: BigInteger) -> Bool:
        return not self.__eq__(other)

    fn __lt__(self, other: BigInteger) -> Bool:
        # With no leading zeros, a longer digit list is the larger number;
        # for equal lengths the most significant differing digit decides.
        if len(self.v) != len(other.v):
            return len(self.v) < len(other.v)
        for i in range(len(self.v)-1, -1, -1):
            if self.v[i] != other.v[i]:
                return self.v[i] < other.v[i]
        return False

    fn __le__(self, other: BigInteger) -> Bool:
        return not self.__gt__(other)

    fn __gt__(self, other: BigInteger) -> Bool:
        if len(self.v) != len(other.v):
            return len(self.v) > len(other.v)
        for i in range(len(self.v)-1, -1, -1):
            if self.v[i] != other.v[i]:
                return self.v[i] > other.v[i]
        return False

    fn __ge__(self, other: BigInteger) -> Bool:
        if len(self.v) != len(other.v):
            return len(self.v) > len(other.v)
        for i in range(len(self.v)-1, -1, -1):
            if self.v[i] != other.v[i]:
                return self.v[i] > other.v[i]
        return True

    fn __add__(self, other: BigInteger) -> BigInteger:
        """Digit-by-digit addition with carry."""
        var v = List[Int]()
        var carry = 0
        for i in range(max(self.v.size, other.v.size)):
            var d1 = self.v[i] if i < self.v.size else 0
            var d2 = other.v[i] if i < other.v.size else 0
            var n = d1 + d2 + carry
            v.append(n % 10)
            carry = n // 10
        if carry > 0:
            v.append(carry)
        return BigInteger(v)

    fn __sub__(self, other: BigInteger) -> BigInteger:
        """Return self - other.

        Assumes self >= other (the result must be non-negative); otherwise
        the final borrow is dropped and the result is meaningless.
        """
        var v = List[Int]()
        var borrow = 0
        for i in range(max(self.v.size, other.v.size)):
            var d1 = self.v[i] if i < self.v.size else 0
            var d2 = other.v[i] if i < other.v.size else 0
            var n = d1 - d2 + borrow
            # Relies on floor semantics of % and // (as in Python):
            # n = -1 yields digit 9 and borrow -1.
            v.append(n % 10)
            borrow = n // 10
        # Strip every leading zero (100 - 99 -> [1, 0, 0] -> [1]), not just
        # one, so the length-based comparisons stay valid; 0 becomes [].
        BigInteger._normalize(v)
        return BigInteger(v)

    fn __mul__(self, other: Int) -> BigInteger:
        """Return self * other for a machine integer (assumes other >= 0)."""
        var v = List[Int]()
        var carry = 0
        for i in range(self.v.size):
            var n = self.v[i] * other + carry
            v.append(n % 10)
            carry = n // 10
        while carry > 0:
            v.append(carry % 10)
            carry //= 10
        # Multiplying by 0 would otherwise leave digits such as [0, 0].
        BigInteger._normalize(v)
        return BigInteger(v)

    fn __rmul__(self, other: Int) -> BigInteger:
        return self * other

    fn __mul__(self, other: BigInteger) -> BigInteger:
        """Schoolbook multiplication of two BigIntegers."""
        var L1 = len(self.v)
        var L2 = len(other.v)
        if L1 == 0 or L2 == 0:
            # A zero operand would otherwise give a negative buffer size
            # (both empty) or a result padded with zero digits.
            return BigInteger(List[Int]())
        # Accumulate raw column sums first, then propagate carries once.
        var v = initialize_list(L1 + L2 - 1, 0)
        for i in range(L1):
            for j in range(L2):
                v[i+j] += self.v[i] * other.v[j]

        var carry = 0
        for i in range(L1 + L2 - 1):
            v[i] += carry
            carry = v[i] // 10
            v[i] = v[i] % 10
        while carry > 0:
            v.append(carry % 10)
            carry //= 10
        return BigInteger(v)

    fn __str__(self) -> String:
        """Decimal representation, most significant digit first."""
        if len(self.v) == 0:
            return "0"

        var s: String = ""
        for i in range(self.v.size-1, -1, -1):
            if self.v[i] < 10:
                s += chr(self.v[i] + 48)    # '0'..'9'
            else:
                # Digits >= 10 render as letters ('a' = 10); not produced by
                # the base-10 arithmetic above.
                s += chr(self.v[i] + 87)
        return s

    @staticmethod
    fn create(owned n: Int) -> BigInteger:
        """Build a BigInteger from a machine integer (n <= 0 yields zero)."""
        var v = List[Int]()
        while n > 0:
            var d = n % 10
            n //= 10
            v.append(d)
        return BigInteger(v)


#################### process ####################

fn proceed(p: Int, q: Int, r: Int, s: Int) -> Tuple[Int, Int, Int]:
    """One step of the continued-fraction expansion of sqrt(r*r + s).

    The current remainder is (sqrt(n) - p) / q with n = r*r + s. Its
    reciprocal is (sqrt(n) + p) / q_next with q_next = (n - p*p) / q, whose
    integer part is the next partial quotient.

    Returns:
        (partial quotient, next p, next q).
    """
    var q_next = (r * r + s - p * p) // q
    var quotient = (p + r) // q_next
    return (quotient, quotient * q_next - p, q_next)

fn minimal_Pell_solution(r: Int, s: Int, D: Int) -> BigInteger:
    """Return the minimal x of the Pell equation x^2 - D*y^2 = 1.

    Walks the convergents a/b of the continued fraction of sqrt(D) and
    returns the first numerator that satisfies a^2 == D*b^2 + 1.

    Args:
        r: floor(sqrt(D)).
        s: D - r*r; expected in [1, 2r], i.e. D is not a perfect square.
           With s == 0 the first proceed() step divides by zero.
        D: The Pell coefficient; expected to equal r*r + s.
    """
    var one = BigInteger.create(1)
    # Convergent recurrence a_{k+1} = f*a_k + a_{k-1} (likewise for b),
    # seeded with 1/0 and the first convergent r/1.
    var a0 = one
    var b0 = BigInteger.create(0)
    var a1 = BigInteger.create(r)
    var b1 = one
    # The un-expanded remainder of sqrt(D) is (sqrt(D) - p) / q.
    var p = r
    var q = 1
    while a1*a1 != D*b1*b1 + one:
        var t = proceed(p, q, r, s)
        var f = t.get[0, Int]()
        p = t.get[1, Int]()
        q = t.get[2, Int]()
        var a2 = f * a1 + a0
        var b2 = f * b1 + b0
        a0 = a1
        b0 = b1
        a1 = a2
        b1 = b2
    return a1

fn f(N: Int, verbose: Bool = False) -> Int:
    """Return the D <= N whose minimal Pell solution x is largest.

    Every non-square D is visited as D = r*r + s with 1 <= s <= 2r, i.e.
    each integer strictly between r^2 and (r+1)^2. Ties keep the smaller D.
    Returns 1 when no such D exists (N < 2).

    Args:
        N: Upper bound (inclusive) for D.
        verbose: When True, print each D with its minimal x (debug trace).
    """
    var max_x = BigInteger.create(1)
    var max_D = 1
    var r = 1
    while True:
        for s in range(1, 2*r+1):
            var D = r * r + s
            if D > N:
                return max_D
            var x = minimal_Pell_solution(r, s, D)
            # Per-D trace is opt-in so the program's output is just the answer.
            if verbose:
                print(D, x)
            if x > max_x:
                max_x = x
                max_D = D
        r += 1

fn main() raises:
    """Entry point: reads the bound N from argv[1] and prints f(N)."""
    var args = sys.argv()
    # Without an argument args[1] is out of range; report usage instead.
    if len(args) < 2:
        print("usage: <program> N   (upper bound for D)")
        return
    var N = atol(args[1])
    print(f(N))