質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1119閲覧

cnnにおけるinputの形状について

hoshihiro

総合スコア2

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/07/13 10:51

編集2020/07/13 10:55

前提・実現したいこと

行列で表されるcsvファイルを画像同様に扱って学習を行う

実装中に以下のエラーメッセージが発生しました。
このエラーを改善したいです。

発生している問題・エラーメッセージ

Error when checking input: expected input_27 to have 4 dimensions, but got array with shape (32, 65, 80)

該当のソースコード

Python

1import numpy as np 2import glob 3import pandas as pd 4from sklearn.model_selection import train_test_split 5 6#ファイルの読み込み 7labelfiles = glob.glob("c:/keraslab/dataset/label/*.csv") 8label110_csv=[] 9for labelfile in labelfiles: 10 label110_csv.append(np.loadtxt(labelfile,delimiter=",")) 11label110_csv = np.array(label110_csv) 12 13inputfiles = glob.glob("c:/keraslab/dataset/input/*.csv") 14input18_csv=[] 15for inputfile in inputfiles: 16 input18_csv.append(np.loadtxt(inputfile,delimiter=",")) 17input18_csv = np.array(input18_csv) 18 19#正規化 20input18_csv=(input18_csv-np.min(input18_csv))/np.max(input18_csv) 21label110_csv=(label110_csv-np.min(label110_csv))/np.max(label110_csv) 22 23#train,testの作成 24i_train, i_test = train_test_split(input18_csv) 25l_train, l_test = train_test_split(label110_csv) 26 27from keras.layers import Input,Conv2D,BatchNormalization,Activation 28from keras.models import Model 29 30 31#DnCNN 32def network_dncnn(): 33 input_img= Input(shape=(65,80,1)) 34 35 x= Conv2D(64,kernel_size=3,activation='relu',padding='same')(input_img) 36 37 for i in range(15): 38 x= Conv2D(64,kernel_size=3,padding='same')(x) 39 x= BatchNormalization()(x) 40 x= Activation('relu')(x) 41 42 x =Conv2D(1,kernel_size=3,activation='tanh',padding='same')(x) 43 44 model=Model(input_img,x) 45 46 return model 47 48model=network_dncnn() 49 50#モデルの表示 51print(model.summary()) 52 53import keras.optimizers as optimizers 54from tensorflow.python.keras import backend as K 55 56#training 57adam=optimizers.Adam(lr=1e-3) 58model.compile(loss='mean_squared_error',optimizer='adam') 59 60#trainingパラメータ- 61training =model.fit(i_train,l_train,epochs=100,batch_size=128,shuffle=True,validation_data=(i_test,l_test),verbose=1) 62 63import matplotlib.pyplot as plt 64 65#学習履歴の表示 66def plot_history(history): 67 plt.plot(history.history['loss']) 68 plt.plot(history.history['val_loss']) 69 plt.title('model loss') 70 plt.xlabel('epoch') 71 plt.ylabel('loss') 72 plt.legend(['loss','val_loss'],loc='lower right') 73 plt.show() 74 75plot_history(training) 76

試したこと

input_img=Input(shape(32,65,80,1))

に変更させると

Input 0 is incompatible with layer conv2d_211: expected ndim=4, found ndim=5

と表示されます。

補足情報(FW/ツールのバージョンなど)

input18_csvとlabel110_csvは(43,65,80)で、trainとtestで32-11に分割されています。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

nput18_csvとlabel110_csvは(43,65,80)で、trainとtestで32-11に分割されています。

Conv2D は (B, H, W, C) の4次元配列を要求するので、
nput18_csv、label110_csv が1チャンネルで形状が (B, H, W) だとしたら、

python

1nput18_csv = np.expand_dims(nput18_csv, axis=-1)

でモデルに投入する前に (B, H, W, 1) に形状を変更しておく必要があります。

Python - expand_dimsの使い方|teratail

投稿2020/07/13 12:17

編集2020/07/15 04:25
tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

hoshihiro

2020/07/15 04:25

ありがとうございます。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問