質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

解決済

[python]ディープラーニングで、学習モデル・重みデータを保存したい

nariho
nariho

総合スコア0

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1回答

0評価

0クリップ

1976閲覧

投稿2018/11/13 04:57

編集2022/01/12 10:58

[python]ディープラーニングで、学習モデル・重みデータを保存したい

私は高専の機械科に所属している学生で、卒業研究でpythonを使用しているのですが、数ヶ月あたりに始めたのでよくわかりません。
研究内容は「ディープラーニングを用いて加工形状を分類する」というもので、既にExtraTreesClassifierという学習手法を使って学習は出来ているのですが、学習モデルと重みデータを保存しようとするとエラーが出てしまいます。

#エラーメッセージ

run trial_senban_ETC2.py ./data/learn900 ./data/test100
0.9
Traceback (most recent call last):

File "C:\Users\kousaku25\Desktop\narita\senban\trial_senban_ETC2.py", line 113, in <module>
model.save('senban_model.pkl')

AttributeError: 'ExtraTreesClassifier' object has no attribute 'save'

該当のソースコード

import os
import glob
import sys
import keras
import numpy as np
from skimage import io
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.ensemble import ExtraTreesClassifier
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Dense, Dropout, Flatten
from keras.optimizers import RMSprop

IMAGE_SIZE1 = 84
IMAGE_SIZE2 = 124
COLOR_BYTE = 3
CATEGORY_NUM = 6
nb_classes = 2

def load_senbanimage(path):

files = glob.glob(os.path.join(path, '*/*.bmp')) images = np.ndarray((len(files), IMAGE_SIZE1, IMAGE_SIZE2, COLOR_BYTE), dtype = np.uint8) labels = np.ndarray(len(files), dtype=np.int) for idx, file in enumerate(files): image = io.imread(file) images[idx] = image label = os.path.split(os.path.dirname(file))[-1] labels[idx] = int(label) flat_data = images.reshape((-1, IMAGE_SIZE1 * IMAGE_SIZE2 * COLOR_BYTE)) images = flat_data.view() return datasets.base.Bunch(data=flat_data, target=labels.astype(np.int), target_names=np.arange(CATEGORY_NUM), images=images, DESCR=None) labels = keras.utils.np_utils.to_categorical(labels, nb_classes) in_size = IMAGE_SIZE1 * IMAGE_SIZE2 model = Sequential() model.add(Convolution2D(32, 3, 3, input_shape=(in_size))) model.add(Activation=('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Convolution2D(32, 3, 3)) model.add(Activation=('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Convolution2D(64, 3, 3,)) model.add(Activation=('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Flatten()) model.add(Dense(2)) model.add(Activation=('relu')) model.add(Dropout(0.5)) model.add(Dense(1)) model.add(Activation=('sigmoid')) model.compile( loss='binary_crossentropy', optimizer=RMSprop(), metrics=['accuracy'])

if name == 'main':
argvs = sys.argv
train_path = argvs[1]
test_path = argvs[2]

train = load_senbanimage(train_path) model = ExtraTreesClassifier(n_estimators=20, random_state=42) model.fit(train.data, train.target) test = load_senbanimage(test_path) predicted = model.predict(test.data) print (accuracy_score(test.target, predicted)) model.save('senban_model.pkl') model.save_weights('senban_weight.h5')

補足

最適な学習手法が明確ではなかったためall_estimatorsメソッドを用いて最も正答率の高かったExtraTreesClassifierを採用しましたが、エラー文を見た感じそこに問題があるのでしょうか?

知識が浅く、参考書やネットで見たものをツギハギしたため、おかしい点があると思いますが、アドバイスいただけると幸いです。
よろしくお願いいたします。

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

ktaro99
ktaro99

2018/11/13 05:01

コードブロックを使用すると見やすくなりますよ。「```」のあとに改行をしてコードを貼り付けて「```」で閉じることで使えます。「```」のあとに「Python」などそのコードに使用されている言語を記入した後に改行するとコードブロックに「Python」などのラベルがつきます。
nariho
nariho

2018/11/13 05:06

ご指摘いただきありがとうございます。参考になりました。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。