CNN:逆伝搬計算
おつかれさまです。
前回はCNN(畳み込みニューラルネットワーク)についてアルゴリズムと順伝搬計算を行うコードを書いていきました。今回は、その誤差から微分値を計算し、逆伝搬でパラメータを更新してみたいと思います。
CNN逆伝搬計算の概要
今回のニューラルネットワークの構成をもう一度確認してみます。
ここの計算は、今まで隠れ層の出力HがCNN層のCに置き換わっただけで、特に問題はないですね。
次に、CNN層の微分を見てみます。考え方は今までと同じく、3段連鎖律で骨格となる微分式を求めます。
後ろ2つの項はsoftmax関数とクロスエントロピーの合成微分なので、今まで通りΔを使います。
一番前の項ですが、式をそのまま書き下してみます。
合成関数の微分ってことで、とりあえず中の「X_batch*F+Bc」をF微分したX_batchは外に出てきそうです。
んで、ReLU関数の微分です。もともと、入力が0以下であれば全部「0」、ずっと同じ値ということは傾きはゼロなので微分しても「0」。0より大きい値は入力値をそのまま返すので、「f(x)=x」となり、これを微分すれば全部1です。
つまり、「ReLU(X_batch*F+Bc)」の「X_batch*F+Bc」がマイナスであれば「0」、プラスなら「1」と処理してあげれば良さそうです。(ただ、正直この辺あんまよく分からないんですよね〜合ってんのかな?)
とりあえず、微分項の処理内容が見えてきたのでまとめてみます。
ReLUの微分はどう書けばいいのか分からんかったので「ReLU'」としています、0と1のやつです。
順伝搬を思い出してみるとFは畳み込みの計算に用いました。今はその逆をたどっているので、逆畳み込み演算が必要です。といっても、「X_batch」と計算された「ReLU’*Wo*Δ」で順伝搬のときに計算した畳み込み計算を行っているだけですが。
次にバイアスです、こちらも同じように微分を行います。
後ろの2項はFと同じですね、前の項もBcでの微分なので1です。ただしReLUの微分は残っています。
つまり、結果をまとめると、
確かに順伝搬のとき、Bcは畳み込み演算には関与していないので、Fのようにややこしい逆畳み込みはありません。なんか、よくできていますね、数学的にあたりまえなんでしょうけど。
コーディング(逆伝搬:出力層Fの微分)
さてさて、逆伝搬の処理が分かってきたので、この処理をコードに落としていきたいと思います。
前回の続きから、まずは出力層の微分値を求めます。
(forの中なので、その場所に応じて各コードはインテンドされています)
dWo = np.dot(FM_batch.T,(Y-T_batch)) dBo = np.reshape(np.sum(Y-T_batch, axis=0),(1,10))
ここは今まで通り、上の式をコードに落としただけです。CのところはFM_batchになっています。
コーディング(逆伝搬:CNN層の微分)
次に、Fの微分です。まずは「Wo*Δ」の部分をコードで書いておきます。
np.dot(Y-T_batch,Wo.T)
ReLU関数の部分の微分も書いてみます。ReLU関数の中身は「X_batch*F+Bc」でした。前回のコードで畳み込み計算をしたときにこの値を求めました。しかし、その値を保存せずに次の処理に進んでしまったのでこのままでは使えません。
もうあと付け感がハンパないですが、前回のコードに「X_batch*F+Bc」を保存するコードを加えます。
FM_bias_batch = np.empty((0,FM_size**2))
バッチごとに保存したいのでこれは「for M…」の外に記述します。欲しい値は「FM_bias = FM_storage+Bc」で計算されているので、この値を「FM_bias_batch」に保存するコードを足します。
FM_bias_batch = np.vstack((FM_bias_batch,FM_bias))
まー自分で書いててなんですが、言葉書きだとわけわからんので、また後で全体コード確認して下さい。。
で、この「FM_bias_batch」を使ってReLUの微分を書きます。ReLUの微分は、0以下が「0」で、0より大きければ「1」となります。ここもwhereを使って記述します。
np.where(FM_bias_batch<=0,0,1)
ここまで書ければ「ReLU'*Wo*Δ」をまとめてdeltaとして書いてしまいます。
delta = np.where(FM_bias_batch<=0,0,1)*np.dot(Y-T_batch,Wo.T)
これで「dF = X_batch*ReLU'*Wo*Δ」を「X_batch*delta」とまとめることができました。あとはこの2つの畳込み計算です。
畳み込み計算の流れは前回とほとんど同じです、まぁ逆方向なのでちょっと違いますが。まずは画像ごとに計算されたdFを保存するハコを準備します。このハコには各画像のdFがバッチ数だけ保存されます。
dF_batch = np.empty((0,F_size**2))
次はバッチごとの処理です。
for M in range (batch_size): img = np.reshape(X_batch[M],(IMG_size,IMG_size)) img_delta = np.reshape(delta[M],(FM_size,FM_size)) dF_storage = []
ここの処理で畳み込むのはX_batchとdeltaなので、この2つを2次元にreshapeします。
そして、画像の要素ごとのかけ算です。forのところで縦と横の計算回数が「F_size(=5)」になっているのと、X_batchからピックアップする画像サイズは「FM_size」になっているところだけ注意、あとは一緒かな。
for i in range(F_size): for j in range(F_size): pick_img = img[i:i+FM_size, j:j+FM_size] dF_storage = np.append(dF_storage,np.tensordot(img_delta,pick_img))
配列で保存されたdFを最初に設定したハコ「dF_batch」にnp.vstack詰め込みます。
dF_batch = np.vstack((dF_batch,dF_storage))
この時点でdF_batchの型は(batch_size, 25)になっています。最後にFを更新するためには(5, 5)のdFが必要なので、dF_batchを(5, 5)にreshapeします。ただ、dF_batchにはバッチ数だけ計算結果が詰まっているのでこれらを平均化してdFとします。reshapeとaverageをまとめて書きます。
dF = np.reshape(np.average(dF_batch,axis=0),(F_size,F_size))
これでようやくdFが求まりました。
コーディング(逆伝搬:出力層Bcの微分)
dBcは「1*ReLU'*Wo*Δ」になるのでdeltaがそのまま求める値になります。ただし、こちらもバッチ数だけデータが詰まっているので平均化してあげる必要があります。
dBc = np.reshape(np.average(delta,axis=0),(1,FM_size**2))
これでBcも更新できる形に整いました。
コーディング(パラメータの更新)
これですべてのパラメータの微分がそろったので、勾配降下法でパラメータを更新します。
Wo = function1.update(Wo,learning_rate,dWo) Bo = function1.update(Bo,learning_rate,dBo) F = function1.update(F,learning_rate,dF) Bc = function1.update(Bc,learning_rate,dBc)
微分さえ求められればなんてことないですね。
コーディング(残りのモロモロ)
まずは学習にかかる時間の計算、
end_time = time.time()
total_time = end_time - start_time
print(total_time)
ほんでからグラフ化、
#show graph
function1.plot_acc(accuracy_save)
function1.plot_loss(E_save)
最後に学習後のパラメータ保存、
#パラメータ保存 CNN_parameters = [Wo,Bo,F,Bc] path = '/Users/FUTOSHI/Desktop/cnn_test' name = '/CNN_ver6' function1.save(CNN_parameters,name,path)
ここは外部ファイルの関数を使ってちょちょっと書いてしまいます。とりあえずここまででコーディングは完了です。はぁ、おつかれさまでした。
まとめ
前回に引き続いてCNNのアルゴリズムとコードをまとめていきました。
なんかいろいろ調べてみたりもしたんですが、逆伝搬についてはあんまり詳しく書かれているサイトとか書籍とかなかったんで、自己流?というか今までのニューラルネットワークを構築した流れに沿って進めていきました。なので、正直この処理が正しいのかよく分かりません。
ミスってたり、抜けがあれば即効修正しますお。
次回はこのニューラルネットワークで学習、あとその精度を評価してみます。学習、時間かかりそうやなぁ。
全体コード:CNN
github.com
全体コード:function
github.com