teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

1

質問の具体化

2021/03/20 10:15

投稿

rest
rest

スコア18

title CHANGED
File without changes
body CHANGED
@@ -3,75 +3,4 @@
3
3
  可能性としては
4
4
  ① rgbaのrの要素が使われる。
5
5
  ② rgbaの画像がグレースケールに変換されて学習される。
6
- の二つを考えています。どちらだと思いますか?
6
+ の二つを考えています。どちらだと思いますか?
7
-
8
- ・参考にしたサイト
9
- [VGG16を転移学習させて「まどか☆マギカ」のキャラを見分ける](https://qiita.com/God_KonaBanana/items/2cf829172087d2423f58)
10
-
11
- ・全文
12
- ```ここに言語を入力
13
- #model&train
14
- from keras.models import Model
15
- from keras.layers import Dense, GlobalAveragePooling2D,Input
16
- from keras.applications.vgg16 import VGG16
17
- from keras.preprocessing.image import ImageDataGenerator
18
- from keras.optimizers import SGD
19
- from keras.callbacks import CSVLogger
20
- import matplotlib.pyplot as plt
21
- import os
22
-
23
- classes = ['hituji','buta','usi']
24
- label=['hituji','buta','usi']
25
- img_height=256
26
- img_width=256
27
- batch_size=16
28
- num_epochs=50
29
- n_categories=3
30
- seed=1
31
- file_name = 'doubutu_bunrui'
32
- print(file_name)
33
-
34
- train_dir==os.path.join('D:','train')
35
-
36
- #model作成(グレースケール)
37
- base_model=VGG16(weights=None,include_top=False,
38
- input_shape=(img_width,img_height,1),
39
- input_tensor=Input(shape=(img_width,img_height,1)),
40
- )
41
-
42
- #add new layers instead of FC networks
43
- x=base_model.output
44
- x=GlobalAveragePooling2D()(x)
45
- x=Dense(1024,activation='relu')(x)
46
- prediction=Dense(n_categories,activation='softmax')(x)
47
- model=Model(inputs=base_model.input,outputs=prediction)
48
- #fix weights
49
- for layer in base_model.layers[:0]:
50
- layer.trainable=False
51
- model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
52
- loss='categorical_crossentropy',
53
- metrics=['accuracy'])
54
- #save model
55
- json_string=model.to_json()
56
- open(file_name+'.json','w').write(json_string)
57
-
58
- #学習(train)
59
- train_datagen=ImageDataGenerator()
60
- train_generator=train_datagen.flow_from_directory(
61
- train_dir,
62
- target_size=(img_width,img_height),
63
- batch_size=batch_size,
64
- classes=classes,
65
- class_mode='categorical',
66
- color_mode='grayscale'
67
- shuffle=True,
68
- seed=seed
69
- )
70
- #history
71
- history=model.fit_generator(train_generator,
72
- epochs=num_epochs,
73
- verbose=0,
74
- callbacks=[CSVLogger(file_name+'.csv')])
75
- #save weights
76
- model.save(file_name+'.h5')
77
- ```