Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve checks and tests for preprocess.py #564

Merged
merged 8 commits into from
Dec 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepforest/__init__.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
"""Top-level package for DeepForest."""
__author__ = """Ben Weinstein"""
__email__ = 'ben.weinstein@weecology.org'
__version__ = '1.3.0'
__version__ = '1.3.1'

import os

2 changes: 1 addition & 1 deletion deepforest/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.3.0'
__version__ = '1.3.1'
16 changes: 8 additions & 8 deletions deepforest/main.py
Original file line number Diff line number Diff line change
@@ -372,14 +372,14 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
dataloader = self.predict_dataloader(ds)

results = predict._dataloader_wrapper_(model=self,
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)

return results

16 changes: 8 additions & 8 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
@@ -134,14 +134,14 @@ def across_class_nms(predicted_boxes, iou_threshold=0.15):


def _dataloader_wrapper_(model,
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
"""Create a dataset and predict entire annotation file

Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
26 changes: 17 additions & 9 deletions deepforest/preprocess.py
Original file line number Diff line number Diff line change
@@ -146,13 +146,12 @@ def split_raster(annotations_file,
patch_overlap=0.05,
allow_empty=False,
image_name=None,
save_dir="."
):
save_dir="."):
"""Divide a large tile into smaller arrays. Each crop will be saved to
file.

Args:
numpy_image: a numpy object to be used as a raster, usually opened from rasterio.open.read()
numpy_image: a numpy object to be used as a raster, usually opened from rasterio.open.read(), in order (height, width, channels)
path_to_raster: (str): Path to a tile that can be read by rasterio on disk
annotations_file (str or pd.DataFrame): A pandas dataframe or path to annotations csv file. In the format -> image_path, xmin, ymin, xmax, ymax, label
save_dir (str): Directory to save images
@@ -167,11 +166,12 @@ def split_raster(annotations_file,
A pandas dataframe with annotations file for training.
A copy of this file is written to save_dir as a side effect.
"""
# Set deprecation warning for base_dir
if not base_dir == ".":
# Set deprecation warning for base_dir and set to save_dir
if base_dir:
warnings.warn(
"base_dir argument will be deprecated in 2.0. The naming is confusing, the rest of the API uses 'save_dir' to refer to location of images. Please use 'save_dir' argument.",
DeprecationWarning)
save_dir = base_dir

# Load raster as image
if (numpy_image is None) & (path_to_raster is None):
@@ -186,11 +186,19 @@ def split_raster(annotations_file,
raise (IOError("If passing an numpy_image, please also specify a image_name"
" to match the column in the annotation.csv file"))

# Confirm that raster is H x W x C, if not, convert, assuming image is wider/taller than channels
if numpy_image.shape[0] < numpy_image.shape[-1]:
warnings.warn(
"Input rasterio had shape {}, assuming channels first. Converting to channels last"
.format(numpy_image.shape), UserWarning)
numpy_image = np.moveaxis(numpy_image, 0, 2)

# Check that its 3 band
bands = numpy_image.shape[2]
if not bands == 3:
warnings.warn("Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape))
warnings.warn(
"Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape), UserWarning)
try:
numpy_image = numpy_image[:, :, :3].astype("uint8")
except:
@@ -265,7 +273,7 @@ def split_raster(annotations_file,
annotations_files.append(crop_annotations)

# save image crop
save_crop(base_dir, image_name, index, crop)
save_crop(save_dir, image_name, index, crop)
if len(annotations_files) == 0:
raise ValueError(
"Input file has no overlapping annotations and allow_empty is {}".format(
@@ -277,7 +285,7 @@ def split_raster(annotations_file,
# Use filename of the raster path to save the annotations
image_basename = os.path.splitext(image_name)[0]
file_path = image_basename + ".csv"
file_path = os.path.join(base_dir, file_path)
file_path = os.path.join(save_dir, file_path)
annotations_files.to_csv(file_path, index=False, header=True)

return annotations_files
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@
project = u'DeepForest'
copyright = u"2019, Ben Weinstein"
author = u"Ben Weinstein"
version = u"__version__ = '__version__ = '__version__ = '1.3.0'''"
version = u"__version__ = '__version__ = '__version__ = '1.3.1'''"
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
pygments_style = 'sphinx'
todo_include_todos = False
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.3.0
current_version = 1.3.1
commit = True
tag = True

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from distutils.command.build_ext import build_ext as DistUtilsBuildExt

NAME = 'deepforest'
VERSION = '1.3.0'
VERSION = '1.3.1'
DESCRIPTION = 'Tree crown prediction using deep learning retinanets'
URL = 'https://github.com/Weecology/DeepForest'
AUTHOR = 'Ben Weinstein'
Binary file removed tests/data/OSBS_029_0.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_1.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_2.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_3.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_4.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_5.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_6.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_7.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_8.png
Binary file not shown.
101 changes: 65 additions & 36 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@

import rasterio


@pytest.fixture()
def config():
config = utilities.read_config("deepforest_config.yml")
@@ -25,8 +26,7 @@ def config():
config["path_to_raster"] = get_data("OSBS_029.tif")

# Create a clean config test data
annotations = utilities.xml_to_annotations(
xml_path=config["annotations_xml"])
annotations = utilities.xml_to_annotations(xml_path=config["annotations_xml"])
annotations.to_csv("tests/data/OSBS_029.csv", index=False)

return config
@@ -74,60 +74,69 @@ def test_select_annotations_tile(config, image):
assert selected_annotations.xmax.max() <= config["patch_size"]
assert selected_annotations.ymax.max() <= config["patch_size"]

@pytest.mark.parametrize("input_type",["path","dataframe"])

@pytest.mark.parametrize("input_type", ["path", "dataframe"])
def test_split_raster(config, tmpdir, input_type):
"""Split raster into crops with overlaps to maintain all annotations"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0

if input_type =="path":
if input_type == "path":
annotations_file = "{}/example.csv".format(tmpdir)
else:
annotations_file = annotations
annotations_file = annotations

output_annotations = preprocess.split_raster(path_to_raster=raster,
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)

# Returns a 6 column pandas array
assert not output_annotations.empty
assert output_annotations.shape[1] == 6


def test_split_raster_empty_crops(config, tmpdir):
"""Split raster into crops with overlaps to maintain all annotations, allow empty crops"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0
#visualize.plot_prediction_dataframe(df=annotations, root_dir=os.path.dirname(get_data(".")), show=True)

annotations_file = preprocess.split_raster(path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

annotations_file = preprocess.split_raster(
path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

# Returns a 6 column pandas array
assert not annotations_file[(annotations_file.xmin == 0) & (annotations_file.xmax == 0)].empty

def test_split_raster_from_image(config):
assert not annotations_file[(annotations_file.xmin == 0) &
(annotations_file.xmax == 0)].empty


def test_split_raster_from_image(config, tmpdir):
r = rasterio.open(config["path_to_raster"]).read()
r = np.rollaxis(r,0,3)
annotations_file = preprocess.split_raster(numpy_image=r,
annotations_file=config["annotations_file"],
base_dir="tests/data/",
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
r = np.rollaxis(r, 0, 3)
annotations_file = preprocess.split_raster(
numpy_image=r,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

# Returns a 6 column pandas array
assert annotations_file.shape[1] == 6


def test_split_raster_empty(config):
# Clean output folder
for f in glob.glob("tests/output/empty/*"):
@@ -168,11 +177,31 @@ def test_split_raster_empty(config):
assert os.path.exists("tests/output/empty/OSBS_029_1.png")


def test_split_size_error(config):
def test_split_size_error(config, tmpdir):
with pytest.raises(ValueError):
annotations_file = preprocess.split_raster(path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir="tests/data/",
patch_size=2000,
patch_overlap=config["patch_overlap"])

annotations_file = preprocess.split_raster(
path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir=tmpdir,
patch_size=2000,
patch_overlap=config["patch_overlap"])


@pytest.mark.parametrize("orders", [(4, 400, 400), (400, 400, 4)])
def test_split_raster_4_band_warns(config, tmpdir, orders):
"""Test rasterio channel order
(400, 400, 4) C x H x W
(4, 400, 400) wrong channel order, H x W x C
"""

# Confirm that the rasterio channel order is C x H x W
assert rasterio.open(get_data("OSBS_029.tif")).read().shape[0] == 3
numpy_image = np.zeros(orders, dtype=np.uint8)

with pytest.warns(UserWarning):
preprocess.split_raster(numpy_image=numpy_image,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
Loading