Pythonで書かれたPyTorchのコードからNumpyに変換可能な部分をNumpyに入れ替える作業をしていたところ、いくつか対応が分からない関数があったので質問をさせてください。
まず一つ目は、"x = pred[:, 4:6].amax(1) > conf" です。
predは1/38/8000ほどのTorch配列なのですが、この処理を実行するとx.shapeが(38,18~23)となります。
amax(1)はaxis=1で最大値を返す関数だと思うのですが、それだとshape[1]が変動する理由が理解できません。
もう一つは、”box, cls, mask = x.split((4, nc, nm), 1)” です。
こちらはxを(4,nc,nm)の形にaxis=1で分割する関数で正しいでしょうか。
PytorchとNumpyだと分割の仕方も違うようなので、Numpyのsplitで言うところのどの様な操作になるのか
ご存じ方が居られれば、教えていただけると幸いです。
よろしくお願いいたします。
回答1件
あなたの回答
tips
プレビュー