numpyを使った特殊な行列演算
はじめに
chainerで新しいconnectionのファイルを作成する際、Numpyの演算で行き詰まった。そこで、Numpyの特殊な行列演算を実行する際にどうするか検討してみた。
問題設定
Figure 1のように3次元(0次元目〜2次元目)行列2つの行列演算を考える。 図左上が(a,b,c)行列、図右上が(a,c,d)行列。0次元目が一致する(b,c)と(c,d)のドット積を求め、図下(a,b,d)の3次元行列にしたい。
通常のdot()で計算すると4次元行列が生成される。
>>> a = np.arange(24).reshape(2,3,4) >>> a array([[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]) >>> b = np.arange(40).reshape(2,4,5) >>> b array([[[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]], [[20, 21, 22, 23, 24], [25, 26, 27, 28, 29], [30, 31, 32, 33, 34], [35, 36, 37, 38, 39]]]) >>> a.dot(b) array([[[[ 70, 76, 82, 88, 94], [ 190, 196, 202, 208, 214]], [[ 190, 212, 234, 256, 278], [ 630, 652, 674, 696, 718]], [[ 310, 348, 386, 424, 462], [1070, 1108, 1146, 1184, 1222]]], [[[ 430, 484, 538, 592, 646], [1510, 1564, 1618, 1672, 1726]], [[ 550, 620, 690, 760, 830], [1950, 2020, 2090, 2160, 2230]], [[ 670, 756, 842, 928, 1014], [2390, 2476, 2562, 2648, 2734]]]])
求めたいのは(a,b,d)の3次元行列。そこで例えばfor文を使ってみる。
>>> c = np.zeros(30).reshape(2,3,5) >>> c array([[[ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.]], [[ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.]]]) >>> for i in range(2): ... c[i] = a[i].dot(b[i]) ... >>> c array([[[ 70., 76., 82., 88., 94.], [ 190., 212., 234., 256., 278.], [ 310., 348., 386., 424., 462.]], [[ 1510., 1564., 1618., 1672., 1726.], [ 1950., 2020., 2090., 2160., 2230.], [ 2390., 2476., 2562., 2648., 2734.]]])
確かにうまくいってる。