Project Euler 34

http://projecteuler.net/index.php?section=problems&id=34


これも暗黙の上限がある問題。
そこまでの整数をしらみつぶしにしても十分速いが、例えば123と213は同じ数に変換されるので、重複組合せを出すクラスを使うと非常に速い。これを自作した。

#include <iostream>
#include "itertools.h"

using namespace std;
using namespace itertools;

int factorial(int n) {
    return n == 0 ? 1 : n * factorial(n - 1);
}

int is_valid(const vector<int>& v) {
    auto    g1 = iterable(v);
    auto    g2 = sorted(list(digits(sum(map(factorial, iterable(v))))));
    return all([] (tuple<int,int> x) {
                return fst(x) == 1 ? snd(x) <= 1 : fst(x) == snd(x);
            }, zip(g1, g2))
            && !g1.exists_next() && !g2.exists_next();
}

template<typename T>
class cRepeatedCombination {
    const vector<T> vmap;
    vector<T>   ret;
    vector<int> ref;
    int         length;
    bool        first;
    
public:
    cRepeatedCombination(const vector<T>& v, int n) : vmap(v), length(n) {
        ret = vector<int>(length, v.front());
        ref = vector<int>(length, 0);
        first = true;
    }
    vector<T> next() {
        if(!first) {
            const int   p = pos(length - 1);
            ref[p]++;
            ret[p] = vmap[ref[p]];
            reset(ref[p], p + 1);
        }
        
        first = false;
        return ret;
    }
    bool exists_next() {
        return exists();
    }
    
private:
    int pos(int p) {
        if(ref[p] != vmap.size() - 1) return p;
        else return pos(p - 1);
    }
    void reset(int n, int p) {
        if(p == length) return;
        ref[p] = n;
        ret[p] = vmap[n];
        reset(n, p + 1);
    }
    bool exists(int p = 0) {
        if(p == length) return false;
        else if(ref[p] != vmap.size() - 1) return true;
        else return exists(p + 1);
    }
};

template<typename T>
cRepeatedCombination<T> repeated_combination(const vector<T>& v, int n) {
    return cRepeatedCombination<T>(v, n);
}

int sum_by_length(int n) {
    return sum(map([] (const vector<int>& v)
                    { return sum(map(factorial, iterable(v))); },
            filter(is_valid,
            repeated_combination(list(range<>(1, 10)), n))));
}

int main() {
    cout << sum(map(sum_by_length, range<>(2, 8))) << endl;
}