AIプログラミング初心者がChatGPTにpytorchで簡単なAIを実装させてみた。前処理編

pytorch

はじめに

前回はChatGPTにpytorchで簡単なAIを実装させ、ライブラリについて簡単に調査しました。

今回は前回のコードの中で前処理部分について調査してみたいと思います。

pytorchの前処理について調べてみた

ChatGPTに聞いたコード全体

まずは前回のコードの全体から

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# Iris データセットの読み込み
iris = load_iris()
X, y = iris.data, iris.target

# データの前処理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# PyTorchのテンソルに変換
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# ニューラルネットワークの定義
class IrisNet(nn.Module):
    def __init__(self):
        super(IrisNet, self).__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = IrisNet()

# 損失関数とオプティマイザ
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# トレーニングプロセス
def train(model, criterion, optimizer, train_loader, epochs=100):
    for epoch in range(epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")

train(model, criterion, optimizer, train_loader)

# テストデータで評価
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)

accuracy

前処理について調べてみた

その中でも前処理に当たる部分は以下のコードになるのでしょうか

# Iris データセットの読み込み
iris = load_iris()
X, y = iris.data, iris.target

# データの前処理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# PyTorchのテンソルに変換
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

データセットについて調べてみた

まずはどんなデータセットを使っているか確認してみたいと思います。

前回のtensorflowではmnistという手書き文字を扱ってましたが、今回の場合はIrisというデータセットを扱っているようです。

データセットの中身はこちらでかなり詳しく説明されています。

Irisという単語が「あやめ」という花の意味らしく、このデータベースでは3種類のあやめの情報が含まれているようです。

このデータセットをprint(iris)で表示させてみます。

pint(iris)
{'data': array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]]), 
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 
'frame': None, 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

おーなんかたくさん数字がでました。ここには表示されていないですがまだまだ下にも続いていました。今回は重要な部分だけ取り出してみました。

データセットの構造としてはpythonの辞書型になっているようです。

一番初めのdataというキーはニューラルネットワークの入力層に入るデータだと思われます。各行があやめの各種類に当たり、各列が花の萼(がく)片の長さと幅、花弁の長さと幅に当たるようです。(がく片と花弁なんて中学以来聞いたことないよ…笑)

がく片と花弁の長さと幅であやめの種類が分類されているようです。

そして、このデータセットのtargetというキーがdataの各行に対応するあやめの種類になるようです。0がsetosa、1がversicolor、2がvirginicaとなっているようです。

とりあえず、データセットの中身が分かりました。このデータセットで解析する目的はがく片と花弁の長さと幅を与えれたときにどの種類のあやめに相当するのかを分類することだと思います。

そのためにニューラルネットワークを使って学習させ、モデルを作っていくことになります。

次にこのデータセットから入力となるXと教師データとなるyに分けていきます。これはデータセットが辞書型なので簡単にでき、

X, y = iris.data, iris.target

これで一発です。

次にこのデータセットを学習用データと検証データに分けていますね。それが以下のコードになりますかね。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

学習用データ(train)と検証データ(test)に分けて、検証データが全体の内の2割になるようにしていますね。また、ランダムに抜き出していますが、毎回同じものが得られるようにrandom_stateを0にしていますね。

正規化について調べてみた

次に行っていることは学習用データの正規化ですね。それがこちら

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

正規化はデータの幅を揃えて、学習をより効率的かつ予測精度の高くするための方法みたいな感じでしたね。

ただfit_transformとtransformの違いがよく分かっていません。なのでprintで何が変わったか確認してみます。

