Skip to content

Commit

Permalink
Allow multiple types of models. (#506)
Browse files Browse the repository at this point in the history
* flexible model structure

* checking out conflicts from main

* update config with correct changes

* model architecture doc
  • Loading branch information
bw4sz authored Oct 11, 2023
1 parent 8de66db commit b612d85
Show file tree
Hide file tree
Showing 14 changed files with 646 additions and 201 deletions.
53 changes: 41 additions & 12 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import rasterio as rio
import cv2
import warnings
import importlib


class deepforest(pl.LightningModule):
Expand All @@ -27,11 +28,16 @@ def __init__(self,
num_classes: int = 1,
label_dict: dict = {"Tree": 0},
transforms=None,
config_file: str = 'deepforest_config.yml'):
config_file: str = 'deepforest_config.yml',
config_args=None,
model=None):
"""
Args:
num_classes (int): number of classes in the model
config_file (str): path to deepforest config file
model (model.Model()): a deepforest model object, see model.Model().
config_args (dict): a dictionary of key->value to update config file at run time. e.g. {"batch_size":10}
- This is useful for iterating over arguments during model testing.
Returns:
self: a deepforest pytorch lightning module
"""
Expand All @@ -51,13 +57,20 @@ def __init__(self,

print("Reading config file: {}".format(config_path))
self.config = utilities.read_config(config_path)
self.model = model

# release version id to flag if release is being used
self.__release_version__ = None

self.num_classes = num_classes
self.config["num_classes"] = num_classes
self.create_model()

# If num classes is specified, overwrite config
if not num_classes == 1:
warnings.warn(
"Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use deepforest.main(config_args={'num_classes':value})"
)

# Metrics
self.iou_metric = IntersectionOverUnion(
class_metrics=True, iou_threshold=self.config["validation"]["iou_threshold"])
Expand All @@ -67,12 +80,12 @@ def __init__(self,
self.create_trainer()

# Label encoder and decoder
if not len(label_dict) == num_classes:
if not len(label_dict) == self.config["num_classes"]:
raise ValueError('label_dict {} does not match requested number of '
'classes {}, please supply a label_dict argument '
'{{"label1":0, "label2":1, "label3":2 ... etc}} '
'for each label in the '
'dataset'.format(label_dict, num_classes))
'dataset'.format(label_dict, self.config["num_classes"]))

self.label_dict = label_dict
self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
Expand All @@ -96,6 +109,12 @@ def use_release(self, check_release=True):
# Download latest model from github release
release_tag, self.release_state_dict = utilities.use_release(
check_release=check_release)
if self.config["architecture"] is not "retinanet":
warnings.warn(
"The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading with deepforest.main with a retinanet model"
.format(self.config["architecture"]))
self.config["architecture"] = "retinanet"
self.create_model()
self.model.load_state_dict(torch.load(self.release_state_dict))

# load saved model and tag release
Expand Down Expand Up @@ -127,16 +146,30 @@ def use_bird_release(self, check_release=True):
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}

def create_model(self):
    """Define a deepforest architecture.

    The model can be supplied in two ways:

    1. Passed as the ``model`` argument to deepforest ``__init__()``. It is
       stored on ``self.model`` and is left untouched by this method.
    2. Named in ``config["architecture"]``. The name must correspond to a
       module in ``deepforest/models/`` that exposes a ``Model`` class
       (a subclass of ``model.Model``) whose ``create_model()`` returns
       the network.

    Architecture specific arguments are nested in the config .yaml under
    the architecture name, e.g.

        retinanet:
            nms_thresh: 0.1
            score_thresh: 0.2
        RCNN:
            nms_thresh: 0.1
        etc.

    Returns:
        None. ``self.model`` is populated as a side effect when it was unset.
    """
    # A user-supplied model always wins: never rebuild on top of it.
    if self.model is not None:
        return

    # Resolve the architecture module by name at run time so that new
    # architectures only require a new file in deepforest/models/.
    architecture = importlib.import_module("deepforest.models.{}".format(
        self.config["architecture"]))
    self.model = architecture.Model(config=self.config).create_model()

def create_trainer(self, logger=None, callbacks=[], **kwargs):
"""Create a pytorch lightning training by reading config files
Args:
callbacks (list): a list of pytorch-lightning callback classes
"""

# If val data is passed, monitor learning rate and setup classification metrics
if not self.config["validation"]["csv_file"] is None:
if logger is not None:
Expand Down Expand Up @@ -293,7 +326,6 @@ def predict_image(self,
"np.array(image).astype(float32)".format(type(image)))

self.model.eval()
self.model.score_thresh = self.config["score_thresh"]

