行列の偏微分

2021年2月1日

ディープラーニングの勉強をしていると行列の偏微分が出てきました。内容は次のような感じです。

\begin{equation*}
Y = XW + B
\end{equation*}

を$X$で偏微分すると

\begin{equation*}
\frac{\partial Y}{\partial X} = W^T
\end{equation*}

となるそうですが私にはよくわかりませんでした。なのでわからなかった所とそれなりに分かった所を書き記しておきます。

ニューラルネットワークの式

ニューラルネットワークの重みとバイアスを式で表すと、一般的に次のような式になります。

\begin{equation}
Y = XW + B
\end{equation}

ここでは$W$は重み、$B$はバイアス、$X$はベクトルを表しています。

それでは上記の式$Y$を$X$で偏微分したいと思いますが、その前に転置を軽く復習。

転置とは

まずはあるベクトルを行列で表してみます。

一般的にベクトルと言えば列ベクトルを表すようなので、それに則ってここでは列ベクトルで表します。

\begin{equation*}
X =
\begin{pmatrix}
a \\
b \\
c
\end{pmatrix}
\end{equation*}

この列ベクトルを転置してみると次のようになります。

\begin{equation*}
X^T =
\begin{pmatrix}
a & b & c
\end{pmatrix}
\end{equation*}

$X^T$は$X$の転置を表します。

次に2行2列の例を見てみましょう。

\begin{equation*}
X =
\begin{pmatrix}
a & b \\
c & d \\
\end{pmatrix}
\end{equation*}

これを転置すると

\begin{equation*}
X^T =
\begin{pmatrix}
a & c \\
b & d \\
\end{pmatrix}
\end{equation*}

となります。

ニューラルネットワークの式を偏微分

それでは上記(1)のニューラルネットワークの式を偏微分したいと思います。もう一度書いてみます。

\begin{equation*}
Y = XW + B
\end{equation*}

ここで$X$、$W$、$B$を次のように仮定します。

\begin{eqnarray*}
X &=&
\begin{pmatrix}
x_1 \\
x_2 \\
\end{pmatrix}
\\
W &=&
\begin{pmatrix}
a & c & e \\
b & d & f \\
\end{pmatrix}
\\
B &:&
定数
\end{eqnarray*}

この式を$X$で偏微分すると

\begin{equation}
\frac{\partial Y}{\partial X} = \frac{\partial XW}{\partial X} + \frac{\partial B}{\partial X}
\end{equation}

と書けます。ここで$X$は2行1列、$W$は2行3列で$XW$の計算を可能にするため$X$を転置します。

\begin{equation*}
X^T =
\begin{pmatrix}
x_1 & x_2
\end {pmatrix}
\end{equation*}

そして式(2)は

\begin{equation}
\frac{\partial Y}{\partial X} = \frac{\partial X^TW}{\partial X} + \frac{\partial B}{\partial X}
\end{equation}

となります。

ここで$X^TW$を整理しますと

\begin{eqnarray}
X^TW &=&
\begin{pmatrix}
x_1 & x_2
\end{pmatrix}
\begin{pmatrix}
a & c & e \\
b & d & f
\end{pmatrix}
\\
&=&
\begin{pmatrix}
ax_1+bx_2 & cx_1+bx_2 & ex_1+fx_2
\end{pmatrix}
\end{eqnarray}

となり$X^TW$を$X$で偏微分すると

\begin{eqnarray}
\frac{\partial X^TW}{\partial X}
&=&
\begin{pmatrix}
\frac{\partial (ax_1+bx_2)}{\partial x_1} & \frac{\partial (ax_1+bx_2)}{\partial x_2} \\
\frac{\partial (cx_1+dx_2)}{\partial x_1} & \frac{\partial (cx_1+dx_2)}{\partial x_2} \\
\frac{\partial (ex_1+fx_2)}{\partial x_1} & \frac{\partial (ex_1+fx_2)}{\partial x_2} \\
\end{pmatrix}
\\
&=&
\begin{pmatrix}
a & b \\
c & d \\
e & f
\end{pmatrix}
\\
&=&
W^T
\end{eqnarray}

そして$\frac{\partial B}{\partial X}$は$B$が定数なので0となり、式(3)は次のようになります。

\begin{equation}
\frac{\partial Y}{\partial X}
=
W^T
\end{equation}

めでたしめでたし。

でも自信はありません。

式(2)を式(3)と解釈していいものなのか、また(5)の式を偏微分すると(6)になるのがどうも自信がありません。でもこうすると正解っぽい。