やりたいこと
はじめてGPyを使います。12次元のoil flowデータをGPLVMを使って潜在空間に埋め込んで、潜在空間での各データの座標を取得していです。
Datasetに関して
oil flowデータはここのgzipped MATLAB workspaceからダウンロードして、解凍することで3Class.matファイルとして使えます。
oil flowデータには12次元のベクトルと3次元のラベルが1000データ分保管されています。
3次元のラベルというのは、
Class A なら [0,0,1]
Class B なら [0,1,0]
Class C なら [1,0,0]
といった形です。
python
1import scipy.io 2oil_flow_dataset = scipy.io.loadmat("3Class.mat") 3oil_flow_dataset.keys() 4#dict_keys(['DataTrnLbls', 'DataVdnLbls', 'DataTrn', 'DataTstFrctns', 'DataTrnFrctns', 'DataTst', 'DataVdn', 'DataVdnFrctns', 'DataTstLbls']) 5oil_flow_dataset["DataTrnLbls"].shape 6#(1000,3) 7oil_flow_dataset["DataTrn"].shape 8#(1000,12)
TrnはTrain, VdnはValid, TstはTestを表してます。
やったこと
次のコードを実行することで、oil_flow_dataのGPLVMによる2次元の潜在空間での埋め込みの様子を知ることができます。
python
1import numpy as np 2import GPy 3import scipy.io 4import matplotlib.pyplot as plt 5 6oil_flow_dataset = scipy.io.loadmat("3Class.mat") 7observed_data = oil_flow_dataset["DataTrn"] 8normalized_observed_data = (observed_data - observed_data.mean(axis=0)) / observed_data.var(axis=0) 9GT = oil_flow_dataset["DataTrnLbls"].nonzero()[1] 10 11model = GPy.models.GPLVM(normalized_observed_data, input_dim=2) 12model.optimize(messages=True, max_iters=1e3) 13model.plot_latent(labels=GT) 14plt.savefig("gplvm.png")
私がやりたいのは、この2次元の潜在空間における各データ(上の図における各△)の座標を取得したいのです。要はplot_latentなんてものを使わずに普通にmatplotlibを使って潜在空間をプロットしたいのです。
よろしくお願いします。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2019/12/05 11:18