まずは大本のX
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.4 3.7 1.5 0.2]
 [4.8 3.4 1.6 0.2]
 [4.8 3.  1.4 0.1]
 [4.3 3.  1.1 0.1]
 [5.8 4.  1.2 0.2]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.3 0.4]
 [5.1 3.5 1.4 0.3]
 [5.7 3.8 1.7 0.3]
 [5.1 3.8 1.5 0.3]
 [5.4 3.4 1.7 0.2]
 [5.1 3.7 1.5 0.4]
 [4.6 3.6 1.  0.2]
 [5.1 3.3 1.7 0.5]
 [4.8 3.4 1.9 0.2]
 [5.  3.  1.6 0.2]
 [5.  3.4 1.6 0.4]
 [5.2 3.5 1.5 0.2]
 [5.2 3.4 1.4 0.2]
 [4.7 3.2 1.6 0.2]
 [4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.2 4.1 1.5 0.1]
 [5.5 4.2 1.4 0.2]
 [4.9 3.1 1.5 0.2]
 [5.  3.2 1.2 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 3.6 1.4 0.1]
 [4.4 3.  1.3 0.2]
 [5.1 3.4 1.5 0.2]
 [5.  3.5 1.3 0.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.2 1.3 0.2]
 [5.  3.5 1.6 0.6]
 [5.1 3.8 1.9 0.4]
 [4.8 3.  1.4 0.3]
 [5.1 3.8 1.6 0.2]
 [4.6 3.2 1.4 0.2]
 [5.3 3.7 1.5 0.2]
 [5.  3.3 1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.9 3.1 4.9 1.5]
 [5.5 2.3 4.  1.3]
 [6.5 2.8 4.6 1.5]
 [5.7 2.8 4.5 1.3]
 [6.3 3.3 4.7 1.6]
 [4.9 2.4 3.3 1. ]
 [6.6 2.9 4.6 1.3]
 [5.2 2.7 3.9 1.4]
 [5.  2.  3.5 1. ]
 [5.9 3.  4.2 1.5]
 [6.  2.2 4.  1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [6.7 3.1 4.4 1.4]
 [5.6 3.  4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [5.9 3.2 4.8 1.8]
 [6.1 2.8 4.  1.3]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.6 3.  4.4 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3.  5.  1.7]
 [6.  2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [6.  2.7 5.1 1.6]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [6.3 2.3 4.4 1.3]
 [5.6 3.  4.1 1.3]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3.  4.6 1.4]
 [5.8 2.6 4.  1.2]
 [5.  2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3.  4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3.  1.1]
 [5.7 2.8 4.1 1.3]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]
 [7.1 3.  5.9 2.1]
 [6.3 2.9 5.6 1.8]
 [6.5 3.  5.8 2.2]
 [7.6 3.  6.6 2.1]
 [4.9 2.5 4.5 1.7]
 [7.3 2.9 6.3 1.8]
 [6.7 2.5 5.8 1.8]
 [7.2 3.6 6.1 2.5]
 [6.5 3.2 5.1 2. ]
 [6.4 2.7 5.3 1.9]
 [6.8 3.  5.5 2.1]
 [5.7 2.5 5.  2. ]
 [5.8 2.8 5.1 2.4]
 [6.4 3.2 5.3 2.3]
 [6.5 3.  5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [7.7 2.6 6.9 2.3]
 [6.  2.2 5.  1.5]
 [6.9 3.2 5.7 2.3]
 [5.6 2.8 4.9 2. ]
 [7.7 2.8 6.7 2. ]
 [6.3 2.7 4.9 1.8]
 [6.7 3.3 5.7 2.1]
 [7.2 3.2 6.  1.8]
 [6.2 2.8 4.8 1.8]
 [6.1 3.  4.9 1.8]
 [6.4 2.8 5.6 2.1]
 [7.2 3.  5.8 1.6]
 [7.4 2.8 6.1 1.9]
 [7.9 3.8 6.4 2. ]
 [6.4 2.8 5.6 2.2]
 [6.3 2.8 5.1 1.5]
 [6.1 2.6 5.6 1.4]
 [7.7 3.  6.1 2.3]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6.  3.  4.8 1.8]
 [6.9 3.1 5.4 2.1]
 [6.7 3.1 5.6 2.4]
 [6.9 3.1 5.1 2.3]
 [5.8 2.7 5.1 1.9]
 [6.8 3.2 5.9 2.3]
 [6.7 3.3 5.7 2.5]
 [6.7 3.  5.2 2.3]
 [6.3 2.5 5.  1.9]
 [6.5 3.  5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]
