-
Notifications
You must be signed in to change notification settings - Fork 180
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow multiple types of models. (#506)
* flexible model structure * checking out conflicts from main * update config with correct changes * model architecture doc
- Loading branch information
Showing
14 changed files
with
646 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,51 @@ | ||
# Model | ||
import torchvision | ||
from torchvision.models.detection.retinanet import RetinaNet | ||
from torchvision.models.detection.retinanet import AnchorGenerator | ||
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights | ||
# Model - common class | ||
from deepforest.models import * | ||
import torch | ||
|
||
|
||
def load_backbone():
    """Build torchvision's RetinaNet (ResNet-50 FPN) detector with COCO_V1 weights.

    Returns:
        The detection model produced by
        ``torchvision.models.detection.retinanet_resnet50_fpn``.
    """
    coco_weights = RetinaNet_ResNet50_FPN_Weights.COCO_V1
    return torchvision.models.detection.retinanet_resnet50_fpn(weights=coco_weights)
|
||
|
||
def create_anchor_generator(sizes=((8, 16, 32, 64, 128, 256, 400),),
                            aspect_ratios=((0.5, 1.0, 2.0),)):
    """Create an anchor box generator from anchor sizes and aspect ratios.

    Documented at
    https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9

    Both arguments are a tuple of tuples because each feature map could
    potentially have different sizes and aspect ratios. The defaults define
    seven sizes and three aspect ratios for a single feature map.

    Args:
        sizes: tuple of tuples of anchor sizes, one inner tuple per feature map
        aspect_ratios: tuple of tuples of aspect ratios, one inner tuple per
            feature map
    Returns:
        anchor_generator, a pytorch module
    """
    return AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)
|
||
|
||
def create_model(num_classes, nms_thresh, score_thresh, backbone=None):
    """Create a retinanet model

    Args:
        num_classes (int): number of classes in the model
        nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
        score_thresh (float): minimum prediction score to keep during prediction [0,1]
        backbone: optional backbone module; when None, the backbone of the
            model returned by load_backbone() is used
    Returns:
        model: a pytorch nn module
    """
    # Compare against None explicitly: truthiness of a module is not a
    # reliable "was it provided" test (container modules define __len__).
    if backbone is None:
        resnet = load_backbone()
        backbone = resnet.backbone

    model = RetinaNet(backbone=backbone, num_classes=num_classes)
    model.nms_thresh = nms_thresh
    model.score_thresh = score_thresh

    # Optionally allow anchor generator parameters to be created here
    # https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html

    return model


class Model():
    """An architecture-agnostic base class for DeepForest model definitions.

    To add a new architecture, create a new module in models/ containing a
    subclass that overrides create_model. A model should optionally allow a
    backbone for pretraining.

    Args:
        config: the deepforest config; stored as ``self.config`` so that
            create_model can read its settings from it
    """

    def __init__(self, config):

        # Keep the config available to create_model
        self.config = config

        # Check input output format:
        self.check_model()

    def create_model(self):
        """This function converts a deepforest config file into a model. An architecture should have a list of nested arguments in config that match this function"""
        raise ValueError(
            "The create_model class method needs to be implemented. Take in args and return a pytorch nn module."
        )

    def check_model(self):
        """
        Ensure that the model returned by create_model follows the expected
        input/output format: given a list of images in eval mode it returns
        one dict per image with exactly the keys 'boxes', 'labels', 'scores'.

        Raises:
            ValueError: if the number of predictions or their keys are wrong
        """
        # This assumes model creation is not expensive
        test_model = self.create_model()
        test_model.eval()

        # Create a dummy batch of 3 band data.
        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

        # Only the output structure is inspected, so skip gradient tracking
        with torch.no_grad():
            predictions = test_model(x)

        # Model takes in a batch of images
        if len(predictions) != len(x):
            raise ValueError(
                "Model must return one prediction per image: expected {}, got {}".format(
                    len(x), len(predictions)))

        # Returns a list equal to number of images with proper keys per image
        expected_keys = ['boxes', 'labels', 'scores']
        for prediction in predictions:
            model_keys = sorted(prediction.keys())
            if model_keys != expected_keys:
                raise ValueError(
                    "Each prediction must have keys {}, got {}".format(
                        expected_keys, model_keys))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Model modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torchvision | ||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | ||
