実現したいこと
弱くラベル付けされたデータセットを用いたSelf-Attentionの学習について理解したい。
前提
以下URLの論文について、弱くラベル付けされたデータセット(監視カメラ映像)を使用しているようなのですが、疑問に思っているのが異常データ(正常フレームと異常フレームが混在)の学習について、何分割かに分けてSelf-Attentionに学習させる場合、正常フレームを誤って異常として学習させている可能性はあるのでしょうか?
その場合の解決策はありますでしょうか?
また、損失の算出方法も理解できず詰まっています。
どこにも参考になるサイトが見つからなかったので、有識者の方、ご教示をお願いします。
論文URL:http://makotookabe.com/publication/watanabe_miru2022.pdf
コードURL:https://github.com/yaegasikk/attention_anomaly_detector/
該当のソースコード
train.py
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
from tqdm import tqdm
import option
from dataset import Dataset
from sklearn.metrics import roc_curve, roc_auc_score, auc
import numpy as np
from radam import RAdam
import torch.nn.functional as F
from network.video_classifier import VideoClassifier


class Sampling_replace_iter:
    """Iterator that yields (features, labels) mini-batches sampled from a
    pool of bag-level features.

    Every batch drawn from one instance carries a single weak (bag-level)
    label: 0 when is_normal=True, 1 otherwise. This is the weakly-supervised
    setting: an "abnormal" bag may well contain normal snippets, but all of
    them receive the bag's label.
    """

    def __init__(self, features, n_sample, is_normal=True, replace=True, T_shuffle=False):
        # features: sequence of per-bag feature arrays; after np.array() the
        # shape is assumed to be [N, T, 2048] — TODO confirm against caller.
        self.features = np.array(features)
        self.is_normal = is_normal
        self.n_sample = n_sample
        self.indexs = np.arange(0, len(features))
        np.random.shuffle(self.indexs)
        self.count = 0
        self.replace = replace
        self.T_shuffle = T_shuffle

    def __iter__(self):
        # All samples share the bag-level weak label.
        if self.is_normal:
            labels = torch.zeros(self.n_sample).long()
        else:
            labels = torch.ones(self.n_sample).long()
        self.count = 0

        # self.features.shape == [N, T, 2048]
        if self.T_shuffle:
            # BUG FIX: the original indexed with the return value of
            # np.random.shuffle(), which is None (shuffle works in place),
            # so the temporal axis was never shuffled and the array got an
            # extra axis instead. Use a permutation index on axis 1.
            bs, t, f = self.features.shape
            perm = np.random.permutation(t)
            self.features = self.features[:, perm, :]

        if not self.replace:
            # Fresh epoch order when sampling without replacement.
            np.random.shuffle(self.indexs)

        while self.count + self.n_sample < len(self.indexs):
            if self.replace:
                # With replacement: draw an independent random subset.
                choice_index = np.random.choice(self.indexs, self.n_sample, replace=True).tolist()
            else:
                # Without replacement: walk the shuffled index list.
                choice_index = self.indexs[self.count:self.count + self.n_sample].tolist()
            yield torch.tensor(self.features[choice_index]), labels
            self.count += self.n_sample


def split_cal_auc_mil(model, dataloader, device, split_size=32, gt_path="list/gt-ucf.npy"):
    """Compute frame-level AUC for a MIL-style model.

    Each test video's features are split into chunks of `split_size`
    snippets along the temporal axis; each chunk is scored by the model,
    snippet scores are expanded to frame scores (16 frames per snippet)
    and compared against the frame-level ground truth in `gt_path`.

    Returns (auc_score, fpr, tpr).
    """
    model.to(device)
    model.eval()
    bag_pred_scores = []

    with torch.no_grad():
        for feature in tqdm(dataloader):
            # feature.shape = [1, n_snippets, n_crops, 2048] — TODO confirm.
            feature = feature.permute(0, 2, 1, 3)
            feature = feature.squeeze(0)
            split_feature = torch.split(feature, split_size, dim=1)
            for feature_i in split_feature:
                feature_i = feature_i.to(device)
                out = model(feature_i)
                # Average over crops, flatten to per-snippet scores.
                out = out.mean(0).reshape(-1).detach().cpu()
                # out = out.repeat(feature_i.size(1))
                bag_pred_scores.append(out)

    gt = np.load(gt_path)
    bag_pred_scores = torch.cat(bag_pred_scores).numpy()
    # Each snippet covers 16 frames: expand snippet scores to frame scores.
    pred_score = np.repeat(bag_pred_scores, 16)
    # gt.shape == pred_score.shape == [1114144] for UCF-Crime
    fpr, tpr, thresholds = roc_curve(gt, pred_score)
    auc_score = auc(fpr, tpr)

    return auc_score, fpr, tpr


