質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.36%
Google Colaboratory

Google Colaboratoryとは、無償のJupyterノートブック環境。教育や研究機関の機械学習の普及のためのGoogleの研究プロジェクトです。PythonやNumpyといった機械学習で要する大方の環境がすでに構築されており、コードの記述・実行、解析の保存・共有などが可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1281閲覧

Google Colaboratoryで機械学習をしているのですが、エラーの解決策がわかりません。

TsukaJ

総合スコア2

Google Colaboratory

Google Colaboratoryとは、無償のJupyterノートブック環境。教育や研究機関の機械学習の普及のためのGoogleの研究プロジェクトです。PythonやNumpyといった機械学習で要する大方の環境がすでに構築されており、コードの記述・実行、解析の保存・共有などが可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/11/11 07:52

編集2021/11/11 11:28

前提・実現したいこと

Google ColaboratoryでPythonも使って画像分類の機械学習をしようとしています。
CNNモデルを定義した後に、CNNモデルによる画像分類の推論を行おうとしたところ、
以下のエラーメッセージが発生しました。
次元の数が合っていないというものだと思うのですが、どうすればいいのかわかりません。

発生している問題・エラーメッセージ

InvalidType Traceback (most recent call last) <ipython-input-4-140080dec570> in <module>() 81 img, label = test_dataset.get_example(i) 82 ---> 83 pred = model.predictor(np.expand_dims(img, axis=0)) 84 85 pred = F.softmax(pred) 8 frames <ipython-input-4-140080dec570> in __call__(self, x) 57 58 def __call__(self, x): ---> 59 h = F.relu(self.conv1(x)) 60 h = F.max_pooling_2d(self.bn1(h), 2, 2) 61 h = F.relu(self.conv2(x)) /usr/local/lib/python3.7/dist-packages/chainer/link.py in __call__(self, *args, **kwargs) 285 # forward is implemented in the child classes 286 forward = self.forward # type: ignore --> 287 out = forward(*args, **kwargs) 288 289 # Call forward_postprocess hook /usr/local/lib/python3.7/dist-packages/chainer/links/connection/convolution_2d.py in forward(self, x) 249 return convolution_2d.convolution_2d( 250 x, self.W, self.b, self.stride, self.pad, dilate=self.dilate, --> 251 groups=self.groups, cudnn_fast=self.cudnn_fast) 252 253 /usr/local/lib/python3.7/dist-packages/chainer/functions/connection/convolution_2d.py in convolution_2d(x, W, b, stride, pad, cover_all, **kwargs) 656 else: 657 args = x, W, b --> 658 y, = fnode.apply(args) 659 return y /usr/local/lib/python3.7/dist-packages/chainer/function_node.py in apply(self, inputs) 305 306 if configuration.config.type_check: --> 307 self._check_data_type_forward(in_data) 308 309 self.check_layout_forward(input_vars) /usr/local/lib/python3.7/dist-packages/chainer/function_node.py in _check_data_type_forward(self, in_data) 453 in_data, 'in_types', False, shapes=in_shapes) 454 with type_check.get_function_check_context(self): --> 455 self.check_type_forward(in_type) 456 457 def check_type_forward(self, in_types): /usr/local/lib/python3.7/dist-packages/chainer/functions/connection/convolution_2d.py in check_type_forward(self, in_types) 71 x_type.ndim == 4, 72 w_type.ndim == 4, ---> 73 x_type.shape[1] == w_type.shape[1] * self.groups, 74 ) 75 /usr/local/lib/python3.7/dist-packages/chainer/utils/type_check.py in expect(*bool_exprs) 562 for expr in bool_exprs: 563 assert isinstance(expr, Testable) --> 564 expr.expect() 565 566 /usr/local/lib/python3.7/dist-packages/chainer/utils/type_check.py in expect(self) 495 raise InvalidType( 496 '{0} {1} {2}'.format(self.lhs, self.exp, self.rhs), --> 497 '{0} {1} {2}'.format(left, self.inv, right)) 498 499 InvalidType: Invalid operation is performed in: Convolution2DFunction (Forward) Expect: in_types[0].ndim == 4 Actual: 5 != 4

