diff --git a/deepforest/main.py b/deepforest/main.py index 733a94f6..9b0e3dd3 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -17,6 +17,7 @@ import rasterio as rio import cv2 import warnings +import importlib class deepforest(pl.LightningModule): @@ -27,11 +28,16 @@ def __init__(self, num_classes: int = 1, label_dict: dict = {"Tree": 0}, transforms=None, - config_file: str = 'deepforest_config.yml'): + config_file: str = 'deepforest_config.yml', + config_args=None, + model=None): """ Args: num_classes (int): number of classes in the model config_file (str): path to deepforest config file + model (model.Model()): a deepforest model object, see model.Model(). + config_args (dict): a dictionary of key->value to update config file at run time. e.g. {"batch_size":10} +- This is useful for iterating over arguments during model testing. Returns: self: a deepforest pytorch lightning module """ @@ -51,13 +57,20 @@ def __init__(self, print("Reading config file: {}".format(config_path)) self.config = utilities.read_config(config_path) + self.model = model # release version id to flag if release is being used self.__release_version__ = None - self.num_classes = num_classes + self.config["num_classes"] = num_classes self.create_model() + # If num classes is specified, overwrite config + if not num_classes == 1: + warnings.warn( + "Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use deepforest.main(config_args={'num_classes':value})" + ) + # Metrics self.iou_metric = IntersectionOverUnion( class_metrics=True, iou_threshold=self.config["validation"]["iou_threshold"]) @@ -67,12 +80,12 @@ def __init__(self, self.create_trainer() # Label encoder and decoder - if not len(label_dict) == num_classes: + if not len(label_dict) == self.config["num_classes"]: raise ValueError('label_dict {} does not match requested number of ' 'classes {}, please supply a label_dict argument ' '{{"label1":0, "label2":1, "label3":2 ... etc}} ' 'for each label in the ' - 'dataset'.format(label_dict, num_classes)) + 'dataset'.format(label_dict, self.config["num_classes"])) self.label_dict = label_dict self.numeric_to_label_dict = {v: k for k, v in label_dict.items()} @@ -96,6 +109,12 @@ def use_release(self, check_release=True): # Download latest model from github release release_tag, self.release_state_dict = utilities.use_release( check_release=check_release) + if self.config["architecture"] is not "retinanet": + warnings.warn( + "The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading with deepforest.main with a retinanet model" + .format(self.config["architecture"])) + self.config["architecture"] = "retinanet" + self.create_model() self.model.load_state_dict(torch.load(self.release_state_dict)) # load saved model and tag release @@ -127,16 +146,30 @@ def use_bird_release(self, check_release=True): self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} def create_model(self): - """Define a deepforest retinanet architecture""" - self.model = model.create_model(self.num_classes, self.config["nms_thresh"], - self.config["score_thresh"]) + """Define a deepforest architecture. This can be done in two ways. + Passed as the model argument to deepforest __init__(), + or as a named architecture in config["architecture"], + which corresponds to a file in models/, as is a subclass of model.Model(). + The config args in the .yaml are specified + retinanet: + nms_thresh: 0.1 + score_thresh: 0.2 + RCNN: + nms_thresh: 0.1 + etc. + """ + if self.model is None: + model_name = importlib.import_module("deepforest.models.{}".format( + self.config["architecture"])) + self.model = model_name.Model(config=self.config).create_model() + else: + pass def create_trainer(self, logger=None, callbacks=[], **kwargs): """Create a pytorch lightning training by reading config files Args: callbacks (list): a list of pytorch-lightning callback classes """ - # If val data is passed, monitor learning rate and setup classification metrics if not self.config["validation"]["csv_file"] is None: if logger is not None: @@ -293,7 +326,6 @@ def predict_image(self, "np.array(image).astype(float32)".format(type(image))) self.model.eval() - self.model.score_thresh = self.config["score_thresh"] if image.dtype != "float32": warnings.warn(f"Image type is {image.dtype}, transforming to float32. " @@ -354,7 +386,6 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 df: pandas dataframe with bounding boxes, label and scores for each image in the csv file """ self.model.eval() - self.model.score_thresh = self.config["score_thresh"] df = pd.read_csv(csv_file) paths = df.image_path.unique() ds = dataset.TreeDataset(csv_file=csv_file, @@ -442,7 +473,6 @@ def predict_tile(self, Otherwise a numpy array of predicted bounding boxes, scores and labels """ self.model.eval() - self.model.score_thresh = self.config["score_thresh"] self.model.nms_thresh = self.config["nms_thresh"] if (raster_path is None) and (image is None): @@ -606,7 +636,6 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None): results: dict of ("results", "precision", "recall") for a given threshold """ self.model.eval() - self.model.score_thresh = self.config["score_thresh"] predictions = self.predict_file(csv_file=csv_file, root_dir=root_dir, diff --git a/deepforest/model.py b/deepforest/model.py index 4b939922..bde76c0a 100644 --- a/deepforest/model.py +++ b/deepforest/model.py @@ -1,43 +1,12 @@ -# Model -import torchvision -from torchvision.models.detection.retinanet import RetinaNet -from torchvision.models.detection.retinanet import AnchorGenerator -from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights +# Model - common class +from deepforest.models import * +import torch -def load_backbone(): - """A torch vision retinanet model""" - backbone = torchvision.models.detection.retinanet_resnet50_fpn( - weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1) - - # load the model onto the computation device - return backbone - - -def create_anchor_generator(sizes=((8, 16, 32, 64, 128, 256, 400),), - aspect_ratios=((0.5, 1.0, 2.0),)): - """ - Create anchor box generator as a function of sizes and aspect ratios - Documented https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9 - let's make the network generate 5 x 3 anchors per spatial - location, with 5 different sizes and 3 different aspect - ratios. We have a Tuple[Tuple[int]] because each feature - map could potentially have different sizes and - aspect ratios - Args: - sizes: - aspect_ratios: - - Returns: anchor_generator, a pytorch module - - """ - anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios) - - return anchor_generator - - -def create_model(num_classes, nms_thresh, score_thresh, backbone=None): - """Create a retinanet model +class Model(): + """A architecture agnostic class that controls the basic train, eval and predict functions. + A model should optionally allow a backbone for pretraining. To add new architectures, simply create a new module in models/ and write a create_model. + Then add the result to the if else statement below. Args: num_classes (int): number of classes in the model nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1] @@ -45,15 +14,38 @@ def create_model(num_classes, nms_thresh, score_thresh, backbone=None): Returns: model: a pytorch nn module """ - if not backbone: - resnet = load_backbone() - backbone = resnet.backbone - model = RetinaNet(backbone=backbone, num_classes=num_classes) - model.nms_thresh = nms_thresh - model.score_thresh = score_thresh + def __init__(self, config): + + # Check for required properties and formats + self.config = config + + # Check input output format: + self.check_model() + + def create_model(self): + """This function converts a deepforest config file into a model. An architecture should have a list of nested arguments in config that match this function""" + raise ValueError( + "The create_model class method needs to be implemented. Take in args and return a pytorch nn module." + ) + + def check_model(self): + """ + Ensure that model follows deepforest guidelines, see ##### + If fails, raise ValueError + """ + # This assumes model creation is not expensive + test_model = self.create_model() + test_model.eval() + + # Create a dummy batch of 3 band data. + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - # Optionally allow anchor generator parameters to be created here - # https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html + predictions = test_model(x) + # Model takes in a batch of images + assert len(predictions) == 2 - return model + # Returns a list equal to number of images with proper keys per image + model_keys = list(predictions[1].keys()) + model_keys.sort() + assert model_keys == ['boxes', 'labels', 'scores'] diff --git a/deepforest/models/# Model modules b/deepforest/models/# Model modules new file mode 100644 index 00000000..06fe33b7 --- /dev/null +++ b/deepforest/models/# Model modules @@ -0,0 +1 @@ +# Model modules \ No newline at end of file diff --git a/deepforest/models/FasterRCNN.py b/deepforest/models/FasterRCNN.py new file mode 100644 index 00000000..16d0f8cb --- /dev/null +++ b/deepforest/models/FasterRCNN.py @@ -0,0 +1,37 @@ +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +from torchvision.models.detection import FasterRCNN +from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +import torch +from deepforest.model import Model + + +class Model(Model): + + def __init__(self, config, **kwargs): + super().__init__(config) + + def load_backbone(self): + """A torch vision retinanet model""" + backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + return backbone + + def create_model(self, backbone=None): + """Create a FasterRCNN model + Args: + backbone: a compatible torchvision backbone, e.g. torchvision.models.detection.fasterrcnn_resnet50_fpn + Returns: + model: a pytorch nn module + """ + # load Faster RCNN pre-trained model + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + # get the number of input features + in_features = model.roi_heads.box_predictor.cls_score.in_features + + # define a new head for the detector with required number of classes + model.roi_heads.box_predictor = FastRCNNPredictor( + in_features, num_classes=self.config["num_classes"]) + + return model diff --git a/deepforest/models/retinanet.py b/deepforest/models/retinanet.py new file mode 100644 index 00000000..efc7fc52 --- /dev/null +++ b/deepforest/models/retinanet.py @@ -0,0 +1,62 @@ +# Model +import torchvision +from torchvision.models.detection.retinanet import RetinaNet +from torchvision.models.detection.retinanet import AnchorGenerator +from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights +from deepforest.model import Model + + +class Model(Model): + + def __init__(self, config, **kwargs): + super().__init__(config) + + def load_backbone(self): + """A torch vision retinanet model""" + backbone = torchvision.models.detection.retinanet_resnet50_fpn( + weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1) + + return backbone + + def create_anchor_generator(self, + sizes=((8, 16, 32, 64, 128, 256, 400),), + aspect_ratios=((0.5, 1.0, 2.0),)): + """ + Create anchor box generator as a function of sizes and aspect ratios + Documented https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9 + let's make the network generate 5 x 3 anchors per spatial + location, with 5 different sizes and 3 different aspect + ratios. We have a Tuple[Tuple[int]] because each feature + map could potentially have different sizes and + aspect ratios + Args: + sizes: + aspect_ratios: + + Returns: anchor_generator, a pytorch module + + """ + anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios) + + return anchor_generator + + def create_model(self): + """Create a retinanet model + Args: + num_classes (int): number of classes in the model + nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1] + score_thresh (float): minimum prediction score to keep during prediction [0,1] + Returns: + model: a pytorch nn module + """ + resnet = self.load_backbone() + backbone = resnet.backbone + + model = RetinaNet(backbone=backbone, num_classes=self.config["num_classes"]) + model.nms_thresh = self.config["nms_thresh"] + model.score_thresh = self.config["retinanet"]["score_thresh"] + + # Optionally allow anchor generator parameters to be created here + # https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html + + return model diff --git a/deepforest_config.yml b/deepforest_config.yml index 60afdf88..9290c509 100644 --- a/deepforest_config.yml +++ b/deepforest_config.yml @@ -1,35 +1,40 @@ # Config file for DeepForest pytorch module -#cpu workers for data loaders -#Dataloaders +# Cpu workers for data loaders +# Dataloaders workers: 1 devices: auto accelerator: auto batch_size: 1 -#Non-max supression of overlapping predictions +# Model Architecture +architecture: 'retinanet' +num_classes: 1 nms_thresh: 0.05 -score_thresh: 0.1 -train: +# Architecture specific params +retinanet: + # Non-max supression of overlapping predictions + score_thresh: 0.1 +train: csv_file: root_dir: - #Optomizer initial learning rate + # Optimizer initial learning rate lr: 0.001 - #Print loss every n epochs + # Print loss every n epochs epochs: 1 - #Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. + # Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. fast_dev_run: False - #pin images to GPU memory for fast training. This depends on GPU size and number of images. + # pin images to GPU memory for fast training. This depends on GPU size and number of images. preload_images: False validation: - #callback args + # callback args csv_file: root_dir: - #Intersection over union evaluation + # Intersection over union evaluation iou_threshold: 0.4 val_accuracy_interval: 20 \ No newline at end of file diff --git a/docs/Model_Architecture.md b/docs/Model_Architecture.md new file mode 100644 index 00000000..ac4dd425 --- /dev/null +++ b/docs/Model_Architecture.md @@ -0,0 +1,79 @@ +# Model Architecture + +DeepForest allows users to specify custom model architectures if they follow certain guidelines. +To create a compliant format, follow the recipe below. + +## Subclass the model.Model() structure + +A subclass is a class instance that inherits the methods and function of super classes. In this cases, model.Model() is defined as: + +``` +# Model - common class +from deepforest.models import * +import torch + +class Model(): + """A architecture agnostic class that controls the basic train, eval and predict functions. + A model should optionally allow a backbone for pretraining. To add new architectures, simply create a new module in models/ and write a create_model. + Then add the result to the if else statement below. + Args: + num_classes (int): number of classes in the model + nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1] + score_thresh (float): minimum prediction score to keep during prediction [0,1] + Returns: + model: a pytorch nn module + """ + def __init__(self, config): + + # Check for required properties and formats + self.config = config + + # Check input output format: + self.check_model() + + def create_model(): + """This function converts a deepforest config file into a model. An architecture should have a list of nested arguments in config that match this function""" + raise ValueError("The create_model class method needs to be implemented. Take in args and return a pytorch nn module.") + + def check_model(self): + """ + Ensure that model follows deepforest guidelines + If fails, raise ValueError + """ + # This assumes model creation is not expensive + test_model = self.create_model() + test_model.eval() + + # Create a dummy batch of 3 band data. + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + + predictions = test_model(x) + # Model takes in a batch of images + assert len(predictions) == 2 + + # Returns a list equal to number of images with proper keys per image + model_keys = list(predictions[1].keys()) + model_keys.sort() + assert model_keys == ['boxes','labels','scores'] +``` + +## Match torchvision formats + +From this definition we can see three format requirements. The model must be able to take in a batch of images in the order [channels, height, width]. The current model weights are trained on 3 band images, but you can update the check_model function if you have other image dimensions. +The second requirement is that the model ouputs a dictionary with keys ["boxes","labels","scores"], the boxes are formatted following torchvision object detection format. From the [docs](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.retinanet_resnet50_fpn.html#torchvision.models.detection.retinanet_resnet50_fpn) + +``` +During training, the model expects both the input tensors and targets (list of dictionary), containing: + +boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H. + +labels (Int64Tensor[N]): the class label for each ground-truth box + +During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows, where N is the number of detections: + +boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H. + +labels (Int64Tensor[N]): the predicted labels for each detection + +scores (Tensor[N]): the scores of each detection +``` diff --git a/docs/index.rst b/docs/index.rst index 7373883f..6a3cbcc6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,6 +44,7 @@ The most helpful thing you can do is leave feedback on DeepForest `issue page`_. training Evaluation multi_species + Model_Architecture ConfigurationFile deepforestr FAQ diff --git a/tests/conftest.py b/tests/conftest.py index a5207217..495e1392 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,16 +3,29 @@ import pytest from deepforest import utilities, main from deepforest import get_data +from deepforest import _ROOT import os collect_ignore = ['setup.py'] +@pytest.fixture(scope="session") +def config(): + config = utilities.read_config("{}/deepforest_config.yml".format(os.path.dirname(_ROOT))) + config["fast_dev_run"] = True + config["batch_size"] = True + + return config + @pytest.fixture(scope="session") def download_release(): print("running fixtures") utilities.use_release() assert os.path.exists(get_data("NEON.pt")) +@pytest.fixture(scope="session") +def ROOT(): + return _ROOT + @pytest.fixture() def two_class_m(): m = main.deepforest(num_classes=2,label_dict={"Alive":0,"Dead":1}) diff --git a/tests/data/OSBS_029.csv b/tests/data/OSBS_029.csv index da5f1754..a1ee76f1 100644 --- a/tests/data/OSBS_029.csv +++ b/tests/data/OSBS_029.csv @@ -1,62 +1,160 @@ image_path,xmin,ymin,xmax,ymax,label -OSBS_029.tif,203,67,227,90,Tree -OSBS_029.tif,256,99,288,140,Tree -OSBS_029.tif,166,253,225,304,Tree -OSBS_029.tif,365,2,400,27,Tree -OSBS_029.tif,312,13,349,47,Tree -OSBS_029.tif,365,21,400,70,Tree -OSBS_029.tif,278,1,312,37,Tree -OSBS_029.tif,364,204,400,246,Tree -OSBS_029.tif,90,117,121,145,Tree -OSBS_029.tif,115,109,150,152,Tree -OSBS_029.tif,161,155,199,191,Tree -OSBS_029.tif,120,153,160,192,Tree -OSBS_029.tif,349,290,375,320,Tree -OSBS_029.tif,1,153,53,217,Tree -OSBS_029.tif,1,218,41,254,Tree -OSBS_029.tif,65,143,117,190,Tree -OSBS_029.tif,368,78,400,110,Tree -OSBS_029.tif,149,95,203,156,Tree -OSBS_029.tif,154,195,190,229,Tree -OSBS_029.tif,103,195,153,244,Tree -OSBS_029.tif,49,377,75,400,Tree -OSBS_029.tif,116,367,151,400,Tree -OSBS_029.tif,234,243,253,267,Tree -OSBS_029.tif,292,367,337,400,Tree -OSBS_029.tif,333,336,374,384,Tree -OSBS_029.tif,1,13,39,62,Tree -OSBS_029.tif,1,65,40,106,Tree -OSBS_029.tif,50,3,102,57,Tree -OSBS_029.tif,102,36,130,68,Tree -OSBS_029.tif,156,5,180,38,Tree -OSBS_029.tif,186,1,231,40,Tree -OSBS_029.tif,382,260,400,305,Tree -OSBS_029.tif,331,128,363,161,Tree -OSBS_029.tif,332,84,360,118,Tree -OSBS_029.tif,363,115,393,146,Tree -OSBS_029.tif,117,263,149,307,Tree -OSBS_029.tif,100,309,155,371,Tree -OSBS_029.tif,179,359,214,398,Tree -OSBS_029.tif,199,340,231,374,Tree -OSBS_029.tif,177,308,214,343,Tree -OSBS_029.tif,239,306,279,342,Tree -OSBS_029.tif,275,332,310,374,Tree -OSBS_029.tif,53,192,90,238,Tree -OSBS_029.tif,115,64,151,103,Tree -OSBS_029.tif,53,69,96,117,Tree -OSBS_029.tif,263,243,289,282,Tree -OSBS_029.tif,331,42,369,87,Tree -OSBS_029.tif,252,47,283,87,Tree -OSBS_029.tif,291,89,333,138,Tree -OSBS_029.tif,288,136,315,167,Tree -OSBS_029.tif,203,88,247,139,Tree -OSBS_029.tif,257,198,289,232,Tree -OSBS_029.tif,31,341,58,372,Tree -OSBS_029.tif,19,368,52,400,Tree -OSBS_029.tif,1,261,31,296,Tree -OSBS_029.tif,73,241,113,287,Tree -OSBS_029.tif,60,292,96,332,Tree -OSBS_029.tif,89,362,114,390,Tree -OSBS_029.tif,236,132,253,152,Tree -OSBS_029.tif,316,174,346,214,Tree -OSBS_029.tif,220,208,251,244,Tree +OSBS_029_0.png,90,117,121,145,Tree +OSBS_029_0.png,115,109,150,152,Tree +OSBS_029_0.png,161,155,199,191,Tree +OSBS_029_0.png,120,153,160,192,Tree +OSBS_029_0.png,1,153,53,200,Tree +OSBS_029_0.png,65,143,117,190,Tree +OSBS_029_0.png,149,95,200,156,Tree +OSBS_029_0.png,154,195,190,200,Tree +OSBS_029_0.png,1,13,39,62,Tree +OSBS_029_0.png,1,65,40,106,Tree +OSBS_029_0.png,50,3,102,57,Tree +OSBS_029_0.png,102,36,130,68,Tree +OSBS_029_0.png,156,5,180,38,Tree +OSBS_029_0.png,186,1,200,40,Tree +OSBS_029_0.png,53,192,90,200,Tree +OSBS_029_0.png,115,64,151,103,Tree +OSBS_029_0.png,53,69,96,117,Tree +OSBS_029_1.png,166,103,200,154,Tree +OSBS_029_1.png,161,5,199,41,Tree +OSBS_029_1.png,120,3,160,42,Tree +OSBS_029_1.png,1,3,53,67,Tree +OSBS_029_1.png,1,68,41,104,Tree +OSBS_029_1.png,65,0,117,40,Tree +OSBS_029_1.png,154,45,190,79,Tree +OSBS_029_1.png,103,45,153,94,Tree +OSBS_029_1.png,117,113,149,157,Tree +OSBS_029_1.png,100,159,155,200,Tree +OSBS_029_1.png,199,190,200,200,Tree +OSBS_029_1.png,177,158,200,193,Tree +OSBS_029_1.png,53,42,90,88,Tree +OSBS_029_1.png,31,191,58,200,Tree +OSBS_029_1.png,1,111,31,146,Tree +OSBS_029_1.png,73,91,113,137,Tree +OSBS_029_1.png,60,142,96,182,Tree +OSBS_029_2.png,166,53,200,104,Tree +OSBS_029_2.png,1,18,41,54,Tree +OSBS_029_2.png,154,0,190,29,Tree +OSBS_029_2.png,103,0,153,44,Tree +OSBS_029_2.png,49,177,75,200,Tree +OSBS_029_2.png,116,167,151,200,Tree +OSBS_029_2.png,117,63,149,107,Tree +OSBS_029_2.png,100,109,155,171,Tree +OSBS_029_2.png,179,159,200,198,Tree +OSBS_029_2.png,199,140,200,174,Tree +OSBS_029_2.png,177,108,200,143,Tree +OSBS_029_2.png,53,0,90,38,Tree +OSBS_029_2.png,31,141,58,172,Tree +OSBS_029_2.png,19,168,52,200,Tree +OSBS_029_2.png,1,61,31,96,Tree +OSBS_029_2.png,73,41,113,87,Tree +OSBS_029_2.png,60,92,96,132,Tree +OSBS_029_2.png,89,162,114,190,Tree +OSBS_029_3.png,53,67,77,90,Tree +OSBS_029_3.png,106,99,138,140,Tree +OSBS_029_3.png,162,13,199,47,Tree +OSBS_029_3.png,128,1,162,37,Tree +OSBS_029_3.png,11,155,49,191,Tree +OSBS_029_3.png,0,153,10,192,Tree +OSBS_029_3.png,0,95,53,156,Tree +OSBS_029_3.png,4,195,40,200,Tree +OSBS_029_3.png,6,5,30,38,Tree +OSBS_029_3.png,36,1,81,40,Tree +OSBS_029_3.png,181,128,200,161,Tree +OSBS_029_3.png,182,84,200,118,Tree +OSBS_029_3.png,0,64,1,103,Tree +OSBS_029_3.png,181,42,200,87,Tree +OSBS_029_3.png,102,47,133,87,Tree +OSBS_029_3.png,141,89,183,138,Tree +OSBS_029_3.png,138,136,165,167,Tree +OSBS_029_3.png,53,88,97,139,Tree +OSBS_029_3.png,107,198,139,200,Tree +OSBS_029_3.png,86,132,103,152,Tree +OSBS_029_3.png,166,174,196,200,Tree +OSBS_029_4.png,16,103,75,154,Tree +OSBS_029_4.png,11,5,49,41,Tree +OSBS_029_4.png,0,3,10,42,Tree +OSBS_029_4.png,199,140,200,170,Tree +OSBS_029_4.png,4,45,40,79,Tree +OSBS_029_4.png,84,93,103,117,Tree +OSBS_029_4.png,183,186,200,200,Tree +OSBS_029_4.png,181,0,200,11,Tree +OSBS_029_4.png,49,190,81,200,Tree +OSBS_029_4.png,27,158,64,193,Tree +OSBS_029_4.png,89,156,129,192,Tree +OSBS_029_4.png,125,182,160,200,Tree +OSBS_029_4.png,113,93,139,132,Tree +OSBS_029_4.png,138,0,165,17,Tree +OSBS_029_4.png,107,48,139,82,Tree +OSBS_029_4.png,86,0,103,2,Tree +OSBS_029_4.png,166,24,196,64,Tree +OSBS_029_4.png,70,58,101,94,Tree +OSBS_029_5.png,16,53,75,104,Tree +OSBS_029_5.png,199,90,200,120,Tree +OSBS_029_5.png,4,0,40,29,Tree +OSBS_029_5.png,0,167,1,200,Tree +OSBS_029_5.png,84,43,103,67,Tree +OSBS_029_5.png,142,167,187,200,Tree +OSBS_029_5.png,183,136,200,184,Tree +OSBS_029_5.png,29,159,64,198,Tree +OSBS_029_5.png,49,140,81,174,Tree +OSBS_029_5.png,27,108,64,143,Tree +OSBS_029_5.png,89,106,129,142,Tree +OSBS_029_5.png,125,132,160,174,Tree +OSBS_029_5.png,113,43,139,82,Tree +OSBS_029_5.png,107,0,139,32,Tree +OSBS_029_5.png,166,0,196,14,Tree +OSBS_029_5.png,70,8,101,44,Tree +OSBS_029_6.png,3,67,27,90,Tree +OSBS_029_6.png,56,99,88,140,Tree +OSBS_029_6.png,165,2,200,27,Tree +OSBS_029_6.png,112,13,149,47,Tree +OSBS_029_6.png,165,21,200,70,Tree +OSBS_029_6.png,78,1,112,37,Tree +OSBS_029_6.png,168,78,200,110,Tree +OSBS_029_6.png,0,1,31,40,Tree +OSBS_029_6.png,131,128,163,161,Tree +OSBS_029_6.png,132,84,160,118,Tree +OSBS_029_6.png,163,115,193,146,Tree +OSBS_029_6.png,131,42,169,87,Tree +OSBS_029_6.png,52,47,83,87,Tree +OSBS_029_6.png,91,89,133,138,Tree +OSBS_029_6.png,88,136,115,167,Tree +OSBS_029_6.png,3,88,47,139,Tree +OSBS_029_6.png,57,198,89,200,Tree +OSBS_029_6.png,36,132,53,152,Tree +OSBS_029_6.png,116,174,146,200,Tree +OSBS_029_7.png,0,103,25,154,Tree +OSBS_029_7.png,164,54,200,96,Tree +OSBS_029_7.png,149,140,175,170,Tree +OSBS_029_7.png,34,93,53,117,Tree +OSBS_029_7.png,133,186,174,200,Tree +OSBS_029_7.png,182,110,200,155,Tree +OSBS_029_7.png,131,0,163,11,Tree +OSBS_029_7.png,0,190,31,200,Tree +OSBS_029_7.png,0,158,14,193,Tree +OSBS_029_7.png,39,156,79,192,Tree +OSBS_029_7.png,75,182,110,200,Tree +OSBS_029_7.png,63,93,89,132,Tree +OSBS_029_7.png,88,0,115,17,Tree +OSBS_029_7.png,57,48,89,82,Tree +OSBS_029_7.png,36,0,53,2,Tree +OSBS_029_7.png,116,24,146,64,Tree +OSBS_029_7.png,20,58,51,94,Tree +OSBS_029_8.png,0,53,25,104,Tree +OSBS_029_8.png,164,4,200,46,Tree +OSBS_029_8.png,149,90,175,120,Tree +OSBS_029_8.png,34,43,53,67,Tree +OSBS_029_8.png,92,167,137,200,Tree +OSBS_029_8.png,133,136,174,184,Tree +OSBS_029_8.png,182,60,200,105,Tree +OSBS_029_8.png,0,159,14,198,Tree +OSBS_029_8.png,0,140,31,174,Tree +OSBS_029_8.png,0,108,14,143,Tree +OSBS_029_8.png,39,106,79,142,Tree +OSBS_029_8.png,75,132,110,174,Tree +OSBS_029_8.png,63,43,89,82,Tree +OSBS_029_8.png,57,0,89,32,Tree +OSBS_029_8.png,116,0,146,14,Tree +OSBS_029_8.png,20,8,51,44,Tree diff --git a/tests/test_FasterRCNN.py b/tests/test_FasterRCNN.py new file mode 100644 index 00000000..ad8c749f --- /dev/null +++ b/tests/test_FasterRCNN.py @@ -0,0 +1,52 @@ +#test FasterRCNN +from deepforest.models import FasterRCNN +from deepforest import get_data +import pytest +import numpy as np +import torch +import torchvision +import os + +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +#Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 +def _make_empty_sample(): + images = [torch.rand((3, 100, 100), dtype=torch.float32)] + boxes = torch.zeros((0, 4), dtype=torch.float32) + negative_target = {"boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64)} + + targets = [negative_target] + return images, targets + +def test_retinanet(config): + r = FasterRCNN.Model(config) + + return r + +def test_load_backbone(config): + r = FasterRCNN.Model(config) + resnet_backbone = r.load_backbone() + resnet_backbone.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + prediction = resnet_backbone(x) + +# This test still fails, do we want a way to pass kwargs directly to method, instead of being limited by config structure? +# Need to create issue when I get online. +@pytest.mark.parametrize("num_classes",[1,2,10]) +def test_create_model(config, num_classes): + config["num_classes"] = num_classes + retinanet_model = FasterRCNN.Model(config).create_model() + retinanet_model.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + predictions = retinanet_model(x) + +def test_forward_empty(config): + r = FasterRCNN.Model(config) + model = r.create_model() + image, targets = _make_empty_sample() + loss = model(image, targets) + assert torch.equal(loss["loss_box_reg"], torch.tensor(0.)) diff --git a/tests/test_main.py b/tests/test_main.py index fc5b5f6f..4e5b4fd6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -23,6 +23,58 @@ #Import release model from global script to avoid thrasing github during testing. Just download once. from .conftest import download_release +@pytest.fixture() +def two_class_m(): + m = main.deepforest(num_classes=2,label_dict={"Alive":0,"Dead":1}) + m.config["train"]["csv_file"] = get_data("testfile_multi.csv") + m.config["train"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) + m.config["train"]["fast_dev_run"] = True + m.config["batch_size"] = 2 + + m.config["validation"]["csv_file"] = get_data("testfile_multi.csv") + m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) + m.config["validation"]["val_accuracy_interval"] = 1 + + m.create_trainer() + + return m + +@pytest.fixture() +def m(download_release): + m = main.deepforest() + m.config["train"]["csv_file"] = get_data("example.csv") + m.config["train"]["root_dir"] = os.path.dirname(get_data("example.csv")) + m.config["train"]["fast_dev_run"] = True + m.config["batch_size"] = 2 + + m.config["validation"]["csv_file"] = get_data("example.csv") + m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv")) + m.config["workers"] = 0 + m.config["validation"]["val_accuracy_interval"] = 1 + m.config["train"]["epochs"] = 2 + + m.create_trainer() + m.use_release(check_release=False) + + return m +@pytest.fixture() + +def m_without_release(): + m = main.deepforest() + m.config["train"]["csv_file"] = get_data("example.csv") + m.config["train"]["root_dir"] = os.path.dirname(get_data("example.csv")) + m.config["train"]["fast_dev_run"] = True + m.config["batch_size"] = 2 + + m.config["validation"]["csv_file"] = get_data("example.csv") + m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv")) + m.config["workers"] = 0 + m.config["validation"]["val_accuracy_interval"] = 1 + m.config["train"]["epochs"] = 2 + + m.create_trainer() + + return m @pytest.fixture() def raster_path(): return get_data(path='OSBS_029.tif') @@ -54,8 +106,7 @@ def test_use_bird_release(m): m.use_bird_release() boxes = m.predict_image(path=imgpath) assert not boxes.empty - assert boxes.label.unique() == "Bird" - + def test_train_empty(m, tmpdir): empty_csv = pd.DataFrame({"image_path":["OSBS_029.png","OSBS_029.tif"],"xmin":[0,10],"xmax":[0,20],"ymin":[0,20],"ymax":[0,30],"label":["Tree","Tree"]}) empty_csv.to_csv("{}/empty.csv".format(tmpdir)) @@ -64,14 +115,6 @@ def test_train_empty(m, tmpdir): m.create_trainer() m.trainer.fit(m) -# If the user forgets to set csv_file, yield an informative message. -def test_no_csv_file_train(m): - m.config["train"]["csv_file"] = None - m.config["batch_size"] = 1 - m.create_trainer() - with pytest.raises(AttributeError) as e: - m.trainer.fit(m) - def test_validation_step(m): m.trainer = None #Turn off trainer to test copying on some linux devices. @@ -82,10 +125,14 @@ def test_validation_step(m): for p1, p2 in zip(before.named_parameters(), m.named_parameters()): assert p1[1].ne(p2[1]).sum() == 0 -def test_train_single(m): - m.config["train"]["fast_dev_run"] = False - m.create_trainer() - m.trainer.fit(m) +# Test train with each architecture +@pytest.mark.parametrize("architecture",["retinanet","FasterRCNN"]) +def test_train_single(m_without_release, architecture): + m_without_release.config["architecture"] = architecture + m_without_release.create_model() + m_without_release.config["train"]["fast_dev_run"] = False + m_without_release.create_trainer() + m_without_release.trainer.fit(m_without_release) def test_train_preload_images(m): m.create_trainer() @@ -163,8 +210,9 @@ def test_predict_dataloader(m, batch_size, raster_path): dl = m.predict_dataloader(ds) batch = next(iter(dl)) batch.shape[0] == batch_size - + def test_predict_tile(m, raster_path): + m.create_model() m.config["train"]["fast_dev_run"] = False m.create_trainer() prediction = m.predict_tile(raster_path = raster_path, @@ -268,12 +316,8 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(fast_dev_run=True) trainer.fit(m, train_ds) -def test_custom_config_file_path(tmpdir): - print(os.getcwd()) - m = main.deepforest(config_file='tests/deepforest_config_test.yml') - assert m.config["batch_size"] == 9999 - assert m.config["nms_thresh"] == 0.9 - assert m.config["score_thresh"] == 0.9 +def test_custom_config_file_path(ROOT, tmpdir): + m = main.deepforest(config_file='{}/deepforest_config.yml'.format(os.path.dirname(ROOT))) def test_save_and_reload_checkpoint(m, tmpdir): img_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") @@ -320,7 +364,7 @@ def test_reload_multi_class(two_class_m, tmpdir): #reload old_model = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir)) old_model.config = two_class_m.config - assert old_model.num_classes == 2 + assert old_model.config["num_classes"] == 2 old_model.create_trainer() after = old_model.trainer.validate(old_model) @@ -354,10 +398,12 @@ def test_over_score_thresh(m): """A user might want to change the config after model training and update the score thresh""" img = get_data("OSBS_029.png") original_score_thresh = m.model.score_thresh - m.config["score_thresh"] = 0.8 + m.model.score_thresh = 0.8 #trigger update boxes = m.predict_image(path = img) + + assert all(boxes.score > 0.8) assert m.model.score_thresh == 0.8 assert not m.model.score_thresh == original_score_thresh diff --git a/tests/test_model.py b/tests/test_model.py index 2fa84325..d9251879 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,51 +1,10 @@ -#test model -from deepforest import model import pytest -import numpy as np import torch -import torchvision -import os - -os.environ['KMP_DUPLICATE_LIB_OK']='True' - -#Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 -def _make_empty_sample(): - images = [torch.rand((3, 100, 100), dtype=torch.float32)] - boxes = torch.zeros((0, 4), dtype=torch.float32) - negative_target = {"boxes": boxes, - "labels": torch.zeros(0, dtype=torch.int64), - "image_id": 4, - "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), - "iscrowd": torch.zeros((0,), dtype=torch.int64)} - - targets = [negative_target] - return images, targets - -def test_load_backbone(): - retinanet = model.load_backbone() - retinanet.eval() - x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - prediction = retinanet(x) - -@pytest.mark.parametrize("num_classes",[1,2,10]) -def test_create_model(num_classes): - retinanet_model = model.create_model(num_classes=2,nms_thresh=0.1, score_thresh=0.2) - - retinanet_model.eval() - x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - predictions = retinanet_model(x) - -def test_forward_empty(): - retinanet_model = model.create_model(num_classes=2,nms_thresh=0.1, score_thresh=0.2) - image, targets = _make_empty_sample() - loss = retinanet_model(image, targets) - assert torch.equal(loss["bbox_regression"], torch.tensor(0.)) - -def test_forward_negative_sample_retinanet(): - model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) +from deepforest import model - images, targets = _make_empty_sample() - loss_dict = model(images, targets) +# The model object is achitecture agnostic container. +def test_model_no_args(config): + with pytest.raises(ValueError): + model.Model(config) - assert torch.equal(loss_dict["bbox_regression"], torch.tensor(0.)) \ No newline at end of file + \ No newline at end of file diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py new file mode 100644 index 00000000..7b10df68 --- /dev/null +++ b/tests/test_retinanet.py @@ -0,0 +1,71 @@ +#test retinanet +from deepforest.models import retinanet +from deepforest import get_data +import pytest +import numpy as np +import torch +import torchvision +import os +from torchvision.models import resnet50, ResNet50_Weights +from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights + + +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +#Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 +def _make_empty_sample(): + images = [torch.rand((3, 100, 100), dtype=torch.float32)] + boxes = torch.zeros((0, 4), dtype=torch.float32) + negative_target = {"boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64)} + + targets = [negative_target] + return images, targets + +def test_retinanet(config): + r = retinanet.Model(config) + + return r + +def test_load_backbone(config): + r = retinanet.Model(config) + resnet_backbone = r.load_backbone() + resnet_backbone.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + prediction = resnet_backbone(x) + +# This test still fails, do we want a way to pass kwargs directly to method, instead of being limited by config structure? +# Need to create issue when I get online. +@pytest.mark.parametrize("num_classes",[1,2,10]) +def test_create_model(config, num_classes): + config["num_classes"] = num_classes + retinanet_model = retinanet.Model(config).create_model() + retinanet_model.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + predictions = retinanet_model(x) + +def test_forward_empty(config): + r = retinanet.Model(config) + model = r.create_model() + image, targets = _make_empty_sample() + loss = model(image, targets) + assert torch.equal(loss["bbox_regression"], torch.tensor(0.)) + +# Can we update parameters after training +def test_mantain_parameters(config): + config["retinanet"]["score_thresh"] = 0.4 + retinanet_model = retinanet.Model(config).create_model() + assert retinanet_model.score_thresh == config["retinanet"]["score_thresh"] + retinanet_model.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + predictions = retinanet_model(x) + assert retinanet_model.score_thresh == config["retinanet"]["score_thresh"] + + retinanet_model.score_thresh = 0.9 + retinanet_model.eval() + x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + predictions = retinanet_model(x) + assert retinanet_model.score_thresh == 0.9 \ No newline at end of file