def split_cal_auc_videoclassifier(model, dataloader, device, split_size=32, gt_path="list/gt-ucf.npy"):
    """Compute frame-level AUC for the video-classifier model.

    Same chunked evaluation as split_cal_auc_mil, except each chunk
    produces a single score that is broadcast to every snippet in the
    chunk via repeat() before the 16-frame expansion.

    Returns (auc_score, fpr, tpr).
    """
    model.to(device)
    model.eval()
    bag_pred_scores = []

    with torch.no_grad():
        for feature in tqdm(dataloader):
            # feature.shape = [1, n_snippets, n_crops, 2048] — TODO confirm.
            feature = feature.permute(0, 2, 1, 3)
            feature = feature.squeeze(0)
            split_feature = torch.split(feature, split_size, dim=1)
            for feature_i in split_feature:
                feature_i = feature_i.to(device)
                out = model(feature_i)
                out = out.mean(0).reshape(-1).detach().cpu()
                # Broadcast the chunk score to all snippets in the chunk.
                out = out.repeat(feature_i.size(1))
                bag_pred_scores.append(out)

    gt = np.load(gt_path)
    bag_pred_scores = torch.cat(bag_pred_scores).numpy()
    # Each snippet covers 16 frames: expand snippet scores to frame scores.
    pred_score = np.repeat(bag_pred_scores, 16)
    # gt.shape == pred_score.shape == [1114144] for UCF-Crime
    fpr, tpr, thresholds = roc_curve(gt, pred_score)
    auc_score = auc(fpr, tpr)

    return auc_score, fpr, tpr


if __name__ == '__main__':

    save_weight = True
    args = option.parser.parse_args()
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    da = args.da
    r = args.r

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True
    print("Use : {}".format(device))
    model_name = 'SelfAttention-da{}-r{}'.format(da, r)
    model = VideoClassifier(args, r=r, da=da)
    model = model.to(device)
    print(model)

    criterion = nn.BCELoss()
    optimizer = RAdam(model.parameters())
    test_data = Dataset(args, test_mode=True, is_normal=False, is_onmemory=True)
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)
    # Baseline AUC of the untrained model.
    max_auc_score, _, _ = split_cal_auc_videoclassifier(model, test_dataloader, device, split_size=args.test_split_size)
    print('first_auc_score : ', max_auc_score)

    # UCF-Crime convention: the first 810 list entries are abnormal videos,
    # the remainder normal — TODO confirm against the feature list file.
    feature_list = list(open(args.pre_rgb_list))
    normal_list = feature_list[810:]
    abnormal_list = feature_list[:810]
    normal_features = []
    abnormal_features = []

    for normal_i in tqdm(normal_list):
        feature = np.array(np.load(normal_i.strip('\n'), allow_pickle=True), dtype=np.float32)
        # Split the 10-crop axis into 10 separate [T, 2048] bags.
        feature_split = np.split(feature, 10, axis=0)
        feature_split = [split_i.reshape(args.T, 2048) for split_i in feature_split]
        for split_i in feature_split:
            normal_features.append(split_i)
    print(len(normal_features))

    for abnormal_i in tqdm(abnormal_list):
        feature = np.array(np.load(abnormal_i.strip('\n'), allow_pickle=True), dtype=np.float32)
        feature_split = np.split(feature, 10, axis=0)
        feature_split = [split_i.reshape(args.T, 2048) for split_i in feature_split]
        for split_i in feature_split:
            abnormal_features.append(split_i)

    normal_iter = Sampling_replace_iter(features=normal_features, n_sample=args.batch_size, is_normal=True, replace=False, T_shuffle=False)
    abnormal_iter = Sampling_replace_iter(features=abnormal_features, n_sample=args.batch_size, is_normal=False, replace=False, T_shuffle=False)

    auc_list = [max_auc_score]

    for epoch_i in range(args.max_epoch):
        print('epoch {} '.format(epoch_i))
        model.train()

        # Each step pairs one normal batch with one abnormal batch.
        for (normal_feature, normal_label), (anomaly_feature, anomaly_label) in zip(normal_iter, abnormal_iter):
            input_datas = torch.cat((normal_feature, anomaly_feature), 0)
            input_labels = torch.cat((normal_label, anomaly_label), 0).float()
            # input_labels.shape = [batch_size*2]
            # input_datas.shape = [batch_size, 10, 32, 2048] — TODO confirm.
            input_datas = input_datas.to(device)
            input_labels = input_labels.to(device)
            output = model(input_datas)
            loss = criterion(output, input_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        auc_score, _, _ = split_cal_auc_videoclassifier(model, test_dataloader, device, split_size=args.test_split_size)
        if max_auc_score < auc_score:
            max_auc_score = auc_score
            if save_weight:
                # Move to CPU for a portable checkpoint, then restore device.
                torch.save(model.to('cpu').state_dict(), 'save_weight/{}-T{}-seed{}.pth'.format(model_name, args.T, args.seed))
                model = model.to(device)
        print('{}-seed{} max_auc {} auc {}'.format(model_name, args.seed, max_auc_score, auc_score))
        auc_list.append(auc_score)
        model.train()

    # auc_list = np.array(auc_list)
    # np.save('list/{}-seed{}.npy'.format(model_name,args.seed),auc_list)
試したこと
Self-Attention機構は理解したつもりです。
恐らくですがフレーム画像に0, 1の値を持たせて0に近ければ正常、1に近ければ異常としていると思います。(違ってたらすみません)
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。