該当のソースコード

Python

1import os 2import numpy as np 3import skimage. io as io 4import chainer 5import chainer.links as L 6import chainer.functions as F 7 8class PreprocessedDataset(chainer.dataset.DatasetMixin): 9 def __init__( 10 self, 11 root_path, 12 split_list 13 ): 14 self.root_path = root_path 15 with open(split_list) as f: 16 self.split_list = [line.rstrip() for line in f] 17 self.dtype = np.float32 18 19 def __len__(self): 20 return len(self.split_list) 21 22 def _get_image(self, i): 23 image = io.imread(os.path.join(self.root_path, self.split_list[i])) 24 image = self._min_max_normalize_one_image(image) 25 return np.expand_dims(image.astype(self.dtype), axis=0) 26 27 def _min_max_normalize_one_image(self, image): 28 max_int = image.max() 29 min_int = image.min() 30 out = (image.astype(np.float32) - min_int) / (max_int - min_int) 31 return out 32 33 def _get_label(self, i): 34 label = 0 if 'false' in self.split_list[i] else 1 35 return label 36 37 def get_example(self, i): 38 x, y = self._get_image(i), self._get_label(i) 39 return x, y 40 41class ClassificationModel(chainer.Chain): 42 43 def __init__(self, n_class=2): 44 super(ClassificationModel, self).__init__() 45 with self.init_scope(): 46 47 self.conv1 = L.Convolution2D(1, 32, 5, 1, 2) 48 self.bn1 = L.BatchNormalization(32) 49 self.conv2 = L.Convolution2D(32, 64, 5, 1, 2) 50 self.bn2 = L.BatchNormalization(64) 51 self.conv3 = L.Convolution2D(64, 128, 3, 1, 1) 52 self.bn3 = L.BatchNormalization(128) 53 self.conv4 = L.Convolution2D(128, 256, 3, 1, 1) 54 self.bn4 = L.BatchNormalization(256) 55 self.fc5 = L.Linear(16384, 1024) 56 self.fc6 = L.Linear(1024, n_class) 57 58 def __call__(self, x): 59 h = F.relu(self.conv1(x))#エラー箇所 60 h = F.max_pooling_2d(self.bn1(h), 2, 2) 61 h = F.relu(self.conv2(x)) 62 h = F.max_pooling_2d(self.bn2(h), 2, 2) 63 h = F.relu(self.conv3(x)) 64 h = F.max_pooling_2d(self.bn3(h), 2, 2) 65 h = F.relu(self.conv4(x)) 66 h = F.max_pooling_2d(self.bn4(h), 2, 2) 67 h = F.dropout(F.relu(self.fc5(h))) 68 return self.fc6(h) 69 70root_path = './dataset_cls' 71split_list = './dataset_cls/split_list/test.txt' 72 73test_dataset = PreprocessedDataset(root_path, split_list) 74 75model = L.Classifier(ClassificationModel(n_class=2)) 76 77print('================') 78for i in range(10): 79 with chainer.using_config('train', False): 80 81 img, label = test_dataset.get_example(i) 82 83 pred = model.predictor(np.expand_dims(img, axis=0)) 84 85 pred = F.softmax(pred) 86 87 print('test {}'.format(i + 1)) 88 print(' pred: {}'.format(np.argmax(pred.data))) 89 print(' label: {}'.format(label)) 90 print('================')

試したこと

Web等で検索をかけて調べてたのですが上手く解決策を見つけることができませんでした。

補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2021/11/11 09:02

掲載されてるエラーメッセージ以外にも、何か表示されてませんでしょうか? できるだけ省略しないで掲載してくれた方が、他人が状況を把握しやすくなります もし、掲載されてるエラーよりも上の方に「Traceback」と書かれてたら、そこから全部載せてください (ここに書くのではなく、質問を編集して追記する)
TsukaJ

2021/11/11 11:30

コメントありがとうございます、jbpb0様のご指摘通り、掲載したエラーメッセージより上にTracebackとありましたので追加修正の方させていただきました。ありがとうございました。
guest

回答1

0

ベストアンサー