次に学習データの変換後のX
[[ 0.61303014  0.10850105  0.94751783  0.736072  ]
 [-0.56776627 -0.12400121  0.38491447  0.34752959]
 [-0.80392556  1.03851009 -1.30289562 -1.33615415]
 [ 0.25879121 -0.12400121  0.60995581  0.736072  ]
 [ 0.61303014 -0.58900572  1.00377816  1.25412853]
 [-0.80392556 -0.82150798  0.04735245  0.21801546]
 [-0.21352735  1.73601687 -1.19037495 -1.20664002]
 [ 0.14071157 -0.82150798  0.72247648  0.47704373]
 [ 0.02263193 -0.12400121  0.21613346  0.34752959]
 [-0.09544771 -1.05401024  0.10361279 -0.04101281]
 [ 1.0853487  -0.12400121  0.94751783  1.1246144 ]
 [-1.39432376  0.34100331 -1.41541629 -1.33615415]
 [ 1.20342834  0.10850105  0.72247648  1.38364267]
 [-1.04008484  1.03851009 -1.24663528 -0.81809761]
 [-0.56776627  1.50351461 -1.30289562 -1.33615415]
 [-1.04008484 -2.4490238  -0.1776889  -0.30004108]
 [ 0.73110978 -0.12400121  0.94751783  0.736072  ]
 [ 0.96726906  0.57350557  1.0600385   1.64267094]
 [ 0.14071157 -1.98401928  0.66621615  0.34752959]
 [ 0.96726906 -1.2865125   1.11629884  0.736072  ]
 [-0.33160699 -1.2865125   0.04735245 -0.17052694]
 [ 2.14806547 -0.12400121  1.28507985  1.38364267]
 [ 0.49495049  0.57350557  0.49743514  0.47704373]
 [-0.44968663 -1.51901476 -0.00890789 -0.17052694]
 [ 0.49495049 -0.82150798  0.60995581  0.736072  ]
 [ 0.49495049 -0.58900572  0.72247648  0.34752959]
 [-1.15816448 -1.2865125   0.38491447  0.60655786]
 [ 0.49495049 -1.2865125   0.66621615  0.86558613]
 [ 1.32150798  0.34100331  0.49743514  0.21801546]
 [ 0.73110978 -0.12400121  0.77873682  0.99510027]
 [ 0.14071157  0.80600783  0.38491447  0.47704373]
 [-1.27624412  0.10850105 -1.24663528 -1.33615415]
 [-0.09544771 -0.82150798  0.72247648  0.86558613]
 [-0.33160699 -0.82150798  0.21613346  0.08850133]
 [-0.33160699 -0.35650346 -0.12142856  0.08850133]
 [-0.44968663 -1.2865125   0.10361279  0.08850133]
 [ 0.25879121 -0.12400121  0.4411748   0.21801546]
 [ 1.55766726  0.34100331  1.22881951  0.736072  ]
 [-0.68584591  1.50351461 -1.30289562 -1.33615415]
 [-1.86664232 -0.12400121 -1.52793696 -1.46566829]
 [ 0.61303014 -0.82150798  0.83499716  0.86558613]
 [-0.21352735 -0.12400121  0.21613346 -0.04101281]
 [-0.56776627  0.80600783 -1.19037495 -1.33615415]
 [-0.21352735  3.13103043 -1.30289562 -1.07712588]
 [ 1.20342834  0.10850105  0.60995581  0.34752959]
 [-1.5124034   0.10850105 -1.30289562 -1.33615415]
 [ 0.02263193 -0.12400121  0.72247648  0.736072  ]
 [-0.9220052  -1.2865125  -0.45899058 -0.17052694]
 [-1.5124034   0.80600783 -1.35915595 -1.20664002]
 [ 0.37687085 -1.98401928  0.38491447  0.34752959]
 [ 1.55766726  1.27101235  1.28507985  1.64267094]
 [-0.21352735 -0.35650346  0.21613346  0.08850133]
 [-1.27624412 -0.12400121 -1.35915595 -1.46566829]
 [ 1.43958762 -0.12400121  1.17255917  1.1246144 ]
 [ 1.20342834  0.34100331  1.0600385   1.38364267]
 [ 0.73110978 -0.12400121  1.11629884  1.25412853]
 [ 0.61303014 -0.58900572  1.00377816  1.1246144 ]
 [-0.9220052   1.73601687 -1.24663528 -1.33615415]
 [-1.27624412  0.80600783 -1.24663528 -1.33615415]
 [ 0.73110978  0.34100331  0.72247648  0.99510027]
 [ 0.96726906  0.57350557  1.0600385   1.1246144 ]
 [-1.63048304 -1.75151702 -1.41541629 -1.20664002]
 [ 0.37687085  0.80600783  0.89125749  1.38364267]
 [-1.15816448 -0.12400121 -1.35915595 -1.33615415]
 [-0.21352735 -1.2865125   0.66621615  0.99510027]
 [ 1.20342834  0.10850105  0.89125749  1.1246144 ]
 [-1.74856268  0.34100331 -1.41541629 -1.33615415]
 [-1.04008484  1.27101235 -1.35915595 -1.33615415]
 [ 1.55766726 -0.12400121  1.11629884  0.47704373]
 [-0.9220052   1.03851009 -1.35915595 -1.20664002]
 [-1.74856268 -0.12400121 -1.41541629 -1.33615415]
 [-0.56776627  1.96851913 -1.19037495 -1.07712588]
 [-0.44968663 -1.75151702  0.10361279  0.08850133]
 [ 1.0853487   0.34100331  1.17255917  1.38364267]
 [ 2.02998583 -0.12400121  1.56638153  1.1246144 ]
 [-0.9220052   1.03851009 -1.35915595 -1.33615415]
 [-1.15816448  0.10850105 -1.30289562 -1.33615415]
 [-0.80392556  0.80600783 -1.35915595 -1.33615415]
 [-0.21352735 -0.58900572  0.38491447  0.08850133]
 [ 0.84918942 -0.12400121  0.32865413  0.21801546]
 [-1.04008484  0.34100331 -1.47167663 -1.33615415]
 [-0.9220052   0.57350557 -1.19037495 -0.94761175]
 [ 0.61303014 -0.35650346  0.27239379  0.08850133]
 [-0.56776627  0.80600783 -1.30289562 -1.07712588]
 [ 2.14806547 -1.05401024  1.73516253  1.38364267]
 [-1.15816448 -1.51901476 -0.29020957 -0.30004108]
 [ 2.38422475  1.73601687  1.45386085  0.99510027]
 [ 0.96726906  0.10850105  0.32865413  0.21801546]
 [-0.80392556  2.43352365 -1.30289562 -1.46566829]
 [ 0.14071157 -0.12400121  0.55369548  0.736072  ]
 [-0.09544771  2.20102139 -1.47167663 -1.33615415]
 [ 2.14806547 -0.58900572  1.62264186  0.99510027]
 [-0.9220052   1.73601687 -1.30289562 -1.20664002]
 [-1.39432376  0.34100331 -1.24663528 -1.33615415]
 [ 1.79382654 -0.58900572  1.28507985  0.86558613]
 [-1.04008484  0.57350557 -1.35915595 -1.33615415]
 [ 0.49495049  0.80600783  1.00377816  1.5131568 ]
 [-0.21352735 -0.58900572  0.15987312  0.08850133]
 [-0.09544771 -0.82150798  0.04735245 -0.04101281]
 [-0.21352735 -1.05401024 -0.1776889  -0.30004108]
 [ 0.61303014  0.34100331  0.83499716  1.38364267]
 [ 0.96726906 -0.12400121  0.77873682  1.38364267]
 [ 0.49495049 -1.2865125   0.60995581  0.34752959]
 [ 0.96726906 -0.12400121  0.66621615  0.60655786]
 [-1.04008484 -0.12400121 -1.24663528 -1.33615415]
 [-0.44968663 -1.51901476 -0.06516822 -0.30004108]
 [ 0.96726906  0.10850105  1.00377816  1.5131568 ]
 [-0.09544771 -0.82150798  0.72247648  0.86558613]
 [-0.9220052   0.80600783 -1.30289562 -1.33615415]
 [ 0.84918942 -0.35650346  0.4411748   0.08850133]
 [-0.33160699 -0.12400121  0.15987312  0.08850133]
 [ 0.02263193  0.34100331  0.55369548  0.736072  ]
 [ 0.49495049 -1.75151702  0.32865413  0.08850133]
 [-0.44968663  1.03851009 -1.41541629 -1.33615415]
 [-0.9220052   1.50351461 -1.30289562 -1.07712588]
 [-1.15816448  0.10850105 -1.30289562 -1.46566829]
 [ 0.49495049 -0.35650346  1.00377816  0.736072  ]
 [-0.09544771 -0.82150798  0.15987312 -0.30004108]
 [ 2.14806547  1.73601687  1.62264186  1.25412853]
 [-1.5124034   0.34100331 -1.35915595 -1.33615415]]
