発生している問題・エラーメッセージ
```
----> 5 model = classifier(flags.aug_kwargs)
      6 model.load_state_dict(torch.load(model_path))
...
    725     result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727     result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729     _global_forward_hooks.values(),

TypeError: forward() missing 1 required positional argument: 'targets'
```
該当のソースコード
Python
# =============================================================================
# Two-class (normal / abnormal) chest X-ray classifier training script for the
# VinBigData Kaggle competition, followed by reloading the trained weights.
#
# NOTE(review): this script relies on names defined elsewhere in the notebook
# (save_yaml, cross_entropy_with_logits, accuracy_with_logits, Transform,
# VinbigdataTwoClassDataset, EMA, dataset_dicts, timm, ppe, E,
# IgniteExtensionsManager, Engine, StratifiedKFold, ...).
# =============================================================================


@dataclass
class Flags:
    """Run-time configuration; fields are overwritten via :meth:`update`."""

    # General
    debug: bool = True
    outdir: str = "results/det"
    device: str = "cuda:0"

    # Data config
    imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256"
    # split_mode: str = "all_train"
    seed: int = 111
    target_fold: int = 0  # 0~4
    label_smoothing: float = 0.0
    # Model config
    model_name: str = "tf_inception_v3"
    model_mode: str = "normal"  # normal, cnn_fixed supported
    # Training config
    epoch: int = 20
    batchsize: int = 8
    valid_batchsize: int = 16
    num_workers: int = 4
    snapshot_freq: int = 5
    ema_decay: float = 0.999  # negative value is to inactivate ema.
    scheduler_type: str = ""
    scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {})
    scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
    aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {})
    mixup_prob: float = -1.0  # Apply mixup augmentation when positive value is set.

    def update(self, param_dict: Dict) -> "Flags":
        """Overwrite fields by `param_dict`; raise ValueError on unknown keys."""
        for key, value in param_dict.items():
            if not hasattr(self, key):
                raise ValueError(f"[ERROR] Unexpected key for flag = {key}")
            setattr(self, key, value)
        return self


flags_dict = {
    "debug": False,  # Change to True for fast debug run!
    "outdir": "results/tmp_debug",
    # Data
    "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256",
    # Model
    "model_name": "tf_inception_v3",  # "resnet18",
    # Training
    "num_workers": 4,
    "epoch": 15,
    "batchsize": 8,
    "scheduler_type": "CosineAnnealingWarmRestarts",
    "scheduler_kwargs": {"T_0": 28125},  # 15000 * 15 epoch // (batchsize=8)
    "scheduler_trigger": [1, "iteration"],
    "aug_kwargs": {
        "HorizontalFlip": {"p": 0.5},
        "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
        "RandomBrightnessContrast": {"p": 0.5},
        "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
        "Blur": {"blur_limit": [3, 7], "p": 0.5},
        "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
        "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
    }
}

# args = parse()
print("torch", torch.__version__)
flags = Flags().update(flags_dict)
print("flags", flags)
debug = flags.debug
outdir = Path(flags.outdir)
os.makedirs(str(outdir), exist_ok=True)
flags_dict = dataclasses.asdict(flags)
save_yaml(str(outdir / "flags.yaml"), flags_dict)

# --- Read data ---
inputdir = Path("/kaggle/input")
datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection"
imgdir = inputdir / flags.imgdir_name

# Read in the data CSV files
train = pd.read_csv(datadir / "train.csv")


class CNNFixedPredictor(nn.Module):
    """Frozen CNN feature extractor followed by a trainable linear head."""

    def __init__(self, cnn: nn.Module, num_classes: int = 2):
        super(CNNFixedPredictor, self).__init__()
        self.cnn = cnn
        self.lin = Linear(cnn.num_features, num_classes)
        print("cnn.num_features", cnn.num_features)

        # We do not learn CNN parameters.
        # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        for param in self.cnn.parameters():
            param.requires_grad = False

    def forward(self, x):
        feat = self.cnn(x)
        return self.lin(feat)


def build_predictor(model_name: str, model_mode: str = "normal"):
    """Build a 2-class predictor from a timm backbone.

    model_mode "normal" trains all parameters; "cnn_fixed" freezes the CNN and
    trains only a linear head.  Raises ValueError for any other mode.
    """
    if model_mode == "normal":
        # normal configuration. train all parameters.
        return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3)
    elif model_mode == "cnn_fixed":
        # feature-extraction configuration: CNN weights stay fixed.
        # https://rwightman.github.io/pytorch-image-models/feature_extraction/
        timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3)
        return CNNFixedPredictor(timm_model, num_classes=2)
    else:
        raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}")