if image.dtype != "float32":
warnings.warn(f"Image type is {image.dtype}, transforming to float32. "
Expand Down Expand Up @@ -354,7 +386,6 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
df: pandas dataframe with bounding boxes, label and scores for each image in the csv file
"""
self.model.eval()
self.model.score_thresh = self.config["score_thresh"]
df = pd.read_csv(csv_file)
paths = df.image_path.unique()
ds = dataset.TreeDataset(csv_file=csv_file,
Expand Down Expand Up @@ -442,7 +473,6 @@ def predict_tile(self,
Otherwise a numpy array of predicted bounding boxes, scores and labels
"""
self.model.eval()
self.model.score_thresh = self.config["score_thresh"]
self.model.nms_thresh = self.config["nms_thresh"]

if (raster_path is None) and (image is None):
Expand Down Expand Up @@ -606,7 +636,6 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None):
results: dict of ("results", "precision", "recall") for a given threshold
"""
self.model.eval()
self.model.score_thresh = self.config["score_thresh"]

predictions = self.predict_file(csv_file=csv_file,
root_dir=root_dir,
Expand Down
86 changes: 39 additions & 47 deletions deepforest/model.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,51 @@
# Model
import torchvision
from torchvision.models.detection.retinanet import RetinaNet
from torchvision.models.detection.retinanet import AnchorGenerator
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights
# Model - common class
from deepforest.models import *
import torch


def load_backbone():
"""A torch vision retinanet model"""
backbone = torchvision.models.detection.retinanet_resnet50_fpn(
weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)

# load the model onto the computation device
return backbone


def create_anchor_generator(sizes=((8, 16, 32, 64, 128, 256, 400),),
aspect_ratios=((0.5, 1.0, 2.0),)):
"""
Create anchor box generator as a function of sizes and aspect ratios
Documented https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9
let's make the network generate 5 x 3 anchors per spatial
location, with 5 different sizes and 3 different aspect
ratios. We have a Tuple[Tuple[int]] because each feature
map could potentially have different sizes and
aspect ratios
Args:
sizes:
aspect_ratios:
Returns: anchor_generator, a pytorch module
"""
anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

return anchor_generator


def create_model(num_classes, nms_thresh, score_thresh, backbone=None):
"""Create a retinanet model
class Model():
    """An architecture agnostic base class for deepforest models.

    To add a new architecture, create a new module in models/ that defines a
    ``Model`` subclass and overrides ``create_model()``. The architecture is
    then selected by name through ``config["architecture"]``.

    Args:
        config (dict): the deepforest config. It is stored on ``self.config``
            and is available to ``create_model()``.

    Raises:
        ValueError: on construction, if ``create_model()`` is not implemented
            or the model it returns does not follow the expected
            input/output format (see ``check_model()``).
    """

    def __init__(self, config):
        # Keep the config available to create_model() in subclasses
        self.config = config

        # Check input output format:
        self.check_model()

    def create_model(self):
        """Convert a deepforest config into a model.

        Subclasses must override this method and return a pytorch nn module.
        An architecture should read its nested arguments from ``self.config``.

        Raises:
            ValueError: always, in this base class.
        """
        raise ValueError(
            "The create_model class method needs to be implemented. Take in args and return a pytorch nn module."
        )

    def check_model(self):
        """Ensure that the model follows the deepforest input/output format.

        A throwaway model is created, switched to eval mode and run on a
        dummy batch of two 3 band images of different sizes. The model must
        return one prediction per image, and each prediction must be a
        dict-like object with exactly the keys 'boxes', 'labels', 'scores'.

        Raises:
            ValueError: if the number of predictions or their keys are wrong.
        """
        # This assumes model creation is not expensive
        test_model = self.create_model()
        test_model.eval()

        # Create a dummy batch of 3 band data.
        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

        # Only the output format is inspected, so skip autograd bookkeeping
        with torch.no_grad():
            predictions = test_model(x)

        # Model takes in a batch of images. Raise explicitly rather than
        # assert, so that the check still runs under `python -O`.
        if len(predictions) != len(x):
            raise ValueError(
                "Model returned {} predictions for a batch of {} images, "
                "expected one prediction per image".format(len(predictions), len(x)))

        # Returns a list equal to number of images with proper keys per image
        expected_keys = ['boxes', 'labels', 'scores']
        for prediction in predictions:
            model_keys = sorted(prediction.keys())
            if model_keys != expected_keys:
                raise ValueError(
                    "Model predictions must have keys {}, found {}".format(
                        expected_keys, model_keys))
1 change: 1 addition & 0 deletions deepforest/models/# Model modules
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Model modules
37 changes: 37 additions & 0 deletions deepforest/models/FasterRCNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
import torch
from deepforest.model import Model