最後に検証データの変換後のX
[[-0.09544771 -0.58900572  0.72247648  1.5131568 ]
 [ 0.14071157 -1.98401928  0.10361279 -0.30004108]
 [-0.44968663  2.66602591 -1.35915595 -1.33615415]
 [ 1.6757469  -0.35650346  1.39760052  0.736072  ]
 [-1.04008484  0.80600783 -1.30289562 -1.33615415]
 [ 0.49495049  0.57350557  1.22881951  1.64267094]
 [-1.04008484  1.03851009 -1.41541629 -1.20664002]
 [ 0.96726906  0.10850105  0.49743514  0.34752959]
 [ 1.0853487  -0.58900572  0.55369548  0.21801546]
 [ 0.25879121 -0.58900572  0.10361279  0.08850133]
 [ 0.25879121 -1.05401024  1.00377816  0.21801546]
 [ 0.61303014  0.34100331  0.38491447  0.34752959]
 [ 0.25879121 -0.58900572  0.49743514 -0.04101281]
 [ 0.73110978 -0.58900572  0.4411748   0.34752959]
 [ 0.25879121 -0.35650346  0.49743514  0.21801546]
 [-1.15816448  1.27101235 -1.35915595 -1.46566829]
 [ 0.14071157 -0.35650346  0.38491447  0.34752959]
 [-0.44968663 -1.05401024  0.32865413 -0.04101281]
 [-1.27624412 -0.12400121 -1.35915595 -1.20664002]
 [-0.56776627  1.96851913 -1.41541629 -1.07712588]
 [-0.33160699 -0.58900572  0.60995581  0.99510027]
 [-0.33160699 -0.12400121  0.38491447  0.34752959]
 [-1.27624412  0.80600783 -1.07785427 -1.33615415]
 [-1.74856268 -0.35650346 -1.35915595 -1.33615415]
 [ 0.37687085 -0.58900572  0.55369548  0.736072  ]
 [-1.5124034   1.27101235 -1.5841973  -1.33615415]
 [-0.9220052   1.73601687 -1.07785427 -1.07712588]
 [ 0.37687085 -0.35650346  0.27239379  0.08850133]
 [-1.04008484 -1.75151702 -0.29020957 -0.30004108]
 [-1.04008484  0.80600783 -1.24663528 -1.07712588]]

