From c7e70aef80476bbcbacbccf04a6f04ac57d84498 Mon Sep 17 00:00:00 2001 From: Rushiraj Gadhvi Date: Mon, 13 Jan 2025 03:01:37 +0530 Subject: [PATCH] Add Axis Return to visualize.plot_results (#880) * Added axes arg default axes=False --- src/deepforest/visualize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/deepforest/visualize.py b/src/deepforest/visualize.py index 51097d15..d32b41cf 100644 --- a/src/deepforest/visualize.py +++ b/src/deepforest/visualize.py @@ -471,7 +471,8 @@ def plot_results(results, thickness=2, basename=None, radius=3, - image=None): + image=None, + axes=False): """Plot the prediction results. Args: @@ -486,8 +487,9 @@ def plot_results(results, basename: optional basename for the saved figure. If None (default), the basename will be extracted from the image path. radius: radius of the points in px image: an optional numpy array of an image to annotate. If None (default), the image will be loaded from the results dataframe. + axes: returns matplotlib axes object if True Returns: - None + Matplotlib axes object if axes=True, otherwise None """ # Convert colors, check for multi-class labels num_labels = len(results.label.unique()) @@ -530,6 +532,8 @@ def plot_results(results, else: # Display the image using Matplotlib plt.imshow(annotated_scene) + if axes: + return ax plt.axis('off') # Hide axes for a cleaner look plt.show()