もう一つの高速化として、C/C++で行列の掛け算を書いたことがある人にはよく知られていると思いますが、右側（乗数）の行列を転置しておくと速くなります。
もう一度C++のコードを見てみましょう。
// Excerpt of mul_core128: the straightforward i-j-k loop order.
Matrix C(H, Vec(W));
for(size_t i = 0U; i < H; ++i) {
for(size_t j = 0U; j < W; ++j) {
__int128 e = 0;  // accumulate the dot product in 128 bits, reduce once per entry
for(size_t k = 0U; k < M; ++k)
e += A[i][k] * B[k][j];  // walks B down column j: one element from each row vector B[k]
C[i][j] = (long)(e % D);
}
}
一番内側のループでkを動かしているので、行列Bを縦（列方向）にスキャンしていくことになります。しかし、メモリ上では各行の要素が横に連続して並んでいる（しかもMatrixは`vector<vector<long>>`なので、行ごとに別々の領域に確保されている）ため、飛び飛びに値を拾うことになり、キャッシュの中身が次々と入れ替わることになります。そこでBを転置しておくと、
// Excerpt of mul_core_trans: same loops, but B is transposed once up front.
const Matrix B_trans = transpose(B);  // B_trans[j][k] == B[k][j]
Matrix C(H, Vec(W));
for(size_t i = 0U; i < H; ++i) {
for(size_t j = 0U; j < W; ++j) {
__int128 e = 0;
for(size_t k = 0U; k < M; ++k)
e += A[i][k] * B_trans[j][k];  // both operands are now read along a single row vector
C[i][j] = (long)(e % D);
}
}
B_transは横に（メモリ上で連続する方向に）スキャンされます。これでキャッシュの入れ替わりが少なくなり、高速化されます。転置自体にかかる時間はO(MW)なので、掛け算本体のO(HMW)に比べれば無視できます。
実行してみましょう。
1134
256330968
12.208092451095581
これまでのバージョンの時間を並べてみましょう。ratioはC++transposeの時間を1としたときの比です。
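| method | time | ratio |
|---|---:|---:|
| Python | 1756.86s | 225.8 |
| PyPy | 277.08s | 22.7 |
| C++ | 206.85s | 16.9 |
| C++128 | 30.57s | 2.5 |
| C++transpose | 12.21s | 1 |

（注: 1756.86 / 12.21 ≈ 143.9 なので、Python行のratio 225.8とは一致しません。225.8は2756.86sに対応するため、時間か比のどちらかが誤記と思われます。要確認。）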
最後に、コード全体を示します（setup.py、matrix.h、matrix.cpp、Pythonスクリプトの順）。
# setup.py: builds the C++ extension module "mymatrix" from matrix.cpp,
# e.g. with `python setup.py build_ext --inplace`.
# NOTE(review): distutils is deprecated (PEP 632) and was removed in
# Python 3.12; there this import only works via setuptools' distutils shim.
# Switching to `from setuptools import setup, Extension` is the documented
# migration path.
from distutils.core import setup, Extension

mymatrix_extension = Extension('mymatrix', ['matrix.cpp'])

setup(
    name='mymatrix',
    version='1.0.0',
    ext_modules=[mymatrix_extension],
)
// matrix.h -- declarations for the mymatrix extension (matrix.cpp).
#pragma once
// Recommended by the Python/C API docs: use Py_ssize_t for "#" length formats.
#ifndef PY_SSIZE_T_CLEAN
#define PY_SSIZE_T_CLEAN
#endif
#include <Python.h>
#include <vector>

// Row-major matrix of longs: Matrix[i] is row i, each row its own vector.
typedef std::vector<long> Vec;
typedef std::vector<Vec> Matrix;

// Python list of ints -> Vec; an empty Vec signals invalid input.
static Vec make_vector(PyObject *obj_list);
// Python list of lists of ints -> Matrix; an empty Matrix signals invalid input.
static Matrix make_matrix(PyObject *obj_table);
// mymatrix.mul(A, B, D): returns (A * B) mod D as a list of lists.
static PyObject *mul(PyObject *self, PyObject *args);
// Three implementations of C = (A * B) mod D, kept for the benchmark:
static Matrix mul_core(const Matrix& A, const Matrix& B, long D);        // reduce after each term
static Matrix mul_core128(const Matrix& A, const Matrix& B, long D);     // 128-bit accumulator
static Matrix mul_core_trans(const Matrix& A, const Matrix& B, long D);  // 128-bit + transposed B
static Matrix transpose(const Matrix& A);
// Matrix / Vec -> new Python list (NULL with an exception set on failure).
static PyObject *make_python_matrix(const Matrix& A);
static PyObject *make_python_vector(const Vec& v);
// matrix.h includes <Python.h>, which must come before any standard header.
#include "matrix.h"
#include <iostream>
using namespace std;
// Convert a Python list of ints into a Vec.
// On invalid input -- obj_list is not a list, an element is not an int, or
// an int does not fit in a C long -- returns an empty Vec (printing a message
// for the element cases). No Python exception is left pending.
Vec make_vector(PyObject *obj_list) {
    if(!PyList_Check(obj_list))
        return Vec(0);
    const Py_ssize_t M = PyList_Size(obj_list);
    Vec v;
    v.reserve((size_t)M);
    for(Py_ssize_t j = 0; j < M; ++j) {
        PyObject *obj = PyList_GetItem(obj_list, j);  // borrowed reference
        if(!PyLong_Check(obj)) {
            printf("vector[%zd] is not long.\n", j);
            return Vec(0);
        }
        const long d = PyLong_AsLong(obj);
        // PyLong_AsLong reports overflow as -1 with OverflowError set; clear it
        // so the caller does not return a result with an exception pending.
        if(d == -1 && PyErr_Occurred()) {
            PyErr_Clear();
            printf("vector[%zd] is out of range of long.\n", j);
            return Vec(0);
        }
        v.push_back(d);
    }
    return v;
}
// Convert a Python list of lists of ints into a Matrix.
// Returns an empty Matrix if obj_table is not a list or if any row fails to
// convert (make_vector returns an empty Vec for a bad -- or empty -- row).
// Rows are NOT checked for equal length here; callers that need a
// rectangular matrix must check it themselves.
Matrix make_matrix(PyObject *obj_table) {
    if(!PyList_Check(obj_table))
        return Matrix(0);
    const Py_ssize_t L = PyList_Size(obj_table);
    Matrix X((size_t)L);
    for(Py_ssize_t i = 0; i < L; ++i) {
        PyObject *obj = PyList_GetItem(obj_table, i);  // borrowed reference
        Vec v = make_vector(obj);
        if(v.empty())
            return Matrix(0);
        X[(size_t)i].swap(v);  // take the row's storage instead of copying it
    }
    return X;
}
// mymatrix.mul(A, B, D): return (A * B) mod D as a new list of lists.
// A must be H x M and B M x W (lists of lists of ints, B non-empty and
// rectangular). Raises TypeError for a bad argument tuple (set by
// PyArg_ParseTuple), ZeroDivisionError for D == 0, and ValueError for
// malformed or non-conforming matrices.
PyObject *mul(PyObject *self, PyObject *args) {
    PyObject *obj1, *obj2;
    long D;
    // PyArg_ParseTuple has already set the exception on failure; returning a
    // value here instead of NULL would surface as a SystemError.
    if(!PyArg_ParseTuple(args, "OOl", &obj1, &obj2, &D))
        return NULL;
    if(D == 0) {
        PyErr_SetString(PyExc_ZeroDivisionError, "modulus D must be nonzero");
        return NULL;
    }
    const Matrix A = make_matrix(obj1);
    const Matrix B = make_matrix(obj2);
    if(PyErr_Occurred())  // e.g. a pending error from an int conversion
        return NULL;
    // make_matrix signals invalid input with an empty Matrix; an empty A is
    // only legitimate when the caller actually passed an empty list.
    const bool A_is_empty_list = PyList_Check(obj1) && PyList_GET_SIZE(obj1) == 0;
    if(A.empty() && !A_is_empty_list) {
        PyErr_SetString(PyExc_ValueError,
                        "first argument must be a list of lists of ints");
        return NULL;
    }
    if(B.empty()) {
        PyErr_SetString(PyExc_ValueError,
                        "second argument must be a non-empty list of lists of ints");
        return NULL;
    }
    // mul_core_trans indexes A[i][k] for k < B.size() and assumes B is
    // rectangular; reject anything that would read out of bounds.
    const size_t M = B.size();
    const size_t W = B.front().size();
    for(const Vec& row : B) {
        if(row.size() != W) {
            PyErr_SetString(PyExc_ValueError,
                            "rows of the second matrix must all have the same length");
            return NULL;
        }
    }
    for(const Vec& row : A) {
        if(row.size() != M) {
            PyErr_SetString(PyExc_ValueError,
                            "each row of the first matrix must have as many entries "
                            "as the second matrix has rows");
            return NULL;
        }
    }
    const Matrix C = mul_core_trans(A, B, D);
    return make_python_matrix(C);
}
// Reference implementation of C = (A * B) mod D that reduces after every
// term, so all arithmetic stays in 64-bit longs. That is exact as long as
// (D-1)^2 + (D-1) fits in a long and entries lie in [0, D).
// Assumes every row of A has B.size() entries and that B is rectangular.
// If B has no rows its width is unknown, so H empty rows are returned.
Matrix mul_core(const Matrix& A, const Matrix& B, long D) {
    const size_t H = A.size();
    const size_t M = B.size();
    if(M == 0)
        return Matrix(H);  // B.front() would be undefined behaviour
    const size_t W = B.front().size();
    Matrix C(H, Vec(W));
    for(size_t i = 0U; i < H; ++i) {
        for(size_t j = 0U; j < W; ++j) {
            for(size_t k = 0U; k < M; ++k)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % D;
        }
    }
    return C;
}
// C = (A * B) mod D, accumulating each dot product in a 128-bit integer and
// reducing only once per entry. Each product is formed in 128 bits, so a
// single term cannot overflow even when it does not fit in a long; the sum
// stays exact while M * max|A[i][k] * B[k][j]| < 2^127.
// Assumes every row of A has B.size() entries and that B is rectangular.
// If B has no rows its width is unknown, so H empty rows are returned.
Matrix mul_core128(const Matrix& A, const Matrix& B, long D) {
    const size_t H = A.size();
    const size_t M = B.size();
    if(M == 0)
        return Matrix(H);  // B.front() would be undefined behaviour
    const size_t W = B.front().size();
    Matrix C(H, Vec(W));
    for(size_t i = 0U; i < H; ++i) {
        for(size_t j = 0U; j < W; ++j) {
            __int128 e = 0;
            for(size_t k = 0U; k < M; ++k)
                e += (__int128)A[i][k] * B[k][j];
            C[i][j] = (long)(e % D);
        }
    }
    return C;
}
// C = (A * B) mod D with a 128-bit accumulator, multiplying against the
// transpose of B so the inner loop reads both operands along a single row
// vector. Each product is formed in 128 bits (no overflow for any long
// entries); the sum stays exact while M * max|product| < 2^127.
// Assumes every row of A has B.size() entries and that B is rectangular.
// If B has no rows its width is unknown, so H empty rows are returned.
Matrix mul_core_trans(const Matrix& A, const Matrix& B, long D) {
    const size_t H = A.size();
    const size_t M = B.size();
    if(M == 0)
        return Matrix(H);  // B.front() / transpose(B) would be undefined behaviour
    const size_t W = B.front().size();
    const Matrix B_trans = transpose(B);  // B_trans[j][k] == B[k][j]
    Matrix C(H, Vec(W));
    for(size_t i = 0U; i < H; ++i) {
        const Vec& a = A[i];
        for(size_t j = 0U; j < W; ++j) {
            const Vec& b = B_trans[j];
            __int128 e = 0;
            for(size_t k = 0U; k < M; ++k)
                e += (__int128)a[k] * b[k];
            C[i][j] = (long)(e % D);
        }
    }
    return C;
}
// Return the transpose of A: a W x H matrix for an H x W input.
// An empty A yields an empty Matrix. Assumes A is rectangular (every row has
// A.front().size() entries).
Matrix transpose(const Matrix& A) {
    const size_t H = A.size();
    if(H == 0)
        return Matrix();  // A.front() would be undefined behaviour
    const size_t W = A.front().size();
    Matrix T(W, Vec(H));
    for(size_t i = 0U; i < W; ++i) {
        for(size_t j = 0U; j < H; ++j)
            T[i][j] = A[j][i];
    }
    return T;
}
// Build a new Python list of lists (of ints) from A.
// Returns a new reference, or NULL with a Python exception set if any
// allocation fails (in which case nothing is leaked).
PyObject *make_python_matrix(const Matrix& A) {
    const Py_ssize_t H = (Py_ssize_t)A.size();
    PyObject *obj_mat = PyList_New(H);
    if(obj_mat == NULL)
        return NULL;
    for(Py_ssize_t i = 0; i < H; ++i) {
        PyObject *obj_list = make_python_vector(A[(size_t)i]);
        if(obj_list == NULL) {
            Py_DECREF(obj_mat);  // also releases the rows already stored
            return NULL;
        }
        PyList_SET_ITEM(obj_mat, i, obj_list);  // steals the reference
    }
    return obj_mat;
}
// Build a new Python list of ints from v.
// Returns a new reference, or NULL with a Python exception set if any
// allocation fails (in which case nothing is leaked).
PyObject *make_python_vector(const Vec& v) {
    const Py_ssize_t L = (Py_ssize_t)v.size();
    PyObject *obj_list = PyList_New(L);
    if(obj_list == NULL)
        return NULL;
    for(Py_ssize_t i = 0; i < L; ++i) {
        PyObject *item = PyLong_FromLong(v[(size_t)i]);
        if(item == NULL) {
            Py_DECREF(obj_list);  // also releases the items already stored
            return NULL;
        }
        PyList_SET_ITEM(obj_list, i, item);  // steals the reference
    }
    return obj_list;
}
// mymatrix.print_matrix(X): write X to C++ std::cout, one row per line with
// each entry followed by a space, and return the int 0. Invalid input prints
// nothing (make_matrix yields an empty Matrix).
// NOTE(review): std::cout is buffered separately from Python's sys.stdout,
// so output may interleave oddly with Python print() -- confirm if relevant.
static PyObject *print_matrix(PyObject *self, PyObject *args) {
    PyObject *obj_table;
    // PyArg_ParseTuple has already set the exception on failure; returning a
    // value here instead of NULL would surface as a SystemError.
    if(!PyArg_ParseTuple(args, "O", &obj_table))
        return NULL;
    const Matrix X = make_matrix(obj_table);
    if(PyErr_Occurred())  // e.g. a pending error from an int conversion
        return NULL;
    for(const Vec& row : X) {
        for(const long x : row)
            cout << x << ' ';
        cout << endl;
    }
    return PyLong_FromLong(0);
}
// Functions exported by the module: {name, C function, calling convention,
// docstring}. The all-NULL entry terminates the table.
static PyMethodDef mymatrixMethods[] = {
    {"print_matrix", print_matrix, METH_VARARGS, NULL},
    {"mul",          mul,          METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}
};
static struct PyModuleDef mymatrix = {
PyModuleDef_HEAD_INIT, "mymatrix", "Python3 C API Module",
-1, mymatrixMethods
};
// Module initialisation entry point: creates the module object from the
// mymatrix definition (PyModule_Create returns NULL on failure).
PyMODINIT_FUNC PyInit_mymatrix(void) {
    PyObject *module = PyModule_Create(&mymatrix);
    return module;
}
from itertools import *
from collections import Counter
import sys
import time
import mymatrix
def mat_pow(M, e, D, mul=None):
    """Return M**e mod D for a square matrix M (list of lists of ints).

    Uses recursive square-and-multiply, so at most about 2*log2(e)
    multiplications are performed.

    Args:
        M: square matrix as a list of lists of ints.
        e: non-negative exponent.
        D: modulus handed to the multiplication routine.
        mul: function ``mul(A, B, D)`` returning ``A @ B mod D``. Defaults to
            ``mymatrix.mul`` (the C++ extension); pass another routine to
            swap the backend or to test without the extension.

    Returns:
        For ``e == 1``, M itself (the same object, not reduced mod D).
        For ``e == 0``, the identity matrix of M's size (entries ``1 % D``).
        Otherwise the product computed by ``mul``.

    Raises:
        ValueError: if ``e`` is negative (previously this recursed forever).
    """
    if mul is None:
        mul = mymatrix.mul
    if e < 0:
        raise ValueError("exponent must be non-negative, got {}".format(e))
    if e == 0:
        # Without this base case e == 0 recursed forever (0 is even).
        n = len(M)
        return [[1 % D if r == c else 0 for c in range(n)] for r in range(n)]
    if e == 1:
        return M
    if e % 2 == 1:
        return mul(M, mat_pow(M, e - 1, D, mul), D)
    half = mat_pow(M, e // 2, D, mul)
    return mul(half, half, D)
def f(H, W, D):
    """Return the number of domino tilings of an H x W board, modulo D.

    A state ("diagram") is a tuple of H per-row fill levels, shifted so the
    lowest level is 0. One transition places a single domino at the first
    row whose level is 0: horizontally (that row +2) or vertically (that row
    and the next +1). Tilings are closed walks of W*H//2 transitions from the
    flat state back to itself, i.e. a diagonal entry of the transition
    matrix raised to that power (computed with mat_pow).

    Degenerate boards are answered directly: an empty board (H*W == 0) has
    one tiling and an odd number of cells has none. For all other boards the
    number of reachable states is printed as a diagnostic.
    """
    if H * W == 0:
        return 1 % D   # exactly one (empty) tiling
    if H * W % 2 == 1:
        # Dominoes cover an even number of cells. This must be checked
        # explicitly: for H == 1 the normalised state returns to "flat" after
        # every domino, so the walk below would wrongly report 1 tiling.
        return 0

    def normalize(diagram):
        # Shift all levels so the lowest row is at 0.
        min_l = min(diagram)
        return tuple(row - min_l for row in diagram)

    def nexts(diagram):
        # Place one domino at the first row whose level is 0.
        y = next(y for y, row in enumerate(diagram) if row == 0)
        yield diagram[:y] + (2,) + diagram[y+1:]          # horizontal
        if y+1 < H and diagram[y+1] == 0:
            yield diagram[:y] + (1, 1) + diagram[y+2:]    # vertical

    init_diagram = (0,) * H

    # Collect every state reachable from the flat one. An explicit stack is
    # used because a recursive walk nests once per newly found state and can
    # exceed Python's recursion limit for larger H.
    seen = {init_diagram}
    stack = [init_diagram]
    while stack:
        diagram = stack.pop()
        for diagram1 in nexts(diagram):
            diagram2 = normalize(diagram1)
            if diagram2 not in seen:
                seen.add(diagram2)
                stack.append(diagram2)

    # Fix one ordering of the states and use it for rows and columns alike.
    states = list(seen)
    row_to_index = {diagram: i for i, diagram in enumerate(states)}
    n = len(states)
    # M[dest][orig] = number of single-domino moves from orig to dest.
    M = [[0] * n for _ in range(n)]
    for orig in states:
        j = row_to_index[orig]
        for diagram1 in nexts(orig):
            M[row_to_index[normalize(diagram1)]][j] += 1

    print(len(M))
    M_pow = mat_pow(M, W * H // 2, D)
    i = row_to_index[init_diagram]
    return M_pow[i][i]
if __name__ == "__main__":
    # Usage: python <script> H W
    # Prints the number of states, the number of domino tilings of an H x W
    # board modulo 10**9+7, and the elapsed time in seconds.
    if len(sys.argv) != 3:
        sys.exit("usage: {} H W".format(sys.argv[0]))
    # perf_counter is monotonic and meant for measuring elapsed time.
    t0 = time.perf_counter()
    D = 10**9+7
    H = int(sys.argv[1])
    W = int(sys.argv[2])
    print(f(H, W, D))
    print(time.perf_counter() - t0)