質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

7631閲覧

.zero_grad()を使う場合と使わない場合?

OOZAWA

総合スコア45

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1グッド

0クリップ

投稿2020/01/15 02:22

PyTorchの.zero_grad()メソッドはどんな場合に使わなければならないのでしょうか

良く見るコードパターンとして:
self.optim.zero_grad()
loss.backward()
self.optim.step()

ところが、 .zero_grad() 抜きで下記の二行だけ実行する実例も見たことがあります。
loss.backward()
self.optim.step()

.zero_grad() 入りと.zero_grad() 抜きとは何が違うのでしょうか。

是非ご教授お願い致します

insecticide👍を押しています

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

Backpropするときにgradientがたまります。たまったgradientを消さないとそのまま残るから普段は学習loopの始まりで消します。消さないと、前のgradient情報も残ってgradientの方向が最小値へ向きません。
複数のバッチのためてからパラメータ更新をしたい場合は、loss.backward()を数回呼んでからoptim.step()を呼びます。そのあとはまたzero_grad()します。
あと、RNNを使うときにgradientをためることもあります。

zero_grad() を呼ばないとこういうふうにたまります

import torch w = torch.rand(5) w.requires_grad_() print(w) s = w.sum() s.backward() print(w.grad) # tensor([1., 1., 1., 1., 1.]) s.backward() print(w.grad) # tensor([2., 2., 2., 2., 2.]) s.backward() print(w.grad) # tensor([3., 3., 3., 3., 3.]) s.backward() print(w.grad) # tensor([4., 4., 4., 4., 4.])

投稿2020/01/23 22:24

kurapan

総合スコア79

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

OOZAWA

2020/01/23 23:16

本当にありがとうございます! ???????????? 凄い事を知りました!  @.backward()が(関数としての)変数@の個々自変数に対する微分を計算するのみだと思っていましたが、 『微分値の累算』をも行いますね。 どうしてでしょうか。 ご説明いただけませんでしょうか。
kurapan

2020/01/24 01:08

累算はRNNを学習するとき便利です。BPTT (backpropagation through time)とかに使えます。 毎回zero_grad()呼ぶとgradientがたまらないし、逆に累算してほしいときは何も工夫しなくてもそのまま累算されます。だから、累算したほうが便利だと判断されてそう実装されたかもしれません。
OOZAWA

2020/01/24 03:11

kurapan様 本当に有難うございました! これほど熟知される方は世の中に多くないな。。。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問