Skip to content

Commit

Permalink
cleanup preprocess tmpdir
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 5, 2023
1 parent 0759e0b commit 52bc353
Show file tree
Hide file tree
Showing 11 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions deepforest/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,6 @@ def split_raster(annotations_file,
if (numpy_image is None) & (path_to_raster is None):
raise IOError("supply a raster either as a path_to_raster or if ready "
"from existing in memory numpy object, as numpy_image=")

# Confirm that raster is H x W x C, if not, convert
if numpy_image.shape[0] < numpy_image.shape[1]:
warnings.warn(
"Input rasterio had shape {}, assuming channels first. Converting to channels last".format(
numpy_image.shape), UserWarning)
numpy_image = np.moveaxis(numpy_image, 0, 2)

if path_to_raster:
numpy_image = rasterio.open(path_to_raster).read()
Expand All @@ -194,6 +187,13 @@ def split_raster(annotations_file,
raise (IOError("If passing an numpy_image, please also specify a image_name"
" to match the column in the annotation.csv file"))

# Confirm that raster is H x W x C, if not, convert
if numpy_image.shape[0] < numpy_image.shape[1]:
warnings.warn(
"Input rasterio had shape {}, assuming channels first. Converting to channels last".format(
numpy_image.shape), UserWarning)
numpy_image = np.moveaxis(numpy_image, 0, 2)

# Check that its 3 band
bands = numpy_image.shape[2]
if not bands == 3:
Expand Down
Binary file removed tests/data/OSBS_029_0.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_1.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_2.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_3.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_4.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_5.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_6.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_7.png
Binary file not shown.
Binary file removed tests/data/OSBS_029_8.png
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def test_split_raster_empty_crops(config, tmpdir):
# Returns a 6 column pandas array
assert not annotations_file[(annotations_file.xmin == 0) & (annotations_file.xmax == 0)].empty

def test_split_raster_from_image(config):
def test_split_raster_from_image(config, tmpdir):
r = rasterio.open(config["path_to_raster"]).read()
r = np.rollaxis(r,0,3)
annotations_file = preprocess.split_raster(numpy_image=r,
annotations_file=config["annotations_file"],
base_dir="tests/data/",
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
Expand Down Expand Up @@ -168,11 +168,11 @@ def test_split_raster_empty(config):
assert os.path.exists("tests/output/empty/OSBS_029_1.png")


def test_split_size_error(config):
def test_split_size_error(config, tmpdir):
with pytest.raises(ValueError):
annotations_file = preprocess.split_raster(path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir="tests/data/",
base_dir=tmpdir,
patch_size=2000,
patch_overlap=config["patch_overlap"])

Expand Down

0 comments on commit 52bc353

Please sign in to comment.