detach() の用途
Pytorch では、勾配情報を保持しているテンソルは、演算した際に計算グラフを構築するようになっています。
例えば、次のコードを実行した場合、以下の計算グラフが作成されます。
import torch
in1 = torch.ones(2, 2, requires_grad=True)
in2 = torch.ones(2, 2, requires_grad=True)
y = in1 + in2
out = y ** 2
ここで、テンソル out
を Tensor.numpy()
で numpy 配列に変換しようとすると、エラーになります。計算グラフの一部になっているテンソルは numpy 配列に変換できない仕様のためです。
out.numpy()
# RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
そのため、detach() を使って、計算グラフから切り離されたテンソルを作成します。
すると、Tensor.numpy()
で numpy 配列に変換できるようになります。
arr = out.detach().numpy()
print(type(arr)) # <class 'numpy.ndarray'>
copy() の用途
テンソルが載っているデバイスが CPU の場合、Tensor.numpy() で numpy に変換しても、元のテンソルとデータは共有しています。
なので、元のテンソルが変更されると、Tensor.numpy() で取得した numpy 配列の中身も変わります。
python
1x = torch.ones(2, 2)
2arr = x.numpy()
3print(arr)
4# [[1. 1.]
5# [1. 1.]]
6
7x[:] = 4 # 元のテンソルを変えると、numpy 配列も変わってしまう
8print(arr)
9# [[4. 4.]
10# [4. 4.]]
それを防ぐために、ndarray.copy() でディープコピーを行い、元のテンソルとは独立した numpy 配列を作成します。
まとめると、
- Tensor.detach() でテンソルを計算グラフから切り離す
- Tensor.numpy() で numpy 配列に変換する
- ndarray.copy() で numpy 配列をディープコピーする
Pytorch の情報元について
一覧や まとめサイトなど知ってる人がいたら教えてください
関数の使い方は公式ドキュメント、個別の QA は Google 検索すれば PyTorch Forums または Stack Overflow でほぼ答えが見つかります。
例えば、今回の detach() はどういうときに使うの?という疑問も
「pytorch how detach() work?」と検索すれば、検索結果の1ページ目に答えが出てきます。
Pytorch は英語圏のライブラリなので、日本語情報は少ないです。
なので、最初から英語で検索したほうがほしい情報がすぐ手に入ります。
それに使える関数があるのに 知らずにどうしようかと悩んでしまう
暇なときに torch モジュール以下の関数一覧を見て、引き出しを増やすということをやっておけばいいと思います。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。