(回答修正しました)
コードと見ると、np.expand_dims を2回行っていますね。ネットワークの入力が1chということから、おそらくコードの参考元では1chのグレースケール画像を入力しており、今回試したのはカラーあるいは3chのグレースケール画像を入力しようとしたのではないでしょうか?

おそらくPreprocessedDatasetクラスで画像をロードするときに np.expand_dims を行っており、推論実行するときにまた np.expand_dims を行っているため、入力する画像の shape が (1, 1, h, w, ch) になっています。

ネットワーク構成的に入力を1chにしなければならないので画像を読み込んだ際にグレースケール化すれば問題ないかと思います。

修正したコードを記入しておきます。

import os import numpy as np import skimage. io as io from skimage.color import rgb2gray # 追加(グレースケール用) import chainer import chainer.links as L import chainer.functions as F class PreprocessedDataset(chainer.dataset.DatasetMixin): def __init__( self, root_path, split_list ): self.root_path = root_path with open(split_list) as f: self.split_list = [line.rstrip() for line in f] self.dtype = np.float32 def __len__(self): return len(self.split_list) def _get_image(self, i): image = io.imread(os.path.join(self.root_path, self.split_list[i])) image = self._min_max_normalize_one_image(image) if len(image.shape) == 3: # カラーあるいは3chグレスケ画像の場合は1chグレスケ化 image = rgb2gray(image) # (h, w, ch) -> (h, w) return np.expand_dims(image.astype(self.dtype), axis=0) # (h, w) -> (1, h, w) def _min_max_normalize_one_image(self, image): max_int = image.max() min_int = image.min() out = (image.astype(np.float32) - min_int) / (max_int - min_int) return out def _get_label(self, i): label = 0 if 'false' in self.split_list[i] else 1 return label def get_example(self, i): x, y = self._get_image(i), self._get_label(i) return x, y class ClassificationModel(chainer.Chain): def __init__(self, n_class=2): super(ClassificationModel, self).__init__() with self.init_scope(): self.conv1 = L.Convolution2D(1, 32, 5, 1, 2) self.bn1 = L.BatchNormalization(32) self.conv2 = L.Convolution2D(32, 64, 5, 1, 2) self.bn2 = L.BatchNormalization(64) self.conv3 = L.Convolution2D(64, 128, 3, 1, 1) self.bn3 = L.BatchNormalization(128) self.conv4 = L.Convolution2D(128, 256, 3, 1, 1) self.bn4 = L.BatchNormalization(256) self.fc5 = L.Linear(16384, 1024) self.fc6 = L.Linear(1024, n_class) def __call__(self, x): h = F.relu(self.conv1(x)) ←#エラー箇所 h = F.max_pooling_2d(self.bn1(h), 2, 2) h = F.relu(self.conv2(x)) h = F.max_pooling_2d(self.bn2(h), 2, 2) h = F.relu(self.conv3(x)) h = F.max_pooling_2d(self.bn3(h), 2, 2) h = F.relu(self.conv4(x)) h = F.max_pooling_2d(self.bn4(h), 2, 2) h = F.dropout(F.relu(self.fc5(h))) return self.fc6(h) root_path = './dataset_cls' split_list = './dataset_cls/split_list/test.txt' test_dataset = PreprocessedDataset(root_path, split_list) model = L.Classifier(ClassificationModel(n_class=2)) print('================') for i in range(10): with chainer.using_config('train', False): img, label = test_dataset.get_example(i) pred = model.predictor(np.expand_dims(img, axis=0)) # (1, h, w) -> (1, 1, h, w) pred = F.softmax(pred) print('test {}'.format(i + 1)) print(' pred: {}'.format(np.argmax(pred.data))) print(' label: {}'.format(label)) print('================')

投稿2021/11/18 01:14

編集2021/11/18 01:58
HRCo4

総合スコア140

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

TsukaJ

2021/11/18 04:56

ありがとうございます。回答者様の仰る通り、参考下ではグレースケール画像を使用していましたが、私はカラー画像を使用していたため、上記のエラーが起きていたようでした。途中で、グレースケール画像に変換するコードを入れたところ無事動きました。ありがとうございます!!
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問