Pytorchで1次元のtensorがあり、その中で最も値が高いインデックスをargmaxで求めました。
ここで、tensorの中で値が一番大きい場所(ここではargmaxで求めた2番目)以外の値を0にしたいのですがどうすればいいでしょうか?
2番目の値はそのままにしたいです。
ご教授よろしくお願いします。
inp1 = torch.rand((1, 5), dtype=torch.float32) idx = torch.argmax(inp1) print(inp1, idx) tensor([[0.6886, 0.0348, 0.9323, 0.4553, 0.6799]]) tensor(2)
以下のようにしたい
tensor([[0, 0, 0.9323, 0, 0]])

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2022/06/11 03:33