TensorFlowでカスタム訓練ループをfitに組み込むための便利な書き方
https://qiita.com/koshian2/items/4b4c1c10b36fa4b03b3e
こちらの書き方を参考にTensorflowの勉強をしています。
今別の問題でtrain_stepで入力データにdata augmentationを施したいと思っています。
Python
1 def train_step(self, data): 2 low_res_input, high_res_gt = data
例えばここの、low_res_inputに、rotationやaffine変換を適用するイメージです。
しかしそのままskimageで書いた画像処理を入れてもTensor形式の入力を求められて実行できないようです。
Python
1def img_processing(imgs): 2 x_ = np.zeros(imgs.shape) 3 print("img processing") 4 for i in range(imgs.shape[0]): 5 y = imgs[i].copy() 6 7 # Rotation 8 deg = 20 9 angle = np.random.randint(-deg, deg) 10 y = rotate(y, angle) 11 12 x_[i] = y 13 14 return x_
Python
1>x_ = np.zeros(imgs.shape) 2 3TypeError: 'NoneType' object cannot be interpreted as an integer
この画像処理の部分をtf.imageの関数で書くこともできるようなのですが、
skimageの関数がすべてあるようではないので、できればskimageのまま実行したいです。
いろいろと書きましたがお聞きしたいのは
・そもそもtrain_stepでのDataAugmentationの記載は正しいのでしょうか?
・上記記法でのEpochごとのDataAugmentationの実行方法
になります。
よろしくお願いします。
あなたの回答
tips
プレビュー