質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
CUDA

CUDAは並列計算プラットフォームであり、Nvidia GPU(Graphics Processing Units)向けのプログラミングモデルです。CUDAは様々なプログラミング言語、ライブラリ、APIを通してNvidiaにインターフェイスを提供します。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

Q&A

解決済

1回答

2586閲覧

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to

kuniima

総合スコア2

CUDA

CUDAは並列計算プラットフォームであり、Nvidia GPU(Graphics Processing Units)向けのプログラミングモデルです。CUDAは様々なプログラミング言語、ライブラリ、APIを通してNvidiaにインターフェイスを提供します。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

0グッド

0クリップ

投稿2021/07/30 04:42

下記の関数を実行するとタイトルのエラーが発生します。Google collabo使用中はエラーが発生しなかったのですJupitor Notebookで同一の内容を実行するとエラーが発生します。ライブラリー等のバージョンの違いからでしょうか。

何かアドバイスを頂けますと幸いでございます。

def train_model(model, criterion, optimizer, scheduler, num_epochs):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
since = time.time()

best_model_wts = copy.deepcopy(model.state_dict()) best_acc = 0.0 #エポックごとの正解率、損失関数情報取得 acc_history_ft = {'train': [], 'valid': []} loss_history_ft= {'train': [], 'valid': []} #エポックループ for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) #それぞれのエポックで訓練データで訓練→検証データで検証 for phase in ['train', 'valid']: if phase == 'train': model.train() # モデルを訓練モードに else: model.eval() # モデルを検証モードに #各種変数の初期化 running_loss = 0.0 running_corrects = 0 #データローダーからミニバッチを読み込むループ for inputs, labels in dataloaders[phase]: inputs = inputs.to(device) labels = labels.to(device) #勾配を初期化 optimizer.zero_grad() #順伝播 #※1:訓練データの時はテンソルの勾配を求める with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) #訓練データの時に逆伝播+重みの更新を行う if phase == 'train': loss.backward() optimizer.step() #損失等を計算する running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) #学習率の更新を行う if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = running_corrects.double() / dataset_sizes[phase] #エポックごとの正解率、損失関数情報取得 acc_history_ft[phase].append(epoch_acc) loss_history_ft[phase].append(epoch_loss) print('{} Loss: {:.4f} Acc: {:.4f}'.format( phase, epoch_loss, epoch_acc)) #検証データで精度が以前より高ければモデルをdeepcopyする if phase == 'valid' and epoch_acc > best_acc: #ベストの精度を更新 best_acc = epoch_acc #モデルのコピーを保存して「best_model_wts」に格納 best_model_wts = copy.deepcopy(model.state_dict()) print() #1エポック終了毎の表示 time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format( time_elapsed // 60, time_elapsed % 60)) print('Best val Acc: {:4f}'.format(best_acc)) # Accuracyのグラフ plt.figure() plt.plot(range(num_epochs), acc_history_ft['train'], color='blue', linestyle='-', label='train_acc') plt.plot(range(num_epochs), acc_history_ft['valid'], color='green', linestyle='--', label='val_acc') plt.legend() plt.xlabel('epoch') plt.ylabel('acc') plt.title('Training and validation accuracy') plt.grid() # Lossのグラフ plt.figure() plt.plot(range(num_epochs), loss_history_ft['train'], color='blue', linestyle='-', label='train_loss') plt.plot(range(num_epochs), loss_history_ft['valid'], color='green', linestyle='--', label='val_loss') plt.legend() plt.xlabel('epoch') plt.ylabel('loss') plt.title('Training and validation loss') plt.grid() #最も良いモデルを読みだして返す model.load_state_dict(best_model_wts) return model

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

attakei

2021/07/30 12:46

質問文にコードやエラーを載せる場合は、コードブロック機能を利用してください。 インデントなどがカットされて、どういったコードなのかが極めてわかりづらくなっています。 特にPythonのようなインデントが文法の一部にになっているプログラミング言語での質問の場合、コードブロックが利用されてないと、コード内容の把握が非常に困難になり、結果として正しい回答が付きにくくなります。
guest

回答1

0

ベストアンサー

Tensor.cpu()を使って、最初にテンソルをホストメモリにコピーします。

投稿2021/07/30 11:11

退会済みユーザー

退会済みユーザー

総合スコア0

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問