質問編集履歴
1
テンプレートに沿って質問を書き換えました。加えて、試したことも追加しました。
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,4 +1,29 @@
|
|
1
|
+
|
2
|
+
### 前提
|
3
|
+
|
1
|
-
|
4
|
+
実行時間の短縮を目指して、遺伝的アルゴリズムで巡回セールスマン問題を解くプログラムに非同期処理を追加しました。このとき、経路を計算する関数calc_distanceをコルーチンにしたところ、sort(key=calc_distance)のところでエラーが出ました。最短経路をsortでリストの先頭に持っていき、先頭を別のリストに保存する方法をとっているのでsortが出来ないと困ってしまいます。
|
5
|
+
|
6
|
+
### 実現したいこと
|
7
|
+
|
8
|
+
・sort(key=)にコルーチンであるcalc_distanceを指定したい
|
9
|
+
|
10
|
+
### 発生している問題・エラーメッセージ
|
11
|
+
|
12
|
+
```
|
13
|
+
33356.142646205066
|
14
|
+
Traceback (most recent call last):
|
15
|
+
line 230, in <module>
|
16
|
+
asyncio.run(main())
|
17
|
+
line 44, in run
|
18
|
+
return loop.run_until_complete(main)
|
19
|
+
line 642, in run_until_complete
|
20
|
+
return future.result()
|
21
|
+
line 207, in main
|
22
|
+
population.sort(key= await Route.calc_distance)
|
23
|
+
TypeError: object function can't be used in 'await' expression
|
24
|
+
```
|
25
|
+
|
26
|
+
### ソースコード
|
2
27
|
|
3
28
|
```python
|
4
29
|
import asyncio
|
@@ -7,43 +32,6 @@
|
|
7
32
|
import csv
|
8
33
|
import copy
|
9
34
|
import time
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
def read_tspfile():
|
14
|
-
"""
|
15
|
-
tspファイルを読み込み、都市の座標(float型)を
|
16
|
-
[[都市番号,X,Y],[...],...] の形で返す
|
17
|
-
"""
|
18
|
-
def str2float(cities):
|
19
|
-
data = [[0]]*len(cities)
|
20
|
-
for i in range(len(cities)):
|
21
|
-
city = [0]*len(cities[i])
|
22
|
-
data[i] = city
|
23
|
-
try:
|
24
|
-
for j in range(len(cities[i])):
|
25
|
-
data[i][j] = float(cities[i][j])
|
26
|
-
except:
|
27
|
-
data[i]*=0
|
28
|
-
continue
|
29
|
-
data2 = list(filter(None,data))
|
30
|
-
return data2
|
31
|
-
|
32
|
-
def remove_blank(cities):
|
33
|
-
for i in range(len(cities)):
|
34
|
-
for j in range(len(cities[i])):
|
35
|
-
try:
|
36
|
-
cities[i].remove('')
|
37
|
-
except:
|
38
|
-
continue
|
39
|
-
|
40
|
-
with open("a280.tsp","r") as fin:
|
41
|
-
data = [city.split(' ') for city in fin.read().splitlines()]
|
42
|
-
remove_blank(data)
|
43
|
-
cities_data = str2float(data)
|
44
|
-
return cities_data
|
45
|
-
|
46
|
-
|
47
35
|
|
48
36
|
cities_data = read_tspfile()
|
49
37
|
population = [] # [[経路],[経路],[経路]...[経路]]
|
@@ -207,7 +195,7 @@
|
|
207
195
|
|
208
196
|
for i in range(generation):
|
209
197
|
print(record_distance)
|
210
|
-
population.sort(key=await Route.calc_distance)
|
198
|
+
population.sort(key= await Route.calc_distance)
|
211
199
|
distance1 = await population[0].calc_distance() # 最短経路
|
212
200
|
|
213
201
|
if distance1 < record_distance:
|
@@ -233,18 +221,110 @@
|
|
233
221
|
asyncio.run(main())
|
234
222
|
end = time.time()
|
235
223
|
print(end-start)
|
236
|
-
|
237
|
-
```
|
224
|
+
```
|
225
|
+
|
226
|
+
### 試したこと
|
227
|
+
|
228
|
+
今のところ、sortのkeyにコルーチンを指定する方法が分からないので、ここに記載するのは変かもしれませんが、経路を計算するcalc_distance自体はコルーチンにしないで、新たに以下のような関数を作りました。
|
238
|
-
```
|
229
|
+
```python
|
239
|
-
35009.37540505089
|
240
|
-
|
230
|
+
async def give2calc_dist(Route_obj):
|
241
|
-
line 230, in <module>
|
242
|
-
asyncio.run(main())
|
243
|
-
line 44, in run
|
244
|
-
return loop.run_until_complete(main)
|
245
|
-
line 642, in run_until_complete
|
246
|
-
return future.result()
|
247
|
-
line 207, in main
|
248
|
-
|
231
|
+
return Route_obj.calc_distance()
|
249
|
-
TypeError: object function can't be used in 'await' expression
|
250
|
-
```
|
232
|
+
```
|
233
|
+
calc_distanceの部分を全てこれに置き換えると、以下のコードになります。
|
234
|
+
```python
|
235
|
+
async def pfga():
|
236
|
+
|
237
|
+
# 2未満なら追加。これだけだとランダムに2こ取り出す動作でエラー吐く。別途初期集団は作っておく
|
238
|
+
if len(population) < 2:
|
239
|
+
population.append(Route())
|
240
|
+
|
241
|
+
# ランダムに2個取り出す
|
242
|
+
p1 = population.pop(random.randint(0, len(population)-1))
|
243
|
+
p2 = population.pop(random.randint(0, len(population)-1))
|
244
|
+
|
245
|
+
# 子を作成
|
246
|
+
c1, c2 = await crossover(p1,p2)
|
247
|
+
|
248
|
+
if await give2calc_dist(p1) < await give2calc_dist(p2):
|
249
|
+
p_good = p1 # 短い経路(優秀)
|
250
|
+
p_bad = p2 # 長い経路(淘汰される)
|
251
|
+
else:
|
252
|
+
p_good = p2
|
253
|
+
p_bad = p1
|
254
|
+
if await give2calc_dist(c1) < await give2calc_dist(c2):
|
255
|
+
c_good = c1
|
256
|
+
c_bad = c2
|
257
|
+
else:
|
258
|
+
c_good = c2
|
259
|
+
c_bad = c1
|
260
|
+
|
261
|
+
if await give2calc_dist(c_bad) <= await give2calc_dist(p_good):
|
262
|
+
# 子2個体がともに親の2個体より良かった場合
|
263
|
+
# 子2個体及び適応度の良かった方の親個体計3個体が局所集団に戻り、局所集団数は1増加する。
|
264
|
+
population.append(c1)
|
265
|
+
population.append(c2)
|
266
|
+
population.append(p_good)
|
267
|
+
elif await give2calc_dist(p_bad) <= await give2calc_dist(c_good):
|
268
|
+
# 子2個体がともに親の2個体より悪かった場合
|
269
|
+
# 親2個体のうち良かった方のみが局所集団に戻り、局所集団数は1減少する。
|
270
|
+
population.append(p_good)
|
271
|
+
elif await give2calc_dist(p_good) <= await give2calc_dist(c_good) and await give2calc_dist(p_bad) >= await give2calc_dist(c_good):
|
272
|
+
# 親2個体のうちどちらか一方のみが子2個体より良かった場合
|
273
|
+
# 親2個体のうち良かった方と子2個体のうち良かった方が局所集団に戻り、局所集団数は変化しない。
|
274
|
+
population.append(c_good)
|
275
|
+
population.append(p_good)
|
276
|
+
elif await give2calc_dist(c_good) <= await give2calc_dist(p_good) and await give2calc_dist(c_bad) >= await give2calc_dist(p_good):
|
277
|
+
# 子2個体のうちどちらか一方のみが親2個体より良かった場合
|
278
|
+
# 子2個体のうち良かった方のみが局所集団に戻り、全探索空間からランダムに1個体選んで局所集団に追加する。局所集団数は変化しない。
|
279
|
+
population.append(c_good)
|
280
|
+
population.append(Route())
|
281
|
+
else:
|
282
|
+
raise ValueError("not comming")
|
283
|
+
|
284
|
+
|
285
|
+
async def main(generation=500):
|
286
|
+
# citiesに読み込んだ座標を持つCityオブジェクトを入れる
|
287
|
+
for i in range(CITIES_N):
|
288
|
+
cities.append(City(cities_data[i][0],
|
289
|
+
cities_data[i][1],
|
290
|
+
cities_data[i][2])) # num,X,Yの順
|
291
|
+
|
292
|
+
# populationに個体を追加
|
293
|
+
for i in range(2):
|
294
|
+
population.append(Route())
|
295
|
+
|
296
|
+
best = random.choice(population) # 個体(経路)
|
297
|
+
record_distance = await give2calc_dist(best) # 距離
|
298
|
+
|
299
|
+
with open('asyncio_PfGA_result.csv','w') as fout:
|
300
|
+
|
301
|
+
csvout = csv.writer(fout)
|
302
|
+
result = []
|
303
|
+
|
304
|
+
for i in range(generation):
|
305
|
+
print(record_distance)
|
306
|
+
population.sort(key=Route.calc_distance)
|
307
|
+
distance1 = await give2calc_dist(population[0]) # 最短経路
|
308
|
+
|
309
|
+
if distance1 < record_distance:
|
310
|
+
record_distance = distance1
|
311
|
+
best = population[0] # 最短経路を更新
|
312
|
+
|
313
|
+
task1 = asyncio.create_task(pfga())
|
314
|
+
#task2 = asyncio.create_task(pfga())
|
315
|
+
|
316
|
+
await task1
|
317
|
+
#await task2
|
318
|
+
|
319
|
+
if generation == 1 or generation%100 == 0:
|
320
|
+
data = []
|
321
|
+
data.extend([record_distance])
|
322
|
+
result.append(data)
|
323
|
+
if i == generation:
|
324
|
+
csvout.writerows(result)
|
325
|
+
print(best.citynums)
|
326
|
+
```
|
327
|
+
こうすると一応正常に動きました。しかし、計算処理自体はおそらく非同期になってないです。
|
328
|
+
|
329
|
+
|
330
|
+
|