質問編集履歴

4

コードをすべて載せた

2021/08/19 09:41

投稿

kaneko_
kaneko_

スコア10

test CHANGED
File without changes
test CHANGED
@@ -4,13 +4,25 @@
4
4
 
5
5
  python module Rayを用いた関数の並列化処理を行いたいのですが、”TypeError: can't pickle _thread.lock objects”とエラーが出てしまいます。
6
6
 
7
+ エラーの原因はメソッドが直列化できないからであることはわかりました。
8
+
7
- エラーの原因は関内で別クラスのクラス内関数引用してるからということはわかのですが、なぜクス内関数を引用するとエラーが出るのかがわかりません。
9
+ インスタンス変を総て外に出して引数にしたクラスメソッドを作成し、そのクラスメソッド並列化処理すればうまくことはわかったのですが、プログムが非常に膨大であるため、関数構造変えられません。
8
-
10
+
9
- みに試しに(クラス内関数でな)単る関数を引用してコードを書いたところ、うまく動作しました
11
+ また、cloudpickleやpickleを使うとうまくいく?よう記述をwebで見つけたのですがそもそもpickleがよわからないた、うまくコードに起こせない状態です
10
-
12
+
11
- また、multiprocessing moduleて記述しところうまくプログラムが動きました。
13
+ multiprocessingモジュール使えば動いたので並列処理できないプログラムでないことは確かなのです、、、
12
-
14
+
15
+
16
+
13
- どこをどのように直せばよいのかわからないため質問いたしました。よろしくおねがいします。
17
+ もし何か解決策等あれ、ご教授よろしくおねがいします。
18
+
19
+
20
+
21
+ 参考にしたwebページ
22
+
23
+ https://docs.ray.io/en/releases-0.8.5/serialization.html
24
+
25
+ https://docs.ray.io/en/master/serialization.html
14
26
 
15
27
 
16
28
 
@@ -20,7 +32,97 @@
20
32
 
21
33
  ```
22
34
 
23
- TypeError: can't pickle _thread.lock objects
35
+ TypeError: can't pickle _thread.lock objects
36
+
37
+ ---------------------------------------------------------------------------
38
+
39
+ TypeError Traceback (most recent call last)
40
+
41
+ <ipython-input-13-9de0171f325e> in <module>
42
+
43
+ 41
44
+
45
+ 42
46
+
47
+ ---> 43 results=[output.remote(EbNodB) for EbNodB in range(-3,5)]
48
+
49
+ 44 print(ray.get(results))
50
+
51
+ 45
52
+
53
+
54
+
55
+ <ipython-input-13-9de0171f325e> in <listcomp>(.0)
56
+
57
+ 41
58
+
59
+ 42
60
+
61
+ ---> 43 results=[output.remote(EbNodB) for EbNodB in range(-3,5)]
62
+
63
+ 44 print(ray.get(results))
64
+
65
+ 45
66
+
67
+
68
+
69
+ ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/remote_function.py in _remote_proxy(*args, **kwargs)
70
+
71
+ 97 @wraps(function)
72
+
73
+ 98 def _remote_proxy(*args, **kwargs):
74
+
75
+ ---> 99 return self._remote(args=args, kwargs=kwargs)
76
+
77
+ 100
78
+
79
+ 101 self.remote = _remote_proxy
80
+
81
+
82
+
83
+ ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/remote_function.py in _remote(self, args, kwargs, num_returns, num_cpus, num_gpus, memory, object_store_memory, accelerator_type, resources, max_retries, placement_group, placement_group_bundle_index, name)
84
+
85
+ 175 # first driver. This is an argument for repickling the function,
86
+
87
+ 176 # which we do here.
88
+
89
+ --> 177 self._pickled_function = pickle.dumps(self._function)
90
+
91
+ 178
92
+
93
+ 179 self._function_descriptor = PythonFunctionDescriptor.from_function(
94
+
95
+
96
+
97
+ ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
98
+
99
+ 68 with io.BytesIO() as file:
100
+
101
+ 69 cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
102
+
103
+ ---> 70 cp.dump(obj)
104
+
105
+ 71 return file.getvalue()
106
+
107
+ 72
108
+
109
+
110
+
111
+ ~/.pyenv/versions/3.7.10/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py in dump(self, obj)
112
+
113
+ 654 def dump(self, obj):
114
+
115
+ 655 try:
116
+
117
+ --> 656 return Pickler.dump(self, obj)
118
+
119
+ 657 except RuntimeError as e:
120
+
121
+ 658 if "recursion" in e.args[0]:
122
+
123
+
124
+
125
+ TypeError: can't pickle _thread.lock objects
24
126
 
25
127
  ```
