前提・実現したいこと
学習済みモデルでの未知データに対しての推論を行い、
データを格納する処理を行いたい。
inputs = inputs.to(DEVICE) で
AttributeError: 'tuple' object has no attribute 'to'が発生...
どなたか詳しいかた抜け道を教えて下さい。
発生している問題・エラーメッセージ
AttributeError: 'tuple' object has no attribute 'to'
該当のソースコード
python
1DEVICE='cpu' 2pred = [] 3 4# データの取り出し 5for i,(inputs, labels) in enumerate(test_dataloader): 6 7 inputs = inputs.to(DEVICE) <-----ここでerror発生 8 9 # 学習済みモデルを推論モードに設定 10 loaded_model.eval() 11 12 # 学習済みモデルにデータをインプットし、推論をさせる 13 outputs = loaded_model(inputs) 14 15 # アウトプットから推定されたラベルを取得 16 _, preds = torch.max(outputs, 1) 17 18 # 事前に用意したリストに推論結果(0 or 1)を格納 19 pred.append(preds.item()) 20 21df_test['pred'] = pred
test_dataloaderのイテレーション要素であるinputsがtensorになっておらずtupleになっている、ということですね。原因追求には、test_dataloaderを定義しているコード、さらにDatasetを定義しているコードまで遡ることが必要です。その部分を提示願います。
少し前に回答のコメントでinputsを表示していただいていますが、やはりtensorになっていないですね。
リスト(タプル、tensor)という変則な形になっています。Datasetのコードが怪しいです。
下記がDataSet作成とDataloader部分のコードになります。
何かおかしな箇所はありますでしょうか?
-----------------------------------------------------------------------------------------------------------
# dataset作成
class Test_Datasets(Dataset):
def __init__(self, data_transform):
#csvファイルの指定カラムを読み取ってDataFrameに変換
#self.df = pd.read_csv('./df_test.csv',names=['filename'])
self.df = pd.read_csv('./df_test.csv')
self.data_transform = data_transform
def __len__(self):
return len(self.df)
def __getitem__(self, index):
#file = self.df['filename'][index]
#↓ここの処理が分からない
file = self.df['filename'][index]
image = Image.open('./test_data/' + file)
image = self.data_transform(image)
return file,image
# datasetのインスタンス作成
#推論用データのアノテーションを行っている。
test_dataset = Test_Datasets(data_transform=test_transforms)
# dataloader作成
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size= 1 ,
shuffle= False ,
num_workers= 0 ,
drop_last= True )
def __getitem__(self, index):
の
return file,image
を
return image
にしてください。ファイル名は学習データには不要ですよね。このファイル名が邪魔で、tensorが崩れています。
また、もとソースコードの
for i,(inputs, labels) in enumerate(test_dataloader):
は
for i, inputs in enumerate(test_dataloader):
にしましょう。未知データなのでDataset側にlabelが無いため、こっちもlabel無しで受ける必要があります。
上記で修正されることを動作確認しましたので、回答として記載します。
回答2件
あなたの回答
tips
プレビュー