例えば、C クラスの画像分類モデルを考えると、サンプル1つを推論した場合、形状が (1, C) のテンソルが出力として得られますが、これを squeeze() して 形状 (C,) のテンソルにするとかでしょうか。
unsqueeze() というサイズ1の軸を追加する関数があるので、それとは逆に、サイズ1の軸を削除する関数があるというのは API 設計としては自然だと思います。
そうしてこんな事をしなければならいのでしょうか。
しなければならないというわけでないです。
コードを書いている際にサイズが1の次元を削除したい場合がなんらかの理由が出てきたら、squeeze() があることを思い出して使えばいいというだけです。
先程の1枚だけ推論する例の場合、別に squeeze() を使わなくても書けるのでどっちでもいいです。
squeeze() を使う場合
python
1pred = model(x) # 推論する。 (1, 1000) (バッチサイズ, クラス数) のテンソル
2pred = pred.squeeze() # (1000, ) (クラス数,) のテンソル バッチサイズの次元を削除
3print("予測したクラス", pred.argmax())
squeeze() を使わない場合
python
1pred = model(x) # 推論する。 (1, 1000)
2print("予測したクラス", pred[0].argmax())
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/01/12 02:47
2020/01/12 06:28 編集
2020/01/12 13:17
2020/01/12 14:12
2020/01/12 21:17 編集
2020/01/14 03:54 編集
2020/01/14 08:13