26
128
 
@@ -32,93 +134,493 @@
32
134
 
33
135
  ```python
34
136
 
35
- @ray.remote
36
-
37
- def output(main_func,EbNodB):#main_funcはクラス内関数です。
38
-
39
- '''
40
-
41
- あるSNRで計算結果を出力する関数を作成
42
-
43
- main_func must input 'EbNodB' and output 1D 'codeword' and 'EST_codeword'
44
-
45
- '''
46
-
47
-
48
-
49
- #seed値の設定
50
-
51
- np.random.seed()
52
-
53
-
54
-
55
- #prepare some constants
56
-
57
- MAX_ERR=10
58
-
59
- count_bitall=0
60
-
61
- count_biterr=0
62
-
63
- count_all=0
64
-
65
- count_err=0
66
-
67
-
68
-
69
-
70
-
71
- while count_err<MAX_ERR:
72
-
73
- #print("ok")
74
-
75
- K=100
76
-
77
- information,EST_information=main_func(EbNodB) #ここでクラス内関数を用いています。
78
-
79
- #calculate block error rate
80
-
81
- if np.any(information!=EST_information):
82
-
83
- count_err+=1
84
-
85
- count_all+=1
86
-
87
-
88
-
89
- #calculate bit error rate
90
-
91
- count_biterr+=np.sum(information!=EST_information)
92
-
93
- count_bitall+=len(information)
94
-
95
-
137
+ #必要なライブラリ、定数
138
+
139
+ import sys
140
+
141
+ sys.path.append("../channel")
142
+
143
+ from AWGN import _AWGN
144
+
145
+ import numpy as np
146
+
147
+ import random
148
+
149
+ import time
150
+
151
+ import math
152
+
153
+ from decimal import *
154
+
155
+ import ray
156
+
157
+ random.seed(time.time())
158
+
159
+
160
+
161
+ ch=_AWGN()
162
+
163
+
164
+
165
+ class coding():
166
+
167
+ def __init__(self):
168
+
169
+ #super().__init__()
170
+
171
+
172
+
173
+ self.N=1024
174
+
175
+ self.R=0.5
176
+
177
+ self.K=math.floor(self.R*self.N)
178
+
179
+ self.design_SNR=4
180
+
181
+
182
+
183
+ #prepere constants
184
+
185
+ tmp2=np.log2(self.N)
186
+
187
+ self.itr_num=tmp2.astype(int)
188
+
189
+ self.frozen_bits,self.info_bits=\
190
+
191
+ self.Bhattacharyya_bounds()
192
+
193
+
194
+
195
+ self.Gres=self.make_H()
196
+
197
+
198
+
199
+ self.filename="polar_code_{}_{}".format(self.N,self.K)
200
+
201
+
202
+
203
+ class coding(coding):
204
+
205
+ #frozen_bitの選択
206
+
207
+ def Bhattacharyya_bounds(self):
208
+
209
+ E=np.zeros(1,dtype=np.float128)
210
+
211
+ E =Decimal('10') ** (Decimal(str(self.design_SNR)) / Decimal('10'))
212
+
213
+ itr_num=np.log2(N)
214
+
215
+ itr_num=itr_num.astype(int)
216
+
217
+ z=np.zeros(self.N,dtype=np.float128)
218
+
219
+
220
+
221
+ #10^10かけて計算する
222
+
223
+
224
+
225
+ z[0]=math.exp(Decimal('-1')*Decimal(str(E)))
226
+
227
+
228
+
229
+ #print("E=",np.exp(-E))
230
+
231
+
232
+
233
+ for j in range(1,itr_num+1):
234
+
235
+ tmp=2**(j)//2
236
+
237
+
238
+
239
+ for t in range(tmp):
240
+
241
+ T=z[t]
242
+
243
+ z[t]=Decimal('2')*Decimal(str(T))-Decimal(str(T))**Decimal('2')
244
+
245
+ z[tmp+t]=Decimal(str(T))**Decimal('2')
246
+
247
+ #print(z)
248
+
249
+ #np.savetxt("z",z)
250
+
251
+ tmp=self.indices_of_elements(z,N)
252
+
253
+ frozen_bits=tmp[:self.N-self.K]
254
+
255
+ info_bits=tmp[self.N-self.K:]
256
+
257
+ return np.sort(frozen_bits),np.sort(info_bits)
258
+
259
+
260
+
261
+ @staticmethod
262
+
263
+ def indices_of_elements(v,l):
264
+
265
+ tmp=np.argsort(v)[::-1]
266
+
267
+ #print(tmp)
268
+
269
+ res=tmp[0:l]
270
+
271
+ return res
272
+
273
+
274
+
275
+ class coding(coding):
276
+
277
+ @staticmethod
278
+
279
+ def tensordot(A):
280
+
281
+ tmp0=np.zeros((A.shape[0],A.shape[1]),dtype=np.int)
282
+
283
+ tmp1=np.append(A,tmp0,axis=1)
284
+
285
+ #print(tmp1)
286
+
287
+ tmp2=np.append(A,A,axis=1)
288
+
289
+ #print(tmp2)
290
+
291
+ tmp3=np.append(tmp1,tmp2,axis=0)
292
+
293
+ #print(tmp3)
294
+
295
+ return tmp3
296
+
297
+
298
+
299
+ def make_H(self):
300
+
301
+ G2=np.array([[1,0],[1,1]],dtype=np.int)
302
+
303
+ Gres=G2
304
+
305
+ for _ in range(self.itr_num-1):
306
+
307
+ #print(i)
308
+
309
+ Gres=self.tensordot(Gres)
310
+
311
+ return Gres
312
+
313
+
314
+
315
+ class encoding(coding):
316
+
317
+ #def __init__(self,N):
318
+
319
+ #super().__init__(N)
320
+
321
+
322
+
323
+ def generate_information(self):
324
+
325
+ #generate information
326
+
327
+ information=np.random.randint(0,2,self.K)
328
+
329
+ return information
330
+
331
+
332
+
333
+ class encoding(encoding):
334
+
335
+ def generate_U(self,information):
336
+
337
+ u_message=np.zeros(self.N)
338
+
339
+ u_message[self.info_bits]=information
340
+
341
+ return u_message
342
+
343
+
344
+
345
+ lass encoding(encoding):
346
+
347
+ def polar_encode(self):
348
+
349
+ information=self.generate_information()
350
+
351
+ u_message=self.generate_U(information)
352
+
353
+ codeword=(u_message@self.Gres)%2
354
+
355
+ return information,codeword
356
+
357
+
358
+
359
+ class decoding(coding):
360
+
361
+ n=0
362
+
363
+ EST_information=np.array([])
364
+
365
+
366
+
367
+ #def __init__(self,N):
368
+
369
+ #super().__init__(N)
370
+
371
+
372
+
373
+ @staticmethod
374
+
375
+ def chk(llr_1,llr_2):
376
+
377
+ CHECK_NODE_TANH_THRES=30
378
+
379
+ res=np.zeros(len(llr_1))
380
+
381
+ for i in range(len(res)):
382
+
383
+
384
+
385
+ if abs(llr_1[i]) > CHECK_NODE_TANH_THRES and abs(llr_2[i]) > CHECK_NODE_TANH_THRES:
386
+
387
+ if llr_1[i] * llr_2[i] > 0:
388
+
389
+ # If both LLRs are of one sign, we return the minimum of their absolute values.
390
+
391
+ res[i]=min(abs(llr_1[i]), abs(llr_2[i]))
392
+
393
+ else:
394
+
395
+ # Otherwise, we return an opposite to the minimum of their absolute values.
396
+
397
+ res[i]=-1 * min(abs(llr_1[i]), abs(llr_2[i]))
398
+
399
+ else:
400
+
401
+ res[i]= 2 * np.arctanh(np.tanh(llr_1[i] / 2, ) * np.tanh(llr_2[i] / 2))
402
+
403
+ return res
404
+
405
+
406
+
407
+ def SC_decoding(self,a):
408
+
409
+ #interior node operation
410
+
411
+ if a.shape[0]==1:
412
+
413
+ #frozen_bit or not
414
+
415
+ if np.any(self.frozen_bits==decoding.n):
416
+
417
+ tmp0=np.zeros(1)
418
+
419
+ elif a>=0:
420
+
421
+ tmp0=np.zeros(1)
422
+
423
+ elif a<0:
424
+
425
+ tmp0=np.ones(1)
426
+
427
+ else:
428
+
429
+ print("err!")
430
+
431
+ exit()
432
+
433
+
434
+
435
+ if np.any(self.info_bits==decoding.n):
436
+
437
+ decoding.EST_information=np.append(decoding.EST_information,a)
438
+
439
+ #print(decoding.n)
440
+
441
+ #print(t)
442
+
443
+ decoding.n+=1
444
+
445
+ #if t>=N:
446
+
447
+ #exit()
448
+
449
+ return tmp0
450
+
451
+
452
+
453
+ #step1 left input a output u1_hat
454
+
455
+
456
+
457
+ tmp1=np.split(a,2)
458
+
459
+ f_half_a=self.chk(tmp1[0],tmp1[1])
460
+
461
+ u1=self.SC_decoding(f_half_a)
462
+
463
+
464
+
465
+ #step2 right input a,u1_hat output u2_hat
466
+
467
+ tmp2=np.split(a,2)
468
+
469
+ g_half_a=tmp2[1]+(1-2*u1)*tmp2[0]
470
+
471
+ u2=self.SC_decoding(g_half_a)
96
472
 
97
473
 
98
474
 
99
-
100
-
101
- return count_err,count_all,count_biterr,count_bitall
102
-
103
-
104
-
105
- #check
106
-
107
- tc=turbo_code() #クラスの呼び出し
108
-
109
-
110
-
111
- output_ids=[]
112
-
113
-
114
-
115
- output_ids=output.remote(tc.turbo_code,i)
116
-
117
-
118
-
119
- outputs=ray.get(output_ids)
120
-
121
- print(outputs)
475
+ #step3 up input u1,u2 output a_hat
476
+
477
+ res=np.concatenate([(u1+u2)%2,u2])
478
+
479
+ return res
480
+
481
+
482
+
483
+ class decoding(decoding):
484
+
485
+ def polar_decode(self,Lc):
486
+
487
+ #initialize class variable
488
+
489
+ decoding.n=0
490
+
491
+ decoding.EST_information=np.array([])
492
+
493
+ self.SC_decoding(Lc)
494
+
495
+ res=decoding.EST_information
496
+
497
+ res=-1*np.sign(res)
498
+
499
+ EST_information=(res+1)/2
500
+
501
+
502
+
503
+ return EST_information
504
+
505
+
506
+
507
+ class polar_code(encoding,decoding):
508
+
509
+ #def __init__(self,N):
510
+
511
+ #super().__init__(N)
512
+
513
+
514
+
515
+ def polar_code(self,EbNodB):
516
+
517
+
518
+
519
+
520
+
521
+ information,codeword=self.polar_encode()
522
+
523
+
524
+
525
+ Lc=-1*ch.generate_LLR(codeword,EbNodB)#デコーダが+、ー逆になってしまうので-1をかける
526
+
527
+
528
+
529
+ EST_information=self.polar_decode(Lc)
530
+
531
+
532
+
533
+ return information,EST_information
534
+
535
+
536
+
537
+ if __name__=="__main__":
538
+
539
+ ray.init()
540
+
541
+
542
+
543
+ #N=512
544
+
545
+ #pc=polar_code()
546
+
547
+
548
+
549
+ #print(len(pc.info_bits))
550
+
551
+ #information,EST_information=pc.polar_code(100)
552
+
553
+ #print(len(information))
554
+
555
+ #print(len(EST_information))
556
+
557
+ #print(np.sum(information!=EST_information))
558
+
559
+
560
+
561
+ @ray.remote
562
+
563
+ def output(EbNodB):
564
+
565
+ count_err=0
566
+
567
+ count_all=0
568
+
569
+ count_berr=0
570
+
571
+ count_ball=0
572
+
573
+ MAX_ERR=8
574
+
575
+
576
+
577
+ while count_err<MAX_ERR:
578
+
579
+
580
+
581
+ pc=polar_code()
582
+
583
+ information,EST_information=pc.polar_code(EbNodB)
584
+
585
+
586
+
587
+ if np.any(information!=EST_information):#BLOCK error check
588
+
589
+ count_err+=1
590
+
591
+
592
+
593
+ count_all+=1
594
+
595
+
596
+
597
+ #calculate bit error rate
598
+
599
+ count_berr+=np.sum(information!=EST_information)
600
+
601
+ count_ball+=N
602
+
603
+
604
+
605
+ #print("\r","count_all=",count_all,",count_err=",count_err,"count_ball="\
606
+
607
+ #,count_ball,"count_berr=",count_berr,end="")
608
+
609
+
610
+
611
+ #print("\n")
612
+
613
+ #print("BER=",count_berr/count_ball)
614
+
615
+ return count_err,count_all,count_berr,count_all
616
+
617
+
618
+
619
+
620
+
621
+ results=[output.remote(EbNodB) for EbNodB in range(-3,5)]
622
+
623
+ print(ray.get(results))
122
624
 
123
625
  ```
