質問編集履歴

1

質問の具体化

2021/03/20 10:15

投稿

rest
rest

スコア18

test CHANGED
File without changes
test CHANGED
@@ -12,138 +12,6 @@
12
12
 
13
13
 
14
14
 
15
- ・全文
16
-
17
- ```ここに言語を入力
18
-
19
- #model&train
20
-
21
- from keras.models import Model
22
-
23
- from keras.layers import Dense, GlobalAveragePooling2D,Input
24
-
25
- from keras.applications.vgg16 import VGG16
26
-
27
- from keras.preprocessing.image import ImageDataGenerator
28
-
29
- from keras.optimizers import SGD
30
-
31
- from keras.callbacks import CSVLogger
32
-
33
- import matplotlib.pyplot as plt
34
-
35
- import os
36
-
37
-
38
-
39
- classes = ['hituji','buta','usi']
40
-
41
- label=['hituji','buta','usi']
42
-
43
- img_height=256
44
-
45
- img_width=256
46
-
47
- batch_size=16
48
-
49
- num_epochs=50
50
-
51
- n_categories=3
52
-
53
- seed=1
54
-
55
- file_name = 'doubutu_bunrui'
56
-
57
- print(file_name)
58
-
59
-
60
-
61
- train_dir==os.path.join('D:','train')
62
-
63
-
64
-
65
- #model作成(グレースケール)
66
-
67
- base_model=VGG16(weights=None,include_top=False,
68
-
69
- input_shape=(img_width,img_height,1),
70
-
71
- input_tensor=Input(shape=(img_width,img_height,1)),
72
-
73
- )
74
-
75
-
76
-
77
- #add new layers instead of FC networks
78
-
79
- x=base_model.output
80
-
81
- x=GlobalAveragePooling2D()(x)
82
-
83
- x=Dense(1024,activation='relu')(x)
84
-
85
- prediction=Dense(n_categories,activation='softmax')(x)
86
-
87
- model=Model(inputs=base_model.input,outputs=prediction)
88
-
89
- #fix weights
90
-
91
- for layer in base_model.layers[:0]:
92
-
93
- layer.trainable=False
94
-
95
- model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
96
-
97
- loss='categorical_crossentropy',
98
-
99
- metrics=['accuracy'])
100
-
101
- #save model
102
-
103
- json_string=model.to_json()
104
-
105
- open(file_name+'.json','w').write(json_string)
106
-
107
-
108
-
109
- #学習(train)
110
-
111
- train_datagen=ImageDataGenerator()
112
-
113
- train_generator=train_datagen.flow_from_directory(
114
-
115
- train_dir,
116
-
117
- target_size=(img_width,img_height),
118
-
119
- batch_size=batch_size,
120
-
121
- classes=classes,
122
-
123
- class_mode='categorical',
124
-
125
- shuffle=True,
126
-
127
- seed=seed
128
-
129
- )
130
-
131
- #history
132
-
133
- history=model.fit_generator(train_generator,
134
-
135
- epochs=num_epochs,
136
-
137
- verbose=0,
138
-
139
- callbacks=[CSVLogger(file_name+'.csv')])
140
-
141
- #save weights
142
-
143
- model.save(file_name+'.h5')
144
-
145
- ```
146
-
147
15
  ・エラー文
148
16
 
149
17
  ```ここに言語を入力