質問編集履歴
1
損失関数を追記しました。
title
CHANGED
File without changes
|
body
CHANGED
@@ -42,7 +42,7 @@
|
|
42
42
|
init = tf.initialize_all_variables()
|
43
43
|
|
44
44
|
```
|
45
|
-
|
45
|
+
該当の関数全体(プログラム全体は文字数制限内に収まりませんでした(>人<)):
|
46
46
|
```Python
|
47
47
|
def train(learn_rate, report_steps, batch_size, initial_weights=None):
|
48
48
|
"""
|
@@ -155,5 +155,27 @@
|
|
155
155
|
numpy.savez("weights.npz", *last_weights)
|
156
156
|
return last_weights
|
157
157
|
```
|
158
|
+
損失関数:
|
159
|
+
```Python
|
160
|
+
def get_loss(y, y_):
    """Build the loss ops for a batch of predictions.

    Column 0 of ``y`` / ``y_`` is treated as the presence indicator. The
    remaining columns are reshaped into rows of ``len(common.CHARS)``
    scores, one row per character position.

    Args:
        y: Predictions; passed to the TensorFlow ``*_with_logits`` ops as
            the logits argument.
        y_: Targets, laid out like ``y``.

    Returns:
        A ``(digits_loss, presence_loss, total_loss)`` tuple of scalar
        tensors, where ``total_loss = digits_loss + presence_loss``.
    """
    # NOTE(review): 7 is presumably the number of characters per plate
    # (the per-character losses are regrouped into rows of 7) -- confirm.
    num_positions = 7

    # Calculate the loss from digits being incorrect.
    # NOTE(review): TensorFlow 1.0+ only accepts labels=/logits= keyword
    # arguments for the cross-entropy ops; the positional form is kept here
    # because the keyword names differ between releases -- adjust for the
    # TensorFlow version actually in use.
    digits_loss = tf.nn.softmax_cross_entropy_with_logits(
        tf.reshape(y[:, 1:], [-1, len(common.CHARS)]),
        tf.reshape(y_[:, 1:], [-1, len(common.CHARS)]))
    digits_loss = tf.reshape(digits_loss, [-1, num_positions])
    digits_loss = tf.reduce_sum(digits_loss, 1)

    # Don't count loss from digits that are in non-present plates.  The
    # mask must be built with tf.not_equal: a plain `!=` on a Tensor is not
    # an element-wise op on older TensorFlow (it evaluates to the Python
    # bool True), and a bool result cannot be multiplied with a float
    # tensor anyway, hence the explicit cast.
    plate_present = tf.cast(tf.not_equal(y_[:, 0], 0), digits_loss.dtype)
    digits_loss *= plate_present
    digits_loss = tf.reduce_sum(digits_loss)

    # Calculate the loss from presence indicator being wrong.  It is scaled
    # by the same factor of 7 that the per-plate digit loss sums over.
    presence_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        y[:, :1], y_[:, :1])
    presence_loss = num_positions * tf.reduce_sum(presence_loss)

    return digits_loss, presence_loss, digits_loss + presence_loss
|
179
|
+
```
|
180
|
+
|
159
181
|
よろしくお願い致します。
|