PyTorchを利用し、DQNでゲームのソルバーを作っております。
ミニバッチの値だけ対局データを取り出し、局面に分解し、torch.utils.data.Datalodaerを使ってモデルにデータを投げています。
損失関数はnn.CrossEntropyLoss()を使用しています。
入力データは、12×12のボードサイズの局面を7680個です。
ラベルは、局面に対応しているので、勝敗データ7680個です。
これを、バッチサイズ32で学習させようとしています。
モデルの定義は、今のところ適当です。
いろいろ値を変えてみても、何かしらのエラーが返ってきます。
今回の場合はnn.CrossEntropyLoss()でつまずいているようですが、エラーメッセージを読んでも何を言いたいのかよくわかりません。
また、公式ドキュメントやその他の質問サイトなども見ましたが、あまり情報が充実していないようです。
どうすれば、一通り動くなるようになるのでしょうか。
ご教示をお願い致します。
ここの値を見たいなどありましたら、対応いたします。
Output
1> python .\train.py 2epoch: 1 3epoch:1 record:0 4(7680, 12, 12) 5(7680, 1, 12, 12) 6torch.Size([32, 1, 12, 12]) 7torch.Size([32, 1]) 8torch.Size([32, 1, 1, 1]) 9Traceback (most recent call last): 10 File ".\train.py", line 98, in <module> 11 loss = criterion(y, t) 12 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__ 13 result = self.forward(*input, **kwargs) 14 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 862, in forward 15 ignore_index=self.ignore_index, reduction=self.reduction) 16 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1550, in cross_entropy 17 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) 18 File "C:\Users\meron\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1407, in nll_loss 19 return torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) 20RuntimeError: multi-target not supported at c:\new-builder_2\win-wheel\pytorch\aten\src\thnn\generic/ClassNLLCriterion.c:21
train.py
Python
1# coding: utf-8 2 3import os 4import glob 5import random 6import numpy as np 7import torch 8import torch.nn as nn 9import torch.optim as optim 10import torch.nn.functional as F 11import torch.utils.data 12import network 13import train_data_creator 14import game 15 16GAMMA = 0.97 #割引率 17BATCH_SIZE = 32 #一度に学習する局面数 18EPOCH = 3 #1つの訓練データを何回学習させるか 19TURN = 32 20 21MODEL_PATH = os.path.join(os.path.dirname(__file__), "./output/model.pth") #モデルの保存パス 22OPTIMIZER_PATH = os.path.join(os.path.dirname(__file__), "./output/optimizer.pth") #オプティマイザの保存パス 23RECORD_LIST_PATH = os.path.join(os.path.dirname(__file__), "./recordlist_train") #対局データ一覧表(学習用)の保存パス 24TEST_RECORD_LIST_PATH = os.path.join(os.path.dirname(__file__), "./recordlist_test") #対局データ一覧表(テスト用)の保存パス 25 26#モデルが保存されていれば読み込み、なければ新規作成 27if os.path.exists(MODEL_PATH): 28 fine_tune = True 29 model = network.Network() 30 model.load_state_dict(torch.load(MODEL_PATH)) 31else: 32 fine_tune = False 33 model = network.Network() 34 35#オプティマイザが保存されていれば読み込み、なければ新規作成 36if os.path.exists(MODEL_PATH): 37 fine_tune = True 38 optimizer = optim.Adam(model.parameters(), lr=0.1) 39 optimizer.load_state_dict(torch.load(OPTIMIZER_PATH)) 40else: 41 fine_tune = False 42 optimizer = optim.Adam(model.parameters(), lr=0.1) 43 44criterion = nn.CrossEntropyLoss() #推論値と理論値の差を計算 45 46 47for i in range(1, EPOCH+1): #エポックを回す 48 print("epoch:", i) #現在のエポック数(何回目のループか) 49 50 record_index = 0 51 52 while True: #全学習データを扱う 53 print("epoch:{0} record:{1}".format(i, record_index)) 54 55 #バッチサイズ分の訓練データ(陣形とタイルの点数)と正解ラベルを取得 56 #datasetは、[value, state, player, won] それぞれの要素は、バッチサイズ分の対局の全局面でシャッフルなし 57 dataset = train_data_creator.get_dataset(RECORD_LIST_PATH, BATCH_SIZE, record_index) 58 if dataset is None: #学習データがなくなった 59 break 60 61 dataset[1] = dataset[1].reshape(len(dataset[1]), 1, game.MAX_BOARD_SIZE, game.MAX_BOARD_SIZE) 62 dataset[3] = dataset[3].reshape(len(dataset[3]), 1, 1, 1) 63 64 train = torch.utils.data.TensorDataset(torch.from_numpy(dataset[1]).float(), torch.from_numpy(dataset[3]).long()) 65 train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True) 66 67 total_loss = 0 68 for i, data in enumerate(train_loader): 69 x, t = data 70 x, t = torch.autograd.Variable(x), torch.autograd.Variable(t) 71 optimizer.zero_grad() 72 y = model(x) 73 74 print(x.shape) 75 print(y.shape) 76 print(t.shape) 77 78 loss = criterion(y, t) 79 total_loss += loss.data[0] 80 loss.backward() 81 optimizer.step() 82 83 84 85 record_index += BATCH_SIZE
network.py
Python
1# coding: utf-8 2 3import numpy as np 4import torch 5import torch.nn as nn 6import torch.optim as optim 7import torch.nn.functional as F 8import game 9 10ch = 192 #中間層のフィルター枚数 11 12class Network(nn.Module): 13 14 def __init__(self): 15 #ニューラルネットワークを作成 <= 畳み込み(Conv2d)を使った方がいい(画像の特徴を抽出できる)と思うけれど、後回し 16 super(Network, self).__init__() 17 self.conv1 = nn.Conv2d(in_channels=1, out_channels=ch, kernel_size=(4, 4), stride=1, padding=1) 18 self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=1, padding=1) 19 self.conv2 = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=(4, 4), stride=1, padding=1) 20 self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=1, padding=1) 21 self.conv3 = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=(4, 4), stride=1, padding=1) 22 self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=1, padding=1) 23 self.fc1 = nn.Linear(ch * 12 * 12, 100) 24 self.fc2 = nn.Linear(100, 100) 25 self.fc3 = nn.Linear(100, 1) #出力は、勝ち負けの値1コ 26 27 def forward(self, x): 28 x = self.pool1(F.relu(self.conv1(x))) 29 x = self.pool2(F.relu(self.conv2(x))) 30 x = self.pool3(F.relu(self.conv3(x))) 31 x = x.view(-1, ch * 12 * 12) 32 x = F.relu(self.fc1(x)) 33 x = F.relu(self.fc2(x)) 34 x = self.fc3(x) 35 return x
環境
Python
PyTorch

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/09/16 09:12
2018/09/16 12:22