🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

解決済

2回答

1906閲覧

CNNの学習がうまくいかない

taichi1602

総合スコア26

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

NumPy

NumPyはPythonのプログラミング言語の科学的と数学的なコンピューティングに関する拡張モジュールです。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2021/01/03 13:09

編集2021/01/04 04:32

ゼロから作るdeep learningのim2colとcol2im関数のみを用いて, かんたんな畳込みネットワークを作成しようとしているのですが, うまく学習が行えていません。
下のプログラムで決定的な誤りがある箇所はありますか?
学習率と活性化関数とバイアスは考えないでプログラムしています。

CNNの構成は、入力→畳み込み層→畳み込み層2→全結合層を想定しています。
全結合層から2乗誤差を逆伝搬して誤差逆伝搬を実装しています。

下のプログラムでは、簡略化のために、入力はnp.arrangeで実装しています。
入力はmnistを想定してのこの形となっています。

学習した結果、テストしたところ、出力の最大の要素が全て同じになってしまい、正しく学習がされていないようです。
col2imや逆伝搬の使い方に誤りがあるのでしょうか?

python

1 2import numpy as np 3 4def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0): 5 6 N, C, H, W = input_data.shape 7 out_h = (H + 2*pad_h - filter_h)//stride_h + 1 8 out_w = (W + 2*pad_w - filter_w)//stride_w + 1 9 10 img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant') 11 col = np.zeros((N, C, filter_h, filter_w, out_h, out_w)) 12 13 for y in range(filter_h): 14 y_max = y + stride_h*out_h 15 for x in range(filter_w): 16 x_max = x + stride_w*out_w 17 col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w] 18 19 col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1) 20 return col, out_h, out_w 21 22def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0): 23 N, C, H, W = input_shape 24 out_h = (H + 2*pad - filter_h)//stride + 1 25 out_w = (W + 2*pad - filter_w)//stride + 1 26 col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2) 27 28 img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1)) 29 for y in range(filter_h): 30 y_max = y + stride*out_h 31 for x in range(filter_w): 32 x_max = x + stride*out_w 33 img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :] 34 35 return img[:, :, pad:H + pad, pad:W + pad] 36 37def f(x): 38 y = 1 / (1 + np.exp(-0.01*x)) 39 return y 40 41#フィルタ 42#w = np.arange(5*1*9*9).reshape(5,1,9,9) 43w = np.random.rand(5,1,9,9)/100 44 45#2層目フィルタ 46#w2 = np.arange(5*5*9*9).reshape(5,5,9,9) 47w2 = np.random.rand(5,5,9,9)/100 48 49#全結合層重み 50w_affine = np.random.rand(6*6*5,10)/100 51 52for i in range(1000):#学習回数 53 #入力 54 x = np.arange(6*1*28*28).reshape(6,1,28,28) 55 x = 0.0001*x 56 57 #入力2次元化 58 col_x,out_h,out_w = im2col(x,9,9,1,1,0,0)#入力, フィルタサイズ9, ストライド1 59 60 #1層目フィルタ2次元化 61 col_w = w.reshape(5,-1).T 62 #1層目畳み込み計算 63 conv1_col = np.dot(col_x, col_w) 64 65 #1層目畳み込み層変換 66 conv1 = conv1_col.reshape(x.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1) 67 68 #2層目フィルタ2次元化 69 col_w2 = w2.reshape(5,-1).T 70 71 #2層目畳み込み計算(2次元) 72 input_col,out_h,out_w = im2col(conv1,9,9,2,2,0,0)#cov1, フィルタサイズ9,ストライド2 73 conv2_col = np.dot(input_col, col_w2) 74 conv2 = conv2_col.reshape(conv1.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1) 75 76 #全結合層出力 77 z = conv2.reshape(conv2.shape[0],-1) 78 y = f(np.dot(z,w_affine)) 79 80 t = np.zeros((6,10)) 81 #教師信号 82 for i in range(6): 83 t[i,i]=1 84 85 #全結合層誤差 86 Error2 = (t - y) * (1 - y) * y 87 88 a = 1e-5 89 #affine層更新 90 w_affine += a * np.dot(z.T, Error2) 91 92 #畳み込み層2誤差(1次元) 93 Error1 = np.dot(Error2, w_affine.T) 94 95 #畳み込み層誤差変換 96 dout2 = Error1.reshape(6,5,6,6) 97 dout2 = dout2.transpose(0,2,3,1).reshape(-1,5) 98 99 #w2更新 100 dw2 = np.dot(input_col.T,dout2) 101 dw2 = dw2.transpose(1,0).reshape(5,5,9,9) 102 w2 += a * dw2 103 104 #畳み込み層1誤差 105 dout1_col = np.dot(dout2,col_w2.T) 106 dout1 = col2im(dout1_col,conv1.shape,9,9,2,0) 107 dout1 = dout1.transpose(0,2,3,1).reshape(-1,5) 108 109 #w1更新 110 dw1 = np.dot(col_x.T,dout1) 111 dw1 = dw1.transpose(1,0).reshape(5,1,9,9) 112 w += a * dw1 113 114#学習結果をテスト 115#入力 116x = np.arange(6*1*28*28).reshape(6,1,28,28) 117x = 0.0001*x 118 119#フィルタ 120#w = np.arange(5*1*9*9).reshape(5,1,9,9) 121#w = 0.01*w 122 123#入力2次元化 124col_x,out_h,out_w = im2col(x,9,9,1,1,0,0)#入力, フィルタサイズ9, ストライド1 125 126#1層目フィルタ2次元化 127col_w = w.reshape(5,-1).T 128#1層目畳み込み計算 129conv1_col = np.dot(col_x, col_w) 130 131#1層目畳み込み層変換 132conv1 = conv1_col.reshape(x.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1) 133 134#2層目フィルタ 135#w2 = np.arange(5*5*9*9).reshape(5,5,9,9) 136#w2 = 0.01*w2 137 138#2層目フィルタ2次元化 139col_w2 = w2.reshape(5,-1).T 140 141#2層目畳み込み計算(2次元) 142input_col,out_h,out_w = im2col(conv1,9,9,2,2,0,0)#cov1, フィルタサイズ9,ストライド2 143conv2_col = np.dot(input_col, col_w2) 144conv2 = conv2_col.reshape(conv1.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1) 145 146#全結合層重み 147#w_affine = np.random.rand(6*6*5,10) 148 149#全結合層出力 150z = conv2.reshape(conv2.shape[0],-1) 151y = f(np.dot(z,w_affine)) 152y_max = np.argmax(y,axis=1) 153 154print(y) 155print(y_max) 156print(t)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

