anacondaで仮想環境を作り、そこでtensorflowとkerasを使って機械学習をしてみたかったので、ネットに転がっていた「自分で用意した画像を使った機械学習のコード」を少し手を加えて(ニューラルネットワーク部分を自分で考えたものにした)実行してみました。すると、実行結果の欄には自分が用意した画像のファイル名がずらりと並んでいき、うまくいくと思いましたが、次のようなエラー文が出ました。
#エラー文
AttributeError Traceback (most recent call last) <ipython-input-9-ddcfcffbcd49> in <module> 59 # Squeeze 60 model = Sequential() ---> 61 model.add(GlobalAveragePooling2D()(input)) 62 model.add(Dense(200,activation="relu")) 63 model.add(Dense(200,activation="sigmoid")) ~\anaconda3\envs\ban\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs) 892 # Eager execution on data tensors. 893 with backend.name_scope(self._name_scope()): --> 894 self._maybe_build(inputs) 895 cast_inputs = self._maybe_cast_inputs(inputs) 896 with base_layer_utils.autocast_context_manager( ~\anaconda3\envs\ban\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in _maybe_build(self, inputs) 2125 if not self.built: 2126 input_spec.assert_input_compatibility( -> 2127 self.input_spec, inputs, self.name) 2128 input_list = nest.flatten(inputs) 2129 if input_list and self._dtype_policy.compute_dtype is None: ~\anaconda3\envs\ban\lib\site-packages\tensorflow_core\python\keras\engine\input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name) 161 spec.min_ndim is not None or 162 spec.max_ndim is not None): --> 163 if x.shape.ndims is None: 164 raise ValueError('Input ' + str(input_index) + ' of layer ' + 165 layer_name + ' is incompatible with the layer: ' AttributeError: 'function' object has no attribute 'shape'
エラーがおきたコードの全文を下に載せます
#ソースコード
import
1from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, AveragePooling2D, GlobalAveragePooling2D, Dense, Multiply, Input 2 3from keras.activations import linear 4from keras.utils import to_categorical 5import numpy as np 6import os, pickle, zipfile, glob 7from keras.models import Sequential 8from keras.layers import Activation, Dense, Dropout 9from keras.utils.np_utils import to_categorical 10from keras.optimizers import Adagrad 11from keras.optimizers import Adam 12import numpy as np 13from PIL import Image 14import os 15 16image_list = [] 17label_list = [] 18 19for dir in os.listdir("data/training"): 20 if dir == ".DS_Store": 21 continue 22 dir1 = "data/training/" + dir 23 label = 0 24 if dir == "no": 25 label = 0 26 elif dir == "ok": 27 label = 1 28 for file in os.listdir(dir1): 29 if file != ".DS_Store": 30 label_list.append(label) 31 filepath = dir1 + "/" + file 32 image = np.array(Image.open(filepath).resize((25, 25))) 33 print(filepath) 34 image = image.transpose(2, 0, 1) 35 image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0] 36 image_list.append(image / 255.) 37image_list = np.array(image_list) 38Y = to_categorical(label_list) 39model = Sequential() 40model.add(GlobalAveragePooling2D()(input)) 41model.add(Dense(200,activation="relu")) 42model.add(Dense(200,activation="sigmoid")) 43 44opt = Adam(lr=0.001) 45model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"]) 46model.fit(image_list, Y, nb_epoch=1500, batch_size=100, validation_split=0.1) 47 48total = 0. 49ok_count = 0. 50 51for dir in os.listdir("data/training"): 52 if dir == ".DS_Store": 53 continue 54 dir1 = "data/validation/" + dir 55 label = 0 56 if dir == "no": 57 label = 0 58 elif dir == "ok": 59 label = 1 60 for file in os.listdir(dir1): 61 if file != ".DS_Store": 62 label_list.append(label) 63 filepath = dir1 + "/" + file 64 image = np.array(Image.open(filepath).resize((25, 25))) 65 print(filepath) 66 image = image.transpose(2, 0, 1) 67 image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0] 68 result = model.predict_classes(np.array([image / 255.])) 69 print("label:", label, "result:", result[0]) 70 total += 1. 71 if label == result[0]: 72 ok_count += 1. 73 74print("seikai: ", ok_count / total * 100, "%") 75 76from IPython.display import SVG 77from keras.utils.vis_utils import model_to_dot 78SVG(model_to_dot(model).create(prog='dot', format='svg'))ここに言語を入力
コードの中にはfunctionという部分は書いてありませんし、一切手を加えていない元のコードを実行した際にはこのようなエラーは発生しませんでした。手詰まりとなり、わからないのは以下の点です
1:何が原因で起きたエラーなのか?
2:何をどうすれば解消できるのか?
3:functionという単語はどこから来たのか?
ソースコードで何をしているか、全然わかっていない初心者です。よろしくお願いします
回答1件
あなたの回答
tips
プレビュー