numpyを使った特殊な行列演算

はじめに

chainerで新しいconnectionのファイルを作成する際、Numpyの演算で行き詰まった。そこで、Numpyの特殊な行列演算を実行する際にどうするか検討してみた。

問題設定

Figure 1のように3次元(0次元目〜2次元目)行列2つの行列演算を考える。 f:id:webfarmer:20161012143537p:plain 図左上が(a,b,c)行列、図右上が(a,c,d)行列。0次元目が一致する(b,c)と(c,d)のドット積を求め、図下(a,b,d)の3次元行列にしたい。

通常のdot()で計算すると4次元行列が生成される。

>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> b = np.arange(40).reshape(2,4,5)
>>> b
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])
>>> a.dot(b)
array([[[[  70,   76,   82,   88,   94],
         [ 190,  196,  202,  208,  214]],

        [[ 190,  212,  234,  256,  278],
         [ 630,  652,  674,  696,  718]],

        [[ 310,  348,  386,  424,  462],
         [1070, 1108, 1146, 1184, 1222]]],


       [[[ 430,  484,  538,  592,  646],
         [1510, 1564, 1618, 1672, 1726]],

        [[ 550,  620,  690,  760,  830],
         [1950, 2020, 2090, 2160, 2230]],

        [[ 670,  756,  842,  928, 1014],
         [2390, 2476, 2562, 2648, 2734]]]])

求めたいのは(a,b,d)の3次元行列。そこで例えばfor文を使ってみる。

>>> c = np.zeros(30).reshape(2,3,5)
>>> c
array([[[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]]])
>>> for i in range(2):
...     c[i] = a[i].dot(b[i])
...
>>> c
array([[[   70.,    76.,    82.,    88.,    94.],
        [  190.,   212.,   234.,   256.,   278.],
        [  310.,   348.,   386.,   424.,   462.]],

       [[ 1510.,  1564.,  1618.,  1672.,  1726.],
        [ 1950.,  2020.,  2090.,  2160.,  2230.],
        [ 2390.,  2476.,  2562.,  2648.,  2734.]]])

確かにうまくいってる。