124
626
 
@@ -144,4 +646,6 @@
144
646
 
145
647
  ray 0.9.0
146
648
 
649
+ VScode のjupyter notebookにて記述
650
+
147
651
  ここにより詳細な情報を記載してください。

3

誤字の修正

2021/08/19 09:41

投稿

kaneko_
kaneko_

スコア10

test CHANGED
File without changes
test CHANGED
@@ -10,7 +10,7 @@
10
10
 
11
11
  また、multiprocessing moduleを用いて記述したところ、うまくプログラムが動きました。
12
12
 
13
- そもそものクラス構造に問題がある気がするのですが、どのように直せばよいのかわからないため質問いたしました。よろしくおねがいします。
13
+ こをどのように直せばよいのかわからないため質問いたしました。よろしくおねがいします。
14
14
 
15
15
 
16
16
 

2

追加の情報を入力

2021/08/04 09:03

投稿

kaneko_
kaneko_

スコア10

test CHANGED
File without changes
test CHANGED
@@ -7,6 +7,8 @@
7
7
  エラーの原因は関数内で別クラスのクラス内関数を引用しているからということはわかるのですが、なぜクラス内関数を引用するとエラーが出るのかがわかりません。
8
8
 
9
9
  ちなみに、試しに(クラス内関数でなく)単なる関数を引用してコードを書いたところ、うまく動作しました。
10
+
11
+ また、multiprocessing moduleを用いて記述したところ、うまくプログラムが動きました。
10
12
 
11
13
  そもそものクラス構造に問題がある気がするのですが、どのように直せばよいのかわからないため質問いたしました。よろしくおねがいします。
12
14
 
@@ -132,6 +134,8 @@
132
134
 
133
135
  - この関数'output'をクラスに変更し、引用したいクラスを継承させても、同じエラーを吐いてしまう。
134
136
 
137
+ - multiprocessing module を用いたところ、うまく動いた
138
+
135
139
 
136
140
 
137
141
  ### 補足情報(FW/ツールのバージョンなど)

1

誤字の修正

2021/08/04 09:03

投稿

kaneko_
kaneko_

スコア10

test CHANGED
File without changes
test CHANGED
@@ -102,7 +102,7 @@
102
102
 
103
103
  #check
104
104
 
105
- tc=turbo_code #クラスの呼び出し
105
+ tc=turbo_code() #クラスの呼び出し
106
106
 
107
107
 
108
108