前提・実現したいこと
python module Rayを用いた関数の並列化処理を行いたいのですが、”TypeError: can't pickle _thread.lock objects”とエラーが出てしまいます。
エラーの原因はメソッドが直列化できないからであることはわかりました。
インスタンス変数を総て外に出して引数にしたクラスメソッドを作成し、そのクラスメソッドを並列化処理すればうまくいくことはわかったのですが、プログラムが非常に膨大であるため、関数構造を変えられません。
また、cloudpickleやpickleを使うとうまくいく?ような記述をwebで見つけたのですが、そもそもpickleがよくわからないため、うまくコードに起こせない状態です。
multiprocessingモジュールを使えば動いたので、並列処理できないプログラムでないことは確かなのですが、、、
もし何か解決策等あれば、ご教授よろしくおねがいします。
参考にしたwebページ
https://docs.ray.io/en/releases-0.8.5/serialization.html
https://docs.ray.io/en/master/serialization.html
発生している問題・エラーメッセージ
TypeError: can't pickle _thread.lock objects --------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-13-9de0171f325e> in <module> 41 42 ---> 43 results=[output.remote(EbNodB) for EbNodB in range(-3,5)] 44 print(ray.get(results)) 45 <ipython-input-13-9de0171f325e> in <listcomp>(.0) 41 42 ---> 43 results=[output.remote(EbNodB) for EbNodB in range(-3,5)] 44 print(ray.get(results)) 45 ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/remote_function.py in _remote_proxy(*args, **kwargs) 97 @wraps(function) 98 def _remote_proxy(*args, **kwargs): ---> 99 return self._remote(args=args, kwargs=kwargs) 100 101 self.remote = _remote_proxy ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/remote_function.py in _remote(self, args, kwargs, num_returns, num_cpus, num_gpus, memory, object_store_memory, accelerator_type, resources, max_retries, placement_group, placement_group_bundle_index, name) 175 # first driver. This is an argument for repickling the function, 176 # which we do here. --> 177 self._pickled_function = pickle.dumps(self._function) 178 179 self._function_descriptor = PythonFunctionDescriptor.from_function( ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback) 68 with io.BytesIO() as file: 69 cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback) ---> 70 cp.dump(obj) 71 return file.getvalue() 72 ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py in dump(self, obj) 654 def dump(self, obj): 655 try: --> 656 return Pickler.dump(self, obj) 657 except RuntimeError as e: 658 if "recursion" in e.args[0]: TypeError: can't pickle _thread.lock objects
該当のソースコード
python
1#必要なライブラリ、定数 2import sys 3sys.path.append("../channel") 4from AWGN import _AWGN 5import numpy as np 6import random 7import time 8import math 9from decimal import * 10import ray 11random.seed(time.time()) 12 13ch=_AWGN() 14 15class coding(): 16 def __init__(self): 17 #super().__init__() 18 19 self.N=1024 20 self.R=0.5 21 self.K=math.floor(self.R*self.N) 22 self.design_SNR=4 23 24 #prepere constants 25 tmp2=np.log2(self.N) 26 self.itr_num=tmp2.astype(int) 27 self.frozen_bits,self.info_bits=\ 28 self.Bhattacharyya_bounds() 29 30 self.Gres=self.make_H() 31 32 self.filename="polar_code_{}_{}".format(self.N,self.K) 33 34class coding(coding): 35 #frozen_bitの選択 36 def Bhattacharyya_bounds(self): 37 E=np.zeros(1,dtype=np.float128) 38 E =Decimal('10') ** (Decimal(str(self.design_SNR)) / Decimal('10')) 39 itr_num=np.log2(N) 40 itr_num=itr_num.astype(int) 41 z=np.zeros(self.N,dtype=np.float128) 42 43 #10^10かけて計算する 44 45 z[0]=math.exp(Decimal('-1')*Decimal(str(E))) 46 47 #print("E=",np.exp(-E)) 48 49 for j in range(1,itr_num+1): 50 tmp=2**(j)//2 51 52 for t in range(tmp): 53 T=z[t] 54 z[t]=Decimal('2')*Decimal(str(T))-Decimal(str(T))**Decimal('2') 55 z[tmp+t]=Decimal(str(T))**Decimal('2') 56 #print(z) 57 #np.savetxt("z",z) 58 tmp=self.indices_of_elements(z,N) 59 frozen_bits=tmp[:self.N-self.K] 60 info_bits=tmp[self.N-self.K:] 61 return np.sort(frozen_bits),np.sort(info_bits) 62 63 @staticmethod 64 def indices_of_elements(v,l): 65 tmp=np.argsort(v)[::-1] 66 #print(tmp) 67 res=tmp[0:l] 68 return res 69 70class coding(coding): 71 @staticmethod 72 def tensordot(A): 73 tmp0=np.zeros((A.shape[0],A.shape[1]),dtype=np.int) 74 tmp1=np.append(A,tmp0,axis=1) 75 #print(tmp1) 76 tmp2=np.append(A,A,axis=1) 77 #print(tmp2) 78 tmp3=np.append(tmp1,tmp2,axis=0) 79 #print(tmp3) 80 return tmp3 81 82 def make_H(self): 83 G2=np.array([[1,0],[1,1]],dtype=np.int) 84 Gres=G2 85 for _ in range(self.itr_num-1): 86 #print(i) 87 Gres=self.tensordot(Gres) 88 return Gres 89 90class encoding(coding): 91 #def __init__(self,N): 92 #super().__init__(N) 93 94 def generate_information(self): 95 #generate information 96 information=np.random.randint(0,2,self.K) 97 return information 98 99class encoding(encoding): 100 def generate_U(self,information): 101 u_message=np.zeros(self.N) 102 u_message[self.info_bits]=information 103 return u_message 104 105lass encoding(encoding): 106 def polar_encode(self): 107 information=self.generate_information() 108 u_message=self.generate_U(information) 109 codeword=(u_message@self.Gres)%2 110 return information,codeword 111 112class decoding(coding): 113 n=0 114 EST_information=np.array([]) 115 116 #def __init__(self,N): 117 #super().__init__(N) 118 119 @staticmethod 120 def chk(llr_1,llr_2): 121 CHECK_NODE_TANH_THRES=30 122 res=np.zeros(len(llr_1)) 123 for i in range(len(res)): 124 125 if abs(llr_1[i]) > CHECK_NODE_TANH_THRES and abs(llr_2[i]) > CHECK_NODE_TANH_THRES: 126 if llr_1[i] * llr_2[i] > 0: 127 # If both LLRs are of one sign, we return the minimum of their absolute values. 128 res[i]=min(abs(llr_1[i]), abs(llr_2[i])) 129 else: 130 # Otherwise, we return an opposite to the minimum of their absolute values. 131 res[i]=-1 * min(abs(llr_1[i]), abs(llr_2[i])) 132 else: 133 res[i]= 2 * np.arctanh(np.tanh(llr_1[i] / 2, ) * np.tanh(llr_2[i] / 2)) 134 return res 135 136 def SC_decoding(self,a): 137 #interior node operation 138 if a.shape[0]==1: 139 #frozen_bit or not 140 if np.any(self.frozen_bits==decoding.n): 141 tmp0=np.zeros(1) 142 elif a>=0: 143 tmp0=np.zeros(1) 144 elif a<0: 145 tmp0=np.ones(1) 146 else: 147 print("err!") 148 exit() 149 150 if np.any(self.info_bits==decoding.n): 151 decoding.EST_information=np.append(decoding.EST_information,a) 152 #print(decoding.n) 153 #print(t) 154 decoding.n+=1 155 #if t>=N: 156 #exit() 157 return tmp0 158 159 #step1 left input a output u1_hat 160 161 tmp1=np.split(a,2) 162 f_half_a=self.chk(tmp1[0],tmp1[1]) 163 u1=self.SC_decoding(f_half_a) 164 165 #step2 right input a,u1_hat output u2_hat 166 tmp2=np.split(a,2) 167 g_half_a=tmp2[1]+(1-2*u1)*tmp2[0] 168 u2=self.SC_decoding(g_half_a) 169 170 #step3 up input u1,u2 output a_hat 171 res=np.concatenate([(u1+u2)%2,u2]) 172 return res 173 174class decoding(decoding): 175 def polar_decode(self,Lc): 176 #initialize class variable 177 decoding.n=0 178 decoding.EST_information=np.array([]) 179 self.SC_decoding(Lc) 180 res=decoding.EST_information 181 res=-1*np.sign(res) 182 EST_information=(res+1)/2 183 184 return EST_information 185 186class polar_code(encoding,decoding): 187 #def __init__(self,N): 188 #super().__init__(N) 189 190 def polar_code(self,EbNodB): 191 192 193 information,codeword=self.polar_encode() 194 195 Lc=-1*ch.generate_LLR(codeword,EbNodB)#デコーダが+、ー逆になってしまうので-1をかける 196 197 EST_information=self.polar_decode(Lc) 198 199 return information,EST_information 200 201if __name__=="__main__": 202 ray.init() 203 204 #N=512 205 #pc=polar_code() 206 207 #print(len(pc.info_bits)) 208 #information,EST_information=pc.polar_code(100) 209 #print(len(information)) 210 #print(len(EST_information)) 211 #print(np.sum(information!=EST_information)) 212 213 @ray.remote 214 def output(EbNodB): 215 count_err=0 216 count_all=0 217 count_berr=0 218 count_ball=0 219 MAX_ERR=8 220 221 while count_err<MAX_ERR: 222 223 pc=polar_code() 224 information,EST_information=pc.polar_code(EbNodB) 225 226 if np.any(information!=EST_information):#BLOCK error check 227 count_err+=1 228 229 count_all+=1 230 231 #calculate bit error rate 232 count_berr+=np.sum(information!=EST_information) 233 count_ball+=N 234 235 #print("\r","count_all=",count_all,",count_err=",count_err,"count_ball="\ 236 #,count_ball,"count_berr=",count_berr,end="") 237 238 #print("\n") 239 #print("BER=",count_berr/count_ball) 240 return count_err,count_all,count_berr,count_all 241 242 243 results=[output.remote(EbNodB) for EbNodB in range(-3,5)] 244 print(ray.get(results))
試したこと
- クラス内関数でなく、普通の関数を用いたらうまく行った。
- 引用するクラスに@ray.remoteデコレータをつけて動かすと、エラーは吐かなくなったが、一向に結果プログラムが進まなかった。(並列化がクラスの方と当関数の方の二重で行われてしまっていることが原因と考察している)
- この関数'output'をクラスに変更し、引用したいクラスを継承させても、同じエラーを吐いてしまう。
- multiprocessing module を用いたところ、うまく動いた
補足情報(FW/ツールのバージョンなど)
python 3.7.10
ray 0.9.0
VScode のjupyter notebookにて記述
ここにより詳細な情報を記載してください。
あなたの回答
tips
プレビュー