MojoでProject Euler 4

https://projecteuler.net/problem=4

この問題はPriorityQueue（優先度付きキュー）を使うと速いので、Pythonでこんな感じに書きます。

# e004.py
from __future__ import annotations
from typing import Iterator
import heapq
import sys

def digits(n: int) -> Iterator[int]:
    """Yield the decimal digits of ``n``, least significant digit first.

    The sign of ``n`` is ignored, and ``0`` yields no digits at all.
    """
    # Work on the magnitude: with floor division a negative value settles
    # at -1 instead of 0, so the loop below would never terminate.
    n = abs(n)
    while n != 0:
        n, d = divmod(n, 10)
        yield d

def is_palindromic(n: int) -> bool:
    """Return True when the decimal digits of ``n`` read the same both ways."""
    forward = list(digits(n))
    # A digit sequence is a palindrome exactly when it equals its reverse.
    return forward == forward[::-1]

def nexts(p: int, q: int, max_n: int) -> Iterator[tuple[int, int]]:
    """Yield the pairs to explore after ``(p, q)``.

    ``(p - 1, q)`` is produced only when ``q`` equals ``max_n``, and
    ``(p, q - 1)`` only when ``p`` differs from ``q``; when both apply they
    are yielded in that order.
    """
    candidates = (
        (q == max_n, (p - 1, q)),
        (p != q, (p, q - 1)),
    )
    yield from (pair for wanted, pair in candidates if wanted)

def products(e: int) -> Iterator[int]:
    """Yield products ``p * q`` starting from ``(10**e - 1) ** 2``.

    Each step pops the largest product currently queued, yields it, and
    queues the pairs returned by ``nexts`` for the popped ``(p, q)``.
    """
    largest = 10**e - 1
    # heapq is a min-heap, so products are stored negated to pop the
    # largest one first. A one-element list is already a valid heap.
    frontier = [(-(largest * largest), largest, largest)]
    while True:
        negated, p, q = heapq.heappop(frontier)
        yield -negated
        for pair in nexts(p, q, largest):
            heapq.heappush(frontier, (-(pair[0] * pair[1]), *pair))

def f(e: int) -> int:
    """Return the first palindromic value produced by ``products(e)``.

    Falls back to ``-1`` if ``products(e)`` ends without a palindrome.
    """
    palindromes = (n for n in products(e) if is_palindromic(n))
    return next(palindromes, -1)

# Script entry point: the first command-line argument is the exponent E
# passed to f().
if len(sys.argv) < 2:
    # Exit with a short usage line instead of an IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} <digits>")
E = int(sys.argv[1])
print(f(E))

MojoでもPythonのライブラリが使えるようなのですが、使い方が分からなかったので、ヒープ木を実装して代用しました。

# e004.mojo
import sys

alias Item = Tuple[Int, Int, Int]
alias HeapQueue = DynamicVector[Item]

fn heappop(inout a: HeapQueue) -> Item:
    """Remove and return the item whose first field is largest.

    `a` is treated as a binary max-heap ordered only by field 0 of each
    tuple. The queue must not be empty.
    """
    # Detach the last slot first. If it was the only item it is also the
    # root, and writing it back to a[0] would store past the new size.
    let last = a.pop_back()
    let N = a.size
    if N == 0:
        return last
    let top = a[0]
    a[0] = last
    # Sift the moved item down until neither child has a larger key.
    var i = 0
    while i < N-1:
        let j1 = i*2+1
        let j2 = i*2+2
        var j: Int = 0
        if j1 >= N:
            # No children: the item has reached a leaf.
            break
        elif j2 == N:
            # Only the left child exists.
            j = j1
        else:
            # Pick the child with the larger key (left child on ties).
            if a[j1].get[0, Int]() >= a[j2].get[0, Int]():
                j = j1
            else:
                j = j2
        if a[j].get[0, Int]() > a[i].get[0, Int]():
            a[i], a[j] = a[j], a[i]
            i = j
        else:
            break

    return top

fn heappush(inout a: HeapQueue, e: Item):
    """Insert `e` into the max-heap `a`, ordered by field 0 of each tuple."""
    a.push_back(e)
    # Bubble the new item up while its key is strictly larger than its
    # parent's key.
    var child = a.size - 1
    while child > 0:
        let parent = (child-1) // 2
        if a[parent].get[0, Int]() < a[child].get[0, Int]():
            a[child], a[parent] = a[parent], a[child]
            child = parent
        else:
            break

fn digits(n: Int) -> DynamicVector[Int]:
    """Return the decimal digits of `n`, least significant digit first.

    The sign of `n` is ignored, and 0 gives an empty vector.
    """
    var ds = DynamicVector[Int]()
    var m = n
    # Work on the magnitude: `//` rounds down, so a negative value would
    # settle at -1 and the loop would never reach 0.
    if m < 0:
        m = -m
    while m != 0:
        ds.push_back(m % 10)
        m //= 10
    return ds

