前回の質問の続きです。
PyTorchを利用し、DQNでゲームのソルバーを作っております。
ミニバッチの値だけ対局データを取り出し、局面に分解し、torch.utils.data.Datalodaerを使ってモデルにデータを投げています。
損失関数はnn.CrossEntropyLoss()を使用しています。
入力データは、12×12のボードサイズの局面を7680個です。
ラベルは、局面に対応しているので、勝敗データ7680個です。
これを、バッチサイズ32で学習させようとしています。
モデルの定義は、今のところ適当です。
教師ラベルを1次元にしてみましたが、エラーが発生しました。
エラーを見ても、具体的な問題点が得られず、ネット上で調べてみても、似たような事例はあるようですが、解決方法を見てもわかりませんでした。
とりあえず一通り動けば良いと思っております。
ご教示をお願い致します。
ここの値を見たいなどありましたら、対応いたします。
Output
1> python .\train.py 2epoch: 1 3epoch:1 record:0 4(7680, 12, 12) 5(7680, 1, 12, 12) 6torch.Size([32, 1, 12, 12]) 7torch.Size([32, 1]) 8torch.Size([32]) 9Traceback (most recent call last): 10 File ".\train.py", line 102, in <module> 11 loss = criterion(y, t) 12 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__ 13 result = self.forward(*input, **kwargs) 14 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 862, in forward 15 ignore_index=self.ignore_index, reduction=self.reduction) 16 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1550, in cross_entropy 17 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) 18 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1407, in nll_loss 19 return torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) 20RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at c:\new-builder_2\win-wheel\pytorch\aten\src\thnn\generic/ClassNLLCriterion.c:93
train.py
Python
1# coding: utf-8 2 3import os 4import glob 5import random 6import numpy as np 7import torch 8import torch.nn as nn 9import torch.optim as optim 10import torch.nn.functional as F 11import torch.utils.data 12import network 13import train_data_creator 14import game 15 16GAMMA = 0.97 #割引率 17BATCH_SIZE = 32 #一度に学習する局面数 18EPOCH = 3 #1つの訓練データを何回学習させるか 19TURN = 32 20 21MODEL_PATH = os.path.join(os.path.dirname(__file__), "./output/model.pth") #モデルの保存パス 22OPTIMIZER_PATH = os.path.join(os.path.dirname(__file__), "./output/optimizer.pth") #オプティマイザの保存パス 23RECORD_LIST_PATH = os.path.join(os.path.dirname(__file__), "./recordlist_train") #対局データ一覧表(学習用)の保存パス 24TEST_RECORD_LIST_PATH = os.path.join(os.path.dirname(__file__), "./recordlist_test") #対局データ一覧表(テスト用)の保存パス 25 26#モデルが保存されていれば読み込み、なければ新規作成 27if os.path.exists(MODEL_PATH): 28 fine_tune = True 29 model = network.Network() 30 model.load_state_dict(torch.load(MODEL_PATH)) 31else: 32 fine_tune = False 33 model = network.Network() 34 35#オプティマイザが保存されていれば読み込み、なければ新規作成 36if os.path.exists(MODEL_PATH): 37 fine_tune = True 38 optimizer = optim.Adam(model.parameters(), lr=0.1) 39 optimizer.load_state_dict(torch.load(OPTIMIZER_PATH)) 40else: 41 fine_tune = False 42 optimizer = optim.Adam(model.parameters(), lr=0.1) 43 44criterion = nn.CrossEntropyLoss() #推論値と理論値の差を計算 45 46 47for i in range(1, EPOCH+1): #エポックを回す 48 print("epoch:", i) #現在のエポック数(何回目のループか) 49 50 record_index = 0 51 52 while True: #全学習データを扱う 53 print("epoch:{0} record:{1}".format(i, record_index)) 54 55 #バッチサイズ分の訓練データ(陣形とタイルの点数)と正解ラベルを取得 56 #datasetは、[value, state, player, won] それぞれの要素は、バッチサイズ分の対局の全局面でシャッフルなし 57 dataset = train_data_creator.get_dataset(RECORD_LIST_PATH, BATCH_SIZE, record_index) 58 if dataset is None: #学習データがなくなった 59 break 60 61 dataset[1] = dataset[1].reshape(len(dataset[1]), 1, game.MAX_BOARD_SIZE, game.MAX_BOARD_SIZE) 62 dataset[3] = dataset[3].reshape(len(dataset[3]), 1, 1, 1) 63 64 train = torch.utils.data.TensorDataset(torch.from_numpy(dataset[1]).float(), torch.from_numpy(dataset[3]).long()) 65 train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True) 66 67 total_loss = 0 68 for i, data in enumerate(train_loader): 69 x, t = data 70 x, t = torch.autograd.Variable(x), torch.autograd.Variable(t) 71 optimizer.zero_grad() 72 y = model(x) 73 74 # *** 追加 *** 1次元にする 75 t = torch.squeeze(t) 76 77 print(x.shape) 78 print(y.shape) 79 print(t.shape) 80 81 loss = criterion(y, t) 82 total_loss += loss.data[0] 83 loss.backward() 84 optimizer.step() 85 86 record_index += BATCH_SIZE
環境
Python
PyTorch
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/09/16 12:17