前提
pytorchのtensorの取り扱いについて
GPUを用いたpytorch tensorの計算(下記スクリプト)について高速に動作させたいです。
実現したいこと
下記スクリプトの高速化を行いたいです。
任意のインデックスiについて a[i],b[i]のペアが、ref_listの0行目と1行目にあるか探索し、ref_listの2行目の値をcに格納する
import torch
import numpy as np
import time
t0=time.time()
a = torch.tensor([1,1,2,1,1,2], device=torch.device('cuda:0'))
b = torch.tensor([2,2,1,1,1,2], device=torch.device('cuda:0'))
ref_list=torch.tensor([[1,1,0.1],[1,2,0.2],[2,2,0.3]],device=torch.device('cuda:0'))
c=torch.empty(6,device=torch.device('cuda:0'))
for pair1,pair2,val in ref_list:
idx=[i for i,(a_,b_) in enumerate(zip(a,b)) if ((a_==pair1) and (b_==pair2)) or ((a_==pair2) and (b_==pair1)) ]
c[idx]=val
print(time.time()-t0)
発生している問題・エラーメッセージ
ネットで調べながら上記スクリプトを自作しました。さらに高速なアルゴリズムがあればお伺いさせてください。
実際の実装はaとbの長さが数十万に及びますので、できるだけ早く動作させたいと考えております。
あなたの回答
tips
プレビュー