Skip to content

Commit

Permalink
Add Axis Return to visualize.plot_results (#880)
Browse files Browse the repository at this point in the history
* Added axes arg default axes=False
  • Loading branch information
gadhvirushiraj authored Jan 12, 2025
1 parent 676289a commit c7e70ae
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@ def plot_results(results,
thickness=2,
basename=None,
radius=3,
image=None):
image=None,
axes=False):
"""Plot the prediction results.
Args:
Expand All @@ -486,8 +487,9 @@ def plot_results(results,
basename: optional basename for the saved figure. If None (default), the basename will be extracted from the image path.
radius: radius of the points in px
image: an optional numpy array of an image to annotate. If None (default), the image will be loaded from the results dataframe.
axes: returns matplotlib axes object if True
Returns:
None
Matplotlib axes object if axes=True, otherwise None
"""
# Convert colors, check for multi-class labels
num_labels = len(results.label.unique())
Expand Down Expand Up @@ -530,6 +532,8 @@ def plot_results(results,
else:
# Display the image using Matplotlib
plt.imshow(annotated_scene)
if axes:
return ax
plt.axis('off') # Hide axes for a cleaner look
plt.show()

Expand Down

0 comments on commit c7e70ae

Please sign in to comment.