fn is_palindromic(n: Int) -> Bool:
    """Return True when the digits of `n` read the same in both directions."""
    let ds = digits(n)
    # Walk inwards from both ends, comparing mirrored positions.
    var lo = 0
    var hi = ds.size - 1
    while lo < hi:
        if ds[lo] != ds[hi]:
            return False
        lo += 1
        hi -= 1
    return True

fn nexts(p: Int, q: Int, max_n: Int) -> DynamicVector[Tuple[Int, Int]]:
    """Return the pairs to explore after `(p, q)`.

    `(p - 1, q)` is included only when `q` equals `max_n`, and `(p, q - 1)`
    only when `p` differs from `q`; when both apply they appear in that
    order.
    """
    let in_last_column = q == max_n
    let off_diagonal = p != q
    var out = DynamicVector[Tuple[Int, Int]]()
    if in_last_column:
        out.push_back((p - 1, q))
    if off_diagonal:
        out.push_back((p, q - 1))
    return out

fn f(e: Int) -> Int:
    """Return the first palindromic product popped from the heap.

    The search starts at `(10**e - 1) * (10**e - 1)`. Every popped item
    that is not a palindrome has the pairs from `nexts` pushed back,
    keyed by their product.
    """
    let max_n = 10**e-1
    var a = HeapQueue()
    a.push_back((max_n*max_n, max_n, max_n))
    while True:
        let t = heappop(a)
        let n = t.get[0, Int]()
        let p = t.get[1, Int]()
        let q = t.get[2, Int]()
        if is_palindromic(n):
            return n
        # Reuse max_n rather than recomputing 10**e-1 on every iteration.
        let neighbors = nexts(p, q, max_n)
        for i in range(neighbors.size):
            let t1 = neighbors[i]
            let p1 = t1.get[0, Int]()
            let q1 = t1.get[1, Int]()
            heappush(a, (p1*q1, p1, q1))

fn main():
    """Read the exponent from the first command-line argument and print f(E)."""
    let args = sys.argv()
    # Without this check args[1] would be read even when no argument was given.
    if len(args) < 2:
        print("usage: e004 <digits>")
        return
    try:
        let E = atol(args[1])
        print(f(E))
    except:
        # Report the failure instead of silently printing nothing.
        print("error: the argument must be an integer")

時間測定すると、

$ time python e004.py 7
99956644665999

real    0m13.420s
user    0m13.398s
sys     0m0.010s

$ time pypy3 e004.py 7
99956644665999

real    0m5.577s
user    0m5.405s
sys     0m0.161s

$ time mojo e004.mojo 7
99956644665999

real    0m1.305s
user    0m0.918s
sys     0m0.399s


ビルドすると、

$ mojo build e004.mojo -o e004_mojo
$ time ./e004_mojo 7
99956644665999

real    0m1.073s
user    0m0.770s
sys     0m0.300s

ちょっと速いですね。
C++とRustでも試してみましょう。

$ g++ -O3 e004.cpp -o e004_cpp
$ time ./e004_cpp 7
99956644665999

real    0m0.694s
user    0m0.690s
sys     0m0.000s

$ rustc -C opt-level=3 e004.rs -o e004_rust
$ time ./e004_rust 7
99956644665999

real    0m0.611s
user    0m0.589s
sys     0m0.010s

今回はRustのほうが速いですね。

// e004.cpp
#include <cstdlib>
#include <iostream>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

typedef long long   ll;

using namespace std;

// Returns the base-10 digits of n, least significant digit first.
// Zero yields an empty vector. For a negative n every digit carries the
// sign of n, because % keeps the sign of the dividend.
vector<ll> digits(ll n) {
    vector<ll> result;
    for(; n != 0; n /= 10)
        result.push_back(n % 10);
    return result;
}

// Returns true when the digits of n read the same in both directions.
bool is_palindromic(ll n) {
    const vector<ll>    ds = digits(n);
    if(ds.empty())
        return true;
    // Walk inwards from both ends, comparing mirrored positions.
    size_t  lo = 0;
    size_t  hi = ds.size() - 1;
    while(lo < hi) {
        if(ds[lo] != ds[hi])
            return false;
        ++lo;
        --hi;
    }
    return true;
}

// Returns the pairs to explore after (p, q): (p - 1, q) only when q equals
// max_n, and (p, q - 1) only when p differs from q, in that order.
vector<pair<ll, ll>> nexts(ll p, ll q, ll max_n) {
    const bool  in_last_column = (q == max_n);
    const bool  off_diagonal = (p != q);
    vector<pair<ll, ll>>    out;
    if(in_last_column)
        out.emplace_back(p - 1, q);
    if(off_diagonal)
        out.emplace_back(p, q - 1);
    return out;
}

// Integer power: returns b multiplied by itself e times.
// A non-positive e returns 1 (the empty product); the recursive form never
// terminated for a negative exponent. The result is not checked for
// overflow.
// NOTE: this shares its name with std::pow, so callers should pass the
// base as ll to be sure this overload is the one selected.
ll pow(ll b, int e) {
    ll  result = 1;
    for(int i = 0; i < e; ++i)
        result *= b;
    return result;
}