taichi1602

2021/01/03 14:57

同じプログラムに関しての質問です。 上記の質問に関するエラーは無くなったのですが、CNNの動作がこれで正しいかどうかの質問です。 特に逆伝搬はこのコードで正しいのかどうかを教えていただきたいです。
guest

回答2

0

こんにちは、これはゼロから作るディープラーニングの本の内容を見ながらご自身でこちらのオリジナルのコードを作成されている、ということで正しいですか?このリポジトリと内容が違ったのでそう解釈しました。

python

1import numpy as np 2#2層目フィルタ 3w2 = np.arange(5*5*9*9).reshape(5,5,9,9)

例えば、このw2だと、0,1,2,3, ...という値が初期値として入って、
たとえば平均0、分散1の初期値が入ると思うのですが、その辺もうまくいかない要因かなあと思いました。im2colはうまく動作している気がしたのですが、他のところは見れていません。
もろもろこちらの解釈ちがいであればすいません。うまくいかない、というのも、どううまくいかないのか(xxというエラー文がでる、△の値が想定よりずっと大きいなど)が書かれていると他の方からよりよい回答が得られそうだと思いました。役に立てば幸いです。

投稿2021/01/03 15:30

Kenta_py

総合スコア132

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

taichi1602

2021/01/03 15:52

回答ありがとうございます。 >これはゼロから作るディープラーニングの本の内容を見ながらご自身でこちらのオリジナルのコードを作成されている はいその通りです。 本では1層の畳み込み層での実装でしたが、作成するプログラムは, 2層の畳み込み層で実装を考えています。 また、学習アルゴリズムの理解のために、im2colとcol2imのみを本から借りてきての実装を試みています。 質問内容を簡略にするために、質問では入力と重みはnp.arraneにしていますが、実際は入力はmnistで、重みは乱数で初期化して学習させています。 うまくいかない点として、mnistの学習結果が未学習の時と学習後で変化が見られないため質問をしました。(重みの値自体は更新されている様です) 入力と重み以外は質問のプログラムと同じのために、こうのような形で質問させていただきました。
Kenta_py

