質問編集履歴

1

損失関数を追記しました。

2020/10/08 06:56

投稿

r.kanke
r.kanke

スコア1

test CHANGED
File without changes
test CHANGED
@@ -86,7 +86,7 @@
86
86
 
87
87
  ```
88
88
 
89
- 問題の関数全体(プログラム全体は文字数以内に入り切りませんでした(>人<)):
89
+ 該当の関数全体(プログラム全体は文字数以内に入り切りませんでした(>人<)):
90
90
 
91
91
  ```Python
92
92
 
@@ -312,6 +312,50 @@
312
312
 
313
313
  ```
314
314
 
315
+ 損失関数:
316
+
317
+ ```Python
318
+
319
+ def get_loss(y, y_):
320
+
321
+ # Calculate the loss from digits being incorrect. Don't count loss from
322
+
323
+ # digits that are in non-present plates.
324
+
325
+ digits_loss = tf.nn.softmax_cross_entropy_with_logits(
326
+
327
+ tf.reshape(y[:, 1:],
328
+
329
+ [-1, len(common.CHARS)]),
330
+
331
+ tf.reshape(y_[:, 1:],
332
+
333
+ [-1, len(common.CHARS)]))
334
+
335
+ digits_loss = tf.reshape(digits_loss, [-1, 7])
336
+
337
+ digits_loss = tf.reduce_sum(digits_loss, 1)
338
+
339
+ digits_loss *= (y_[:, 0] != 0)
340
+
341
+ digits_loss = tf.reduce_sum(digits_loss)
342
+
343
+
344
+
345
+ # Calculate the loss from presence indicator being wrong.
346
+
347
+ presence_loss = tf.nn.sigmoid_cross_entropy_with_logits(
348
+
349
+ y[:, :1], y_[:, :1])
350
+
351
+ presence_loss = 7 * tf.reduce_sum(presence_loss)
352
+
353
+
354
+
355
+ return digits_loss, presence_loss, digits_loss + presence_loss
356
+
357
+ ```
358
+
315
359
 
316
360
 
317
361
  よろしくお願い致します。