社畜エンジニア発掘戦線

駆けだしAIエンジニア

MNIST手書き数字の画像識別(学習の評価)

前回はシンプルなニューラルネットワークを構築し、MNISTの訓練データを使って学習を行いました。その正解率(精度)は80%を超える結果となり、おおすげぇ、と満足しています。しかしながら、ここまでは訓練データを使った正解率です。7万枚の画像のうち、6万枚を使って訓練を行ったので、次は残りの1万枚でどれだけの正解率を出せるのか評価してみようと思います。

CONTENTS

パラメータの保存(pickle形式)

前回の問題でイテレーションにより更新されたパラメータ:WとBがこのニューラルネットワークの「賢さ」になります。つまり、学習(イテレーション)が終了した時点でこのWとBを保存しておかなければいけません。幸いにも、MNISTの画像をダウンロードするときにpickle形式で保存する方法を勉強しました。このpickle形式で、学習が完了したWとBを保存したいと思います。

前回のプログラムのコードの最後に、パラメータを保存するコードを追加します。

#Save parameters
parameters = [W,B]

dataset_dir = '/Users/FUTOSHI/Desktop/MNIST_test'
save_file = dataset_dir + '/perceptron.pkl'

with open(save_file, 'wb') as f:
    pickle.dump(parameters, f) 

MNISTのデータダウンロードで一度やっているのでサラッといきましょう。まずは保存したいパラメータ:WとBをリストで保存します。それから、保存先ディレクトリと保存ファイル名でパスを指定して、保存を実行する、という流れです。

詳しくは以前の記事をご参考。
MNIST手書き数字の画像識別(データ・セット) - 社畜エンジニア発掘戦線

プログラムを実行し、これで指定したディレクトリに「perceptron.pkl」という、学習済みのWとBが詰まったファイルが新しく追加されました。perceptronという名前に深い意味はないです、なんかperceptronっぽいってだけでこの名前にしただけです笑

全体コード
github.com


データセット、及びパラメータのロード

データが保存できたので、評価用のプログラムファイルを新しく作成します。ここから、評価用のプログラムコードを書いていきます。まずはテスト用のMNISTデータセットと上で保存したパラメータWとBをロードするコードを書きます。

#load dataset
import pickle

#MNIST_dataset
save_file = '/Users/FUTOSHI/Desktop/MNIST_test/mnist.pkl'

with open(save_file, 'rb') as f:
    dataset = pickle.load(f)

train_img,train_label,test_img,test_label = dataset

#Parameters
save_file = '/Users/FUTOSHI/Desktop/MNIST_test/perceptron.pkl'

with open(save_file, 'rb') as f:
    parameters = pickle.load(f)

W, B = parameters

こちらもこれまでと同じ内容でデータをロードしています。MNISTのデータについては、「save_file」でロードしたいディレクトリとファイルをパス指定して、ロードを実行、「train_img,train_label,test_img,test_label = dataset」でデータをパラメータに展開しています。Parameters(WとB)のデータについても同様です。

これでMNISTのテストデータと学習済のパラメータをロードできました。

評価プログラムの実装

では具体的に評価を実行するコードを書いていきたいと思います。何はともあれ、まずはライブラリのインポートからスタート。特に新しいライブラリは使わないので、前回分をコピペです。

#library import
import numpy as np
import matplotlib.pyplot as plt
import time

次に、MNISTのデータセットを規格化したり、One-Hotにしたりの前処理をします。今回はテストデータだけでいいんですが、まぁコピペしたってことで訓練用も入ってます。

#pixel normalization
X_train, X_test = train_img/255, test_img/255

#transform_OneHot
T_train = np.eye(10)[list(map(int,train_label))] 
T_test = np.eye(10)[list(map(int,test_label))]

次にニューラルネットの計算に使う活性化関数をコードに落としておきます。今回、学習はすでに終えているので逆伝搬計算は不要です。そのため、順伝搬計算に使うソフトマックス関数だけで大丈夫です。

#function
def softmax(z):
    return np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)

あと、計算した画像が「正解グループ」と「間違いグループ」にそれぞれ振り分けられるように空の配列を準備しておきます。

correct_img = np.empty((0,784))
error_img = np.empty((0,784))

計算結果が合っていれば「correct_img」に、間違っていれば「error_img」にその画像が保存されるようにします。

では、画像をニューラルネットに突っ込む処理を書いていきます。

start_time = time.time()
num_of_test = X_test.shape[0]
for i in range(num_of_test):

1万枚の画像を繰り返し処理するのでfor文を使います。繰り返し回数には直接10000って入れてもいいんですが、なんとなく汎用性を持たせたくて「X_test.shape[0] #10000」にしています。

計算する内容は順伝搬計算だけです

    #forward prop
    Y = softmax(np.dot(X_test[i], W)+B)

さて、ニューラルネットは頑張って計算してくれるので、その結果を判定する内容を書いていきます。

    #set accuracy
    Y_accuracy = np.argmax(Y)
    T_accuracy = np.argmax(T_test[i])

    if Y_accuracy == T_accuracy:
        correct_img = np.vstack((correct_img,X_test[i]))

    else:
        error_img = np.vstack((error_img,X_test[i]))

計算結果(Y)と正解ラベル(T)の最大値となる番号を比較して、一致していれば「correct_img」、不一致なら「error_img 」に追加する、という処理内容です。

あと、計算時間を測定するコードも貼り付けておきます。

end_time = time.time()

これで計算すべき内容は完了です。後は欲しい情報を出力する処理を書いておきます。このテストで欲しい情報はこんな感じ。
・正解数(correct_imgの型)
・間違い数(error_imgの型)
・正解率(正解数/10000 [%])
・計算時間

特にグラフで出力したい内容はないので全部printでターミナル上に出力させます。

print(correct_img.shape)
print(error_img.shape)
print(100*correct_img.shape[0]/num_of_test)
print(end_time - start_time)

これでコードの記述はすべて完了です。

プログラムの実行

コードも書けたことなので、このプログラムを実行してみます。

> (8664, 784)
> (1336, 784)
> 86.64
> 211.79083514213562

おお、うまくいきました、まぁそんなに複雑な計算でもないし。10000枚中、正解は8664枚ということで、正解率は86%となりました。だいたい訓練データと一致していますね。せっかくピックアップしてるので、間違い画像がどんな数字なのか確認してみたいんですが、まだ1336枚もあるのでもうちょっと正解率が上がってきたら具体的にチェックしてみたいと思います。

ただ、計算時間が200秒を超えてきました、だいたい3分半。けっこう時間かかったな、学習は2秒ぐらいで終わるのに、そのテストは100倍かかるのか、ふーん、そんなもんなのか。

まとめ

今回は学習が完了した時点のパラメータを保存して、そのパラメータを使って実際にこのニューラルネットワークがどの程度の実力を持っているのかを評価しました。その結果は概ね訓練データと一致しており、MNISTを識別するニューラルネットワークの評価ベースができあがった、というところでしょうか。ここから、ニューラルネットワークをいろいろと工夫して、どのように正解率が向上していくのか見ていきたいと思います。

次回はこのニューラルネットワークに隠れ層を追加してみます。

全体コード
github.com

元の記事
週末のディープラーニング - 社畜エンジニア発掘戦線

Twitter
世界の社畜 (@sekai_syachiku) | Twitter