質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1020閲覧

混合行列を使ってヒートマップ表示したい

shimauma111

総合スコア6

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2019/11/16 14:34

前提・実現したいこと

深層学習をやっていて、学習後の精度確認で、混合行列を使ってヒートマップ形式で精度を確認したいのですが、どのようにコードを作成したらよいのか分からなくて質問させていただきました。
■■な機能を実装中に以下のエラーメッセージが発生しました。

該当のソースコード

# 分類クラス classes = ['apple', 'ball','banana','onigiri'] nb_classes = len(classes) batch_size = 32 nb_epoch = 50 # 画像のサイズ img_rows, img_cols = 224, 224 # モデルの構築 def build_model() : # 画像の読み込み.今回はカラー画像の為,shapeの3番目の引数が3(ch) input_tensor = Input(shape=(img_rows, img_cols, 3)) # VGG16はモデルの名前.引数weightsでpre-trainingしている vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor) # Sequentialは単純に,前の層の全ノードから矢印を引っ張ってくるモデルを意味している # 矢印のつなぎ方を複雑にするには,Functional APIを使う # https://qiita.com/Ishotihadus/items/e28dd461a8ba27a2676e _model = Sequential() _model.add(Flatten(input_shape=vgg16.output_shape[1:])) _model.add(Dense(256, activation='relu')) _model.add(Dropout(0.5)) _model.add(Dense(nb_classes, activation='softmax')) model = Model(inputs=vgg16.input, outputs=_model(vgg16.output)) # modelの14層目までのモデル重み for layer in model.layers[:15]: layer.trainable = False # 損失関数と評価関数を指定 model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=1e-4, momentum=0.9), metrics=['accuracy']) return model if __name__ == "__main__": # ImageDataGeneratorはリアルタイムにデータ拡張しながら,テンソル画像データのバッチを生成する # 要はデータの水増し(Data Augumentation)に関するオプションを指定している # https://keras.io/ja/preprocessing/image/#imagedatagenerator_1 train_datagen = ImageDataGenerator( rescale=1.0 / 255 ) # train_generator: 指定したディレクトリから画像を読み込むときに使用する関数 train_generator = train_datagen.flow_from_directory( directory= 'dataset/train', target_size=(img_rows, img_cols), color_mode='rgb', classes=classes, class_mode='categorical', batch_size=batch_size, shuffle=True) # 評価用画像の用意 test_datagen = ImageDataGenerator(rescale=1.0 / 255) test_generator = test_datagen.flow_from_directory( directory= 'dataset/test', target_size=(img_rows, img_cols), color_mode='rgb', classes=classes, class_mode='categorical', batch_size=batch_size, shuffle=True) # インスタンスの呼び出し model = build_model() # 過学習の抑制 mc = ModelCheckpoint('weights.{epoch:02d}-{loss:.2f}-{acc:.2f}-{val_loss:.2f}-{val_acc:.2f}.h5',monitor="val_loss", verbose=1, save_best_only=True) es = EarlyStopping(monitor='val_loss', patience=10, verbose=1) # Fine-tuning history = model.fit_generator( train_generator, samples_per_epoch=2925, nb_epoch=nb_epoch, validation_data=test_generator, nb_val_samples=975, callbacks=[mc,es] ) #acc, val_accのプロット plt.plot(history.history["acc"], label="acc", ls="-", marker="o") plt.plot(history.history["val_acc"], label="val_acc", ls="-", marker="x") plt.ylabel("accuracy") plt.xlabel("epoch") plt.legend(loc="best") #Final.pngという名前で、結果を保存 plt.savefig('acc.png') plt.show() model.save('vgg16_transfer.h5')

試したこと

import pandas as pd import seaborn as sns from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt def print_cmx(y_true, y_pred): labels = sorted(list(set(y_true))) cmx_data = confusion_matrix(y_true, y_pred, labels=labels) labels= ["tantan","iekei","jirou","shoyu","sio","udon"]#ラベルを付け加える df_cmx = pd.DataFrame(cmx_data, index=labels, columns=labels) plt.figure(figsize = (10,7)) sns.heatmap(df_cmx, annot=True) plt.xlabel("Predict-labels") plt.ylabel("True-labels") plt.show() predict_classes = model.predict_classes(x_test, batch_size=32) true_classes = np.argmax(y_test,1) print_cmx(true_classes,predict_classes)

上記のコードを基に作成したかったのですがうまくいかなかったです。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

投稿2019/11/16 14:40

WathMorks

総合スコア1582

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

shimauma111

2019/11/17 07:19

y_trueとか設定していなんですが、実装できますか? 参考にして試したのですができなかったので、教えていただけると助かります
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問