MojoでProject Euler 60

https://projecteuler.net/problem=60

各エッジを出発点として、その時点のノードすべてと隣接する共通のノードを探して追加する、ということを繰り返して完全グラフを作ります。
しかし、Graphを

alias Node = Int
struct Graph:
    var g: Dict[Node, Set[Node]]

としようとしても、SetがCollectionElementでないためできません。仕方がないので、

struct Graph:
    var g: Dict[Node, List[Node]]

として、別にSet[Edge]を作っておき、2つのノードがエッジで結ばれているかどうかはそのSetで判定します。このとき、EdgeはSetの要素にするためにKeyElement(CollectionElementに加えて、ハッシュ化と等価比較ができること)を実装していなければなりません。

import sys


#################### List ####################

# Build a List of length N in which every slot holds a copy of `init`.
fn initialize_list[T: CollectionElement](N: Int, init: T) -> List[T]:
    var filled = List[T](capacity=N)
    var remaining = N
    while remaining > 0:
        filled.append(init)
        remaining -= 1
    return filled

# Element types that can both live in a List (CollectionElement) and be
# converted with str() (Stringable); required by print_list below.
trait Printable(CollectionElement, Stringable):
    pass

# Print a list as "[a0, a1, ...]" (or "[]" when empty) on a single line.
fn print_list[T: Printable](a: List[T]):
    var text = String("[")
    for i in range(a.size):
        if i > 0:
            text += ", "
        text += str(a[i])
    text += "]"
    print(text)


#################### library ####################

# Sieve of Eratosthenes: return every prime p with 2 <= p <= N, in
# ascending order.
fn make_prime_table(N: Int) -> List[Int]:
    var sieve = initialize_list(N+1, True)
    var p = 2
    while p * p <= N:
        if sieve[p]:
            # Multiples below p*p were already crossed out by smaller primes.
            for m in range(p*p, N+1, p):
                sieve[m] = False
        p += 1

    var primes = List[Int]()
    for n in range(2, N+1):
        if sieve[n]:
            primes.append(n)
    return primes

# Trial-division primality test.
# Returns False for n < 2: 0, 1 and negative numbers are not prime.
fn is_prime(n: Int) -> Bool:
    if n < 2:
        return False
    var p = 2
    while p * p <= n:
        if n % p == 0:
            return False
        p += 1
    return True


#################### Graph ####################

alias Node = Int

# A graph edge between two nodes (primes). Equality is order-sensitive:
# Edge(a, b) != Edge(b, a), so lookups in a Set[Edge] must try both
# orientations. collect_edges always stores the smaller prime as v1.
@value
struct Edge(KeyElement, Stringable):
    var v1: Node
    var v2: Node

    fn __eq__(self, other: Edge) -> Bool:
        return self.v1 == other.v1 and self.v2 == other.v2

    fn __ne__(self, other: Edge) -> Bool:
        return not (self == other)

    fn __hash__(self) -> Int:
        # A multiplier of 10000 collides as soon as v2 >= 10000 (e.g.
        # (1, 10007) vs (2, 7)), and the search in `f` does use primes past
        # 10^4. A large prime multiplier keeps distinct edges apart for any
        # realistic node range.
        return self.v1 * 1000003 + self.v2

    fn __str__(self) -> String:
        return "(" + str(self.v1) + ", " + str(self.v2) + ")"

# Undirected graph stored as adjacency lists: each node maps to the list of
# nodes it shares an edge with. List[Node] is used instead of Set[Node]
# because Set is not a CollectionElement, so it cannot be a Dict value;
# "are u and v adjacent?" is answered by a separate Set[Edge] instead.
struct Graph:
    var g: Dict[Node, List[Node]]
    
    fn __init__(inout self, g: Dict[Node, List[Node]]):
        self.g = g
    
    # True if v has at least one incident edge (create_from_edges only ever
    # inserts edge endpoints as keys).
    fn __contains__(self, v: Node) -> Bool:
        return v in self.g
    
    # Build the adjacency lists from an edge list, recording each edge in
    # both directions (v1 -> v2 and v2 -> v1).
    @staticmethod
    fn create_from_edges(edges: List[Edge]) -> Graph:
        var g = Dict[Node, List[Node]]()
        try:
            for edge in edges:
                var v1 = edge[].v1
                var v2 = edge[].v2
                if v1 in g:
                    g[v1].append(v2)
                else:
                    g[v1] = List[Node](v2)
                if v2 in g:
                    g[v2].append(v1)
                else:
                    g[v2] = List[Node](v1)
        except:
            # Presumably needed only because Dict indexing can raise; every
            # g[...] read above is guarded by an `in` check, so this handler
            # is not expected to run.
            pass
        return Graph(g)


#################### process ####################

# Concatenate the decimal digits of p and q: connect(673, 109) == 673109.
# q must be non-negative; q == 0 counts as the single digit "0", so
# connect(12, 0) == 120.
# Scaling p by the smallest power of ten above q avoids building a digit
# list on every call (this runs for every candidate prime pair).
fn connect(owned p: Int, owned q: Int) -> Int:
    var scale = 10
    while scale <= q:
        scale *= 10
    return p * scale + q

# Two primes are "connected" when both concatenations p|q and q|p are prime.
fn is_connected(p: Int, q: Int) -> Bool:
    if not is_prime(connect(p, q)):
        return False
    return is_prime(connect(q, p))

