From 880a4c3d9acea2cae317fa6607189b6e38b63567 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Fri, 1 Dec 2023 19:51:23 -0800 Subject: [PATCH] style tests --- deepforest/callbacks.py | 18 ++-- deepforest/evaluate.py | 55 +++++----- deepforest/main.py | 45 ++++---- deepforest/predict.py | 8 +- deepforest/visualize.py | 7 +- tests/data/OSBS_029.csv | 220 +++++++++++----------------------------- tests/test_IoU.py | 2 +- 7 files changed, 139 insertions(+), 216 deletions(-) diff --git a/deepforest/callbacks.py b/deepforest/callbacks.py index 072c4729..65872023 100644 --- a/deepforest/callbacks.py +++ b/deepforest/callbacks.py @@ -32,7 +32,13 @@ class images_callback(Callback): None: either prints validation scores or logs them to the pytorch-lightning logger """ - def __init__(self, savedir, n=2, every_n_epochs=5, select_random=False, color=None, thickness=1): + def __init__(self, + savedir, + n=2, + every_n_epochs=5, + select_random=False, + color=None, + thickness=1): self.savedir = savedir self.n = n self.color = color @@ -43,19 +49,19 @@ def __init__(self, savedir, n=2, every_n_epochs=5, select_random=False, color=No def log_images(self, pl_module): # It is not clear if this is per device, or per batch. If per batch, then this will not work. df = pl_module.predictions[0] - + # limit to n images, potentially randomly selected if self.select_random: selected_images = np.random.choice(df.image_path.unique(), self.n) else: selected_images = df.image_path.unique()[:self.n] df = df[df.image_path.isin(selected_images)] - + visualize.plot_prediction_dataframe( df, root_dir=pl_module.config["validation"]["root_dir"], - savedir=self.savedir, - color=self.color, + savedir=self.savedir, + color=self.color, thickness=self.thickness) try: @@ -73,4 +79,4 @@ def on_validation_epoch_end(self, trainer, pl_module): if trainer.current_epoch % self.every_n_epochs == 0: print("Running image callback") - self.log_images(pl_module) \ No newline at end of file + self.log_images(pl_module) diff --git a/deepforest/evaluate.py b/deepforest/evaluate.py index bb836e1f..07670a8f 100644 --- a/deepforest/evaluate.py +++ b/deepforest/evaluate.py @@ -90,8 +90,14 @@ def compute_class_recall(results): return class_recall -def __evaluate_wrapper__(predictions, ground_df, root_dir, iou_threshold, numeric_to_label_dict, savedir=None): - """Evaluate a set of predictions against a ground truth csv file + +def __evaluate_wrapper__(predictions, + ground_df, + root_dir, + iou_threshold, + numeric_to_label_dict, + savedir=None): + """Evaluate a set of predictions against a ground truth csv file Args: predictions: a pandas dataframe, if supplied a root dir is needed to give the relative path of files in df.name. The labels in ground truth and predictions must match. If one is numeric, the other must be numeric. csv_file: a csv file with columns xmin, ymin, xmax, ymax, label, image_path @@ -101,28 +107,29 @@ def __evaluate_wrapper__(predictions, ground_df, root_dir, iou_threshold, numeri Returns: results: a dictionary of results with keys, results, box_recall, box_precision, class_recall """ - # remove empty samples from ground truth - ground_df = ground_df[~((ground_df.xmin == 0) & (ground_df.xmax == 0))] - - results = evaluate(predictions=predictions, - ground_df=ground_df, - root_dir=root_dir, - iou_threshold=iou_threshold, - savedir=savedir) - - # replace classes if not NUll - if not results is None: - results["results"]["predicted_label"] = results["results"][ - "predicted_label"].apply(lambda x: numeric_to_label_dict[x] - if not pd.isnull(x) else x) - results["results"]["true_label"] = results["results"]["true_label"].apply( - lambda x: numeric_to_label_dict[x]) - results["predictions"] = predictions - results["predictions"]["label"] = results["predictions"]["label"].apply( - lambda x: numeric_to_label_dict[x]) - - return results - + # remove empty samples from ground truth + ground_df = ground_df[~((ground_df.xmin == 0) & (ground_df.xmax == 0))] + + results = evaluate(predictions=predictions, + ground_df=ground_df, + root_dir=root_dir, + iou_threshold=iou_threshold, + savedir=savedir) + + # replace classes if not NUll + if not results is None: + results["results"]["predicted_label"] = results["results"][ + "predicted_label"].apply(lambda x: numeric_to_label_dict[x] + if not pd.isnull(x) else x) + results["results"]["true_label"] = results["results"]["true_label"].apply( + lambda x: numeric_to_label_dict[x]) + results["predictions"] = predictions + results["predictions"]["label"] = results["predictions"]["label"].apply( + lambda x: numeric_to_label_dict[x]) + + return results + + def evaluate(predictions, ground_df, root_dir, iou_threshold=0.4, savedir=None): """Image annotated crown evaluation routine submission can be submitted as a .shp, existing pandas dataframe or .csv path diff --git a/deepforest/main.py b/deepforest/main.py index 8d927da7..e3b10616 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -197,7 +197,7 @@ def create_trainer(self, logger=None, callbacks=[], **kwargs): limit_val_batches=limit_val_batches, num_sanity_val_steps=num_sanity_val_steps, **kwargs) - + def on_fit_start(self): if self.config["train"]["csv_file"] is None: raise AttributeError( @@ -380,22 +380,21 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 """ df = pd.read_csv(csv_file) ds = dataset.TreeDataset(csv_file=csv_file, - root_dir=root_dir, - transforms=None, - train=False) + root_dir=root_dir, + transforms=None, + train=False) dataloader = self.predict_dataloader(ds) - - results = predict.predict_file( - model=self, - trainer=self.trainer, - annotations=df, - dataloader=dataloader, - root_dir=root_dir, - nms_thresh=self.config["nms_thresh"], - savedir=savedir, - color=color, - thickness=thickness) - + + results = predict.predict_file(model=self, + trainer=self.trainer, + annotations=df, + dataloader=dataloader, + root_dir=root_dir, + nms_thresh=self.config["nms_thresh"], + savedir=savedir, + color=color, + thickness=thickness) + return results def predict_tile(self, @@ -558,7 +557,7 @@ def on_validation_epoch_end(self): output = {key: value for key, value in output.items() if not key == "classes"} self.log_dict(output) self.mAP_metric.reset() - + # Evaluate on validation data predictions self.predictions_df = pd.concat(self.predictions) ground_df = pd.read_csv(self.config["validation"]["csv_file"]) @@ -572,17 +571,17 @@ def on_validation_epoch_end(self): root_dir=self.config["validation"]["root_dir"], iou_threshold=self.config["validation"]["iou_threshold"], savedir=None, - numeric_to_label_dict=self.numeric_to_label_dict - ) - + numeric_to_label_dict=self.numeric_to_label_dict) + self.log("box_recall", results["box_recall"]) self.log("box_precision", results["box_precision"]) if isinstance(results, pd.DataFrame): for index, row in results["class_recall"].iterrows(): + self.log("{}_Recall".format(self.numeric_to_label_dict[row["label"]]), + row["recall"]) self.log( - "{}_Recall".format(self.numeric_to_label_dict[row["label"]]),row["recall"]) - self.log( - "{}_Precision".format(self.numeric_to_label_dict[row["label"]]),row["precision"]) + "{}_Precision".format(self.numeric_to_label_dict[row["label"]]), + row["precision"]) def predict_step(self, batch, batch_idx): batch_results = self.model(batch) diff --git a/deepforest/predict.py b/deepforest/predict.py index 751eb9d6..3ea22fb9 100644 --- a/deepforest/predict.py +++ b/deepforest/predict.py @@ -188,6 +188,10 @@ def predict_file(trainer, results = pd.concat(results, ignore_index=True) if savedir: - visualize.plot_prediction_dataframe(results,root_dir=root_dir, savedir=savedir, color=color, thickness=thickness) - + visualize.plot_prediction_dataframe(results, + root_dir=root_dir, + savedir=savedir, + color=color, + thickness=thickness) + return results diff --git a/deepforest/visualize.py b/deepforest/visualize.py index 1b4a6ef4..37a7a0e1 100644 --- a/deepforest/visualize.py +++ b/deepforest/visualize.py @@ -69,7 +69,12 @@ def plot_prediction_and_targets(image, predictions, targets, image_name, savedir return figure_path -def plot_prediction_dataframe(df, root_dir, savedir, color=None, thickness=1, ground_truth=None): +def plot_prediction_dataframe(df, + root_dir, + savedir, + color=None, + thickness=1, + ground_truth=None): """For each row in dataframe, call plot predictions and save plot files to disk. For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class. Args: diff --git a/tests/data/OSBS_029.csv b/tests/data/OSBS_029.csv index a1ee76f1..da5f1754 100644 --- a/tests/data/OSBS_029.csv +++ b/tests/data/OSBS_029.csv @@ -1,160 +1,62 @@ image_path,xmin,ymin,xmax,ymax,label -OSBS_029_0.png,90,117,121,145,Tree -OSBS_029_0.png,115,109,150,152,Tree -OSBS_029_0.png,161,155,199,191,Tree -OSBS_029_0.png,120,153,160,192,Tree -OSBS_029_0.png,1,153,53,200,Tree -OSBS_029_0.png,65,143,117,190,Tree -OSBS_029_0.png,149,95,200,156,Tree -OSBS_029_0.png,154,195,190,200,Tree -OSBS_029_0.png,1,13,39,62,Tree -OSBS_029_0.png,1,65,40,106,Tree -OSBS_029_0.png,50,3,102,57,Tree -OSBS_029_0.png,102,36,130,68,Tree -OSBS_029_0.png,156,5,180,38,Tree -OSBS_029_0.png,186,1,200,40,Tree -OSBS_029_0.png,53,192,90,200,Tree -OSBS_029_0.png,115,64,151,103,Tree -OSBS_029_0.png,53,69,96,117,Tree -OSBS_029_1.png,166,103,200,154,Tree -OSBS_029_1.png,161,5,199,41,Tree -OSBS_029_1.png,120,3,160,42,Tree -OSBS_029_1.png,1,3,53,67,Tree -OSBS_029_1.png,1,68,41,104,Tree -OSBS_029_1.png,65,0,117,40,Tree -OSBS_029_1.png,154,45,190,79,Tree -OSBS_029_1.png,103,45,153,94,Tree -OSBS_029_1.png,117,113,149,157,Tree -OSBS_029_1.png,100,159,155,200,Tree -OSBS_029_1.png,199,190,200,200,Tree -OSBS_029_1.png,177,158,200,193,Tree -OSBS_029_1.png,53,42,90,88,Tree -OSBS_029_1.png,31,191,58,200,Tree -OSBS_029_1.png,1,111,31,146,Tree -OSBS_029_1.png,73,91,113,137,Tree -OSBS_029_1.png,60,142,96,182,Tree -OSBS_029_2.png,166,53,200,104,Tree -OSBS_029_2.png,1,18,41,54,Tree -OSBS_029_2.png,154,0,190,29,Tree -OSBS_029_2.png,103,0,153,44,Tree -OSBS_029_2.png,49,177,75,200,Tree -OSBS_029_2.png,116,167,151,200,Tree -OSBS_029_2.png,117,63,149,107,Tree -OSBS_029_2.png,100,109,155,171,Tree -OSBS_029_2.png,179,159,200,198,Tree -OSBS_029_2.png,199,140,200,174,Tree -OSBS_029_2.png,177,108,200,143,Tree -OSBS_029_2.png,53,0,90,38,Tree -OSBS_029_2.png,31,141,58,172,Tree -OSBS_029_2.png,19,168,52,200,Tree -OSBS_029_2.png,1,61,31,96,Tree -OSBS_029_2.png,73,41,113,87,Tree -OSBS_029_2.png,60,92,96,132,Tree -OSBS_029_2.png,89,162,114,190,Tree -OSBS_029_3.png,53,67,77,90,Tree -OSBS_029_3.png,106,99,138,140,Tree -OSBS_029_3.png,162,13,199,47,Tree -OSBS_029_3.png,128,1,162,37,Tree -OSBS_029_3.png,11,155,49,191,Tree -OSBS_029_3.png,0,153,10,192,Tree -OSBS_029_3.png,0,95,53,156,Tree -OSBS_029_3.png,4,195,40,200,Tree -OSBS_029_3.png,6,5,30,38,Tree -OSBS_029_3.png,36,1,81,40,Tree -OSBS_029_3.png,181,128,200,161,Tree -OSBS_029_3.png,182,84,200,118,Tree -OSBS_029_3.png,0,64,1,103,Tree -OSBS_029_3.png,181,42,200,87,Tree -OSBS_029_3.png,102,47,133,87,Tree -OSBS_029_3.png,141,89,183,138,Tree -OSBS_029_3.png,138,136,165,167,Tree -OSBS_029_3.png,53,88,97,139,Tree -OSBS_029_3.png,107,198,139,200,Tree -OSBS_029_3.png,86,132,103,152,Tree -OSBS_029_3.png,166,174,196,200,Tree -OSBS_029_4.png,16,103,75,154,Tree -OSBS_029_4.png,11,5,49,41,Tree -OSBS_029_4.png,0,3,10,42,Tree -OSBS_029_4.png,199,140,200,170,Tree -OSBS_029_4.png,4,45,40,79,Tree -OSBS_029_4.png,84,93,103,117,Tree -OSBS_029_4.png,183,186,200,200,Tree -OSBS_029_4.png,181,0,200,11,Tree -OSBS_029_4.png,49,190,81,200,Tree -OSBS_029_4.png,27,158,64,193,Tree -OSBS_029_4.png,89,156,129,192,Tree -OSBS_029_4.png,125,182,160,200,Tree -OSBS_029_4.png,113,93,139,132,Tree -OSBS_029_4.png,138,0,165,17,Tree -OSBS_029_4.png,107,48,139,82,Tree -OSBS_029_4.png,86,0,103,2,Tree -OSBS_029_4.png,166,24,196,64,Tree -OSBS_029_4.png,70,58,101,94,Tree -OSBS_029_5.png,16,53,75,104,Tree -OSBS_029_5.png,199,90,200,120,Tree -OSBS_029_5.png,4,0,40,29,Tree -OSBS_029_5.png,0,167,1,200,Tree -OSBS_029_5.png,84,43,103,67,Tree -OSBS_029_5.png,142,167,187,200,Tree -OSBS_029_5.png,183,136,200,184,Tree -OSBS_029_5.png,29,159,64,198,Tree -OSBS_029_5.png,49,140,81,174,Tree -OSBS_029_5.png,27,108,64,143,Tree -OSBS_029_5.png,89,106,129,142,Tree -OSBS_029_5.png,125,132,160,174,Tree -OSBS_029_5.png,113,43,139,82,Tree -OSBS_029_5.png,107,0,139,32,Tree -OSBS_029_5.png,166,0,196,14,Tree -OSBS_029_5.png,70,8,101,44,Tree -OSBS_029_6.png,3,67,27,90,Tree -OSBS_029_6.png,56,99,88,140,Tree -OSBS_029_6.png,165,2,200,27,Tree -OSBS_029_6.png,112,13,149,47,Tree -OSBS_029_6.png,165,21,200,70,Tree -OSBS_029_6.png,78,1,112,37,Tree -OSBS_029_6.png,168,78,200,110,Tree -OSBS_029_6.png,0,1,31,40,Tree -OSBS_029_6.png,131,128,163,161,Tree -OSBS_029_6.png,132,84,160,118,Tree -OSBS_029_6.png,163,115,193,146,Tree -OSBS_029_6.png,131,42,169,87,Tree -OSBS_029_6.png,52,47,83,87,Tree -OSBS_029_6.png,91,89,133,138,Tree -OSBS_029_6.png,88,136,115,167,Tree -OSBS_029_6.png,3,88,47,139,Tree -OSBS_029_6.png,57,198,89,200,Tree -OSBS_029_6.png,36,132,53,152,Tree -OSBS_029_6.png,116,174,146,200,Tree -OSBS_029_7.png,0,103,25,154,Tree -OSBS_029_7.png,164,54,200,96,Tree -OSBS_029_7.png,149,140,175,170,Tree -OSBS_029_7.png,34,93,53,117,Tree -OSBS_029_7.png,133,186,174,200,Tree -OSBS_029_7.png,182,110,200,155,Tree -OSBS_029_7.png,131,0,163,11,Tree -OSBS_029_7.png,0,190,31,200,Tree -OSBS_029_7.png,0,158,14,193,Tree -OSBS_029_7.png,39,156,79,192,Tree -OSBS_029_7.png,75,182,110,200,Tree -OSBS_029_7.png,63,93,89,132,Tree -OSBS_029_7.png,88,0,115,17,Tree -OSBS_029_7.png,57,48,89,82,Tree -OSBS_029_7.png,36,0,53,2,Tree -OSBS_029_7.png,116,24,146,64,Tree -OSBS_029_7.png,20,58,51,94,Tree -OSBS_029_8.png,0,53,25,104,Tree -OSBS_029_8.png,164,4,200,46,Tree -OSBS_029_8.png,149,90,175,120,Tree -OSBS_029_8.png,34,43,53,67,Tree -OSBS_029_8.png,92,167,137,200,Tree -OSBS_029_8.png,133,136,174,184,Tree -OSBS_029_8.png,182,60,200,105,Tree -OSBS_029_8.png,0,159,14,198,Tree -OSBS_029_8.png,0,140,31,174,Tree -OSBS_029_8.png,0,108,14,143,Tree -OSBS_029_8.png,39,106,79,142,Tree -OSBS_029_8.png,75,132,110,174,Tree -OSBS_029_8.png,63,43,89,82,Tree -OSBS_029_8.png,57,0,89,32,Tree -OSBS_029_8.png,116,0,146,14,Tree -OSBS_029_8.png,20,8,51,44,Tree +OSBS_029.tif,203,67,227,90,Tree +OSBS_029.tif,256,99,288,140,Tree +OSBS_029.tif,166,253,225,304,Tree +OSBS_029.tif,365,2,400,27,Tree +OSBS_029.tif,312,13,349,47,Tree +OSBS_029.tif,365,21,400,70,Tree +OSBS_029.tif,278,1,312,37,Tree +OSBS_029.tif,364,204,400,246,Tree +OSBS_029.tif,90,117,121,145,Tree +OSBS_029.tif,115,109,150,152,Tree +OSBS_029.tif,161,155,199,191,Tree +OSBS_029.tif,120,153,160,192,Tree +OSBS_029.tif,349,290,375,320,Tree +OSBS_029.tif,1,153,53,217,Tree +OSBS_029.tif,1,218,41,254,Tree +OSBS_029.tif,65,143,117,190,Tree +OSBS_029.tif,368,78,400,110,Tree +OSBS_029.tif,149,95,203,156,Tree +OSBS_029.tif,154,195,190,229,Tree +OSBS_029.tif,103,195,153,244,Tree +OSBS_029.tif,49,377,75,400,Tree +OSBS_029.tif,116,367,151,400,Tree +OSBS_029.tif,234,243,253,267,Tree +OSBS_029.tif,292,367,337,400,Tree +OSBS_029.tif,333,336,374,384,Tree +OSBS_029.tif,1,13,39,62,Tree +OSBS_029.tif,1,65,40,106,Tree +OSBS_029.tif,50,3,102,57,Tree +OSBS_029.tif,102,36,130,68,Tree +OSBS_029.tif,156,5,180,38,Tree +OSBS_029.tif,186,1,231,40,Tree +OSBS_029.tif,382,260,400,305,Tree +OSBS_029.tif,331,128,363,161,Tree +OSBS_029.tif,332,84,360,118,Tree +OSBS_029.tif,363,115,393,146,Tree +OSBS_029.tif,117,263,149,307,Tree +OSBS_029.tif,100,309,155,371,Tree +OSBS_029.tif,179,359,214,398,Tree +OSBS_029.tif,199,340,231,374,Tree +OSBS_029.tif,177,308,214,343,Tree +OSBS_029.tif,239,306,279,342,Tree +OSBS_029.tif,275,332,310,374,Tree +OSBS_029.tif,53,192,90,238,Tree +OSBS_029.tif,115,64,151,103,Tree +OSBS_029.tif,53,69,96,117,Tree +OSBS_029.tif,263,243,289,282,Tree +OSBS_029.tif,331,42,369,87,Tree +OSBS_029.tif,252,47,283,87,Tree +OSBS_029.tif,291,89,333,138,Tree +OSBS_029.tif,288,136,315,167,Tree +OSBS_029.tif,203,88,247,139,Tree +OSBS_029.tif,257,198,289,232,Tree +OSBS_029.tif,31,341,58,372,Tree +OSBS_029.tif,19,368,52,400,Tree +OSBS_029.tif,1,261,31,296,Tree +OSBS_029.tif,73,241,113,287,Tree +OSBS_029.tif,60,292,96,332,Tree +OSBS_029.tif,89,362,114,390,Tree +OSBS_029.tif,236,132,253,152,Tree +OSBS_029.tif,316,174,346,214,Tree +OSBS_029.tif,220,208,251,244,Tree diff --git a/tests/test_IoU.py b/tests/test_IoU.py index d5aa1731..287030f0 100644 --- a/tests/test_IoU.py +++ b/tests/test_IoU.py @@ -10,7 +10,7 @@ import geopandas as gpd import pandas as pd -def test_compute_IoU(m, download_release, tmpdir): +def test_compute_IoU(m, tmpdir): csv_file = get_data("OSBS_029.csv") predictions = m.predict_file(csv_file=csv_file, root_dir=os.path.dirname(csv_file)) ground_truth = pd.read_csv(csv_file)