Skip to content

Commit

Permalink
Improve checks and tests for preprocess.py (weecology#564)
Browse files Browse the repository at this point in the history
* Bump version: 1.3.0 → 1.3.1

* rasterio reads in a .tif, its channels first, CxHxW
  improve warnings and 4 band checking in preprocess
  pytest.mark.parametrize for multiple test instances

---------

Co-authored-by: henry senyondo <[email protected]>
  • Loading branch information
bw4sz and henrykironde authored Dec 12, 2023
1 parent 741d5ed commit a01acef
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 61 deletions.
16 changes: 8 additions & 8 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,14 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
dataloader = self.predict_dataloader(ds)

results = predict._dataloader_wrapper_(model=self,
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)

return results

Expand Down
16 changes: 8 additions & 8 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def across_class_nms(predicted_boxes, iou_threshold=0.15):


def _dataloader_wrapper_(model,
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Expand Down
26 changes: 17 additions & 9 deletions deepforest/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,12 @@ def split_raster(annotations_file,
patch_overlap=0.05,
allow_empty=False,
image_name=None,
save_dir="."
):
save_dir="."):
"""Divide a large tile into smaller arrays. Each crop will be saved to
file.
Args:
numpy_image: a numpy object to be used as a raster, usually opened from rasterio.open.read()
numpy_image: a numpy object to be used as a raster, usually opened from rasterio.open.read(), in order (height, width, channels)
path_to_raster: (str): Path to a tile that can be read by rasterio on disk
annotations_file (str or pd.DataFrame): A pandas dataframe or path to annotations csv file. In the format -> image_path, xmin, ymin, xmax, ymax, label
save_dir (str): Directory to save images
Expand All @@ -167,11 +166,12 @@ def split_raster(annotations_file,
A pandas dataframe with annotations file for training.
A copy of this file is written to save_dir as a side effect.
"""
# Set deprecation warning for base_dir
if not base_dir == ".":
# Set deprecation warning for base_dir and set to save_dir
if base_dir:
warnings.warn(
"base_dir argument will be deprecated in 2.0. The naming is confusing, the rest of the API uses 'save_dir' to refer to location of images. Please use 'save_dir' argument.",
DeprecationWarning)
save_dir = base_dir

# Load raster as image
if (numpy_image is None) & (path_to_raster is None):
Expand All @@ -186,11 +186,19 @@ def split_raster(annotations_file,
raise (IOError("If passing an numpy_image, please also specify a image_name"
" to match the column in the annotation.csv file"))

# Confirm that raster is H x W x C, if not, convert, assuming image is wider/taller than channels
if numpy_image.shape[0] < numpy_image.shape[-1]:
warnings.warn(
"Input rasterio had shape {}, assuming channels first. Converting to channels last"
.format(numpy_image.shape), UserWarning)
numpy_image = np.moveaxis(numpy_image, 0, 2)

# Check that its 3 band
bands = numpy_image.shape[2]
if not bands == 3:
warnings.warn("Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape))
warnings.warn(
"Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape), UserWarning)
try:
numpy_image = numpy_image[:, :, :3].astype("uint8")
except:
Expand Down Expand Up @@ -265,7 +273,7 @@ def split_raster(annotations_file,
annotations_files.append(crop_annotations)

# save image crop
save_crop(base_dir, image_name, index, crop)
save_crop(save_dir, image_name, index, crop)
if len(annotations_files) == 0:
raise ValueError(
"Input file has no overlapping annotations and allow_empty is {}".format(
Expand All @@ -277,7 +285,7 @@ def split_raster(annotations_file,
# Use filename of the raster path to save the annotations
image_basename = os.path.splitext(image_name)[0]
file_path = image_basename + ".csv"
file_path = os.path.join(base_dir, file_path)
file_path = os.path.join(save_dir, file_path)
annotations_files.to_csv(file_path, index=False, header=True)

return annotations_files
Binary file removed tests/data/OSBS_029_0.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_1.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_2.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_3.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_4.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_5.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_6.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_7.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_8.png
Binary file not shown.
101 changes: 65 additions & 36 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import rasterio


@pytest.fixture()
def config():
config = utilities.read_config("deepforest_config.yml")
Expand All @@ -25,8 +26,7 @@ def config():
config["path_to_raster"] = get_data("OSBS_029.tif")

# Create a clean config test data
annotations = utilities.xml_to_annotations(
xml_path=config["annotations_xml"])
annotations = utilities.xml_to_annotations(xml_path=config["annotations_xml"])
annotations.to_csv("tests/data/OSBS_029.csv", index=False)

return config
Expand Down Expand Up @@ -74,60 +74,69 @@ def test_select_annotations_tile(config, image):
assert selected_annotations.xmax.max() <= config["patch_size"]
assert selected_annotations.ymax.max() <= config["patch_size"]

@pytest.mark.parametrize("input_type",["path","dataframe"])

@pytest.mark.parametrize("input_type", ["path", "dataframe"])
def test_split_raster(config, tmpdir, input_type):
"""Split raster into crops with overlaps to maintain all annotations"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0

if input_type =="path":
if input_type == "path":
annotations_file = "{}/example.csv".format(tmpdir)
else:
annotations_file = annotations
annotations_file = annotations

output_annotations = preprocess.split_raster(path_to_raster=raster,
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)

# Returns a 6 column pandas array
assert not output_annotations.empty
assert output_annotations.shape[1] == 6


def test_split_raster_empty_crops(config, tmpdir):
"""Split raster into crops with overlaps to maintain all annotations, allow empty crops"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0
#visualize.plot_prediction_dataframe(df=annotations, root_dir=os.path.dirname(get_data(".")), show=True)

annotations_file = preprocess.split_raster(path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

annotations_file = preprocess.split_raster(
path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

# Returns a 6 column pandas array
assert not annotations_file[(annotations_file.xmin == 0) & (annotations_file.xmax == 0)].empty

def test_split_raster_from_image(config):
assert not annotations_file[(annotations_file.xmin == 0) &
(annotations_file.xmax == 0)].empty


def test_split_raster_from_image(config, tmpdir):
r = rasterio.open(config["path_to_raster"]).read()
r = np.rollaxis(r,0,3)
annotations_file = preprocess.split_raster(numpy_image=r,
annotations_file=config["annotations_file"],
base_dir="tests/data/",
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
r = np.rollaxis(r, 0, 3)
annotations_file = preprocess.split_raster(
numpy_image=r,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

# Returns a 6 column pandas array
assert annotations_file.shape[1] == 6


def test_split_raster_empty(config):
# Clean output folder
for f in glob.glob("tests/output/empty/*"):
Expand Down Expand Up @@ -168,11 +177,31 @@ def test_split_raster_empty(config):
assert os.path.exists("tests/output/empty/OSBS_029_1.png")


def test_split_size_error(config):
def test_split_size_error(config, tmpdir):
with pytest.raises(ValueError):
annotations_file = preprocess.split_raster(path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir="tests/data/",
patch_size=2000,
patch_overlap=config["patch_overlap"])

annotations_file = preprocess.split_raster(
path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir=tmpdir,
patch_size=2000,
patch_overlap=config["patch_overlap"])


@pytest.mark.parametrize("orders", [(4, 400, 400), (400, 400, 4)])
def test_split_raster_4_band_warns(config, tmpdir, orders):
"""Test rasterio channel order
(400, 400, 4) C x H x W
(4, 400, 400) wrong channel order, H x W x C
"""

# Confirm that the rasterio channel order is C x H x W
assert rasterio.open(get_data("OSBS_029.tif")).read().shape[0] == 3
numpy_image = np.zeros(orders, dtype=np.uint8)

with pytest.warns(UserWarning):
preprocess.split_raster(numpy_image=numpy_image,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

0 comments on commit a01acef

Please sign in to comment.