実現したいこと
pytorchの学習をしております。
学習前と学習後のモデルの重みの変化を確認してたいです。
発生している問題・エラーメッセージ
モデルをMNISTを用いて一定回数学習した後の1つ目の全結合層の重みを学習前と学習後で確認しているのですが、変化が見られません。
Lossは落ちているので学習はしているのですが、重みが変化しないのはなぜでしょうか?
該当のソースコード
python
1import torch 2from torch import nn 3from torchvision import datasets, models, transforms 4 5import torch.nn.functional as F 6 7num_epochs = 2 # 学習を繰り返す回数 8num_batch = 100 # 一度に処理する画像の枚数 9learning_rate = 0.1 # 学習率 10image_size = 28*28 # 画像の画素数(幅x高さ) 11 12# GPU(CUDA)が使えるかどうか? 13device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 15 16class Net(nn.Module): 17 def __init__(self, input_size, output_size): 18 super(Net, self).__init__() 19 20 # 各クラスのインスタンス(入出力サイズなどの設定) 21 self.fc1 = nn.Linear(input_size, 100) 22 self.fc2 = nn.Linear(100, output_size) 23 24 def forward(self, x): 25 # 順伝播の設定(インスタンスしたクラスの特殊メソッド(__call__)を実行) 26 x = self.fc1(x) 27 x = torch.sigmoid(x) 28 x = self.fc2(x) 29 return F.log_softmax(x, dim=1) 30 31 32# データセット 33trainset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) 34trainloader = torch.utils.data.DataLoader(trainset, batch_size = 100) 35 36# モデル 37model = Net(image_size, 10).to(device) 38 39# 学習前 40print("before") 41print(model.state_dict()['fc1.weight']) 42 43# 適当に学習 44optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) 45criterion = nn.CrossEntropyLoss() 46 47for epoch in range(num_epochs): # 学習を繰り返し行う 48 loss_sum = 0 49 model.train() 50 51 for inputs, labels in trainloader: 52 53 # GPUが使えるならGPUにデータを送る 54 inputs = inputs.to(device) 55 labels = labels.to(device) 56 57 # optimizerを初期化 58 optimizer.zero_grad() 59 60 # ニューラルネットワークの処理を行う 61 inputs = inputs.view(-1, image_size) # 画像データ部分を一次元へ並び変える 62 63 outputs = model(inputs) 64 65 # 損失(出力とラベルとの誤差)の計算 66 loss = criterion(outputs, labels) 67 loss_sum += loss 68 69 # 勾配の計算 70 loss.backward() 71 72 # 重みの更新 73 optimizer.step() 74 print(f"Epoch: {epoch+1}/{num_epochs}, Loss: {loss_sum.item() / len(trainloader)}") 75 76model.eval() 77# 学習後 78print("after") 79print(model.state_dict()['fc1.weight'])``` 80

回答1件
あなたの回答
tips
プレビュー