-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* release better dino result * refine * fix deta config * add dino better hyper * add DINO pretrained weights links * bump to v0.4.0
- Loading branch information
Showing
7 changed files
with
414 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
""" | ||
Training script using the new "LazyConfig" python config files. | ||
This scripts reads a given python config file and runs the training or evaluation. | ||
It can be used to train any models or dataset as long as they can be | ||
instantiated by the recursive construction defined in the given config file. | ||
Besides lazy construction of models, dataloader, etc., this scripts expects a | ||
few common configuration parameters currently defined in "configs/common/train.py". | ||
To add more complicated training logic, you can easily add other configs | ||
in the config file and implement a new train_net.py to handle them. | ||
""" | ||
import logging | ||
import os | ||
import sys | ||
import time | ||
import torch | ||
from torch.nn.parallel import DataParallel, DistributedDataParallel | ||
|
||
from detectron2.checkpoint import DetectionCheckpointer | ||
from detectron2.config import LazyConfig, instantiate | ||
from detectron2.engine import ( | ||
SimpleTrainer, | ||
default_argument_parser, | ||
default_setup, | ||
default_writers, | ||
hooks, | ||
launch, | ||
) | ||
from detectron2.engine.defaults import create_ddp_model | ||
from detectron2.evaluation import inference_on_dataset, print_csv_format | ||
from detectron2.utils import comm | ||
|
||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) | ||
|
||
logger = logging.getLogger("detrex") | ||
|
||
|
||
def match_name_keywords(n, name_keywords): | ||
out = False | ||
for b in name_keywords: | ||
if b in n: | ||
out = True | ||
break | ||
return out | ||
|
||
|
||
class Trainer(SimpleTrainer): | ||
""" | ||
We've combine Simple and AMP Trainer together. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model, | ||
dataloader, | ||
optimizer, | ||
amp=False, | ||
clip_grad_params=None, | ||
grad_scaler=None, | ||
): | ||
super().__init__(model=model, data_loader=dataloader, optimizer=optimizer) | ||
|
||
unsupported = "AMPTrainer does not support single-process multi-device training!" | ||
if isinstance(model, DistributedDataParallel): | ||
assert not (model.device_ids and len(model.device_ids) > 1), unsupported | ||
assert not isinstance(model, DataParallel), unsupported | ||
|
||
if amp: | ||
if grad_scaler is None: | ||
from torch.cuda.amp import GradScaler | ||
|
||
grad_scaler = GradScaler() | ||
self.grad_scaler = grad_scaler | ||
|
||
# set True to use amp training | ||
self.amp = amp | ||
|
||
# gradient clip hyper-params | ||
self.clip_grad_params = clip_grad_params | ||
|
||
def run_step(self): | ||
""" | ||
Implement the standard training logic described above. | ||
""" | ||
assert self.model.training, "[Trainer] model was changed to eval mode!" | ||
assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!" | ||
from torch.cuda.amp import autocast | ||
|
||
start = time.perf_counter() | ||
""" | ||
If you want to do something with the data, you can wrap the dataloader. | ||
""" | ||
data = next(self._data_loader_iter) | ||
data_time = time.perf_counter() - start | ||
|
||
""" | ||
If you want to do something with the losses, you can wrap the model. | ||
""" | ||
loss_dict = self.model(data) | ||
with autocast(enabled=self.amp): | ||
if isinstance(loss_dict, torch.Tensor): | ||
losses = loss_dict | ||
loss_dict = {"total_loss": loss_dict} | ||
else: | ||
losses = sum(loss_dict.values()) | ||
|
||
""" | ||
If you need to accumulate gradients or do something similar, you can | ||
wrap the optimizer with your custom `zero_grad()` method. | ||
""" | ||
self.optimizer.zero_grad() | ||
|
||
if self.amp: | ||
self.grad_scaler.scale(losses).backward() | ||
if self.clip_grad_params is not None: | ||
self.grad_scaler.unscale_(self.optimizer) | ||
self.clip_grads(self.model.parameters()) | ||
self.grad_scaler.step(self.optimizer) | ||
self.grad_scaler.update() | ||
else: | ||
losses.backward() | ||
if self.clip_grad_params is not None: | ||
self.clip_grads(self.model.parameters()) | ||
self.optimizer.step() | ||
|
||
self._write_metrics(loss_dict, data_time) | ||
|
||
def clip_grads(self, params): | ||
params = list(filter(lambda p: p.requires_grad and p.grad is not None, params)) | ||
if len(params) > 0: | ||
return torch.nn.utils.clip_grad_norm_( | ||
parameters=params, | ||
**self.clip_grad_params, | ||
) | ||
|
||
|
||
def do_test(cfg, model): | ||
if "evaluator" in cfg.dataloader: | ||
ret = inference_on_dataset( | ||
model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) | ||
) | ||
print_csv_format(ret) | ||
return ret | ||
|
||
|
||
def do_train(args, cfg): | ||
""" | ||
Args: | ||
cfg: an object with the following attributes: | ||
model: instantiate to a module | ||
dataloader.{train,test}: instantiate to dataloaders | ||
dataloader.evaluator: instantiate to evaluator for test set | ||
optimizer: instantaite to an optimizer | ||
lr_multiplier: instantiate to a fvcore scheduler | ||
train: other misc config defined in `configs/common/train.py`, including: | ||
output_dir (str) | ||
init_checkpoint (str) | ||
amp.enabled (bool) | ||
max_iter (int) | ||
eval_period, log_period (int) | ||
device (str) | ||
checkpointer (dict) | ||
ddp (dict) | ||
""" | ||
model = instantiate(cfg.model) | ||
logger = logging.getLogger("detectron2") | ||
logger.info("Model:\n{}".format(model)) | ||
model.to(cfg.train.device) | ||
|
||
# this is an hack of train_net | ||
param_dicts = [ | ||
{ | ||
"params": [ | ||
p | ||
for n, p in model.named_parameters() | ||
if not match_name_keywords(n, ["backbone"]) | ||
and not match_name_keywords(n, ["reference_points", "sampling_offsets"]) | ||
and p.requires_grad | ||
], | ||
"lr": 2e-4, | ||
}, | ||
{ | ||
"params": [ | ||
p | ||
for n, p in model.named_parameters() | ||
if match_name_keywords(n, ["backbone"]) and p.requires_grad | ||
], | ||
"lr": 2e-5, | ||
}, | ||
{ | ||
"params": [ | ||
p | ||
for n, p in model.named_parameters() | ||
if match_name_keywords(n, ["reference_points", "sampling_offsets"]) | ||
and p.requires_grad | ||
], | ||
"lr": 2e-5, | ||
}, | ||
] | ||
optim = torch.optim.AdamW(param_dicts, 2e-4, weight_decay=1e-4) | ||
|
||
train_loader = instantiate(cfg.dataloader.train) | ||
|
||
model = create_ddp_model(model, **cfg.train.ddp) | ||
|
||
trainer = Trainer( | ||
model=model, | ||
dataloader=train_loader, | ||
optimizer=optim, | ||
amp=cfg.train.amp.enabled, | ||
clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None, | ||
) | ||
|
||
checkpointer = DetectionCheckpointer( | ||
model, | ||
cfg.train.output_dir, | ||
trainer=trainer, | ||
) | ||
|
||
trainer.register_hooks( | ||
[ | ||
hooks.IterationTimer(), | ||
hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), | ||
hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) | ||
if comm.is_main_process() | ||
else None, | ||
hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), | ||
hooks.PeriodicWriter( | ||
default_writers(cfg.train.output_dir, cfg.train.max_iter), | ||
period=cfg.train.log_period, | ||
) | ||
if comm.is_main_process() | ||
else None, | ||
] | ||
) | ||
|
||
checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) | ||
if args.resume and checkpointer.has_checkpoint(): | ||
# The checkpoint stores the training iteration that just finished, thus we start | ||
# at the next iteration | ||
start_iter = trainer.iter + 1 | ||
else: | ||
start_iter = 0 | ||
trainer.train(start_iter, cfg.train.max_iter) | ||
|
||
|
||
def main(args): | ||
cfg = LazyConfig.load(args.config_file) | ||
cfg = LazyConfig.apply_overrides(cfg, args.opts) | ||
default_setup(cfg, args) | ||
|
||
if args.eval_only: | ||
model = instantiate(cfg.model) | ||
model.to(cfg.train.device) | ||
model = create_ddp_model(model) | ||
DetectionCheckpointer(model).load(cfg.train.init_checkpoint) | ||
print(do_test(cfg, model)) | ||
else: | ||
do_train(args, cfg) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = default_argument_parser().parse_args() | ||
launch( | ||
main, | ||
args.num_gpus, | ||
num_machines=args.num_machines, | ||
machine_rank=args.machine_rank, | ||
dist_url=args.dist_url, | ||
args=(args,), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.