AIプログラミング初心者がChatGPTにpytorchで簡単なAIを実装させてみた。学習編

はじめに

前回はChatGPTにpytorchで簡単なAIを実装させ、モデルについて簡単に調査しました。

今回は前回のコードの中で学習部分について調査してみたいと思います。

pytorchの学習について調べてみた

ChatGPTに聞いたコード全体

まずは前回のコードの全体から

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# Iris データセットの読み込み
iris = load_iris()
X, y = iris.data, iris.target

# データの前処理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# PyTorchのテンソルに変換
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# ニューラルネットワークの定義
class IrisNet(nn.Module):
    def __init__(self):
        super(IrisNet, self).__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = IrisNet()

# 損失関数とオプティマイザ
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# トレーニングプロセス
def train(model, criterion, optimizer, train_loader, epochs=100):
    for epoch in range(epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")

train(model, criterion, optimizer, train_loader)

# テストデータで評価
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)

accuracy

この中から今回は学習に焦点を当ててみたいと思います。該当箇所がこちら

model = IrisNet()

# 損失関数とオプティマイザ
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# トレーニングプロセス
def train(model, criterion, optimizer, train_loader, epochs=100):
    for epoch in range(epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")

train(model, criterion, optimizer, train_loader)

AIの学習方法の復習

AIの学習について復習です。

学習では入力層から出力層にかけて順伝播を行い、損失関数を計算、損失関数を基に逆伝播を行いモデルを最適化という流れでした。

pytorchではどうでしょうか?確認していきます。

順伝播

まずは順伝播。

順伝播は前回のモデルでforwardという言葉が合った通り、

model = IrisNet()

で順伝播を行っているのでしょう。ただ、入力層に何も入力されていないので、

outputs = model(inputs)

ここで本格的に順伝播を行っているかもしれません。

損失関数

次に損失関数の計算です。これは

criterion = nn.CrossEntropyLoss()

でCrossEntropyLossという損失関数を定義していますね。クロスエントロピーと呼ぶらしいですが、その正体が分かりません。ちょっと調べてみます。

数式がゴリゴリ書いてある記事がたくさんありますが、こちらが比較的分かりやすかったです。

このクロスエントロピーは分類問題に適した関数らしいので、今回のあやめの分類などには向いているのでしょう。

実際の計算は

loss = criterion(outputs, targets)

ここで行っているようです。

outputsは順伝播で計算された結果でしょう。targetsは正解データだと思われます。この二つを使って損失関数を計算しています。

逆伝播

そして、この損失関数を基に逆伝播を行いますが、コードでは

loss.backward()

これで行っていそうです。なぜならbackwardと書いてあるから…。

最適化

そして、最適化ですね。それはきっと

optimizer = optim.Adam(model.parameters(), lr=0.001)

が最適化関数で、実際に実行しているのは

optimizer.step()

ここでしょう。

pytorchにおける学習はもう理論通りというかまんまでしたね。

未解決なものを調べる

あと未解決なのは

for inputs, targets in train_loader:

optimizer.zero_grad()

ぐらいでしょうか?epochはミニバッチ学習法で紹介しました。

では、さて一つ目のtrain_loaderから何か二つ取り出してますね。train_loaderは前々回の前処理で調べました。入力層に代入するデータと正解データのセットと作っていました。

なので、取り出しているのは入力層に代入するデータと正解データなんでしょう。

実際にprintで見てみると、

print結果
tensor([[ 0.4950,  0.8060,  1.0038,  1.5132],
        [ 0.1407, -1.9840,  0.6662,  0.3475],
        [-0.0954, -0.8215,  0.7225,  0.8656],
        [-1.0401,  1.0385, -1.2466, -0.8181],
        [ 0.6130, -0.5890,  1.0038,  1.1246],
        [-0.9220,  1.7360, -1.3029, -1.2066],
        [ 0.4950, -1.7515,  0.3287,  0.0885],
        [-0.2135, -0.5890,  0.3849,  0.0885]]) tensor([2, 2, 2, 0, 2, 0, 1, 1])
tensor([[-1.3943,  0.3410, -1.2466, -1.3362],
        [-0.5678, -0.1240,  0.3849,  0.3475],
        [-0.9220, -1.2865, -0.4590, -0.1705],
        [ 0.9673, -1.2865,  1.1163,  0.7361],
        [ 1.0853, -0.1240,  0.9475,  1.1246],
        [ 0.7311, -0.1240,  1.1163,  1.2541],
        [ 0.4950, -0.5890,  0.7225,  0.3475],
        [ 0.0226, -0.1240,  0.7225,  0.7361],
        [-0.5678,  0.8060, -1.1904, -1.3362],
        [-1.8666, -0.1240, -1.5279, -1.4657],
        [-1.0401,  1.2710, -1.3592, -1.3362],
        [ 0.7311, -0.1240,  0.9475,  0.7361],
        [-0.8039,  1.0385, -1.3029, -1.3362],
        [ 0.2588, -0.1240,  0.4412,  0.2180],
        [ 0.1407, -0.1240,  0.5537,  0.7361],
        [-0.2135, -1.0540, -0.1777, -0.3000]]) tensor([0, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 2, 0, 1, 2, 1])
tensor([[-0.9220,  0.5735, -1.1904, -0.9476],
        [ 0.6130, -0.5890,  1.0038,  1.2541],
        [ 1.2034,  0.1085,  0.8913,  1.1246],
        [-0.3316, -0.1240,  0.1599,  0.0885],
        [-1.0401,  0.5735, -1.3592, -1.3362],
        [-1.2762, -0.1240, -1.3592, -1.4657],
        [ 0.4950, -0.8215,  0.6100,  0.7361],
        [-0.9220,  0.8060, -1.3029, -1.3362],
        [-0.0954, -0.8215,  0.1599, -0.3000],
        [-1.0401,  1.0385, -1.2466, -0.8181],
        [ 2.0300, -0.1240,  1.5664,  1.1246],
        [ 1.2034,  0.1085,  0.6100,  0.3475],
        [-0.2135, -0.3565,  0.2161,  0.0885],
        [-0.9220,  1.7360, -1.3029, -1.2066],
        [ 1.5577,  1.2710,  1.2851,  1.6427],
        [-0.8039,  2.4335, -1.3029, -1.4657]]) tensor([0, 2, 2, 1, 0, 0, 2, 0, 1, 0, 2, 1, 1, 0, 2, 0])
tensor([[ 0.9673,  0.5735,  1.0600,  1.6427],
        [-0.2135, -1.2865,  0.6662,  0.9951],
        [ 0.3769, -1.9840,  0.3849,  0.3475],
        [-0.4497, -1.5190, -0.0652, -0.3000],
        [-0.4497, -1.5190, -0.0089, -0.1705],
        [-1.7486, -0.1240, -1.4154, -1.3362],
        [ 0.6130, -0.5890,  1.0038,  1.1246],
        [-1.0401, -0.1240, -1.2466, -1.3362],
        [-1.2762,  0.1085, -1.2466, -1.3362],
        [ 0.6130, -0.3565,  0.2724,  0.0885],
        [-0.8039,  0.8060, -1.3592, -1.3362],
        [ 0.1407, -1.9840,  0.6662,  0.3475],
        [ 0.1407, -0.8215,  0.7225,  0.4770],
        [-0.6858,  1.5035, -1.3029, -1.3362],
        [ 1.5577, -0.1240,  1.1163,  0.4770],
        [-0.2135, -0.1240,  0.2161, -0.0410]]) tensor([2, 2, 1, 1, 1, 0, 2, 0, 0, 1, 0, 2, 1, 0, 2, 1])
tensor([[ 0.9673,  0.1085,  1.0038,  1.5132],
        [ 0.8492, -0.1240,  0.3287,  0.2180],
        [-0.4497,  1.0385, -1.4154, -1.3362],
        [ 0.9673, -0.1240,  0.6662,  0.6066],
        [ 2.1481, -0.5890,  1.6226,  0.9951],
        [-1.1582, -0.1240, -1.3592, -1.3362],
        [-0.3316, -0.3565, -0.1214,  0.0885],
        [-0.0954, -1.0540,  0.1036, -0.0410],
        [ 0.4950, -0.3565,  1.0038,  0.7361],
        [-0.9220,  1.0385, -1.3592, -1.3362],
        [-1.1582,  0.1085, -1.3029, -1.4657],
        [ 0.4950, -1.2865,  0.6100,  0.3475],
        [-0.9220,  1.5035, -1.3029, -1.0771],
        [-0.0954, -0.8215,  0.0474, -0.0410],
        [ 0.4950, -1.2865,  0.6662,  0.8656],
        [-1.1582, -1.2865,  0.3849,  0.6066]]) tensor([2, 1, 0, 1, 2, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2])
tensor([[-0.3316, -0.8215,  0.2161,  0.0885],
        [-1.1582, -1.5190, -0.2902, -0.3000],
        [-0.2135,  1.7360, -1.1904, -1.2066],
        [ 0.8492, -0.3565,  0.4412,  0.0885],
        [ 0.7311,  0.3410,  0.7225,  0.9951],
        [-0.0954, -0.8215,  0.7225,  0.8656],
        [ 0.9673,  0.5735,  1.0600,  1.1246],
        [-0.3316, -1.2865,  0.0474, -0.1705],
        [ 0.6130,  0.1085,  0.9475,  0.7361],
        [-1.5124,  0.3410, -1.3592, -1.3362],
        [-0.4497, -1.7515,  0.1036,  0.0885],
        [ 0.6130, -0.8215,  0.8350,  0.8656],
        [ 0.4950, -1.7515,  0.3287,  0.0885],
        [ 0.9673,  0.1085,  0.3287,  0.2180],
        [-1.5124,  0.8060, -1.3592, -1.2066],
        [ 1.2034,  0.3410,  1.0600,  1.3836]]) tensor([1, 1, 0, 1, 2, 2, 2, 1, 2, 0, 1, 2, 1, 1, 0, 2])
tensor([[-1.0401, -2.4490, -0.1777, -0.3000],
        [ 0.0226, -0.1240,  0.2161,  0.3475],
        [-0.5678,  0.8060, -1.3029, -1.0771],
        [-0.0954,  2.2010, -1.4717, -1.3362],
        [-0.0954, -0.8215,  0.7225,  0.8656],
        [ 1.2034,  0.1085,  0.7225,  1.3836],
        [-0.4497, -1.2865,  0.1036,  0.0885],
        [ 1.4396, -0.1240,  1.1726,  1.1246],
        [-1.2762,  0.8060, -1.2466, -1.3362],
        [ 1.0853,  0.3410,  1.1726,  1.3836],
        [ 0.9673, -0.1240,  0.7787,  1.3836],
        [-1.5124,  0.1085, -1.3029, -1.3362],
        [-0.2135, -0.5890,  0.1599,  0.0885],
        [-1.6305, -1.7515, -1.4154, -1.2066],
        [ 2.1481, -0.1240,  1.2851,  1.3836],
        [ 1.3215,  0.3410,  0.4974,  0.2180]]) tensor([1, 1, 0, 0, 2, 2, 1, 2, 0, 2, 2, 0, 1, 0, 2, 1])
tensor([[ 0.7311, -0.1240,  0.7787,  0.9951],
        [ 0.2588, -0.1240,  0.6100,  0.7361],
        [ 1.5577,  0.3410,  1.2288,  0.7361],
        [ 0.6130,  0.3410,  0.8350,  1.3836],
        [-1.1582,  0.1085, -1.3029, -1.3362],
        [ 0.4950,  0.8060,  1.0038,  1.5132],
        [-0.2135,  3.1310, -1.3029, -1.0771],
        [-1.0401,  0.3410, -1.4717, -1.3362],
        [-0.8039, -0.8215,  0.0474,  0.2180],
        [ 2.1481, -1.0540,  1.7352,  1.3836],
        [ 0.3769,  0.8060,  0.8913,  1.3836],
        [-0.5678,  1.5035, -1.3029, -1.3362],
        [ 2.1481,  1.7360,  1.6226,  1.2541],
        [ 1.7938, -0.5890,  1.2851,  0.8656],
        [ 2.3842,  1.7360,  1.4539,  0.9951],
        [-1.7486,  0.3410, -1.4154, -1.3362]]) tensor([2, 2, 2, 2, 0, 2, 0, 0, 1, 2, 2, 0, 2, 2, 2, 0])
tensor([[ 0.0226,  0.3410,  0.5537,  0.7361],
        [ 0.4950,  0.5735,  0.4974,  0.4770],
        [ 0.1407,  0.8060,  0.3849,  0.4770],
        [-0.9220,  1.0385, -1.3592, -1.2066],
        [-0.9220,  1.7360, -1.2466, -1.3362],
        [-0.5678,  1.9685, -1.1904, -1.0771],
        [-1.3943,  0.3410, -1.4154, -1.3362],
        [-0.2135, -0.5890,  0.3849,  0.0885]]) tensor([1, 1, 1, 0, 0, 0, 0, 1])

めちゃめちゃ長かったので、一部省力してますが、思惑通りの結果になってます。pytorchはミニバッチ学習法はすごく直観的で分かりやすいですね。

最後の疑問です。

optimizer.zero_grad()

こちらに非常に詳しく書かれていました。最適化するときに正しく計算するためのようですね。

これだけはない方が分かりやすかったな。と感想。

これでpytorchで学習する方法が分かりました。

最後に

とりあえず、pytorchの流れが分かり、研究開発を進めることができそうです。次回から本格的なAIを作っていくぞー。

コメント

タイトルとURLをコピーしました