前提
現在理解を深めるためにpytorchのチュートリアルを基に自分でDQNの実装を行っているのですが
env.render()を実行した際にエラーが発生していました。
どなたか教えていただきたいです。
よろしくお願いいたします。
注 プログラムは余計な部分を省いています。
必要がありましたら追記いたします。
実現したいこと
ここに実現したいことを箇条書きで書いてください。
DQNに画像を入力させたいのでenv.render()を動作させたい。
発生している問題・エラーメッセージ
エラーメッセージ File "C:\Users\zxyv_\reinforcement Learning\dqn_original.py", line 109, in <module> init_screen=get_screen(env) File "C:\Users\zxyv_\reinforcement Learning\dqn_original.py", line 57, in get_screen screen=env.render().transpose((2,0,1)) File "C:\Users\zxyv_\anaconda3\lib\site-packages\gym\envs\classic_control\cartpole.py", line 179, in render from gym.envs.classic_control import rendering File "C:\Users\zxyv_\anaconda3\lib\site-packages\gym\envs\classic_control\rendering.py", line 27, in <module> from pyglet.gl import * File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\gl\__init__.py", line 236, in <module> import pyglet.window File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\window\__init__.py", line 1816, in <module> gl._create_shadow_window() File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\gl\__init__.py", line 205, in _create_shadow_window _shadow_window = Window(width=1, height=1, visible=False) File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\window\win32\__init__.py", line 131, in __init__ super(Win32Window, self).__init__(*args, **kwargs) File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\window\__init__.py", line 493, in __init__ display = get_platform().get_default_display() File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\window\__init__.py", line 1765, in get_default_display return pyglet.canvas.get_display() File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\canvas\__init__.py", line 77, in get_display from pyglet.app import displays File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\app\__init__.py", line 177, in <module> event_loop = EventLoop() File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\app\base.py", line 116, in __init__ self.clock = clock.get_default() File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\__init__.py", line 357, in __getattr__ __import__(import_name) File "C:\Users\zxyv_\anaconda3\lib\site-packages\pyglet\clock.py", line 165, in <module> _default_time_function = time.clock AttributeError: module 'time' has no attribute 'clock'
該当のソースコード
Python
1ソースコード 2from configparser import Interpolation 3from pickletools import optimize 4from statistics import mean 5import gym 6import math 7import random 8import numpy as np 9import matplotlib 10import matplotlib.pyplot as plt 11from collections import namedtuple, deque 12from itertools import count 13from PIL import Image 14 15import torch 16import torch.nn as nn 17import torch.optim as optim 18import torch.nn.functional as F 19from torchvision import datasets,transforms#データの前処理に必要なモジュール 20import torchvision.transforms as T 21 22env = gym.make('CartPole-v0').unwrapped 23 24 25# if gpu is to be used 26device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27#環境内の単一の遷移を表す名前付きタプル 28Transition=namedtuple('Transition',('state','action','nextstate','reward')) 29#経験を保存するためのバッファ。バッチを取得するためのサンプルメソッドも実装 30class ReplayMemory(object): 31 #capacityはバッファの容量 32 def __init__(self,capacity): 33 self.memory=deque([],maxlen=capacity) 34 def push(self,*args): 35 #Save a taransition 36 self.memory.append(Transition(*args)) 37 def sample(self,batch_size): 38 return random.sample(self.memory,batch_size) 39 def __lem__(self): 40 return len(self.memory) 41 42def get_cart_location(screen_width): 43 world_width=env.x_threshold*2 44 scale=screen_width/world_width 45 return int(env.state[0]*scale+screen_width/2.0)#Middle of cart 46 47def get_screen(env): 48 # ジムからリクエストされた返却画面は 400x600x3 ですが、 49 # 800x1200x3 のように大きい場合もあります。それをトーチ オーダー (HWCからCHW) に置き換えます。 50 screen=env.render().transpose((2,0,1)) 51 # カートは下半分にあるので、画面の上下をはがします 52 _,screen_height,screen_width=screen.shape 53 screen=screen[:,int(screen_height*0.4):int(screen_height*0.8)] 54 view_width=int(screen_width*0.6) 55 cart_location=get_cart_location(screen_width) 56 if cart_location<view_width//2: 57 slice_range=slice(view_width) 58 elif cart_location>(screen_width-view_width//2): 59 slice_range=slice(-view_width,None) 60 else: 61 slice_range=slice(cart_location-view_width//2,cart_location+view_width//2) 62 63 # カートを中心とした正方形の画像になるように、端を取り除きます 64 screen=screen[:, :,slice_range] 65 # float への変換、再スケーリング、torch tensor への変換 66 screen=np.ascontiguousarray(screen,dtype=np.float32)/255 67 screen=torch.from_numpy(screen) 68 # サイズを変更し、バッチ ディメンションを追加します (BCHW) 69 return resize(screen).unsqueeze(0) 70 71 72class DQN(nn.Module): 73 74 def __init__(self,h,w,outputs): 75 super(DQN,self).__init__() 76 self.conv1=nn.Conv2d(3,16,kernel_size=5,stride=2) 77 self.bn1=nn.BatchNorm2d(16) 78 self.conv2=nn.Conv2d(16,32,kernel_size=5,stride=2) 79 self.bn2=nn.BatchNorm2d(32) 80 self.conv3=nn.Conv2d(32,32,kernel_size=5,stride=2) 81 self.bn3=nn.BatchNorm2d(32) 82 83 #畳み込み層からの出力を全結合層に入力するためのサイズを計算している 84 def conv2d_size_out(size,kernel_size=5,stride=2): 85 return (size-(kernel_size-1)-1)//stride +1 86 convw=conv2d_size_out(conv2d_size_out(conv2d_size_out(w))) 87 convh=conv2d_size_out(conv2d_size_out(conv2d_size_out(h))) 88 linear_input_size=convw*convh*32 89 self.head=nn.Linear(linear_input_size,outputs) 90 91 def forward(self,x): 92 x=x.to(device) 93 x=F.relu(self.bn1(self.conv1(x))) 94 x=F.relu(self.bn2(self.conv2(x))) 95 x=F.relu(self.bn3(self.conv3(x))) 96 return self.head(x.view(x.size(0),-1))#わからない 97 98 99memory=ReplayMemory(10000) 100 101steps_done=0 102init_screen=get_screen(env) 103_,_,screen_height,screen_width=init_screen.shape 104
試したこと
ここに問題に対して試したことを記載してください。
env = gym.make('CartPole-v0')をenv = gym.make('CartPole-v0').unwrappedに変更してみたり
自分で調べたりしたのですが、わからず質問させていただきました。

回答2件
あなたの回答
tips
プレビュー