2021/01/03 16:43

なるほど、詳しくありがとうございます。質問を簡略化するためにいろいろ実際のものと入れ替えている旨も理解しました。確かに、簡略化することも大事だと思うのですが、そのあたりも詳しく書かないと初見で見る人にとっては意味を汲み取るのも難しいと思うので次からはそのような条件も質問文に書くと良いと思いました。 自作してる件も承知いたしました。これはコードのごく一部だと思うので他の部分も必要だと思うのですが、他の部分はもとのコードと全く同じなのでしょうか? 残念ながらこの本の実装は私は試したことがなくて、いますぐこの場所がちがう、、、とは言えそうにないです、、、すいません。 言えそうなこととしては、2層目を追加した、とのことですが、正しく実装自体はできたとして、その2層目の構造でMNISTの学習がうまくいく保証はありそうですか?何かしらのソフトやライブラリでうまくいく2層のCNNを探して、それをもとに2層目を定義⇒誤差の計算を実装とすればよいかなと思いました。 また、逆にこちらから質問させてほしいのですが、全結合⇒2層目の逆誤差伝搬と2層目⇒1層目の逆誤差伝搬の式は結構異なると思うのですが、そのあたりはどう実装しましたか?プーリング層のデルタとかいろいろ考えるところがあって、実際自分で書くのは非常に難しそうだなあと思っていました。もしそのあたりを区別せず、そのまま2層分をうごくように書き下しただけだとすんなりとは行かなさそうです。 mnist cnn from scratch 2 layers 等で調べるといろいろ出てきたので(例:https://mlfromscratch.com/neural-network-tutorial/#/)、まずはそれらを見て勉強するのもよさそうです。
taichi1602

2021/01/04 02:41

回答ありがとうございます。簡略化の件も確かにそうですね。 質問のプログラムは、mnistの実装以外は全て、同じものです。 質問プログラムで順伝搬と逆伝搬を行います。 2層目の追加は一般的なCNNは畳み込み層を複数入れて実装しているために勉強のために追加しました。 誤差逆伝搬の式も、本にあるcol2im関数で逆伝搬を行っています。 質問のCNNでは、プーリング層はなしで、入力→畳み込み層→畳み込み層2→全結合層(出力)の構成になっています。 逆伝搬は、2乗誤差を誤差逆伝搬法で更新しています。 質問では、CNNの構造というよりも、正しい更新式になっているかという質問です。 正しく学習はしているが、バッチや活性化関数、学習率が適切でなく学習ができていなければ、そこは次の勉強にしたいと思います。 参考URLありがとうございます。確認してみます。
Kenta_py

2021/01/04 07:59

なるほど、もろもろ理解しました。質問の更新もありがとうございます。他の良い回答が得られるとよいですね。 >>質問では、CNNの構造というよりも、正しい更新式になっているかという質問です。 ありがとうございます。もちろんそこは理解しているのですが、そもそもその構造やパラメータでうまくいくかわかっていれば、うまく学習できないのはパラメータのせいなのか、コーディングのせいかの切り分けができるのでより検証がしやすいのではと思った次第です。 また時間があればコードのほうも見ておきますね。また改善箇所がわかればここで返信させていただきますね。
taichi1602

2021/01/04 10:08

ありがとうございます。 よろしくお願いします
guest

0

ベストアンサー

教師信号がOne-hotベクトルなので、活性化関数はSoftmaxなどを使って0〜1の範囲にしないと適切な誤差が算出できないかと思います。正しい誤差が求められなければ、正しい学習は行えません。

また、学習が進んでいるかどうかは、交差エントロピー誤差が求められれば分かります。誤差が小さくなることがわかれば、学習が進んでいると判断できます。

そのため、活性化関数と交差エントロピー誤差を勉強していただいて、それをプログラムに組み込んでいただいてから、学習ができているかどうかを判断されるのが良いかと思います。

投稿2021/01/06 12:36

segavvy

総合スコア1038

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

taichi1602

2021/01/07 00:48

回答ありがとうございます。 softmaxの誤りがありました。 segavvyさんのQiita見させていただきました。 無事、誤差が減少していき、学習できていることを確認できました。 回答ありがとうございました。
segavvy

2021/01/07 03:39

お役に立てたようでよかったです! Qiitaもご参照いただき、ありがとうございました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問