質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

2回答

755閲覧

module importできない

chgrios

総合スコア70

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/08/03 08:54

前提・実現したいこと

はじめてPyTorchで書籍を参考にしながら、画像認識のプログラムを勉強しています。
プログラムを関数化するのがはじめてでとても苦労しています。
ほかのファイルからプログラムをimortするにあたって次のようなエラーができました。

発生している問題・エラーメッセージ

cannot import name 'train_model' from 'dataloader_image_classification' (/Users/**/**/dataloader_image_classification.py)

該当のソースコード

python

1from dataloader_image_classification import ImageTransform, make_datapath_list 2from dataloader_image_classification import HymenopteraDataset, train_model

以下がpyファイルdataloader_image_classificationの中身です。作業ファイルと同列のディレクトリに存在しています。train_model以外のモジュールは正常にimportできていました。

python

1import glob 2import os.path as osp 3import torch.utils.data as data 4from torchvision import models, transforms 5from PIL import Image 6 7 8class ImageTransform(): 9 #中略 10def make_datapath_list(phase="train"): 11 #中略 12class HymenopteraDataset(data.Dataset): 13 #中略 14def train_model(net, dataloader_dict, criterion. optimzer, num_epochs): 15 16 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 print(device) 18 19 net.top(device) 20 21 torch.backends.cudnn.benchmark =True 22 23 for epoch in range(num_epochs): 24 print('epoch: {}/{}'.format(epoch, num_epochs)) 25 print('-------------') 26 27 for phase in ['train', 'val']: 28 if phase == 'train': 29 net.train() 30 else: 31 net.eval() 32 33 epoch_loss = 0 34 epoch_acc = 0 35 36 if (epoch === 0 ) and (phase == 'train'): 37 continue 38 for inputs, labels in tqdm(dataloader_dict[phase]): 39 inputs = inputs.to(device) 40 labels = labels.to(device) 41 42 optimzer.zero_grad() 43 44 with torch.set_grad_enabled(phase == 'train'): 45 outputs = net(inputs) 46 loss = criterion(outputs, labels) 47 _, pred = torch.max(outputs, 1) 48 49 if phase == 'train': 50 loss.backward() 51 optimzier.step() 52 53 epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset) 54 epoch_acc = epoch_acc / len(dataloader_dict[phase].dataset) 55 56 print(phase, epoch_loss, epoch_acc)
解決したいです。よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

meg_

2019/08/03 10:22 編集

「作業ファイルと同列のディレクトリ」とは「dataloader_image_classification.py」とこのファイルをインポートしているソースコードが同じディレクトリにあるということで良いでしょうか?
chgrios

2019/08/03 10:57

そうです!
meg_

2019/08/03 14:38 編集

from dataloader_image_classification import ImageTransform, make_datapath_list from dataloader_image_classification import HymenopteraDataset, train_model 上記ソースコードのファイル名は何ですか? 2つのソースコード内でインポートし合っているときにも出るエラーですが、そんなことはしていないですよね?
guest

回答2

0

python

1from dataloader_image_classification import ImageTransform, make_datapath_list

より上で

python

1import dataloader_image_classification 2print(dataloader_image_classification.__file__)

としてみてdataloader_image_classificationが自分が思っているファイルを指し示すか確認しましょう。

投稿2019/08/03 11:20

quickquip

総合スコア11038

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

ベストアンサー

お疲れ様です。

何かと重複していると考えるのが普通です。
仮に、
train_modelを train_modelzz に変更してみて下さい。呼ぶ側、呼ばれる側を。

で、解決するとしたら、
さて、どこと重複しているかを探す作業になるかと。

投稿2019/08/03 09:14

0kcal

総合スコア275

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

chgrios

2019/08/03 09:20

呼ぶ側呼ばれる側どちらも名前を変えて見ましたが、解決しませんでした。基本的なところなんですが、 __init__.pyファイルというものをつくって from dataloader_image_classification import * としないといけないんでしょうか?(試してもうまくいきませんでした)
0kcal

2019/08/03 13:02

上記は、関係ないと思います。 呼ばれる側が、★印など、間違えています。 まず、呼ばれる側が単独で、エラーがない状態にされるのがいいと思います。 (いまのエラーの出方は納得いきませんが、そういうエラーの出し方のあるのかもしれまません。) def train_model(net, dataloader_dict, criterion.★  optimzer, num_epochs): if (epoch ===★ 0 ) and (phase == 'train'):
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問