質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
86.12%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

受付中

pytorchでデータローダーの全データを取り出したい

tmc5
tmc5

総合スコア26

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

1回答

0グッド

0クリップ

167閲覧

投稿2022/11/23 09:42

前提

pytorchで下記のページを参考に、VAEの潜在空間の可視化を行っています。
参考サイト:徹底解説:VAEをはじめから丁寧に

こちらのサイトでは、潜在空間の可視化をバッチごとに行っていました。バッチサイズは1000です。

潜在空間の可視化

python

1z_dim = 2 2model = VAE(z_dim) 3cm = plt.get_cmap("tab10") # カラーマップの用意 4# 可視化開始 5for num_batch, data in enumerate(dataloader_test): 6 fig_plot, ax_plot = plt.subplots(figsize=(9, 9)) 7 fig_scatter, ax_scatter = plt.subplots(figsize=(9, 9)) 8 # 学習済みVAEに入力を与えたときの潜在変数を抽出 9 _, z, _ = model(data[0], device) 10 z = z.detach().numpy() 11 # 各クラスごとに可視化する 12 for k in range(10): 13 cluster_indexes = np.where(data[1].detach().numpy() == k)[0] 14 ax_plot.plot(z[cluster_indexes,0], z[cluster_indexes,1], "o", ms=4, color=cm(k)) 15 fig_plot.savefig(f"./latent_space_z_{z_dim}_{num_batch}_plot.png") 16 fig_scatter.savefig(f"./latent_space_z_{z_dim}_{num_batch}_scatter.png") 17 plt.close(fig_plot) 18 plt.close(fig_scatter)

これを、全データを一つのプロットとして出力したいです。

やってみたこと

①バッチごとにキャンバスを変えずに、同じプロットに描画
→ tSNE処理(多次元を2次元に圧縮する手法)を行っていたため不具合が起きて以下のような画像が生成される
イメージ説明

全データを描画に一括で与える必要があるとわかったのですが、そのやり方がわからず詰まっています。

発生している問題・エラーメッ

python

1 def latent_space_tSNE_alldata(self): 2 """Visualization of latent space using all data 3 4 Args: 5 None. (zs, labels) 6 7 Returns: 8 None. 9 """ 10 cm = plt.get_cmap("tab10") 11 fig_plot, ax_plot = plt.subplots(figsize=(9, 9)) 12 fig_scatter, ax_scatter = plt.subplots(figsize=(9, 9)) 13 14 15#### 16 for num_batch, data in enumerate(self.dataloader_test): 17 _, z, _ = self.model(data[0], self.device) 18#### 19 20 z = z.cpu().detach().numpy() 21 points = TSNE(n_components=2, random_state=0).fit_transform(z) 22 for k in range(10): 23 cluster_indexes = np.where(data[1].cpu().detach().numpy() == k)[0] 24 ax_plot.plot(points[cluster_indexes,0], points[cluster_indexes,1], "o", ms=4, color=cm(k)) 25 ax_scatter.scatter(points[cluster_indexes,0], points[cluster_indexes,1], marker=f"${k}$", color=cm(k),label=k) 26 27 fig_plot.savefig(f"./images/latent_space_tSNE_alldata/z_{self.z_dim}_alldata_plot.png") 28 fig_scatter.savefig(f"./images/latent_space_tSNE_alldata/z_{self.z_dim}_alldata_scatter.png") 29 plt.legend(loc="upper left", fontsize=10) 30 plt.close(fig_plot) 31 plt.close(fig_scatter) 32 33

"####" で囲んだfor文のところで、zにappendしたところエラーで動きませんでした。
描画部以外のコードは参考サイトのものを使用しています。

さまざまやり方はあると思うのですが、初心者のためどのような形がスマートなのいまいちわからない状態です。全くお門違いの質問だったら申し訳ありませんが、ご助言いただければ幸いです。よろしくお願いいたします。

以下のような質問にはグッドを送りましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

グッドが多くついた質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

下記のような質問は推奨されていません。

  • 間違っている
  • 質問になっていない投稿
  • スパムや攻撃的な表現を用いた投稿

適切な質問に修正を依頼しましょう。

jbpb0

2022/11/24 06:47

「dataloader_test」じゃなくて、全データが入ってる「dataset_test」を使ったら、いかがでしょうか スマートではなくて少々強引ですが、 img = [] label = [] for num_data in range(len(dataset_test)): # この下の二行はインデント有り img.append(np.array(dataset_test[num_data][0])) label.append(np.array(dataset_test[num_data][1])) imgt = torch.tensor(np.array(img)) labelt = torch.tensor(np.array(label)) を実行して作成した「imgt」と「labelt」は、どちらも全データが入ってる「torch.Tensor」なので、 > for num_batch, data in enumerate(dataloader_test): で「dataloader_test」から取り出した「data[0]」と「data[1]」の代わりに使えると思います
jbpb0

2022/11/24 06:52

または、こちらも強引ですが、 dataloader_test2 = torch.utils.data.DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=False) と、バッチサイズが全データ数と同じものを別に作って、 for num_batch, data in enumerate(dataloader_test2): として使えば、「data」には全データが入ります

回答1

0

それは辛いですね、、、🥺

投稿2022/11/23 10:16

user_x_xxx

総合スコア22

良いと思った回答にはグッドを送りましょう。
グッドが多くついた回答ほどページの上位に表示されるので、他の人が素晴らしい回答を見つけやすくなります。

下記のような回答は推奨されていません。

  • 間違っている回答
  • 質問の回答になっていない投稿
  • スパムや攻撃的な表現を用いた投稿

このような回答には修正を依頼しましょう。

2022/11/23 11:30

こちらの回答が他のユーザーから「質問に対する回答となっていない投稿」という指摘を受けました。

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
86.12%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。