質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.31%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

Q&A

解決済

1回答

544閲覧

CNNの画像を二次元配列から三次元配列へ変換

karintage

総合スコア5

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

0グッド

0クリップ

投稿2022/02/18 09:52

[Python CNN Grad-CAM]
CNNの画像にノイズを加えたものの特徴マップをGrad-CAMをを用いて可視化することを目標としています。
ノイズを加える前の画像は可視化することができましたが、ノイズを加えた後の画像は、同じプログラムでは実行できませんでした。

ノイズなしの場合は、[1,28,28]の三次元、ありの場合は[28,28]の二次元配列です。
以下がエラー箇所のプログラムです。
画像データの引数はdataです。

data.shape[2]

でエラーが発生します。

二次元配列のデータはそのままで、軸を増やせば解決できると思うのですが、どうすれば次元を増やせるでしょうか。
ご教授よろしくおねがいします。

以下の二行目がエラー部です。

python

1data, label = trainset[0] 2input_data = data.view(1, data.shape[0], data.shape[1], data.shape[2]).to(device) 3 4# モデルで予測 5# confidence: 信頼度、 predicted: 予測ラベル 6# (今回は、torch.max(, 1)でTop1の情報のみを取得 7output = nn.Softmax(dim=1)(model(input_data)) 8confidence, predicted = torch.max(output.data, 1) 9 10# GradCAMオブジェクト生成 11gcam = GradCAM(model=model) 12 13# 画像を順伝搬 14_ = gcam.forward(input_data) 15 16# 予測ラベルを元に逆伝搬し、勾配を計算 17single_predicted = predicted.view(1, predicted.shape[0]).to(device) 18gcam.backward(ids=single_predicted) 19 20# 指定の層の勾配マップを取得 21#(各層の名前は、print(model)で参照可能) 22regions = gcam.generate("conv1") 23 24# tensorを、numpy に変換 25raw_image = input_data[0].to('cpu').detach().numpy().copy() 26raw_image = raw_image.transpose((1, 2, 0)) 27raw_image = ((raw_image * 0.5) + 0.5) * 255.0 28raw_image = raw_image.astype(np.uint8) 29 30# 各画像 出力 31output = out_gradcam( 32 gcam=regions[0, 0], 33 raw_image=raw_image 34) 35 36## GradCAM 37plt.imshow(output, cmap="jet") 38plt.colorbar()

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

python

1>>> data = data.reshape(1,28,28) 2>>> data.shape 3(1, 28, 28)

です。

投稿2022/02/18 10:04

ppaul

総合スコア24672

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.31%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問