// Returns the first palindromic product taken from the priority queue.
// The search starts at (10^e - 1)^2. Every popped entry that is not a
// palindrome has the pairs from nexts() pushed back, keyed by their
// product. Returns 0 if the queue runs empty.
// Neither 10^e nor max_n * max_n is checked for overflow.
ll f(int e) {
    // Pass the base as ll so this file's pow(ll, int) is an exact match.
    // With a plain int literal, std::pow's generic overload (reachable via
    // `using namespace std` whenever <cmath> is included, even indirectly)
    // may be preferred and silently switch the computation to double.
    const ll    max_n = pow(static_cast<ll>(10), e) - 1;
    priority_queue<tuple<ll, ll, ll>>   pq;
    pq.push(make_tuple(max_n*max_n, max_n, max_n));
    while(!pq.empty()) {
        const auto  t = pq.top();
        const ll    n = get<0>(t);
        const ll    p = get<1>(t);
        const ll    q = get<2>(t);
        if(is_palindromic(n))
            return n;
        pq.pop();
        const auto  neighbors = nexts(p, q, max_n);
        for(auto it = neighbors.begin(); it != neighbors.end(); ++it) {
            const ll    p1 = it->first;
            const ll    q1 = it->second;
            pq.push(make_tuple(p1*q1, p1, q1));
        }
    }
    return 0;
}

// Entry point: argv[1] is the exponent passed to f(); the result is
// printed on stdout.
int main(int argc, char **argv) {
    // Refuse to run without the argument instead of dereferencing argv[1].
    if(argc < 2) {
        cerr << "usage: " << (argc > 0 ? argv[0] : "e004") << " <digits>" << endl;
        return EXIT_FAILURE;
    }
    const int   E = atoi(argv[1]);
    cout << f(E) << endl;
    return 0;
}
// e004.rs
#![allow(non_snake_case)]

use std::env;
use std::collections::BinaryHeap;

/// Returns the base-10 digits of `n`, least significant digit first.
///
/// Zero yields an empty vector. For a negative `n` every digit carries the
/// sign of `n`, because `%` keeps the sign of the dividend.
fn digits(n: i64) -> Vec<i64> {
    // Walk n, n/10, n/100, ... until the quotient reaches zero, taking the
    // last digit of each intermediate value.
    std::iter::successors(Some(n), |&rest| Some(rest / 10))
        .take_while(|&rest| rest != 0)
        .map(|rest| rest % 10)
        .collect()
}

/// Returns `true` when the digits of `n` read the same in both directions.
fn is_palindromic(n: i64) -> bool {
    let ds = digits(n);
    let half = ds.len() / 2;
    // Pair each digit of the first half with its mirror from the back.
    ds.iter()
        .zip(ds.iter().rev())
        .take(half)
        .all(|(front, back)| front == back)
}

/// Returns the pairs to explore after `(p, q)`.
///
/// `(p - 1, q)` is included only when `q` equals `max_n`, and `(p, q - 1)`
/// only when `p` differs from `q`; when both apply they appear in that
/// order.
fn nexts(p: i64, q: i64, max_n: i64) -> Vec<(i64, i64)> {
    let lower_p = (q == max_n).then(|| (p - 1, q));
    let lower_q = (p != q).then(|| (p, q - 1));
    lower_p.into_iter().chain(lower_q).collect()
}

/// Returns the first palindromic product popped from the heap.
///
/// The search starts at `(10^e - 1)^2`. Every popped entry that is not a
/// palindrome has the pairs from `nexts` pushed back, keyed by their
/// product. Returns `0` if the heap runs empty.
///
/// # Panics
///
/// Panics when `10^e` or `(10^e - 1)^2` does not fit in an `i64`.
fn f(e: u32) -> i64 {
    // Plain `pow` and `*` wrap silently in release builds, which would
    // seed the search with a meaningless value; fail loudly instead.
    let max_n: i64 = 10i64
        .checked_pow(e)
        .expect("10^e does not fit in i64")
        - 1;
    let largest = max_n
        .checked_mul(max_n)
        .expect("(10^e - 1)^2 does not fit in i64");

    let mut heap = BinaryHeap::new();
    heap.push((largest, max_n, max_n));
    while let Some((n, p, q)) = heap.pop() {
        if is_palindromic(n) {
            return n;
        }
        for (p1, q1) in nexts(p, q, max_n) {
            heap.push((p1 * q1, p1, q1));
        }
    }
    0
}

/// Entry point: parses the first command-line argument as the exponent and
/// prints `f` of it.
///
/// A missing or non-numeric argument prints a message on stderr and exits
/// with status 2 instead of panicking.
fn main() {
    let arg = match env::args().nth(1) {
        Some(arg) => arg,
        None => {
            eprintln!("usage: e004 <digits>");
            std::process::exit(2);
        }
    };
    match arg.parse::<u32>() {
        Ok(e) => println!("{}", f(e)),
        Err(err) => {
            eprintln!("invalid digit count {:?}: {}", arg, err);
            std::process::exit(2);
        }
    }
}