質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

2回答

1947閲覧

python sorted()が遅い もっと早いアルゴリズムにしたい

Lizard_knight

総合スコア18

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

1クリップ

投稿2019/01/15 16:31

pythonのsorted()関数から高速なソート関数に変えたい

28行目のpre_dataにzip関数を適用したイテレータになっていて、key=pre_data[0]をソートした配列を生成したいがpythonのsorted()しか思いつきません。
np.sort()もエラーが出てしまいます。
助けてください(;´Д`)

python3

1 18 def predict_image(model, labels, x_data): 2 19 predicts = [] 3 20 top5s = [] 4 21 times = [] 5 22 6 23 pre = model.predict(x_data) 7 24 8 25 """ output top5 """ 9 26 for i in range(len(pre)): 10 27 print('i:', i) 11 28 pre_data = zip(pre[i].data, labels) 12 29 import pdb;pdb.set_trace() 13 30 top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5] 14 31 #print(np.sort(pre[i].data)) 15 32 top5_unit = np.sort(pre_data, 0)[:5] 16 33 top5s.append(top5_unit) 17 34 18 35 19 36 """ output result """ 20 37 flag = 0 21 38 tlabels = ['tusker', 'African elephant', 'Indian elephant'] 22 39 top1 = 0 23 40 top5 = 0 24 41 for x in range(len(top5s)): 25 42 top1_acc = 0 26 43 top5_acc = 0 27 44 for i, data in enumerate(top5s[x]): 28 45 str_split = data[1].split(",") 29 46 for tlabel in tlabels: 30 47 if tlabel == str_split[0]: 31 48 if i == 0: 32 49 top1_acc = data[0]*100 33 50 top5_acc += data[0]*100 34 51 #print("{0}-label:{2} / accuracy:{1}%".format(i + 1, data[0] * 100, data[1])) 35 52 top1 += top1_acc 36 53 top5 += top5_acc 37 54 38 55 path_w = './result' 39 56 40 57 #with open(path_w, mode='w') as f: 41 58 # f.write('accuracy.txt') 42 59 print("top1:{0} | top5:{1}".format(top1 / len(top5s), top5 / len(top5s))) 43 60 #print("execution time:{:.3f}s".format(mean(times))) 44 61 45 62 return top1, top5

試したこと

30行目をnp.sort()に単純に変えたが
以下のエラーコードが出てしまう。

ValueError: object __array__ method not producing an array

補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

hayataka2049

2019/01/15 16:41

zipは一回読むと空になっちゃいますが、そのコードでまともに動きますか?
quickquip

2019/01/16 03:41

sortedはIterableを受け取るので動くかと思います。
quickquip

2019/01/16 03:44

このmodelはなんですか? model.predictが返すものの型、pre[i].dataの型はなにか分かってますか?
Lizard_knight

2019/01/16 05:23

mdoelはVGG16Layers()というクラスで要するに学習済みのネットワークモデルです。 model.predictの型は<class 'chainer.variable.Variable'>で、 pre[i].dataの型は<class 'cupy.core.core.ndarray'>です。
hayataka2049

2019/01/16 05:31

top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5]の行はいけるでしょうけど下にtop5_unit = np.sort(pre_data, 0)[:5]を書いたら駄目だろうという話です>quiquiさん
quickquip

2019/01/16 05:39

> hayataka2049さん なるほど。了解しました。 質問の中ですでに"そこは動かないと判明している"という頭で私が読んでいて、その指摘だと思いませんでした。失礼しました。
Lizard_knight

2019/01/16 06:31 編集

>hayataka2049さん すみません、top5_unit = np.sort(pre_data, 0)[:5]の部分はコメントアウトし忘れていました。 ないものだと思ってください。
hayataka2049

2019/01/16 06:33

念の為確認したいのですが、 top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5]で書くと遅いが一応動く top5_unit = np.sort(pre_data, 0)[:5]だと質問文の通りのエラー という理解で良いですか?
Lizard_knight

2019/01/16 07:43

>hayataka2049さん はい、そうです!
guest

回答2

0

ベストアンサー

cupyの配列で返ってくると扱いに困るみたいなので、numpyに変換してみます(モデル側をいじれればそちらで対応するべきかもしれませんが)。

cupyは扱ったことがないのですが、cupy.asnumpyでnumpy配列に変換できるらしいです。

また、以下の点を変更した方が良いでしょう。

  • zipをやめる
  • そのnp.sortの使い方で0列目を基準にソートされたりはしないので、ソート対象でargsortしてから並び替える
  • 降順ソートになるようにスライスを直す

動作未検証ですが以下のようにしてできませんか。

python

1 for i in range(len(pre)): 2 data = cupy.asnumpy(pre[i].data) 3 idx = data.argsort(data)[:-6:-1] 4 top5_unit = np.stack([data[idx], labels[idx]], axis=1) # labelsもnumpy配列を想定しているので、違うのなら適宜変換してください 5 top5s.append(top5_unit)

投稿2019/01/16 12:29

hayataka2049

総合スコア30933

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Lizard_knight

2019/01/16 13:50

回答ありがとうございます! sorted()とは比べ物にならないくらい早くなりました! ベストアンサーに選ばさせていただきます。
guest

0

hayataka2049さんのソースを基に書き換えた結果です。

python3

1 19 def predict_image(model, labels, x_data, num): 2 20 predicts = [] 3 21 top5s = [] 4 22 times = [] 5 23 6 24 pre = model.predict(x_data) 7 25 """ output top5 """ 8 26 for i in range(len(pre)): 9 27 data = cupy.asnumpy(pre[i].data) 10 28 idx = np.argsort(data)[:-6:-1] 11 29 top5_unit = np.stack([data[idx], idx], axis=1) 12 30 top5s.append(top5_unit) 13 31 14 32 #pre_data = zip(pre[i].data, labels) 15 33 #top5_unit = sorted(pre_data, key=lambda x: -x[0])[:5] 16 34 #print(cupy.sort(pre[i].data)) 17 35 #print(pre_data[0]) 18 36 #top5_unit = cupy.sort(pre_data, axis=0)[:5] 19 37 20 38 21 39 """ output result """ 22 40 flag = 0 23 41 #tlabels = ['tusker', 'African elephant', 'Indian elephant'] 24 42 tlabels = [101, 385, 386] 25 43 top1 = 0 26 44 top5 = 0 27 45 for x in range(len(top5s)): 28 46 top1_acc = 0 29 47 top5_acc = 0 30 48 for i in range(len(top5s[x])): 31 49 for tlabel in tlabels: 32 50 if tlabel == top5s[x][i][1]: 33 51 if i == 0: 34 52 top1_acc = top5s[x][i][0]*100 35 53 top5_acc += top5s[x][i][0]*100 36 54 37 55 top1 += top1_acc 38 56 top5 += top5_acc 39 57 40 58 41 59 path_w = './result' 42 60 43 61 #with open(path_w, mode='w') as f: 44 62 # f.write('accuracy.txt') 45 63 print("top1:{0:2.5f} | top5:{1:2.5f} | batch {2}/{3}". format(top1 / len(top5s), top5 / len(top5s), num+1, 13)) 46 64 #print("execution time:{:.3f}s".format(mean(times))) 47 65 48 66 return top1, top5

投稿2019/01/16 13:52

Lizard_knight

総合スコア18

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問