class Model(Model):
    """FasterRCNN architecture for deepforest, built from torchvision.

    Reads ``config["num_classes"]`` to size the detection head.
    """

    def __init__(self, config, **kwargs):
        # Extra keyword arguments are accepted for interface parity with the
        # other architectures but are not used.
        super().__init__(config)

    def load_backbone(self):
        """Load a torchvision Faster R-CNN ResNet50-FPN model with COCO weights.

        Returns:
            backbone: a torchvision detection model
        """
        # `weights=` replaces the deprecated `pretrained=True` flag and matches
        # the weights API used by the retinanet architecture. "COCO_V1" is
        # the entry the legacy flag resolved to.
        backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            weights="COCO_V1")

        return backbone

    def create_model(self, backbone=None):
        """Create a FasterRCNN model
        Args:
            backbone: a compatible torchvision detection model, e.g. the
                result of torchvision.models.detection.fasterrcnn_resnet50_fpn.
                If None, the pre-trained model from load_backbone() is used.
                Note that the box predictor of a supplied model is replaced
                in place.
        Returns:
            model: a pytorch nn module
        """
        # Honor a caller-supplied model; otherwise load the pre-trained one.
        # Compare against None explicitly, since the truthiness of an
        # nn.Module is not a reliable "was it provided" test.
        if backbone is None:
            backbone = self.load_backbone()
        model = backbone

        # get the number of input features
        in_features = model.roi_heads.box_predictor.cls_score.in_features

        # define a new head for the detector with required number of classes
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes=self.config["num_classes"])

        return model
62 changes: 62 additions & 0 deletions deepforest/models/retinanet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Model
import torchvision
from torchvision.models.detection.retinanet import RetinaNet
from torchvision.models.detection.retinanet import AnchorGenerator
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights
from deepforest.model import Model


class Model(Model):
    """RetinaNet architecture for deepforest, built from torchvision.

    Reads ``num_classes``, ``nms_thresh`` and ``retinanet.score_thresh``
    from the config.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config)

    def load_backbone(self):
        """Return torchvision's retinanet_resnet50_fpn with COCO_V1 weights."""
        coco_weights = RetinaNet_ResNet50_FPN_Weights.COCO_V1
        return torchvision.models.detection.retinanet_resnet50_fpn(
            weights=coco_weights)

    def create_anchor_generator(self,
                                sizes=((8, 16, 32, 64, 128, 256, 400),),
                                aspect_ratios=((0.5, 1.0, 2.0),)):
        """Build an anchor box generator from sizes and aspect ratios.

        Documented https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9
        Both arguments are a Tuple[Tuple[...]] because each feature map
        could potentially have different sizes and aspect ratios. The
        defaults describe a single entry with seven sizes and three
        aspect ratios.

        Args:
            sizes: anchor sizes, one inner tuple per feature map
            aspect_ratios: anchor aspect ratios, one inner tuple per feature map
        Returns: anchor_generator, a pytorch module
        """
        return AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

    def create_model(self):
        """Create a retinanet model from the config.

        The backbone is taken from the pre-trained model returned by
        load_backbone(). The number of classes comes from
        config["num_classes"], the non-max suppression threshold from
        config["nms_thresh"] and the minimum prediction score from
        config["retinanet"]["score_thresh"].

        Returns:
            model: a pytorch nn module
        """
        pretrained = self.load_backbone()

        retinanet = RetinaNet(backbone=pretrained.backbone,
                              num_classes=self.config["num_classes"])

        # Thresholds are set on the constructed module, in the same order
        # as before: nms first, then score.
        retinanet.nms_thresh = self.config["nms_thresh"]
        retinanet.score_thresh = self.config["retinanet"]["score_thresh"]

        return retinanet
27 changes: 16 additions & 11 deletions deepforest_config.yml
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
# Config file for DeepForest pytorch module

#cpu workers for data loaders
#Dataloaders
# Cpu workers for data loaders
# Dataloaders
workers: 1
devices: auto
accelerator: auto
batch_size: 1

#Non-max supression of overlapping predictions
# Model Architecture
architecture: 'retinanet'
num_classes: 1
nms_thresh: 0.05
score_thresh: 0.1

train:
# Architecture specific params
retinanet:
# Minimum prediction score to keep during prediction [0,1]
score_thresh: 0.1

train:
csv_file:
root_dir:

#Optomizer initial learning rate
# Optimizer initial learning rate
lr: 0.001

#Print loss every n epochs
# Print loss every n epochs
epochs: 1
#Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
# Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
fast_dev_run: False
#pin images to GPU memory for fast training. This depends on GPU size and number of images.
# pin images to GPU memory for fast training. This depends on GPU size and number of images.
preload_images: False

validation:
#callback args
# callback args
csv_file:
root_dir:
#Intersection over union evaluation
# Intersection over union evaluation
iou_threshold: 0.4
val_accuracy_interval: 20
Loading

0 comments on commit b612d85

Please sign in to comment.