Kerasの損失関数を自作したいがエラーが生じる。
Kerasを用いて自作の損失関数を用いたいと考えています。
y_predとy_trueをtensor型で取り出し、y_trueが0から1までの連続値を5つ持っているとします。
y_true = [[0, 0.2, 0.5, 0.7, 1], [0.1, 0.6, 0.7, 1, 1], ・・・]のような形です。
ここでy_trueの要素の値が0.5以上の場合のみだけ損失関数の計算の対象として、y_trueが0.5未満の場合は損失関数の計算からは除外するような損失関数を作成したいと考えています。
そのためy_trueの要素1つ1つをif文で0.5以上かどうが判断して計算しようと思ったのですが、y_pred, y_trueの要素を1つずつ取り出すことができません。
(上の例でいくと、0, 0.2, 0.5...という風に1つずつ順に取り出したいです。)
y_true, y_predはともにshapeは(None,5)であり、flattenを用いて1次元に落としこみy_true[i]のような形で参照しようと試みましたが不可でした。Noneなので[i]で参照できないのは想像容易いですが、他に方法が思いつかない状態です。
損失関数であるため、listやnumpyに変更することもできないようです...。
解決法を探しましたがどれもy_true, y_predを要素ごとに参照している例がなく困っています。
解決法のわかる方おられましたらお願いします。
発生している問題・エラーメッセージ
TypeError: 'TensorShape' object cannot be interpreted as an integer
該当のソースコード
python
1def loss_function(y_true, y_pred, total=0, num_over_half=0): 2 y_true = torch.flatten(y_true) 3 y_pred = torch.flatten(y_pred) 4 for i in range(y_true.shape): 5 if y_true[i] < 0.5: 6 pass 7 else: 8 error = y_true[i] - y_pred[i] 9 total += error**2 10 num_over_half += 1 11 MSE_custom = total / num_over_half 12 13 return MSE_custom
補足情報(FW/ツールのバージョンなど)
python 3.8.5
tensorflow 2.3.0
keras 2.4.3
あなたの回答
tips
プレビュー