質問編集履歴
1
損失関数を追記しました。
test
CHANGED
File without changes
|
test
CHANGED
@@ -86,7 +86,7 @@
|
|
86
86
|
|
87
87
|
```
|
88
88
|
|
89
|
-
|
89
|
+
該当の関数全体(プログラム全体は文字数以内に入り切りませんでした(>人<)):
|
90
90
|
|
91
91
|
```Python
|
92
92
|
|
@@ -312,6 +312,50 @@
|
|
312
312
|
|
313
313
|
```
|
314
314
|
|
315
|
+
損失関数:
|
316
|
+
|
317
|
+
```Python
|
318
|
+
|
319
|
+
def get_loss(y, y_):
|
320
|
+
|
321
|
+
# Calculate the loss from digits being incorrect. Don't count loss from
|
322
|
+
|
323
|
+
# digits that are in non-present plates.
|
324
|
+
|
325
|
+
digits_loss = tf.nn.softmax_cross_entropy_with_logits(
|
326
|
+
|
327
|
+
tf.reshape(y[:, 1:],
|
328
|
+
|
329
|
+
[-1, len(common.CHARS)]),
|
330
|
+
|
331
|
+
tf.reshape(y_[:, 1:],
|
332
|
+
|
333
|
+
[-1, len(common.CHARS)]))
|
334
|
+
|
335
|
+
digits_loss = tf.reshape(digits_loss, [-1, 7])
|
336
|
+
|
337
|
+
digits_loss = tf.reduce_sum(digits_loss, 1)
|
338
|
+
|
339
|
+
digits_loss *= (y_[:, 0] != 0)
|
340
|
+
|
341
|
+
digits_loss = tf.reduce_sum(digits_loss)
|
342
|
+
|
343
|
+
|
344
|
+
|
345
|
+
# Calculate the loss from presence indicator being wrong.
|
346
|
+
|
347
|
+
presence_loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
348
|
+
|
349
|
+
y[:, :1], y_[:, :1])
|
350
|
+
|
351
|
+
presence_loss = 7 * tf.reduce_sum(presence_loss)
|
352
|
+
|
353
|
+
|
354
|
+
|
355
|
+
return digits_loss, presence_loss, digits_loss + presence_loss
|
356
|
+
|
357
|
+
```
|
358
|
+
|
315
359
|
|
316
360
|
|
317
361
|
よろしくお願い致します。
|