質問するログイン新規登録

Q&A

1回答

6087閲覧

PyTorchとKerasの推論速度について

takayoukey

総合スコア21

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1グッド

1クリップ

投稿2020/04/06 06:38

編集2020/04/07 04:15

1

1

PyTorchとKerasの推論速度を比較してみたのですが、同じネットワークを比較するとVGG19などでは10倍近く速度差がありました。使用したのはどちらも公式にあるImagenetでのトレーニング済モデルです。

そこで質問なのですが、何が原因でこれほどの速度差が出たのでしょうか。Kerasがラッパーであるため遅くなることは理解できるのですが、もう少し内部動作に踏み込んだ洞察が欲しいと考えています。

また、今回は特殊な条件として、画像のロード->推論を1FPSとして測定しています。
若干メタ的な質問で恐縮ですがよろしくお願いします。

【補足】
TensorFlow 2.1.0
Keras 2.3.1
PyTorch 1.4.0

使用したコードは下記です。

Python

1import time 2 3def create_model(frame, arch): 4 if(frame == 'keras'): 5 import tensorflow.keras.applications as model 6 7 if(arch == 'densenet121'): 8 return model.DenseNet121() 9 if(arch == 'densenet169'): 10 return model.DenseNet169() 11 if(arch == 'densenet201'): 12 return model.DenseNet201() 13 14 elif(frame == 'pytorch'): 15 import torch 16 import torchvision.model as model 17 18 if(arch == 'densenet121'): 19 model = model.densenet121(pretrained=True).cuda() 20 if(arch == 'densenet169'): 21 model = model.densenet169(pretrained=True).cuda() 22 if(arch == 'densenet201'): 23 model = model.densenet201(pretrained=True).cuda() 24 25 model.eval() 26 model.to(torch.device('cuda')) 27 28 return model 29 30 31LOOP = 10 32img_path = 'xxx.jpg' 33frame = 'keras' 34name = 'densenet121' 35 36model = create_model(frame, name) 37 38# inference 39if(frame == 'keras'): 40 import numpy 41 from tensorflow.keras.preprocessing import image 42 43 start = time.perf_counter() 44 for i in range(LOOP): 45 img = image.load_img(img_path, target_size=(224, 224)) 46 img = image.img_to_array(img) 47 img = numpy.expand_dims(img, axis=0) 48 preds = model.predict(img) 49 elapsed_time = (time.perf_counter() - start) / LOOP_SIZE * 1000 50 fps = 1 / elapsed_time * 1000 51 52elif(frame == 'pytorch'): 53 import torchvision.transforms as transforms 54 from PIL import Image 55 import torch 56 from torch.autograd import Variable 57 device = torch.device('cuda') 58 transformation = transforms.Compose( 59 [ 60 transforms.Resize([224, 224]), 61 transforms.ToTensor(), 62 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 ] 64 ) 65 66 start = time.perf_counter() 67 for i in range(LOOP): 68 img = Image.open(img_path) 69 img = transformation(img).float() 70 img = img.unsqueeze_(0) 71 img = Variable(img) 72 img = img.to(device) 73 preds = model(img) 74 elapsed_time = (time.perf_counter() - start) / LOOP_SIZE * 1000 75 fps = 1 / elapsed_time * 1000 76 77line = '| %.3g | %.3g |' % (fps, elapsed_time) 78print('| FPS | Throughput |') 79print(line)
Gappoi-j👍を押しています

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

guest

回答の取得に失敗しました

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.29%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問