前提・実現したいこと
はじめてPyTorchで書籍を参考にしながら、画像認識のプログラムを勉強しています。
プログラムを関数化するのがはじめてでとても苦労しています。
ほかのファイルからプログラムをimortするにあたって次のようなエラーができました。
発生している問題・エラーメッセージ
cannot import name 'train_model' from 'dataloader_image_classification' (/Users/**/**/dataloader_image_classification.py)
該当のソースコード
python
1from dataloader_image_classification import ImageTransform, make_datapath_list 2from dataloader_image_classification import HymenopteraDataset, train_model
以下がpyファイルdataloader_image_classificationの中身です。作業ファイルと同列のディレクトリに存在しています。train_model以外のモジュールは正常にimportできていました。
python
1import glob 2import os.path as osp 3import torch.utils.data as data 4from torchvision import models, transforms 5from PIL import Image 6 7 8class ImageTransform(): 9 #中略 10def make_datapath_list(phase="train"): 11 #中略 12class HymenopteraDataset(data.Dataset): 13 #中略 14def train_model(net, dataloader_dict, criterion. optimzer, num_epochs): 15 16 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 print(device) 18 19 net.top(device) 20 21 torch.backends.cudnn.benchmark =True 22 23 for epoch in range(num_epochs): 24 print('epoch: {}/{}'.format(epoch, num_epochs)) 25 print('-------------') 26 27 for phase in ['train', 'val']: 28 if phase == 'train': 29 net.train() 30 else: 31 net.eval() 32 33 epoch_loss = 0 34 epoch_acc = 0 35 36 if (epoch === 0 ) and (phase == 'train'): 37 continue 38 for inputs, labels in tqdm(dataloader_dict[phase]): 39 inputs = inputs.to(device) 40 labels = labels.to(device) 41 42 optimzer.zero_grad() 43 44 with torch.set_grad_enabled(phase == 'train'): 45 outputs = net(inputs) 46 loss = criterion(outputs, labels) 47 _, pred = torch.max(outputs, 1) 48 49 if phase == 'train': 50 loss.backward() 51 optimzier.step() 52 53 epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset) 54 epoch_acc = epoch_acc / len(dataloader_dict[phase].dataset) 55 56 print(phase, epoch_loss, epoch_acc)
解決したいです。よろしくお願いします。
「作業ファイルと同列のディレクトリ」とは「dataloader_image_classification.py」とこのファイルをインポートしているソースコードが同じディレクトリにあるということで良いでしょうか?
そうです!
from dataloader_image_classification import ImageTransform, make_datapath_list
from dataloader_image_classification import HymenopteraDataset, train_model
上記ソースコードのファイル名は何ですか?
2つのソースコード内でインポートし合っているときにも出るエラーですが、そんなことはしていないですよね?
回答2件
あなたの回答
tips
プレビュー