はじめに
前回はChatGPTにpytorchで簡単なAIを実装させ、モデルについて簡単に調査しました。
今回は前回のコードの中で学習部分について調査してみたいと思います。
pytorchの学習について調べてみた
ChatGPTに聞いたコード全体
まずは前回のコードの全体から
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
# Iris データセットの読み込み
iris = load_iris()
X, y = iris.data, iris.target
# データの前処理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# PyTorchのテンソルに変換
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# ニューラルネットワークの定義
class IrisNet(nn.Module):
def __init__(self):
super(IrisNet, self).__init__()
self.fc1 = nn.Linear(4, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 3)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
model = IrisNet()
# 損失関数とオプティマイザ
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# トレーニングプロセス
def train(model, criterion, optimizer, train_loader, epochs=100):
for epoch in range(epochs):
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")
train(model, criterion, optimizer, train_loader)
# テストデータで評価
with torch.no_grad():
outputs = model(X_test)
_, predicted = torch.max(outputs, 1)
accuracy = (predicted == y_test).sum().item() / y_test.size(0)
accuracy
この中から今回は学習に焦点を当ててみたいと思います。該当箇所がこちら
model = IrisNet()
# 損失関数とオプティマイザ
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# トレーニングプロセス
def train(model, criterion, optimizer, train_loader, epochs=100):
for epoch in range(epochs):
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")
train(model, criterion, optimizer, train_loader)
AIの学習方法の復習
AIの学習について復習です。
学習では入力層から出力層にかけて順伝播を行い、損失関数を計算、損失関数を基に逆伝播を行いモデルを最適化という流れでした。
pytorchではどうでしょうか?確認していきます。
順伝播
まずは順伝播。
順伝播は前回のモデルでforwardという言葉が合った通り、
model = IrisNet()
で順伝播を行っているのでしょう。ただ、入力層に何も入力されていないので、
outputs = model(inputs)
ここで本格的に順伝播を行っているかもしれません。
損失関数
次に損失関数の計算です。これは
criterion = nn.CrossEntropyLoss()
でCrossEntropyLossという損失関数を定義していますね。クロスエントロピーと呼ぶらしいですが、その正体が分かりません。ちょっと調べてみます。
数式がゴリゴリ書いてある記事がたくさんありますが、こちらが比較的分かりやすかったです。
このクロスエントロピーは分類問題に適した関数らしいので、今回のあやめの分類などには向いているのでしょう。
実際の計算は
loss = criterion(outputs, targets)
ここで行っているようです。
outputsは順伝播で計算された結果でしょう。targetsは正解データだと思われます。この二つを使って損失関数を計算しています。
逆伝播
そして、この損失関数を基に逆伝播を行いますが、コードでは
loss.backward()
これで行っていそうです。なぜならbackwardと書いてあるから…。
最適化
そして、最適化ですね。それはきっと
optimizer = optim.Adam(model.parameters(), lr=0.001)
が最適化関数で、実際に実行しているのは
optimizer.step()
ここでしょう。
pytorchにおける学習はもう理論通りというかまんまでしたね。
未解決なものを調べる
あと未解決なのは
for inputs, targets in train_loader:
と
optimizer.zero_grad()
ぐらいでしょうか?epochはミニバッチ学習法で紹介しました。
では、さて一つ目のtrain_loaderから何か二つ取り出してますね。train_loaderは前々回の前処理で調べました。入力層に代入するデータと正解データのセットと作っていました。
なので、取り出しているのは入力層に代入するデータと正解データなんでしょう。
実際にprintで見てみると、
print結果
tensor([[ 0.4950, 0.8060, 1.0038, 1.5132],
[ 0.1407, -1.9840, 0.6662, 0.3475],
[-0.0954, -0.8215, 0.7225, 0.8656],
[-1.0401, 1.0385, -1.2466, -0.8181],
[ 0.6130, -0.5890, 1.0038, 1.1246],
[-0.9220, 1.7360, -1.3029, -1.2066],
[ 0.4950, -1.7515, 0.3287, 0.0885],
[-0.2135, -0.5890, 0.3849, 0.0885]]) tensor([2, 2, 2, 0, 2, 0, 1, 1])
tensor([[-1.3943, 0.3410, -1.2466, -1.3362],
[-0.5678, -0.1240, 0.3849, 0.3475],
[-0.9220, -1.2865, -0.4590, -0.1705],
[ 0.9673, -1.2865, 1.1163, 0.7361],
[ 1.0853, -0.1240, 0.9475, 1.1246],
[ 0.7311, -0.1240, 1.1163, 1.2541],
[ 0.4950, -0.5890, 0.7225, 0.3475],
[ 0.0226, -0.1240, 0.7225, 0.7361],
[-0.5678, 0.8060, -1.1904, -1.3362],
[-1.8666, -0.1240, -1.5279, -1.4657],
[-1.0401, 1.2710, -1.3592, -1.3362],
[ 0.7311, -0.1240, 0.9475, 0.7361],
[-0.8039, 1.0385, -1.3029, -1.3362],
[ 0.2588, -0.1240, 0.4412, 0.2180],
[ 0.1407, -0.1240, 0.5537, 0.7361],
[-0.2135, -1.0540, -0.1777, -0.3000]]) tensor([0, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 2, 0, 1, 2, 1])
tensor([[-0.9220, 0.5735, -1.1904, -0.9476],
[ 0.6130, -0.5890, 1.0038, 1.2541],
[ 1.2034, 0.1085, 0.8913, 1.1246],
[-0.3316, -0.1240, 0.1599, 0.0885],
[-1.0401, 0.5735, -1.3592, -1.3362],
[-1.2762, -0.1240, -1.3592, -1.4657],
[ 0.4950, -0.8215, 0.6100, 0.7361],
[-0.9220, 0.8060, -1.3029, -1.3362],
[-0.0954, -0.8215, 0.1599, -0.3000],
[-1.0401, 1.0385, -1.2466, -0.8181],
[ 2.0300, -0.1240, 1.5664, 1.1246],
[ 1.2034, 0.1085, 0.6100, 0.3475],
[-0.2135, -0.3565, 0.2161, 0.0885],
[-0.9220, 1.7360, -1.3029, -1.2066],
[ 1.5577, 1.2710, 1.2851, 1.6427],
[-0.8039, 2.4335, -1.3029, -1.4657]]) tensor([0, 2, 2, 1, 0, 0, 2, 0, 1, 0, 2, 1, 1, 0, 2, 0])
tensor([[ 0.9673, 0.5735, 1.0600, 1.6427],
[-0.2135, -1.2865, 0.6662, 0.9951],
[ 0.3769, -1.9840, 0.3849, 0.3475],
[-0.4497, -1.5190, -0.0652, -0.3000],
[-0.4497, -1.5190, -0.0089, -0.1705],
[-1.7486, -0.1240, -1.4154, -1.3362],
[ 0.6130, -0.5890, 1.0038, 1.1246],
[-1.0401, -0.1240, -1.2466, -1.3362],
[-1.2762, 0.1085, -1.2466, -1.3362],
[ 0.6130, -0.3565, 0.2724, 0.0885],
[-0.8039, 0.8060, -1.3592, -1.3362],
[ 0.1407, -1.9840, 0.6662, 0.3475],
[ 0.1407, -0.8215, 0.7225, 0.4770],
[-0.6858, 1.5035, -1.3029, -1.3362],
[ 1.5577, -0.1240, 1.1163, 0.4770],
[-0.2135, -0.1240, 0.2161, -0.0410]]) tensor([2, 2, 1, 1, 1, 0, 2, 0, 0, 1, 0, 2, 1, 0, 2, 1])
tensor([[ 0.9673, 0.1085, 1.0038, 1.5132],
[ 0.8492, -0.1240, 0.3287, 0.2180],
[-0.4497, 1.0385, -1.4154, -1.3362],
[ 0.9673, -0.1240, 0.6662, 0.6066],
[ 2.1481, -0.5890, 1.6226, 0.9951],
[-1.1582, -0.1240, -1.3592, -1.3362],
[-0.3316, -0.3565, -0.1214, 0.0885],
[-0.0954, -1.0540, 0.1036, -0.0410],
[ 0.4950, -0.3565, 1.0038, 0.7361],
[-0.9220, 1.0385, -1.3592, -1.3362],
[-1.1582, 0.1085, -1.3029, -1.4657],
[ 0.4950, -1.2865, 0.6100, 0.3475],
[-0.9220, 1.5035, -1.3029, -1.0771],
[-0.0954, -0.8215, 0.0474, -0.0410],
[ 0.4950, -1.2865, 0.6662, 0.8656],
[-1.1582, -1.2865, 0.3849, 0.6066]]) tensor([2, 1, 0, 1, 2, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2])
tensor([[-0.3316, -0.8215, 0.2161, 0.0885],
[-1.1582, -1.5190, -0.2902, -0.3000],
[-0.2135, 1.7360, -1.1904, -1.2066],
[ 0.8492, -0.3565, 0.4412, 0.0885],
[ 0.7311, 0.3410, 0.7225, 0.9951],
[-0.0954, -0.8215, 0.7225, 0.8656],
[ 0.9673, 0.5735, 1.0600, 1.1246],
[-0.3316, -1.2865, 0.0474, -0.1705],
[ 0.6130, 0.1085, 0.9475, 0.7361],
[-1.5124, 0.3410, -1.3592, -1.3362],
[-0.4497, -1.7515, 0.1036, 0.0885],
[ 0.6130, -0.8215, 0.8350, 0.8656],
[ 0.4950, -1.7515, 0.3287, 0.0885],
[ 0.9673, 0.1085, 0.3287, 0.2180],
[-1.5124, 0.8060, -1.3592, -1.2066],
[ 1.2034, 0.3410, 1.0600, 1.3836]]) tensor([1, 1, 0, 1, 2, 2, 2, 1, 2, 0, 1, 2, 1, 1, 0, 2])
tensor([[-1.0401, -2.4490, -0.1777, -0.3000],
[ 0.0226, -0.1240, 0.2161, 0.3475],
[-0.5678, 0.8060, -1.3029, -1.0771],
[-0.0954, 2.2010, -1.4717, -1.3362],
[-0.0954, -0.8215, 0.7225, 0.8656],
[ 1.2034, 0.1085, 0.7225, 1.3836],
[-0.4497, -1.2865, 0.1036, 0.0885],
[ 1.4396, -0.1240, 1.1726, 1.1246],
[-1.2762, 0.8060, -1.2466, -1.3362],
[ 1.0853, 0.3410, 1.1726, 1.3836],
[ 0.9673, -0.1240, 0.7787, 1.3836],
[-1.5124, 0.1085, -1.3029, -1.3362],
[-0.2135, -0.5890, 0.1599, 0.0885],
[-1.6305, -1.7515, -1.4154, -1.2066],
[ 2.1481, -0.1240, 1.2851, 1.3836],
[ 1.3215, 0.3410, 0.4974, 0.2180]]) tensor([1, 1, 0, 0, 2, 2, 1, 2, 0, 2, 2, 0, 1, 0, 2, 1])
tensor([[ 0.7311, -0.1240, 0.7787, 0.9951],
[ 0.2588, -0.1240, 0.6100, 0.7361],
[ 1.5577, 0.3410, 1.2288, 0.7361],
[ 0.6130, 0.3410, 0.8350, 1.3836],
[-1.1582, 0.1085, -1.3029, -1.3362],
[ 0.4950, 0.8060, 1.0038, 1.5132],
[-0.2135, 3.1310, -1.3029, -1.0771],
[-1.0401, 0.3410, -1.4717, -1.3362],
[-0.8039, -0.8215, 0.0474, 0.2180],
[ 2.1481, -1.0540, 1.7352, 1.3836],
[ 0.3769, 0.8060, 0.8913, 1.3836],
[-0.5678, 1.5035, -1.3029, -1.3362],
[ 2.1481, 1.7360, 1.6226, 1.2541],
[ 1.7938, -0.5890, 1.2851, 0.8656],
[ 2.3842, 1.7360, 1.4539, 0.9951],
[-1.7486, 0.3410, -1.4154, -1.3362]]) tensor([2, 2, 2, 2, 0, 2, 0, 0, 1, 2, 2, 0, 2, 2, 2, 0])
tensor([[ 0.0226, 0.3410, 0.5537, 0.7361],
[ 0.4950, 0.5735, 0.4974, 0.4770],
[ 0.1407, 0.8060, 0.3849, 0.4770],
[-0.9220, 1.0385, -1.3592, -1.2066],
[-0.9220, 1.7360, -1.2466, -1.3362],
[-0.5678, 1.9685, -1.1904, -1.0771],
[-1.3943, 0.3410, -1.4154, -1.3362],
[-0.2135, -0.5890, 0.3849, 0.0885]]) tensor([1, 1, 1, 0, 0, 0, 0, 1])
めちゃめちゃ長かったので、一部省力してますが、思惑通りの結果になってます。pytorchはミニバッチ学習法はすごく直観的で分かりやすいですね。
最後の疑問です。
optimizer.zero_grad()
こちらに非常に詳しく書かれていました。最適化するときに正しく計算するためのようですね。
これだけはない方が分かりやすかったな。と感想。
これでpytorchで学習する方法が分かりました。
最後に
とりあえず、pytorchの流れが分かり、研究開発を進めることができそうです。次回から本格的なAIを作っていくぞー。
コメント