パット見、何が変わったか分かりません。

この違いについてこちらで解説してくださっています。

やっていることは変わらないですが、検証データを変換時に学習データのパラメータを共通にして、変換しているようでうね。

pytorch特有の前処理について調べてみた

ようやく前処理の半分ですね。残りの前処理を改めて

# PyTorchのテンソルに変換
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
torch.tensorについて調べてみた

まずはtensorってやつですね。X_trainをprintで確認すると

print結果
tensor([[ 0.6130,  0.1085,  0.9475,  0.7361],
        [-0.5678, -0.1240,  0.3849,  0.3475],
        [-0.8039,  1.0385, -1.3029, -1.3362],
        [ 0.2588, -0.1240,  0.6100,  0.7361],
        [ 0.6130, -0.5890,  1.0038,  1.2541],
        [-0.8039, -0.8215,  0.0474,  0.2180],
        [-0.2135,  1.7360, -1.1904, -1.2066],
        [ 0.1407, -0.8215,  0.7225,  0.4770],
        [ 0.0226, -0.1240,  0.2161,  0.3475],
        [-0.0954, -1.0540,  0.1036, -0.0410],
        [ 1.0853, -0.1240,  0.9475,  1.1246],
        [-1.3943,  0.3410, -1.4154, -1.3362],
        [ 1.2034,  0.1085,  0.7225,  1.3836],
        [-1.0401,  1.0385, -1.2466, -0.8181],
        [-0.5678,  1.5035, -1.3029, -1.3362],
        [-1.0401, -2.4490, -0.1777, -0.3000],
        [ 0.7311, -0.1240,  0.9475,  0.7361],
        [ 0.9673,  0.5735,  1.0600,  1.6427],
        [ 0.1407, -1.9840,  0.6662,  0.3475],
        [ 0.9673, -1.2865,  1.1163,  0.7361],
        [-0.3316, -1.2865,  0.0474, -0.1705],
        [ 2.1481, -0.1240,  1.2851,  1.3836],
        [ 0.4950,  0.5735,  0.4974,  0.4770],
        [-0.4497, -1.5190, -0.0089, -0.1705],
        [ 0.4950, -0.8215,  0.6100,  0.7361],
        [ 0.4950, -0.5890,  0.7225,  0.3475],
        [-1.1582, -1.2865,  0.3849,  0.6066],
        [ 0.4950, -1.2865,  0.6662,  0.8656],
        [ 1.3215,  0.3410,  0.4974,  0.2180],
        [ 0.7311, -0.1240,  0.7787,  0.9951],
        [ 0.1407,  0.8060,  0.3849,  0.4770],
        [-1.2762,  0.1085, -1.2466, -1.3362],
        [-0.0954, -0.8215,  0.7225,  0.8656],
        [-0.3316, -0.8215,  0.2161,  0.0885],
        [-0.3316, -0.3565, -0.1214,  0.0885],
        [-0.4497, -1.2865,  0.1036,  0.0885],
        [ 0.2588, -0.1240,  0.4412,  0.2180],
        [ 1.5577,  0.3410,  1.2288,  0.7361],
        [-0.6858,  1.5035, -1.3029, -1.3362],
        [-1.8666, -0.1240, -1.5279, -1.4657],
        [ 0.6130, -0.8215,  0.8350,  0.8656],
        [-0.2135, -0.1240,  0.2161, -0.0410],
        [-0.5678,  0.8060, -1.1904, -1.3362],
        [-0.2135,  3.1310, -1.3029, -1.0771],
        [ 1.2034,  0.1085,  0.6100,  0.3475],
        [-1.5124,  0.1085, -1.3029, -1.3362],
        [ 0.0226, -0.1240,  0.7225,  0.7361],
        [-0.9220, -1.2865, -0.4590, -0.1705],
        [-1.5124,  0.8060, -1.3592, -1.2066],
        [ 0.3769, -1.9840,  0.3849,  0.3475],
        [ 1.5577,  1.2710,  1.2851,  1.6427],
        [-0.2135, -0.3565,  0.2161,  0.0885],
        [-1.2762, -0.1240, -1.3592, -1.4657],
        [ 1.4396, -0.1240,  1.1726,  1.1246],
        [ 1.2034,  0.3410,  1.0600,  1.3836],
        [ 0.7311, -0.1240,  1.1163,  1.2541],
        [ 0.6130, -0.5890,  1.0038,  1.1246],
        [-0.9220,  1.7360, -1.2466, -1.3362],
        [-1.2762,  0.8060, -1.2466, -1.3362],
        [ 0.7311,  0.3410,  0.7225,  0.9951],
        [ 0.9673,  0.5735,  1.0600,  1.1246],
        [-1.6305, -1.7515, -1.4154, -1.2066],
        [ 0.3769,  0.8060,  0.8913,  1.3836],
        [-1.1582, -0.1240, -1.3592, -1.3362],
        [-0.2135, -1.2865,  0.6662,  0.9951],
        [ 1.2034,  0.1085,  0.8913,  1.1246],
        [-1.7486,  0.3410, -1.4154, -1.3362],
        [-1.0401,  1.2710, -1.3592, -1.3362],
        [ 1.5577, -0.1240,  1.1163,  0.4770],
        [-0.9220,  1.0385, -1.3592, -1.2066],
        [-1.7486, -0.1240, -1.4154, -1.3362],
        [-0.5678,  1.9685, -1.1904, -1.0771],
        [-0.4497, -1.7515,  0.1036,  0.0885],
        [ 1.0853,  0.3410,  1.1726,  1.3836],
        [ 2.0300, -0.1240,  1.5664,  1.1246],
        [-0.9220,  1.0385, -1.3592, -1.3362],
        [-1.1582,  0.1085, -1.3029, -1.3362],
        [-0.8039,  0.8060, -1.3592, -1.3362],
        [-0.2135, -0.5890,  0.3849,  0.0885],
        [ 0.8492, -0.1240,  0.3287,  0.2180],
        [-1.0401,  0.3410, -1.4717, -1.3362],
        [-0.9220,  0.5735, -1.1904, -0.9476],
        [ 0.6130, -0.3565,  0.2724,  0.0885],
        [-0.5678,  0.8060, -1.3029, -1.0771],
        [ 2.1481, -1.0540,  1.7352,  1.3836],
        [-1.1582, -1.5190, -0.2902, -0.3000],
        [ 2.3842,  1.7360,  1.4539,  0.9951],
        [ 0.9673,  0.1085,  0.3287,  0.2180],
        [-0.8039,  2.4335, -1.3029, -1.4657],
        [ 0.1407, -0.1240,  0.5537,  0.7361],
        [-0.0954,  2.2010, -1.4717, -1.3362],
        [ 2.1481, -0.5890,  1.6226,  0.9951],
        [-0.9220,  1.7360, -1.3029, -1.2066],
        [-1.3943,  0.3410, -1.2466, -1.3362],
        [ 1.7938, -0.5890,  1.2851,  0.8656],
        [-1.0401,  0.5735, -1.3592, -1.3362],
        [ 0.4950,  0.8060,  1.0038,  1.5132],
        [-0.2135, -0.5890,  0.1599,  0.0885],
        [-0.0954, -0.8215,  0.0474, -0.0410],
        [-0.2135, -1.0540, -0.1777, -0.3000],
        [ 0.6130,  0.3410,  0.8350,  1.3836],
        [ 0.9673, -0.1240,  0.7787,  1.3836],
        [ 0.4950, -1.2865,  0.6100,  0.3475],
        [ 0.9673, -0.1240,  0.6662,  0.6066],
        [-1.0401, -0.1240, -1.2466, -1.3362],
        [-0.4497, -1.5190, -0.0652, -0.3000],
        [ 0.9673,  0.1085,  1.0038,  1.5132],
        [-0.0954, -0.8215,  0.7225,  0.8656],
        [-0.9220,  0.8060, -1.3029, -1.3362],
        [ 0.8492, -0.3565,  0.4412,  0.0885],
        [-0.3316, -0.1240,  0.1599,  0.0885],
        [ 0.0226,  0.3410,  0.5537,  0.7361],
        [ 0.4950, -1.7515,  0.3287,  0.0885],
        [-0.4497,  1.0385, -1.4154, -1.3362],
        [-0.9220,  1.5035, -1.3029, -1.0771],
        [-1.1582,  0.1085, -1.3029, -1.4657],
        [ 0.4950, -0.3565,  1.0038,  0.7361],
        [-0.0954, -0.8215,  0.1599, -0.3000],
        [ 2.1481,  1.7360,  1.6226,  1.2541],
        [-1.5124,  0.3410, -1.3592, -1.3362]])

