前提・実現したいこと
深層学習をkeras/tensorflowで行っているものです
狙ったタイミングで深層学習の損失に付加する重みを変えるためにtf.keras.losses.Lossのサブクラスを使って
自作関数を作成しています
計算ごとに関数内でself.stepを積み上げ、epochが5になったら変化するようにしたいのですが、
うまくいっておりません
(下記ではわかりやすいように0割して発散するようにしている)
狙ったエポック数で深層学習の損失に付加する重みを変えることができるよう、下記のコードをどのように
修正してよいかご教示いただけますと幸いです。
該当のソースコード
python
1class CategoricalCrossentropy1(keras.losses.Loss): 2 def __init__(self, steps_per_epoch=None, end_epoch=None, name="example"): 3 super().__init__(name=name) 4 self.steps_per_epoch, self.end_epoch = steps_per_epoch, end_epoch 5 6 self.step = 0 7 def call(self, y_true, y_pred): 8 9 epoch = self.step/self.steps_per_epoch 10 11 if epoch < self.end_epoch: 12 weight=1 13 else: 14 weight=1/0 #Note that extreme number was put here to check how this code work 15 16 loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred)) #impose weight on CategoricalCrossentropy 17 18 self.step += 1 # add to step 19 20 return loss 21 22################ main part 23batch_size = 10 24steps_per_epoch = 4 25end_epoch = 5 26 27#... 28model.compile(loss=CategoricalCrossentropy1(steps_per_epoch, end_epoch), 29 optimizer=keras.optimizers.SGD(),) 30 31#... 32model.fit(generater_train, 33 batch_size = batch_size, 34 steps_per_epoch = steps_per_epoch, 35 epochs=100000, 36 )
結果
Epoch 5 付近でlossが発散(nan)してほしいのですが、計算が続行されてしまいます
Epoch 1/100000 4/4 [==============================] - 72s 3s/step - loss: 2.9023 Epoch 2/100000 4/4 [==============================] - 11s 3s/step - loss: 2.0072 Epoch 3/100000 4/4 [==============================] - 11s 3s/step - loss: 1.7888 Epoch 4/100000 4/4 [==============================] - 11s 3s/step - loss: 2.5708 Epoch 5/100000 4/4 [==============================] - 11s 3s/step - loss: 2.0668 Epoch 6/100000 4/4 [==============================] - 11s 3s/step - loss: 1.9443 Epoch 7/100000 4/4 [==============================] - 11s 3s/step - loss: 2.7553 Epoch 8/100000 4/4 [==============================] - 11s 3s/step - loss: 2.6547 Epoch 9/100000 4/4 [==============================] - 11s 3s/step - loss: 2.7103 Epoch 10/100000 3/4 [=====================>........] - ETA: 2s - loss: 2.3025
その他
stackoverflowで同様の質問中
リンク
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2021/07/30 01:17
2021/07/30 01:37
2021/07/30 01:42
2021/07/30 02:17