From 3673ffe2551846d52c7d85f8773a604d6fa764f9 Mon Sep 17 00:00:00 2001 From: alcinos Date: Sun, 28 Jun 2020 15:59:10 +0200 Subject: [PATCH] Add Detectron2 wrapper (#103) --- README.md | 2 + d2/README.md | 34 ++++ d2/configs/detr_256_6_6_torchvision.yaml | 45 +++++ d2/converter.py | 69 +++++++ d2/detr/__init__.py | 4 + d2/detr/config.py | 32 ++++ d2/detr/dataset_mapper.py | 122 ++++++++++++ d2/detr/detr.py | 230 +++++++++++++++++++++++ d2/train_net.py | 154 +++++++++++++++ 9 files changed, 692 insertions(+) create mode 100644 d2/README.md create mode 100644 d2/configs/detr_256_6_6_torchvision.yaml create mode 100644 d2/converter.py create mode 100644 d2/detr/__init__.py create mode 100644 d2/detr/config.py create mode 100644 d2/detr/dataset_mapper.py create mode 100644 d2/detr/detr.py create mode 100644 d2/train_net.py diff --git a/README.md b/README.md index 06a9fe2c4..269edc44c 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ Training code follows this idea - it is not a library, but simply a [main.py](main.py) importing model and criterion definitions with standard training loops. +Additionnally, we provide a Detectron2 wrapper in the d2/ folder. See the readme there for more information. + For details see [End-to-End Object Detection with Transformers](https://ai.facebook.com/research/publications/end-to-end-object-detection-with-transformers) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko. # Model Zoo diff --git a/d2/README.md b/d2/README.md new file mode 100644 index 000000000..81ef70a3b --- /dev/null +++ b/d2/README.md @@ -0,0 +1,34 @@ +Detectron2 wrapper for DETR +======= + +We provide a Detectron2 wrapper for DETR, thus providing a way to better integrate it in the existing detection ecosystem. It can be used for example to easily leverage datasets or backbones provided in Detectron2. + +This wrapper currently supports only box detection, and is intended to be as close as possible to the original implementation, and we checked that it indeed match the results. Some notable facts and caveats: +- The data augmentation matches DETR's original data augmentation. This required patching the RandomCrop augmentation from Detectron2, so you'll need a version from the master branch from June 24th 2020 or more recent. +- To match DETR's original backbone initialization, we use the weights of a ResNet50 trained on imagenet using torchvision. This network uses a different pixel mean and std than most of the backbones available in Detectron2 by default, so extra care must be taken when switching to another one. Note that no other torchvision models are available in Detectron2 as of now, though it may change in the future. +- The gradient clipping mode is "full_model", which is not the default in Detectron2. + +# Usage + +To install Detectron2, please follow the [official installation instructions](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). + +## Evaluating a model + +For convenience, we provide a conversion script to convert models trained by the main DETR training loop into the format of this wrapper. To download and convert the main Resnet50 model, simply do: + +``` +python converter.py --source_model https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --output_model converted_model.pth +``` + +You can then evaluate it using: +``` +python train_net.py --eval-only --config configs/detr_256_6_6_torchvision.yaml MODEL.WEIGHTS "converted_model.pth" +``` + + +## Training + +To train DETR on a single node with 8 gpus, simply use: +``` +python train_net.py --config configs/detr_256_6_6_torchvision.yaml --num-gpus 8 +``` diff --git a/d2/configs/detr_256_6_6_torchvision.yaml b/d2/configs/detr_256_6_6_torchvision.yaml new file mode 100644 index 000000000..25d641845 --- /dev/null +++ b/d2/configs/detr_256_6_6_torchvision.yaml @@ -0,0 +1,45 @@ +MODEL: + META_ARCHITECTURE: "Detr" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + MASK_ON: False + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + DETR: + GIOU_WEIGHT: 2.0 + L1_WEIGHT: 5.0 + NUM_OBJECT_QUERIES: 100 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 64 + BASE_LR: 0.0001 + STEPS: (369600,) + MAX_ITER: 554400 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 10 + WEIGHT_DECAY: 0.0001 + OPTIMIZER: "ADAMW" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 +INPUT: + MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + CROP: + ENABLED: True + TYPE: "absolute_range" + SIZE: (384, 600) + FORMAT: "RGB" +TEST: + EVAL_PERIOD: 4000 +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: False + NUM_WORKERS: 4 +VERSION: 2 diff --git a/d2/converter.py b/d2/converter.py new file mode 100644 index 000000000..6fa5ff4c0 --- /dev/null +++ b/d2/converter.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Helper script to convert models trained with the main version of DETR to be used with the Detectron2 version. +""" +import json +import argparse + +import numpy as np +import torch + + +def parse_args(): + parser = argparse.ArgumentParser("D2 model converter") + + parser.add_argument("--source_model", default="", type=str, help="Path or url to the DETR model to convert") + parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model") + return parser.parse_args() + + +def main(): + args = parse_args() + + # D2 expects contiguous classes, so we need to remap the 92 classes from DETR + # fmt: off + coco_idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, + 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91] + # fmt: on + + coco_idx = np.array(coco_idx) + + if args.source_model.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url(args.source_model, map_location="cpu", check_hash=True) + else: + checkpoint = torch.load(args.source_model, map_location="cpu") + model_to_convert = checkpoint["model"] + + model_converted = {} + for k in model_to_convert.keys(): + old_k = k + if "backbone" in k: + k = k.replace("backbone.0.body.", "") + if "layer" not in k: + k = "stem." + k + for t in [1, 2, 3, 4]: + k = k.replace(f"layer{t}", f"res{t + 1}") + for t in [1, 2, 3]: + k = k.replace(f"bn{t}", f"conv{t}.norm") + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + k = "backbone.0.backbone." + k + k = "detr." + k + print(old_k, "->", k) + if "class_embed" in old_k: + v = model_to_convert[old_k].detach() + if v.shape[0] == 92: + shape_old = v.shape + model_converted[k] = v[coco_idx] + print("Head conversion: changing shape from {} to {}".format(shape_old, model_converted[k].shape)) + continue + model_converted[k] = model_to_convert[old_k].detach() + + model_to_save = {"model": model_converted} + torch.save(model_to_save, args.output_model) + + +if __name__ == "__main__": + main() diff --git a/d2/detr/__init__.py b/d2/detr/__init__.py new file mode 100644 index 000000000..a618f8288 --- /dev/null +++ b/d2/detr/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .config import add_detr_config +from .detr import Detr +from .dataset_mapper import DetrDatasetMapper diff --git a/d2/detr/config.py b/d2/detr/config.py new file mode 100644 index 000000000..45cc81e35 --- /dev/null +++ b/d2/detr/config.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from detectron2.config import CfgNode as CN + + +def add_detr_config(cfg): + """ + Add config for DETR. + """ + cfg.MODEL.DETR = CN() + cfg.MODEL.DETR.NUM_CLASSES = 80 + + # LOSS + cfg.MODEL.DETR.GIOU_WEIGHT = 2.0 + cfg.MODEL.DETR.L1_WEIGHT = 5.0 + cfg.MODEL.DETR.DEEP_SUPERVISION = True + cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1 + + # TRANSFORMER + cfg.MODEL.DETR.NHEADS = 8 + cfg.MODEL.DETR.DROPOUT = 0.1 + cfg.MODEL.DETR.DIM_FEEDFORWARD = 2048 + cfg.MODEL.DETR.ENC_LAYERS = 6 + cfg.MODEL.DETR.DEC_LAYERS = 6 + cfg.MODEL.DETR.PRE_NORM = False + cfg.MODEL.DETR.PASS_POS_AND_QUERY = True + + cfg.MODEL.DETR.HIDDEN_DIM = 256 + cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100 + + cfg.SOLVER.OPTIMIZER = "ADAMW" + cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 diff --git a/d2/detr/dataset_mapper.py b/d2/detr/dataset_mapper.py new file mode 100644 index 000000000..c2cf2ec8d --- /dev/null +++ b/d2/detr/dataset_mapper.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import copy +import logging + +import numpy as np +import torch + +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.data.transforms import TransformGen + +__all__ = ["DetrDatasetMapper"] + + +def build_transform_gen(cfg, is_train): + """ + Create a list of :class:`TransformGen` from config. + Returns: + list[TransformGen] + """ + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + if sample_style == "range": + assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) + + logger = logging.getLogger(__name__) + tfm_gens = [] + if is_train: + tfm_gens.append(T.RandomFlip()) + tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) + if is_train: + logger.info("TransformGens used in training: " + str(tfm_gens)) + return tfm_gens + + +class DetrDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by DETR. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + def __init__(self, cfg, is_train=True): + if cfg.INPUT.CROP.ENABLED and is_train: + self.crop_gen = [ + T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), + T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), + ] + else: + self.crop_gen = None + + assert not cfg.MODEL.MASK_ON, "Mask is not supported" + + self.tfm_gens = build_transform_gen(cfg, is_train) + logging.getLogger(__name__).info( + "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen)) + ) + + self.img_format = cfg.INPUT.FORMAT + self.is_train = is_train + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + if self.crop_gen is None: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + if np.random.rand() > 0.5: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + image, transforms = T.apply_transform_gens( + self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image + ) + + image_shape = image.shape[:2] # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. + dataset_dict.pop("annotations", None) + return dataset_dict + + if "annotations" in dataset_dict: + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + anno.pop("segmentation", None) + anno.pop("keypoints", None) + + # USER: Implement additional transformations if you have other types of data + annos = [ + utils.transform_instance_annotations(obj, transforms, image_shape) + for obj in dataset_dict.pop("annotations") + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances(annos, image_shape) + dataset_dict["instances"] = utils.filter_empty_instances(instances) + return dataset_dict diff --git a/d2/detr/detr.py b/d2/detr/detr.py new file mode 100644 index 000000000..7d1d6c2a7 --- /dev/null +++ b/d2/detr/detr.py @@ -0,0 +1,230 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import logging +import math +from typing import List + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torch import nn + +from detectron2.layers import ShapeSpec +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess +from detectron2.structures import Boxes, ImageList, Instances +from detectron2.utils.logger import log_first_n +from fvcore.nn import giou_loss, smooth_l1_loss +from models.backbone import Joiner +from models.detr import DETR, SetCriterion +from models.matcher import HungarianMatcher +from models.position_encoding import PositionEmbeddingSine +from models.transformer import Transformer +from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh +from util.misc import NestedTensor + +__all__ = ["Detr"] + + +class MaskedBackbone(nn.Module): + """ This is a thin wrapper around D2's backbone to provide padding masking""" + + def __init__(self, cfg): + super().__init__() + self.backbone = build_backbone(cfg) + backbone_shape = self.backbone.output_shape() + self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] + self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels + + def forward(self, images): + features = self.backbone(images.tensor) + masks = self.mask_out_padding( + [features_per_level.shape for features_per_level in features.values()], + images.image_sizes, + images.tensor.device, + ) + assert len(features) == len(masks) + for i, k in enumerate(features.keys()): + features[k] = NestedTensor(features[k], masks[i]) + return features + + def mask_out_padding(self, feature_shapes, image_sizes, device): + masks = [] + assert len(feature_shapes) == len(self.feature_strides) + for idx, shape in enumerate(feature_shapes): + N, _, H, W = shape + masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) + for img_idx, (h, w) in enumerate(image_sizes): + masks_per_feature_level[ + img_idx, + : int(np.ceil(float(h) / self.feature_strides[idx])), + : int(np.ceil(float(w) / self.feature_strides[idx])), + ] = 0 + masks.append(masks_per_feature_level) + return masks + + +@META_ARCH_REGISTRY.register() +class Detr(nn.Module): + """ + Implement Detr + """ + + def __init__(self, cfg): + super().__init__() + + self.device = torch.device(cfg.MODEL.DEVICE) + + self.num_classes = cfg.MODEL.DETR.NUM_CLASSES + hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM + num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES + # Transformer parameters: + nheads = cfg.MODEL.DETR.NHEADS + dropout = cfg.MODEL.DETR.DROPOUT + dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD + enc_layers = cfg.MODEL.DETR.ENC_LAYERS + dec_layers = cfg.MODEL.DETR.DEC_LAYERS + pre_norm = cfg.MODEL.DETR.PRE_NORM + pass_pos_and_query = cfg.MODEL.DETR.PASS_POS_AND_QUERY + + # Loss parameters: + giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT + l1_weight = cfg.MODEL.DETR.L1_WEIGHT + deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION + no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT + + N_steps = hidden_dim // 2 + d2_backbone = MaskedBackbone(cfg) + backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True)) + backbone.num_channels = d2_backbone.num_channels + + transformer = Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=deep_supervision, + pass_pos_and_query=pass_pos_and_query, + ) + + self.detr = DETR( + backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, aux_loss=deep_supervision + ) + self.detr.to(self.device) + + # building criterion + matcher = HungarianMatcher(cost_class=1, cost_bbox=l1_weight, cost_giou=giou_weight) + weight_dict = {"loss_ce": 1, "loss_bbox": l1_weight} + weight_dict["loss_giou"] = giou_weight + if deep_supervision: + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + losses = ["labels", "boxes", "cardinality"] + self.criterion = SetCriterion( + self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses + ) + self.criterion.to(self.device) + + pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) + pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) + self.normalizer = lambda x: (x - pixel_mean) / pixel_std + self.to(self.device) + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper` . + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + + * image: Tensor, image in (C, H, W) format. + * instances: Instances + + Other information that's included in the original dicts, such as: + + * "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + Returns: + dict[str: Tensor]: + mapping from a named loss to a tensor storing the loss. Used during training only. + """ + images = self.preprocess_image(batched_inputs) + output = self.detr(images) + + if self.training: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + + targets = self.prepare_targets(gt_instances) + loss_dict = self.criterion(output, targets) + weight_dict = self.criterion.weight_dict + for k in loss_dict.keys(): + if k in weight_dict: + loss_dict[k] *= weight_dict[k] + return loss_dict + else: + box_cls = output["pred_logits"] + box_pred = output["pred_boxes"] + results = self.inference(box_cls, box_pred, images.image_sizes) + processed_results = [] + for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"instances": r}) + return processed_results + + def prepare_targets(self, targets): + new_targets = [] + for targets_per_image in targets: + h, w = targets_per_image.image_size + image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) + gt_classes = targets_per_image.gt_classes + gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy + gt_boxes = box_xyxy_to_cxcywh(gt_boxes) + new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) + return new_targets + + def inference(self, box_cls, box_pred, image_sizes): + """ + Arguments: + box_cls (Tensor): tensor of shape (batch_size, num_queries, K). + The tensor predicts the classification probability for each query. + box_pred (Tensor): tensors of shape (batch_size, num_queries, 4). + The tensor predicts 4-vector (x,y,w,h) box + regression values for every queryx + image_sizes (List[torch.Size]): the input image sizes + + Returns: + results (List[Instances]): a list of #images elements. + """ + assert len(box_cls) == len(image_sizes) + results = [] + + # For each box we assign the best class or the second best if the best on is `no_object`. + scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1) + + for scores_per_image, labels_per_image, box_pred_per_image, image_size in zip( + scores, labels, box_pred, image_sizes + ): + result = Instances(image_size) + result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image)) + + result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) + + result.scores = scores_per_image + result.pred_classes = labels_per_image + results.append(result) + return results + + def preprocess_image(self, batched_inputs): + """ + Normalize, pad and batch the input images. + """ + images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] + images = ImageList.from_tensors(images) + return images diff --git a/d2/train_net.py b/d2/train_net.py new file mode 100644 index 000000000..bd79f7f70 --- /dev/null +++ b/d2/train_net.py @@ -0,0 +1,154 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Training Script. + +This script is a simplified version of the training script in detectron2/tools. +""" +import os +import sys + +# fmt: off +sys.path.insert(1, os.path.join(sys.path[0], '..')) +# fmt: on + +import time +from typing import Any, Dict, List, Set + +import torch + +import detectron2.utils.comm as comm +from d2.detr import DetrDatasetMapper, add_detr_config +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog, build_detection_train_loader +from detectron2.engine import AutogradProfiler, DefaultTrainer, default_argument_parser, default_setup, launch +from detectron2.evaluation import COCOEvaluator, verify_results + +from detectron2.solver.build import maybe_add_gradient_clipping + + +class Trainer(DefaultTrainer): + """ + Extension of the Trainer class adapted to DETR. + """ + + def __init__(self, cfg): + """ + Args: + cfg (CfgNode): + """ + self.clip_norm_val = 0.0 + if cfg.SOLVER.CLIP_GRADIENTS.ENABLED: + if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": + self.clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE + super().__init__(cfg) + + def run_step(self): + assert self.model.training, "[Trainer] model was changed to eval mode!" + start = time.perf_counter() + data = next(self._data_loader_iter) + data_time = time.perf_counter() - start + + loss_dict = self.model(data) + losses = sum(loss_dict.values()) + self._detect_anomaly(losses, loss_dict) + + metrics_dict = loss_dict + metrics_dict["data_time"] = data_time + self._write_metrics(metrics_dict) + + self.optimizer.zero_grad() + losses.backward() + if self.clip_norm_val > 0.0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm_val) + self.optimizer.step() + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. + """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + return COCOEvaluator(dataset_name, cfg, True, output_folder) + + @classmethod + def build_train_loader(cls, cfg): + if "Detr" == cfg.MODEL.META_ARCHITECTURE: + mapper = DetrDatasetMapper(cfg, True) + else: + mapper = None + return build_detection_train_loader(cfg, mapper=mapper) + + @classmethod + def build_optimizer(cls, cfg, model): + params: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + for key, value in model.named_parameters(recurse=True): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "backbone" in key: + lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + + optimizer_type = cfg.SOLVER.OPTIMIZER + if optimizer_type == "SGD": + optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM) + elif optimizer_type == "ADAMW": + optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR) + else: + raise NotImplementedError(f"no optimizer type {optimizer_type}") + if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": + optimizer = maybe_add_gradient_clipping(cfg, optimizer) + return optimizer + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + add_detr_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + return cfg + + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume) + res = Trainer.test(cfg, model) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + )