Pythonを使って深層学習を勉強している初心者です。
resnet101を用いて、ファインチューニングを行いたいと考えています。(resnet101の方は、torchvision で提供されている学習済みのモデルを使用しております。)前処理として、手元にある画像サイズ640×480を、画像サイズ、224×224に変更する必要があります。画像は、音をスペクトログラムにしたものです。そのため、横軸に時間、縦軸に周波数、色は信号の強さを表しており、リサイズを行った場合、画像にある情報が潰れてしまいます。これを避けるために、前処理で画像サイズを変更しない方法や画像にある情報を潰さずに済む方法がありましたら、教えていただけると助かります。
参考させていただいたサイト: https://blog.brainpad.co.jp/entry/2018/04/17/143000
スペクトログラム
スペクトログラムを224×224にリサイズし、色情報の規格化を行った画像。
Python
1 2# 入力画像の前処理をするクラス 3# 訓練時と推論時で処理が異なる 4 5 6class ImageTransform(): 7 """ 8 画像の前処理クラス。訓練時、検証時で異なる動作をする。 9 画像のサイズをリサイズし、色を標準化する。 10 訓練時はRandomResizedCropとRandomHorizontalFlipでデータオーギュメンテーションする。 11 12 13 Attributes 14 ---------- 15 resize : int 16 リサイズ先の画像の大きさ。 17 mean : (R, G, B) 18 各色チャネルの平均値。 19 std : (R, G, B) 20 各色チャネルの標準偏差。 21 """ 22 23 def __init__(self, resize, mean, std): 24 self.data_transform = { 25 'train': transforms.Compose([ 26 transforms.Resize((224,224)), 27 #transforms.RandomHorizontalFlip(), # データオーギュメンテーション 28 transforms.ToTensor(), # テンソルに変換 29 transforms.Normalize(mean, std) # 標準化 30 ]), 31 'val': transforms.Compose([ 32 transforms.Resize((224,224)), 33 transforms.ToTensor(), # テンソルに変換 34 transforms.Normalize(mean, std) # 標準化 35 ]), 36 } 37 38 39 def __call__(self, img, phase='train'): 40 """ 41 Parameters 42 ---------- 43 phase : 'train' or 'val' 44 前処理のモードを指定。 45 """ 46 return self.data_transform[phase](img) 47 48# 訓練時の画像前処理の動作を確認 49# 実行するたびに処理結果の画像が変わる 50 51# 1. 画像読み込み 52image_file_path = './data/rockk.00000.jpg' 53img = Image.open(image_file_path) # [高さ][幅][色RGB] 54 55# 2. 元の画像の表示 56plt.imshow(img) 57plt.show() 58 59# 3. 画像の前処理と処理済み画像の表示 60size = 224 61mean = (0.485, 0.456, 0.406) 62std = (0.229, 0.224, 0.225) 63 64transform = ImageTransform(size, mean, std) 65img_transformed = transform(img, phase="train") # torch.Size([3, 224, 224]) 66 67# (色、高さ、幅)を (高さ、幅、色)に変換し、0-1に値を制限して表示 68img_transformed = img_transformed.numpy().transpose((1, 2, 0)) 69img_transformed = np.clip(img_transformed, 0, 1) 70plt.imshow(img_transformed) 71plt.show() 72 73 74 75 76 77#resnet101を用いたファインチューニング 78 79use_pretrained = True # 学習済みのパラメータを使用 80net = models.resnet101(pretrained=use_pretrained) 81 82 83net.fc = nn.Linear(in_features=2048, out_features=10) 84print(net) 85# 訓練モードに設定 86net.train() 87 88print('ネットワーク設定完了:学習済みの重みをロードし、訓練モードに設定しました') 89 90 91# 損失関数の設定 92criterion = nn.CrossEntropyLoss() 93 94 95# ファインチューニングで学習させるパラメータを、変数params_to_updateの1~3に格納する 96 97params_to_update_1 = [] 98params_to_update_2 = [] 99params_to_update_3 = [] 100 101# 学習させる層のパラメータ名を指定 102update_param_names_1 = ["layer"] 103update_param_names_2 = ["downsample"] 104update_param_names_3 = ["fc.weight", "fc.bias"] 105update_param_names_4 = ["bn"] 106update_param_names_5 = ["downsample.1.weight"] 107update_param_names_6 = ["conv1.weight"] 108update_param_names_7 = ["bn1.weight","bn1.bias"] 109 110# パラメータごとに各リストに格納する 111for name, param in net.named_parameters(): 112 print(name) 113 if update_param_names_1[0] in name:# layer:conv 114 if update_param_names_2[0] in name:# downsample:conv 115 param.requires_grad = False 116 print("勾配計算なし。学習しない:", name) 117 else: 118 param.requires_grad = True 119 params_to_update_1.append(param) 120 print("params_to_update_1に格納:", name) 121 122 elif name in update_param_names_6: 123 param.requires_grad = True 124 params_to_update_2.append(param) 125 print("params_to_update_2に格納:", name) 126 127 elif name in update_param_names_7: 128 param.requires_grad = True 129 params_to_update_2.append(param) 130 print("params_to_update_2に格納:", name) 131 132 133 elif name in update_param_names_3:# fc:conv 134 param.requires_grad = True 135 params_to_update_3.append(param) 136 print("params_to_update_3に格納:", name) 137 138 elif name in update_param_names_4: 139 param.requires_grad = False 140 print("bn: 勾配計算なし。学習しない:", name) 141 142 143 else: 144 param.requires_grad = False 145 print("勾配計算なし。学習しない:", name) 146 147 148 149# 最適化手法の設定 150# 1e-4 = 1*10^-4 , 5e-4 = 5*10^-4 151 152optimizer = optim.Adam([ 153 {'params': params_to_update_1, 'lr': 1e-4}, 154 {'params': params_to_update_2, 'lr': 1e-4}, 155 {'params': params_to_update_3, 'lr': 1e-3} 156]) 157
あなたの回答
tips
プレビュー