teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

1

損失関数を追記しました。

2020/10/08 06:56

投稿

r.kanke
r.kanke

スコア1

title CHANGED
File without changes
body CHANGED
@@ -42,7 +42,7 @@
42
42
  init = tf.initialize_all_variables()
43
43
 
44
44
  ```
45
- 問題の関数全体(プログラム全体は文字数以内に入り切りませんでした(>人<)):
45
+ 該当の関数全体(プログラム全体は文字数以内に入り切りませんでした(>人<)):
46
46
  ```Python
47
47
  def train(learn_rate, report_steps, batch_size, initial_weights=None):
48
48
  """
@@ -155,5 +155,27 @@
155
155
  numpy.savez("weights.npz", *last_weights)
156
156
  return last_weights
157
157
  ```
158
+ 損失関数:
159
+ ```Python
160
+ def get_loss(y, y_):
161
+ # Calculate the loss from digits being incorrect. Don't count loss from
162
+ # digits that are in non-present plates.
163
+ digits_loss = tf.nn.softmax_cross_entropy_with_logits(
164
+ tf.reshape(y[:, 1:],
165
+ [-1, len(common.CHARS)]),
166
+ tf.reshape(y_[:, 1:],
167
+ [-1, len(common.CHARS)]))
168
+ digits_loss = tf.reshape(digits_loss, [-1, 7])
169
+ digits_loss = tf.reduce_sum(digits_loss, 1)
170
+ digits_loss *= (y_[:, 0] != 0)
171
+ digits_loss = tf.reduce_sum(digits_loss)
158
172
 
173
+ # Calculate the loss from presence indicator being wrong.
174
+ presence_loss = tf.nn.sigmoid_cross_entropy_with_logits(
175
+ y[:, :1], y_[:, :1])
176
+ presence_loss = 7 * tf.reduce_sum(presence_loss)
177
+
178
+ return digits_loss, presence_loss, digits_loss + presence_loss
179
+ ```
180
+
159
181
  よろしくお願い致します。