from torchvision.models.detection import FasterRCNN | ||
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork | ||
import torch | ||
from deepforest.model import Model | ||
|
||
|
||
class Model(Model):
    """Faster R-CNN (ResNet-50 FPN) architecture built from a deepforest config.

    Subclasses the ``Model`` imported from ``deepforest.model`` and reads
    ``num_classes`` from ``self.config``.
    """

    def __init__(self, config, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__(config)

    def load_backbone(self):
        """Load torchvision's Faster R-CNN (ResNet-50 FPN) with COCO_V1 weights.

        Returns:
            The detection model produced by
            ``torchvision.models.detection.fasterrcnn_resnet50_fpn``.
        """
        # Explicit weights enum instead of the deprecated ``pretrained=True``
        # (which torchvision maps to these same COCO_V1 weights).
        weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
        backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

        return backbone

    def create_model(self, backbone=None):
        """Create a FasterRCNN model

        Args:
            backbone: an optional, already constructed torchvision Faster R-CNN
                model to use as the starting point. Its box predictor is
                replaced in place. When None, load_backbone() is used.
        Returns:
            model: a pytorch nn module
        """
        # load Faster RCNN pre-trained model unless the caller supplied one
        model = self.load_backbone() if backbone is None else backbone

        # get the number of input features
        in_features = model.roi_heads.box_predictor.cls_score.in_features

        # define a new head for the detector with required number of classes
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes=self.config["num_classes"])

        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Model | ||
import torchvision | ||
from torchvision.models.detection.retinanet import RetinaNet | ||
from torchvision.models.detection.retinanet import AnchorGenerator | ||
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights | ||
from deepforest.model import Model | ||
|
||
|
||
class Model(Model):
    """RetinaNet (ResNet-50 FPN) architecture built from a deepforest config.

    Subclasses the ``Model`` imported from ``deepforest.model`` and reads
    ``num_classes``, ``nms_thresh`` and ``retinanet.score_thresh`` from
    ``self.config``.
    """

    def __init__(self, config, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__(config)

    def load_backbone(self):
        """Load torchvision's RetinaNet (ResNet-50 FPN) with COCO_V1 weights.

        Returns:
            The detection model produced by
            ``torchvision.models.detection.retinanet_resnet50_fpn``.
        """
        coco_weights = RetinaNet_ResNet50_FPN_Weights.COCO_V1
        return torchvision.models.detection.retinanet_resnet50_fpn(weights=coco_weights)

    def create_anchor_generator(self,
                                sizes=((8, 16, 32, 64, 128, 256, 400),),
                                aspect_ratios=((0.5, 1.0, 2.0),)):
        """Create an anchor box generator from anchor sizes and aspect ratios.

        Documented at
        https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9

        Both arguments are a tuple of tuples because each feature map could
        potentially have different sizes and aspect ratios. The defaults
        define seven sizes and three aspect ratios for a single feature map.

        Args:
            sizes: tuple of tuples of anchor sizes, one inner tuple per
                feature map
            aspect_ratios: tuple of tuples of aspect ratios, one inner tuple
                per feature map
        Returns:
            anchor_generator, a pytorch module
        """
        return AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

    def create_model(self):
        """Create a retinanet model from the values stored in self.config.

        Config keys read:
            num_classes (int): number of classes in the model
            nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
            retinanet.score_thresh (float): minimum prediction score to keep during prediction [0,1]
        Returns:
            model: a pytorch nn module
        """
        # Only the backbone of the pretrained detector is reused
        pretrained = self.load_backbone()

        # Optionally allow anchor generator parameters to be created here
        # https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html
        return RetinaNet(backbone=pretrained.backbone,
                         num_classes=self.config["num_classes"],
                         nms_thresh=self.config["nms_thresh"],
                         score_thresh=self.config["retinanet"]["score_thresh"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,40 @@ | ||
# Config file for DeepForest pytorch module | ||
|
||
#cpu workers for data loaders | ||
#Dataloaders | ||
# Cpu workers for data loaders | ||
# Dataloaders | ||
workers: 1 | ||
devices: auto | ||
accelerator: auto | ||
batch_size: 1 | ||
|
||
#Non-max suppression of overlapping predictions | |
# Model Architecture | ||
architecture: 'retinanet' | ||
num_classes: 1 | ||
nms_thresh: 0.05 | ||
score_thresh: 0.1 | ||
|
||
train: | ||
# Architecture specific params | ||
retinanet: | ||
# Non-max suppression of overlapping predictions | |
score_thresh: 0.1 | ||
|
||
train: | ||
csv_file: | ||
root_dir: | ||
|
||
#Optimizer initial learning rate | |
# Optimizer initial learning rate | ||
lr: 0.001 | ||
|
||
#Print loss every n epochs | ||
# Print loss every n epochs | ||
epochs: 1 | ||
#Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. | ||
# Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. | ||
fast_dev_run: False | ||
#pin images to GPU memory for fast training. This depends on GPU size and number of images. | ||
# pin images to GPU memory for fast training. This depends on GPU size and number of images. | ||
preload_images: False | ||
|
||
validation: | ||
#callback args | ||
# callback args | ||
csv_file: | ||
root_dir: | ||
#Intersection over union evaluation | ||
# Intersection over union evaluation | ||
iou_threshold: 0.4 | ||
val_accuracy_interval: 20 |
Oops, something went wrong.