前提
pytorchで下記のページを参考に、VAEの潜在空間の可視化を行っています。
参考サイト:徹底解説:VAEをはじめから丁寧に
こちらのサイトでは、潜在空間の可視化をバッチごとに行っていました。バッチサイズは1000です。
python
1z_dim = 2 2model = VAE(z_dim) 3cm = plt.get_cmap("tab10") # カラーマップの用意 4# 可視化開始 5for num_batch, data in enumerate(dataloader_test): 6 fig_plot, ax_plot = plt.subplots(figsize=(9, 9)) 7 fig_scatter, ax_scatter = plt.subplots(figsize=(9, 9)) 8 # 学習済みVAEに入力を与えたときの潜在変数を抽出 9 _, z, _ = model(data[0], device) 10 z = z.detach().numpy() 11 # 各クラスごとに可視化する 12 for k in range(10): 13 cluster_indexes = np.where(data[1].detach().numpy() == k)[0] 14 ax_plot.plot(z[cluster_indexes,0], z[cluster_indexes,1], "o", ms=4, color=cm(k)) 15 fig_plot.savefig(f"./latent_space_z_{z_dim}_{num_batch}_plot.png") 16 fig_scatter.savefig(f"./latent_space_z_{z_dim}_{num_batch}_scatter.png") 17 plt.close(fig_plot) 18 plt.close(fig_scatter)
これを、全データを一つのプロットとして出力したいです。
やってみたこと
①バッチごとにキャンバスを変えずに、同じプロットに描画
→ tSNE処理(多次元を2次元に圧縮する手法)を行っていたため不具合が起きて以下のような画像が生成される
全データを描画に一括で与える必要があるとわかったのですが、そのやり方がわからず詰まっています。
発生している問題・エラーメッ
python
1 def latent_space_tSNE_alldata(self): 2 """Visualization of latent space using all data 3 4 Args: 5 None. (zs, labels) 6 7 Returns: 8 None. 9 """ 10 cm = plt.get_cmap("tab10") 11 fig_plot, ax_plot = plt.subplots(figsize=(9, 9)) 12 fig_scatter, ax_scatter = plt.subplots(figsize=(9, 9)) 13 14 15#### 16 for num_batch, data in enumerate(self.dataloader_test): 17 _, z, _ = self.model(data[0], self.device) 18#### 19 20 z = z.cpu().detach().numpy() 21 points = TSNE(n_components=2, random_state=0).fit_transform(z) 22 for k in range(10): 23 cluster_indexes = np.where(data[1].cpu().detach().numpy() == k)[0] 24 ax_plot.plot(points[cluster_indexes,0], points[cluster_indexes,1], "o", ms=4, color=cm(k)) 25 ax_scatter.scatter(points[cluster_indexes,0], points[cluster_indexes,1], marker=f"${k}$", color=cm(k),label=k) 26 27 fig_plot.savefig(f"./images/latent_space_tSNE_alldata/z_{self.z_dim}_alldata_plot.png") 28 fig_scatter.savefig(f"./images/latent_space_tSNE_alldata/z_{self.z_dim}_alldata_scatter.png") 29 plt.legend(loc="upper left", fontsize=10) 30 plt.close(fig_plot) 31 plt.close(fig_scatter) 32 33
"####" で囲んだfor文のところで、zにappendしたところエラーで動きませんでした。
描画部以外のコードは参考サイトのものを使用しています。
さまざまやり方はあると思うのですが、初心者のためどのような形がスマートなのいまいちわからない状態です。全くお門違いの質問だったら申し訳ありませんが、ご助言いただければ幸いです。よろしくお願いいたします。