# Collect Edge(u, v) for every pair with u before v in `nodes`, u + v <= U,
# and both concatenations prime. The early `break` is only exhaustive when
# `nodes` is ascending (as make_prime_table's output is): once u + v exceeds
# U, every later v is larger as well.
fn collect_edges(nodes: List[Int], U: Int) -> List[Edge]:
    var found = List[Edge]()
    var count = len(nodes)
    for a in range(count):
        var u = nodes[a]
        for b in range(a + 1, count):
            var v = nodes[b]
            if u + v > U:
                break
            if is_connected(u, v):
                found.append(Edge(u, v))
    return found

# Split integers into buckets by decimal digit count: nss[d] holds the
# d-digit numbers, and nss[0] receives values below 1 (0 or negatives).
# `ns` must be sorted ascending; a smaller number that follows a larger one
# lands in the larger one's bucket.
# The comparison is `>=` so that exact powers of ten (10, 100, ...) start a
# new bucket instead of joining the previous one.
fn divide_by_digits(ns: List[Int]) -> List[List[Int]]:
    var nss = List[List[Int]](List[Int]())
    var E = 0
    var bound = 1  # always equal to 10**E
    for n in ns:
        while n[] >= bound:
            E += 1
            bound *= 10
            nss.append(List[Int]())
        nss[E].append(n[])
    return nss

# Return every node that is adjacent to *all* nodes in `vs` and not already
# in it. Candidates come from the adjacency list of the last node of `vs`;
# each is then checked against every member via `set_edges`, trying both
# orientations because Edge equality is order-sensitive.
fn find_common_neighbors(vs: List[Node], graph: Graph,
                                    set_edges: Set[Edge]) -> List[Node]:
    var shared = List[Node]()
    var last = vs[len(vs)-1]
    var candidates = graph.g.get(last, List[Node]())
    for cand in candidates:
        var c = cand[]
        if c in vs:
            continue
        var linked_to_all = True
        for member in vs:
            var m = member[]
            var linked = Edge(m, c) in set_edges or Edge(c, m) in set_edges
            if not linked:
                linked_to_all = False
                break
        if linked_to_all:
            shared.append(c)
    return shared

# For each common neighbor of `vs`, return a copy of `vs` with that neighbor
# appended (one result per neighbor, in the order find_common_neighbors
# yields them).
fn expand_complete_graph(vs: List[Node], graph: Graph,
                                    set_edges: Set[Edge]) -> List[List[Node]]:
    var candidates = find_common_neighbors(vs, graph, set_edges)
    var grown = List[List[Node]](capacity=len(candidates))
    for i in range(len(candidates)):
        var extended = List[Node](vs)
        extended.append(candidates[i])
        grown.append(extended)
    return grown

# Grow `edge` into every complete subgraph with N nodes by appending a common
# neighbor N - 2 times. Neighbors are not ordered, so the same node set can
# appear several times in different insertion orders.
fn find_complete_graphs(edge: Edge, graph: Graph,
                            set_edges: Set[Edge], N: Int) -> List[List[Node]]:
    var frontier = List[List[Int]](List[Int](edge.v1, edge.v2))
    var size = 2
    while size < N:
        var next_frontier = List[List[Int]]()
        for partial in frontier:
            var grown = expand_complete_graph(partial[], graph, set_edges)
            next_frontier.extend(grown)
        frontier = next_frontier
        size += 1
    return frontier

# Sum of the elements of v (0 for an empty list).
fn sum(v: List[Int]) -> Int:
    var total = 0
    for i in range(len(v)):
        total += v[i]
    return total

# One search pass with fixed limits. Primes up to L become nodes, and prime
# pairs with sum <= U whose concatenations are both prime become edges. Every
# N-node complete subgraph grown from an edge is scored by its sum.
# Returns the smallest such sum, or 10**18 when no clique exists.
fn f_each(N: Int, L: Int, U: Int) -> Int:
    var nodes = make_prime_table(L)
    var edge_list = collect_edges(nodes, U)
    var edge_set = Set[Edge](edge_list)
    var graph = Graph.create_from_edges(edge_list)
    var best = 10**18
    for e in edge_list:
        var cliques = find_complete_graphs(e[], graph, edge_set, N)
        for clique in cliques:
            var total = sum(clique[])
            if total < best:
                best = total
    return best

# Smallest sum of N primes in which every pair concatenates to primes in
# both orders (Project Euler 60).
#
# Search strategy: start with primes up to L = 1000 and pair sums up to
# U = 2000, growing both tenfold while nothing is found. Once a clique with
# sum S exists, any cheaper clique uses only primes below S whose pairwise
# sums are below S. One more pass with L = U = S therefore settles the
# minimum, and the loop stops as soon as the best sum no longer exceeds L.
fn f(N: Int) -> Int:
    var L = 1000
    var U = 2000
    var INF = 10**18
    var min_s = INF
    while True:
        var s = f_each(N, L, U)
        if s < min_s:
            min_s = s
        if min_s <= L:
            # Return the tracked minimum, not merely the last pass's result.
            return min_s
        elif min_s < INF:
            L = min_s
            U = min_s
        else:
            L *= 10
            U *= 10

# Entry point: print f(N). N (the number of primes in the set) is read from
# the first command-line argument. It defaults to 5, the size Problem 60
# asks for, so running without arguments no longer indexes past argv.
fn main() raises:
    var args = sys.argv()
    var N = 5
    if len(args) > 1:
        N = atol(args[1])
    print(f(N))