tensorで変換する前とした後で結果がほとんど変化していないですが、一番初めにtensorが付きました。

tensorって日本語でテンソルって意味だと思うのですが、数学や物理のテンソルは行列とかベクトルに近いものだと思っています。何が違うのか調べてみます。

こちらにnumpyのndarrayとの違いを書かれています。以下に引用を書かせていただきます。

  • GPUによる高速化(Google Colab向き)
  • 自動微分(Autograd)機能により勾配演算を自動化できる(順伝播 / 逆伝播)

なるほど。とりあえず、高速化、自動化ができるようです。この辺りもいずれ勉強していきたいと思います。

TensorDatasetとDataLoaderについて調べてみた

最後の前処理ですね。それがこちら

# データローダーの作成
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

さっそくprintで確認。

torch.utils.data.dataset.TensorDataset object at 0x0000025BDCAA2970

torch.utils.data.dataloader.DataLoader object at 0x0000025BDCAA2130

なんかよくわからない結果が出てきました。何を行っているか調べてみます。

こちらが分かりやすかったですね。これを基に図を作りました。

TensorDatasetはX_trainとy_trainを結合して一個のデータのかたまりを作っているイメージをしました。

DataLoaderはTensorDatasetで作ったデータからミニバッチ学習法を行うための準備だと思いました。ミニバッチ学習法というのはデータセットからバッチサイズ分だけ取り出し、学習させ、またバッチサイズ分だけ取り出し、学習という繰り返しと行うことでデータ全体を学習させようという方法です。

それを簡単にするための準備だと思います。なので、今回のDataLoaderは作成したデータセットからランダムに16バッチサイズ選んで学習させる準備と思っています。

コメント

タイトルとURLをコピーしました