PyTorch を用いて、手書き数字の分類を行ってみます。サポートベクターマシンを用いた場合は HOG などの特徴量を考える必要がありましたが、ディープラーニングでは十分な質の良いデータがあればその必要がありません。
MNIST データの読み込み
手書き数字のデータとして、MNIST データをダウンロードして利用することにします。Matplotlibで描画する例は以下のようになります。
# -*- coding: utf-8 -*-
import gzip
import pickle
import matplotlib.pyplot as plt
import torch
def Main():
# pickle 形式で保存されています。
with gzip.open('mnist.pkl.gz', 'rb') as f:
((xTrain, yTrain), (xValid, yValid), _) = pickle.load(f, encoding='latin-1')
# 28x28 ピクセルの画像データが 50000 枚分あります。
print(xTrain.shape) #=> (50000, 784)
# 描画してみます。
print(yTrain[0]) #=> 5
plt.imshow(xTrain[0].reshape(28, 28), cmap='gray')
plt.show()
# pytorch で利用するためには torch.tensor に変換します。
print(type(xTrain[0])) #=> <class 'numpy.ndarray'>
xTrain, yTrain, xValid, yValid = map(
torch.tensor, (xTrain, yTrain, xValid, yValid)
)
print(type(xTrain[0])) #=> <class 'torch.Tensor'>
if __name__ == '__main__':
Main()
ニューラルネットワークの定義
手書き数字の描かれた画像を分類するニューラルネットワークとして、ディープラーニングでよく利用される「畳み込みニューラルネットワーク (CNN; Convolutional Neural Network)」を用いてみます。torch.nn.Module
を継承したクラスを利用してネットワークを定義できます。
MNIST データの分類を考えたときには、以下のようなネットワーク定義となります。ただしこれは CNN の一つの例であり、一般形ではありません。
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
def Main():
# MNIST データは RGB ではなくグレースケールです。
inChannels = 1
# 0 から 9 までの数字への分類を考えます。
outFeatures = 10
# MNIST データは 28x28 の画像です。
inputSize = 28
# ネットワークを定義します。
cnn = CNN(inChannels, outFeatures, inputSize)
# 一つのミニバッチに含まれるデータの個数
bs = 1
# 乱数で MNIST と同じサイズのデータを用意してみます。
x = torch.randn(bs, inChannels, inputSize, inputSize)
yPred = cnn(x)
print(cnn)
print(yPred.shape)
class CNN(nn.Module):
def __init__(self, inChannels, outFeatures3, inputSize):
super(CNN, self).__init__()
# 隠れ層の次元数など
outChannels = 6
kernelSize = 3
outChannels2 = 16
outFeatures = 120
outFeatures2 = 84
poolingStride = 2
sz = inputSize - kernelSize + 1
sz = sz // poolingStride
sz = sz - kernelSize + 1
sz = sz // poolingStride
self.__poolingStride = poolingStride
self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
self.__fc2 = nn.Linear(outFeatures, outFeatures2)
self.__fc3 = nn.Linear(outFeatures2, outFeatures3)
def forward(self, x):
# 1 x 1 x 28 x 28
x = self.__conv1(x) #=> 1 x 6 x 26 x 26
x = F.relu(x) #=> 1 x 6 x 26 x 26
x = F.max_pool2d(x, self.__poolingStride) #=> 1 x 6 x 13 x 13
x = self.__conv2(x) #=> 1 x 16 x 11 x 11
x = F.relu(x) #=> 1 x 16 x 11 x 11
x = F.max_pool2d(x, self.__poolingStride) #=> 1 x 16 x 5 x 5
# note: 第一引数を -1 とすることで、第二引数の値から形状を推定させることができます。
x = x.reshape(-1, self.__GetNumFlatFeatures(x)) #=> 1 x 400
x = F.relu(self.__fc1(x)) #=> 1 x 120
x = F.relu(self.__fc2(x)) #=> 1 x 84
x = self.__fc3(x) #=> 1 x 10
return x
def __GetNumFlatFeatures(self, x):
size = x.size()[1:] # ミニバッチの個数の次元を除く、すべての次元
numFeatures = 1
for sz in size:
numFeatures *= sz
return numFeatures
if __name__ == '__main__':
Main()
Conv2d(inChannels, outChannels, kernelSize)
inChannels
方向には動かさず、画像の平面内で畳み込みを行います。この畳み込みを独立に outChannels
個のフィルタで行い、結果を一つのテンソルとしてまとめます。kernelSize
は OpenCV での畳み込み処理におけるカーネルと同じ概念です。
Linear(inFeatures, outFeatures)
線形変換です。重みとバイアスをパラメータとして持ちます。
max_pool2d(x, stride)
stride x stride
において最大となる値をフィルタします。
出力例
CNN(
(_CNN__conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
(_CNN__conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
(_CNN__fc1): Linear(in_features=400, out_features=120, bias=True)
(_CNN__fc2): Linear(in_features=120, out_features=84, bias=True)
(_CNN__fc3): Linear(in_features=84, out_features=10, bias=True)
)
torch.Size([1, 10])
分類問題で利用する損失関数について
回帰問題を扱う際に利用した平均二乗誤差は、分類問題ではそのまま利用できません。そのため、ここでは手書き数字の分類を行うために、以下の式で表される、交差エントロピー誤差という損失関数を用いてニューラルネットワークを学習します。
MNIST データには学習用のデータが 50000 枚あります。ディープラーニングでパラメータの学習のためにループを回す際に、利用可能なすべての学習用のデータを分割して、小さなバッチデータ毎にループを回す手法があります。本ページではミニバッチのサイズ を 64 として学習することにします。
ある一つの画像データをニューラルネットワークに入力として与えると、入力画像が 0-9 の数字である確率が、長さ 10 のベクトルとして出力されます。実際には 個のデータを一度に入力するため、このベクトルが 個出力されます。
交差エントロピー誤差では、10 の長さのベクトルのうち、例えば入力画像が 0 という数字であった場合は、最初の要素だけを取り出して対数を取ります。10 個の確率から一つのデータを取り出せるように は 0 または 1 の値を取ります。 個のデータについて同様の処理を行い、平均を計算したものが誤差となります。
例えば は 0 となるため、正しい分類ができている場合の誤差は 0 となります。
ソフトマックス関数について
交差エントロピー誤差を計算するためには、ニューラルネットワークの出力を確率として扱えるように変換する必要があります。PyTorch では交差エントロピー誤差を計算する関数 nn.CrossEntropyLoss の内部で、ソフトマックス関数 nn.Softmax を利用して出力を確率として扱えるように変換しています。
nn.CrossEntropyLoss の利用例
以下では nn.CrossEntropyLoss
で計算した誤差と、定義に基いて手動計算した誤差が一致することを確認しています。
import torch
import torch.nn as nn
lossFn = nn.CrossEntropyLoss()
N = 64
y = torch.empty(N, dtype=torch.long).random_(10)
yPred = torch.randn(N, 10)
loss = lossFn(yPred, y)
loss2 = 0.0
for j in range(N):
loss2 += -torch.log(torch.exp(yPred[j][y[j]]) / sum(torch.exp(yPred[j])))
loss2 /= N
print(loss) #=> 2.8512
print(loss2) #=> 2.8512
ニューラルネットワークの学習
上述の CNN と交差エントロピー誤差を用いて MNIST データの分類を試してみます。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import gzip
import pickle
import matplotlib.pyplot as plt
def Main():
# MNIST データ
xTrain, yTrain, xValid, yValid = GetMnistData()
# CNN モデル
inChannels = 1
outFeatures = 10
inputSize = 28
model = CNN(inChannels, outFeatures, inputSize)
# 交差エントロピー誤差
lossFn = F.cross_entropy
# 学習率、学習の反復回数
learningRate = 0.001
iters = 10
# 最適化関数
optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
# ミニバッチのサイズ
bs = 64
trainDs = TensorDataset(xTrain, yTrain)
trainDl = DataLoader(trainDs, batch_size=bs, shuffle=True)
# 全体のループ
for t in range(iters):
# ミニバッチ毎のループ
for x, y in trainDl:
yPred = model(x)
loss = lossFn(yPred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 誤差の出力
print(t, loss.item())
# 学習済みモデルの検証
CheckOutput(model, xTrain, yTrain)
CheckOutput(model, xValid, yValid)
def CheckOutput(model, x, y):
yPred = map(lambda xx: xx.max(0).indices.item(), model(x))
wrong = 0
for xx, yy, yyPred in zip(x, y, yPred):
if yy.item() == yyPred:
continue
# print('{} != {}'.format(yy.item(), yyPred))
# plt.imshow(xx.reshape(28, 28), cmap='gray')
# plt.show()
wrong += 1
print('Accuracy: {}'.format(100 - wrong / len(x) * 100))
def GetMnistData():
with gzip.open('mnist.pkl.gz', 'rb') as f:
(xTrain, yTrain), (xValid, yValid), _ = pickle.load(f, encoding='latin-1')
xTrain = list(map(lambda x: x.reshape(1, 28, 28), xTrain))
xValid = list(map(lambda x: x.reshape(1, 28, 28), xValid))
return map(torch.tensor, (xTrain, yTrain, xValid, yValid))
class CNN(nn.Module):
def __init__(self, inChannels, outFeatures3, inputSize):
super(CNN, self).__init__()
outChannels = 6
kernelSize = 3
outChannels2 = 16
outFeatures = 120
outFeatures2 = 84
poolingStride = 2
sz = inputSize - kernelSize + 1
sz = sz // poolingStride
sz = sz - kernelSize + 1
sz = sz // poolingStride
self.__poolingStride = poolingStride
self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
self.__fc2 = nn.Linear(outFeatures, outFeatures2)
self.__fc3 = nn.Linear(outFeatures2, outFeatures3)
def forward(self, x):
x = F.max_pool2d(F.relu(self.__conv1(x)), self.__poolingStride)
x = F.max_pool2d(F.relu(self.__conv2(x)), self.__poolingStride)
x = x.reshape(-1, self.__GetNumFlatFeatures(x))
x = F.relu(self.__fc1(x))
x = F.relu(self.__fc2(x))
x = self.__fc3(x)
return x
def __GetNumFlatFeatures(self, x):
size = x.size()[1:]
numFeatures = 1
for sz in size:
numFeatures *= sz
return numFeatures
if __name__ == '__main__':
Main()
実行例
0 0.10681144148111343
1 0.06327979266643524
2 0.040145404636859894
3 0.015086745843291283
4 0.005156606901437044
5 0.0018728474387899041
6 0.000745357247069478
7 0.0005938038229942322
8 5.65591617487371e-05
9 0.0003749439201783389
Accuracy: 99.456
Accuracy: 98.58
訓練用のデータで 99.456%、未知のデータで 98.58% となりました。分類に失敗したデータの例としては以下のようなものがあります。
2 と認識 (正しくは 3)
8 と認識 (正しくは 3)
6 と認識 (正しくは 5)
関連記事
- Python コードスニペット (条件分岐)if-elif-else sample.py #!/usr/bin/python # -*- coding: utf-8 -*- # コメント内であっても、ASCII外の文字が含まれる場合はエンコーディング情報が必須 x = 1 # 一行スタイル if x==0: print 'a' # 参考: and,or,notが使用可能 (&&,||はエラー) elif x==1: p...
- Python コードスニペット (リスト、タプル、ディクショナリ)リスト range 「0から10まで」といった範囲をリスト形式で生成します。 sample.py print range(10) # for(int i=0; i<10; ++i) ← C言語などのfor文と比較 print range(5,10) # for(int i=5; i<10; ++i) print range(5,10,2) # for(int i=5; i<10;...
- ZeroMQ (zmq) の Python サンプルコードZeroMQ を Python から利用する場合のサンプルコードを記載します。 Fixing the World To fix the world, we needed to do two things. One, to solve the general problem of "how to connect any code to any code, anywhere". Two, to wra...
- Matplotlib/SciPy/pandas/NumPy サンプルコードPython で数学的なことを試すときに利用される Matplotlib/SciPy/pandas/NumPy についてサンプルコードを記載します。 Matplotlib SciPy pandas [NumPy](https://www.numpy
- pytest の基本的な使い方pytest の基本的な使い方を記載します。 適宜参照するための公式ドキュメントページ Full pytest documentation API Reference インストール 適当なパッケージ