実現したいこと
Pytorchで3次元のtensorがあり、tensor中のdim=1の中で値が一番大きい場所を1に、一番大きい値以外を0にしたいのですがどうすればいいでしょうか?
ご教授よろしくお願いします。
torch.argmaxやtorch.max、スライス処理など色々考えてみたんですがわからなかったです、、、
該当のソースコード
import torch x = torch.randn(4, 3, 1) print(x) tensor([[[ 0.4082], [ 2.0627], [ 0.7252]], [[ 0.7946], [ 0.2679], [-0.4184]], [[ 0.3380], [ 0.8403], [-1.7227]], [[-1.1250], [-1.8144], [ 1.4441]]])
以下のようにしたい
print(x) tensor([[[ 0], [ 1], [ 0]], [[ 1], [ 0], [0]], [[ 0], [ 1], [0]], [[0], [0], [ 1]]])