fp16からtorch.flort32に変えれば下のエラーが治ると思うのですが
調べてもnumpyからtorch など違う物しかでてきません
fp16からtorch.flort32に変える方法がわかる人教えてください
もし変えても治らないなら 原因を教えてください
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-1-00ebb105f79f> in <module> 396 total_reward_vec = np.hstack((total_reward_vec[1:], episode_reward)) # 報酬を記録 397 if batch_size<len(memory.buffer*4): --> 398 memory_TDerror.update_TDerror(gamma,multireward_steps) 399 for _ in range(t): 400 trin.pioritized_experience_replay(batch_size, gamma,step=episode,state_size=state_,action_size=acthon,multireward_steps=multireward_steps) <ipython-input-1-00ebb105f79f> in update_TDerror(self, gamma, multireward_steps) 294 295 next_state=memory.buffer[i][0] --> 296 target = memory.buffer[i][2] + (gamma**multireward_steps) * targetQN.forward(next_state,"net_v")[0] 297 self.buffer[i] =target - mainQN.forward(inpp,"net_q")[0] 298 <ipython-input-1-00ebb105f79f> in forward(self, inputs, net) 188 if net=="net_v": 189 V=self.V(x)-self.V_q(x) --> 190 V=V.to("cpu")+0 191 return V 192 if net=="net_a": ~\Anaconda3\envs\pyflan\lib\site-packages\apex\amp\wrap.py in wrapper(*args, **kwargs) 51 52 if len(types) <= 1: ---> 53 return orig_fn(*args, **kwargs) 54 elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']): 55 new_args = utils.casted_args(cast_fn, RuntimeError: "add_cpu/sub_cpu" not implemented for 'Half'
回答1件
あなたの回答
tips
プレビュー