カスタムする学習方法を作成するのは初めてで、こちらの方の記事も参考にしたのですが、(「TensorFlow2.0でDistribute Trainingしたときにfitと訓練ループで精度が違ってハマった話」https://blog.shikoan.com/tf20-distribute-error/)
このコードで間違っている部分と、改善策などを教えていただきたいです。(keras + tensorflowを用いております。)
1import tensorflow as tf 2import os 3from tensorflow.keras import layers 4import pickle 5strategy = tf.distribute.get_strategy() 6 7x = np.load(---) 8x_ = np.load(---) 9x_test = np.load(---) 10x_test_ = np.load(---) 11 12class Model(layers.Layer): 13 14 def create_model(self): 15 inputs = layers.Input((100,100,3)) 16 17 x = layers.Conv2D(32, 3, padding="same")(inputs) 18 x = layers.BatchNormalization()(x) 19 x = layers.Activation("relu")(x) 20 x = layers.MaxPooling2D(2)(x) 21 22 x = layers.Conv2D(64, 3, padding="same")(x) 23 x = layers.BatchNormalization()(x) 24 x = layers.Activation("relu")(x) 25 x = layers.MaxPooling2D(2)(x) 26 27 x = layers.Conv2D(128, 3, padding="same")(x) 28 x = layers.BatchNormalization()(x) 29 x = layers.Activation("relu")(x) 30 x = layers.MaxPooling2D(2)(x) 31 32 x = layers.Conv2D(256, 3, padding="same")(x) 33 x = layers.BatchNormalization()(x) 34 x = layers.Activation("relu")(x) 35 conv_out = layers.GlobalAveragePooling2D()(x) 36 37 x = layers.Dense(128, activation="relu")(conv_out) 38 x = layers.Dense(64, activation="relu")(x) 39 Z = layers.Dense(10, activation="softmax")(x) 40 41 x = layers.Dense(128, activation="relu")(conv_out) 42 x = layers.Dense(64, activation="relu")(x) 43 overclustering = layers.Dense(50, activation="softmax")(x) 44 45 return Model(inputs, [Z, overclustering]) 46 47 48def load_dataset(batch_size): 49 (X_train, y_train), (X_test, y_test) = (x,x_),(x_test,x_test_) 50 trainset = (X_train, y_train) 51 testset = (X_test, y_test) 52 return trainset, testset 53 54def main(): 55 batch_size = 10 56 trainset, valset = load_dataset(batch_size) 57 result = {"val_acc": [], "lr": []} 58 with strategy.scope(): 59 model = Model() 60 optim = tf.keras.optimizers.Adam(0.0001) 61 62 # Distributed用のデータセットの変換を行う 63 trainset = strategy.experimental_distribute_dataset(trainset) 64 valset = strategy.experimental_distribute_dataset(valset) 65 66 def IIC(self, z, z_, c=10): 67 z = tf.reshape(z, [-1, c, 1]) 68 z_ = tf.reshape(z_, [-1, 1, c]) 69 P = tf.math.reduce_sum(z * z_, axis=0) # 同時確率 70 P = (P + tf.transpose(P)) / 2 # 対称化 71 P = tf.clip_by_value(P, 1e-7, tf.float32.max) # logが発散しないようにバイアス 72 P = P / tf.math.reduce_sum(P) # 規格化 73 74 # 周辺確率 75 Pi = tf.math.reduce_sum(P, axis=0) 76 Pi = tf.reshape(Pi, [c, 1]) 77 Pi = tf.tile(Pi, [1,c]) 78 Pj = tf.math.reduce_sum(P, axis=1) 79 Pj = tf.reshape(Pj, [1, c]) 80 Pj = tf.tile(Pj, [c,1]) 81 82 loss = tf.math.reduce_sum(P * (tf.math.log(Pi) + tf.math.log(Pj) - tf.math.log(P))) 83 84 return loss 85 acc = tf.keras.metrics.SparseCategoricalAccuracy() 86 87 88 @tf.function 89 def train_on_batch( X, X_): 90 with tf.GradientTape() as tape: 91 z, overclustering = model(X, training=True) 92 z_, overclustering_ = model(X_, training=True) 93 loss_cluster = IIC(z, z_) 94 loss_overclustering = IIC(overclustering, overclustering_, c=50) 95 96 loss = (loss_cluster + loss_overclustering) / 2 97 98 graidents = tape.gradient(loss, self.model.trainable_weights) 99 self.optim.apply_gradients(zip(graidents, self.model.trainable_weights)) 100 return loss_cluster, loss_overclustering 101 102 103 def validation_on_batch(X, y_true): 104 y_pred = model(X, training=False) 105 acc.update_state(y_true, y_pred) 106 107 108 for i in range(100): 109 acc.reset_states() 110 print("Epoch = ", i) 111 for X, y in trainset: 112 train_on_batch(X, y) 113 train_acc = acc.result().numpy() 114 115 acc.reset_states() 116 for X, y in valset: 117 validation_on_batch(X, y) 118 print(f"Train acc = {train_acc}, Validation acc = {acc.result().numpy()}") 119 120 if i == 60: 121 optim.lr = 0.01 122 elif i == 85: 123 optim.lr = 0.001 124 125 result["val_acc"].append(acc.result().numpy()) 126 result["lr"].append(optim.lr.numpy()) # numpyにしないとWeaker objectが云々言うから 127 128 with open("history_correct.dat", "wb") as fp: 129 pickle.dump(result, fp) 130 131if __name__ == "__main__": 132 main()
1--------------------------------------------------------------------------- 2ValueError Traceback (most recent call last) 3<ipython-input-6-06b2ec27c006> in <module> 4 121 5 122 if __name__ == "__main__": 6--> 123 main() 7 8<ipython-input-6-06b2ec27c006> in main() 9 100 acc.reset_states() 10 101 print("Epoch = ", i) 11--> 102 for X, y in trainset: 12 103 train_on_batch(X, y) 13 104 train_acc = acc.result().numpy() 14 15ValueError: too many values to unpack (expected 2)