質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

1594閲覧

pythonにおけるnnablaでのCNNでエラーが出てしまう件

tomoki_takaba

総合スコア62

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/01/14 15:09

この記事を参考に以下のようなCNNのプログラムを書きました。
環境はmacOS X 10.15、python3.7.1です

python

1from __future__ import absolute_import 2from six.moves import range 3 4import os 5 6import nnabla as nn 7# ① NNabla関連モジュールのインポート 8import nnabla as nn 9import nnabla.functions as F 10import nnabla.parametric_functions as PF 11import nnabla.solvers as S 12from nnabla.utils.data_iterator import data_iterator_simple 13import nnabla.initializer as I 14 15# ② NNabla関連以外のモジュールのインポート 16import numpy as np 17from sklearn.datasets import load_digits 18from PIL import Image 19 20 21# ③ 学習データを読み込み使えるようにする 22 23 24def data_iterator_tiny_digits(digits, batch_size=64, shuffle=False, rng=None): 25 def load_func(index): 26 data_array = np.empty((0,8),float) 27 target_arrry = np.array([]) 28 main_folder = os.listdir("./data/") 29 for sub_folder in main_folder: 30 data_list = os.listdir("./data/" + sub_folder) 31 for data in data_list: 32 img = Image.open("./data/" + sub_folder + "/" + data) 33 fix_img = img.convert('L') 34 data_grey = np.asarray(fix_img) 35 target_arrry = np.append(target_arrry,sub_folder) 36 data_array = np.append(data_array,data_grey,axis=0) 37 return data_array,np.array([target_arrry]).astype(np.int32) 38 return data_iterator_simple(load_func, 200, batch_size, shuffle, rng, with_file_cache=False) 39 40# ④ 損失グラフを構築する関数を定義する 41 42 43def logreg_loss(y, t): 44 loss_f = F.mean(F.softmax_cross_entropy(y, t)) 45 return loss_f 46 47# ⑤ トレーニング関数を定義する 48 49 50def training(xt, tt, data_t, loss_t, steps, learning_rate): 51 solver = S.Sgd(learning_rate) 52 # Set parameter variables to be updatd. 53 solver.set_parameters(nn.get_parameters()) 54 for i in range(steps): 55 xt.d, tt.d = data_t.next() 56 loss_t.forward() 57 solver.zero_grad() # Initialize gradients of all parameters to zero. 58 loss_t.backward() 59 # Applying weight decay as an regularization 60 solver.weight_decay(1e-5) 61 solver.update() 62 if i % 100 == 0: # Print for each 10 iterations 63 print(str(i) + ":" + str(loss.d)) 64 65# ⑥ ニューラルネットを構築する関数を定義する 66 67 68def network(x): 69 initializer = I.UniformInitializer((-0.1, 0.1)) 70 with nn.parameter_scope("cnn"): 71 with nn.parameter_scope("conv1"): 72 h = F.relu(PF.batch_normalization( 73 PF.convolution(x, 4, (3, 3), pad=(1, 1), stride=(2, 2),w_init=initializer, with_bias=False))) 74 with nn.parameter_scope("conv2"): 75 h = F.relu(PF.batch_normalization( 76 PF.convolution(h, 8, (3, 3), pad=(1, 1)))) 77 h = F.average_pooling(h, (2, 2)) 78 with nn.parameter_scope("fc3"): 79 h = F.relu(PF.affine(h, 32)) 80 with nn.parameter_scope("classifier"): 81 h = PF.affine(h, 10) 82 return h 83 84 85# ⑦ 実行開始:scikit_learnでdigits(8✕8サイズ)データを取得し、NNablaで処理可能に整形する 86np.random.seed(0) 87digits = 0 88data = data_iterator_tiny_digits(digits,batch_size=64, shuffle=True) 89 90# ⑧ ニューラルネットワークを構築する 91nn.clear_parameters() 92img, label = data.next() 93x = nn.Variable(img.shape) 94y = network(x) 95t = nn.Variable(label.shape) 96loss = logreg_loss(y, t) 97 98# ⑨ 学習する 99learning_rate = 1e-1 100training(x, t, data, loss, 1000, learning_rate) 101 102# ⑩ 推論し、最後に正確さを求めて表示する 103x.d, t.d = data.next() 104y.forward() 105mch = 0 106for p in range(len(t.d)): 107 if t.d[p] == y.d.argmax(axis=1)[p]: 108 mch += 1 109 110print("Accuracy:{}".format(mch / len(t.d))) 111

しかし、実行すると以下のようなエラーが出てしまいます

bash

12020-01-14 21:13:45,163 [nnabla][INFO]: Initializing CPU extension... 22020-01-14 21:13:48,083 [nnabla][INFO]: DataSource with shuffle(True) 32020-01-14 21:13:48,524 [nnabla][INFO]: Using DataSourceWithMemoryCache 42020-01-14 21:13:48,524 [nnabla][INFO]: DataSource with shuffle(True) 52020-01-14 21:13:48,612 [nnabla][INFO]: On-memory 62020-01-14 21:13:48,613 [nnabla][INFO]: Using DataIterator 7Traceback (most recent call last): 8 File "file_losd_test.py", line 94, in <module> 9 y = network(x) 10 File "file_losd_test.py", line 73, in network 11 PF.convolution(x, 4, (3, 3), pad=(1, 1), stride=(2, 2),w_init=initializer, with_bias=False))) 12 File "<string>", line 5, in convolution 13 File "/Users/takabatomoki/.pyenv/versions/3.7.1/lib/python3.7/site-packages/nnabla/parametric_functions.py", line 639, in convolution 14 return F.convolution(inp, w, b, base_axis, pad, stride, dilation, group, channel_last) 15 File "<convolution>", line 3, in convolution 16 File "/Users/takabatomoki/.pyenv/versions/3.7.1/lib/python3.7/site-packages/nnabla/function_bases.py", line 352, in convolution 17 return F.Convolution(ctx, base_axis, pad, stride, dilation, group, channel_last)(*inputs, n_outputs=n_outputs, auto_forward=get_auto_forward(), outputs=outputs) 18 File "function.pyx", line 292, in nnabla.function.Function.__call__ 19 File "function.pyx", line 271, in nnabla.function.Function._cg_call 20RuntimeError: value error in setup_impl 21/Users/gitlab-runner/builds/9703d983/0/nnabla/builders/all/nnabla/src/nbla/function/./generic/convolution.cpp:51 22Failed `shape_weights.size() == 2 + spatial_dims_`: Weights must be a tensor more than 3D.

初心者なので疎いのですが、どうにも重みが規定値を超えなければならないそうなのですが、どうすればいいのでしょうか?
ご教授いただけると幸いです。
よろしくお願い申し上げます。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問