class Classifier(nn.Module):
    """Two-class classification wrapper.

    ``forward(image, targets)`` is the *training* entry point (computes loss
    and reports metrics); use :meth:`predict` / :meth:`predict_proba` for
    inference, where no targets are available.
    """

    def __init__(self, predictor, lossfun=cross_entropy_with_logits):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(self, image, targets):
        """Training/eval step: returns (loss, metrics) and reports via ppe."""
        outputs = self.predictor(image)
        loss = self.lossfun(outputs, targets)
        metrics = {
            f"{self.prefix}loss": loss.item(),
            f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item()
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics

    def predict(self, data_loader):
        """Return hard labels: argmax over the predicted probabilities."""
        pred = self.predict_proba(data_loader)
        label = torch.argmax(pred, dim=1)
        return label

    def predict_proba(self, data_loader):
        """Return softmax probabilities for all samples in `data_loader`."""
        device: torch.device = next(self.parameters()).device
        y_list = []
        self.eval()
        with torch.no_grad():
            for batch in data_loader:
                if isinstance(batch, (tuple, list)):
                    # Assumes first argument is "image"
                    batch = batch[0].to(device)
                else:
                    batch = batch.to(device)
                y = self.predictor(batch)
                y = torch.softmax(y, dim=-1)
                y_list.append(y)
        pred = torch.cat(y_list)
        return pred


def create_trainer(model, optimizer, device) -> Engine:
    """Create an ignite Engine running one optimization step per batch."""
    model.to(device)

    def update_fn(engine, batch):
        model.train()
        optimizer.zero_grad()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        loss.backward()
        optimizer.step()
        return metrics

    trainer = Engine(update_fn)
    return trainer


# --- Cross-validation split, stratified on "has any annotation" ---
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed)
# skf.get_n_splits(None, None)
y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts])
split_inds = list(skf.split(dataset_dicts, y))
train_inds, valid_inds = split_inds[flags.target_fold]  # 0th fold
train_dataset = VinbigdataTwoClassDataset(
    [dataset_dicts[i] for i in train_inds],
    image_transform=Transform(flags.aug_kwargs),
    mixup_prob=flags.mixup_prob,
    label_smoothing=flags.label_smoothing,
)
valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds])

train_loader = DataLoader(
    train_dataset,
    batch_size=flags.batchsize,
    num_workers=flags.num_workers,
    shuffle=True,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=flags.valid_batchsize,
    num_workers=flags.num_workers,
    shuffle=False,
    pin_memory=True,
)

device = torch.device(flags.device)

predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
classifier = Classifier(predictor)
model = classifier
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)

# Train setup
trainer = create_trainer(model, optimizer, device)

ema = EMA(predictor, decay=flags.ema_decay)


def eval_func(*batch):
    """Validation step: report plain metrics, then EMA metrics under "ema_"."""
    loss, metrics = model(*[elem.to(device) for elem in batch])
    # HACKING: report ema value with prefix.
    if flags.ema_decay > 0:
        classifier.prefix = "ema_"
        ema.assign()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        ema.resume()
        classifier.prefix = ""


valid_evaluator = E.Evaluator(
    valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
)

# log_trigger = (10 if debug else 1000, "iteration")
log_trigger = (1, "epoch")
log_report = E.LogReport(trigger=log_trigger)
extensions = [
    log_report,
    E.ProgressBarNotebook(update_interval=10 if debug else 100),  # Show progress bar during training
    E.PrintReportNotebook(),
    E.FailOnNonNumber(),
]
epoch = flags.epoch
models = {"main": model}
optimizers = {"main": optimizer}
manager = IgniteExtensionsManager(
    trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
)
# Run evaluation for valid dataset in each epoch.
manager.extend(valid_evaluator)

# Save predictor.pt every epoch
manager.extend(
    E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
)

manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)

if flags.ema_decay > 0:
    # Exponential moving average
    manager.extend(lambda manager: ema(), trigger=(1, "iteration"))

    def save_ema_model(manager):
        ema.assign()
        # torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
        torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt")
        ema.resume()

    manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))

_ = trainer.run(train_loader, max_epochs=epoch)

torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt")
df = log_report.to_dataframe()
df.to_csv("/kaggle/working/log.csv", index=False)

# #################################### end of additions ####################################

# --- Load the trained model ---
# BUG FIX: the original code was
#     model = classifier(flags.aug_kwargs)
#     model.load_state_dict(torch.load(model_path))
# `classifier` is an *instance* of Classifier, so calling it invokes
# Classifier.forward(image, targets) with the augmentation dict as `image`
# and no `targets`, which raises
#     TypeError: forward() missing 1 required positional argument: 'targets'
# Additionally, what was saved above is `predictor.state_dict()`, so the
# weights must be loaded into a predictor, not into the Classifier wrapper.
model_path = './predictor_last.pt'
# NOTE(review): this script writes "predictor.pt" snapshots and
# "/kaggle/working/predictor_tf_inception_v3.pt" -- confirm that `model_path`
# points at a checkpoint file that was actually saved.
predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
predictor.load_state_dict(torch.load(model_path))
model = Classifier(predictor)
回答2件
あなたの回答
tips
プレビュー