質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

9060閲覧

Pytorchでメモリを異常に消費する原因

wildgeece96

総合スコア8

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2018/08/20 15:22

編集2018/08/20 15:34

このサイトで質問するのが適切か判断しかねますが詰まってしまったので。

今、Pytorchを使ってデータセットの事前学習をさせようという段階です。

そのとき、メモリが消去されずに残っているのか30GBあるメモリが1バッチ回すだけでメモリが満杯になり次のバッチが何故か回せないという状況になっています。

原因となりそうな部分があったら教えていただけると嬉しいです。

いか、一部ですがコードになります。

generatorがモデルで今回はU-Netが内包されています。
dataloaderの部分
事前学習の部分です。入力された画像(256×256に変形済み)とそれをグレースケールに変換したもへと学習させています。

python

1for _, (images, img_masks) in enumerate(pre_batch): 2 gen_optimizer.zero_grad() 3 4 x = Variable(images) 5 y = generator.forward(x).cuda(0) 6 img_arr = x.data.cpu().numpy() 7 img_arr = img_arr.transpose(0, 2, 3, 1) 8 img_gray = np.zeros((img_arr.shape[0], img_size, img_size)) 9 for k in range(img_arr.shape[0]): 10 temp = cv2.cvtColor(img_arr[k], cv2.COLOR_BGR2GRAY) 11 img_gray[k] = temp.copy() 12 print(k) 13 del temp 14 gc.collect() 15 img_gray += img_gray.min(axis=(1,2), keepdims=True) 16 img_gray /= img_gray.max(axis=(1,2), keepdims=True) - img_gray.min(axis=(1,2), keepdims=True) 17 img_gray = img_gray.reshape(-1, 1, img_size, img_size) 18 img_tensor = torch.from_numpy(img_gray) 19 y_ = Variable(img_tensor.float()).cuda(0) 20 21 loss = recon_loss_func(y, y_) 22 loss.backward() 23 gen_optimizer.step() 24 if _ % 100 == 0: 25 print(_ , "\ttime loss:\t", loss.data[0]) 26 torch.save(generator.state_dict(),"pre_trained_model.pkl") 27 if _ >= 1000: 28 break

エラーメッセージを追記します。
出力

0 1 2 3 4 5 6 7 0 time loss: 0.8070976138114929 (エラーメッセージ) RuntimeError: $ Torch: not enough memory: you tried to allocate 0GB. Buy new RAM! at /opt/conda/conda-bld/pytorch_1512387374934/work/torch/lib/TH/THGeneral.c:246

ちなみになのですが、以下のコードは同じ環境で実行しても問題なく学習が進んでいます。

python

1for _, (images,img_masks) in enumerate(train_batch): 2 gen_optimizer.zero_grad() 3 4 x = Variable(images).cuda(0) 5 y_ = Variable(img_masks.mean(dim=1,keepdim=True)).cuda(0) 6 y = generator.forward(x).mean(dim=1,keepdim=True).cuda(0) 7 8 loss = recon_loss_func(y,y_) # nn.MSELoss()です。 9 file.write(str(loss.data[0])+"\n") 10 loss_record.append(loss.data[0]) 11 loss.backward() 12 gen_optimizer.step()

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

純粋にバッチサイズの問題だったようです。

おさわがせしました。

投稿2018/08/20 15:37

wildgeece96

総合スコア8

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問