pythonのsorted()関数から高速なソート関数に変えたい
28行目のpre_dataにzip関数を適用したイテレータになっていて、key=pre_data[0]をソートした配列を生成したいがpythonのsorted()しか思いつきません。
np.sort()もエラーが出てしまいます。
助けてください(;´Д`)
python3
1 18 def predict_image(model, labels, x_data): 2 19 predicts = [] 3 20 top5s = [] 4 21 times = [] 5 22 6 23 pre = model.predict(x_data) 7 24 8 25 """ output top5 """ 9 26 for i in range(len(pre)): 10 27 print('i:', i) 11 28 pre_data = zip(pre[i].data, labels) 12 29 import pdb;pdb.set_trace() 13 30 top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5] 14 31 #print(np.sort(pre[i].data)) 15 32 top5_unit = np.sort(pre_data, 0)[:5] 16 33 top5s.append(top5_unit) 17 34 18 35 19 36 """ output result """ 20 37 flag = 0 21 38 tlabels = ['tusker', 'African elephant', 'Indian elephant'] 22 39 top1 = 0 23 40 top5 = 0 24 41 for x in range(len(top5s)): 25 42 top1_acc = 0 26 43 top5_acc = 0 27 44 for i, data in enumerate(top5s[x]): 28 45 str_split = data[1].split(",") 29 46 for tlabel in tlabels: 30 47 if tlabel == str_split[0]: 31 48 if i == 0: 32 49 top1_acc = data[0]*100 33 50 top5_acc += data[0]*100 34 51 #print("{0}-label:{2} / accuracy:{1}%".format(i + 1, data[0] * 100, data[1])) 35 52 top1 += top1_acc 36 53 top5 += top5_acc 37 54 38 55 path_w = './result' 39 56 40 57 #with open(path_w, mode='w') as f: 41 58 # f.write('accuracy.txt') 42 59 print("top1:{0} | top5:{1}".format(top1 / len(top5s), top5 / len(top5s))) 43 60 #print("execution time:{:.3f}s".format(mean(times))) 44 61 45 62 return top1, top5
試したこと
30行目をnp.sort()に単純に変えたが
以下のエラーコードが出てしまう。
ValueError: object __array__ method not producing an array
補足情報(FW/ツールのバージョンなど)
ここにより詳細な情報を記載してください。
zipは一回読むと空になっちゃいますが、そのコードでまともに動きますか?
sortedはIterableを受け取るので動くかと思います。
このmodelはなんですか? model.predictが返すものの型、pre[i].dataの型はなにか分かってますか?
mdoelはVGG16Layers()というクラスで要するに学習済みのネットワークモデルです。
model.predictの型は<class 'chainer.variable.Variable'>で、
pre[i].dataの型は<class 'cupy.core.core.ndarray'>です。
top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5]の行はいけるでしょうけど下にtop5_unit = np.sort(pre_data, 0)[:5]を書いたら駄目だろうという話です>quiquiさん
> hayataka2049さん
なるほど。了解しました。
質問の中ですでに"そこは動かないと判明している"という頭で私が読んでいて、その指摘だと思いませんでした。失礼しました。
>hayataka2049さん
すみません、top5_unit = np.sort(pre_data, 0)[:5]の部分はコメントアウトし忘れていました。
ないものだと思ってください。
念の為確認したいのですが、
top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5]で書くと遅いが一応動く
top5_unit = np.sort(pre_data, 0)[:5]だと質問文の通りのエラー
という理解で良いですか?
>hayataka2049さん
はい、そうです!
回答2件
あなたの回答
tips
プレビュー