Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve predict batch #876

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,12 +881,11 @@ def predict_batch(self, images, preprocess_fn=None):

#using Pytorch Ligthning's predict_step
with torch.no_grad():
predictions = []
for idx, image in enumerate(images):
predictions = self.predict_step(image.unsqueeze(0), idx)
predictions.extend(predictions)
predictions = self.predict_step(images, 0)

#convert predictions to dataframes
results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
results = [utilities.read_file(pred) for pred in predictions if pred is not None]

return results

def configure_optimizers(self):
Expand Down
133 changes: 39 additions & 94 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,107 +701,52 @@ def test_predict_tile_with_crop_model_empty():
# Assert the result
assert result is None

# @pytest.mark.parametrize("batch_size", [1, 4, 8])
# def test_batch_prediction(m, batch_size, raster_path):
#
# # Prepare input data
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=batch_size)

# # Perform prediction
# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# # Check results
# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# @pytest.mark.parametrize("batch_size", [1, 4])
# def test_batch_training(m, batch_size, tmpdir):
#
# # Generate synthetic training data
# csv_file = get_data("example.csv")
# root_dir = os.path.dirname(csv_file)
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# # Configure the model and trainer
# m.config["batch_size"] = batch_size
# m.create_trainer()
# trainer = m.trainer

# # Train the model
# trainer.fit(m, train_dl)

# # Assertions
# assert trainer.current_epoch == 1
# assert trainer.batch_size == batch_size

# @pytest.mark.parametrize("batch_size", [2, 4])
# def test_batch_data_augmentation(m, batch_size, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
# dl = DataLoader(ds, batch_size=batch_size)
def test_batch_prediction(m, raster_path):
# Prepare input data
tile = np.array(Image.open(raster_path))
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
dl = DataLoader(ds, batch_size=3)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)
# Perform prediction
predictions = []
for batch in dl:
prediction = m.predict_batch(batch)
predictions.append(prediction)

# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# def test_batch_inference_consistency(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=4)
# Check results
assert len(predictions) == len(dl)
for batch_pred in predictions:
for image_pred in batch_pred:
assert isinstance(image_pred, pd.DataFrame)
assert "label" in image_pred.columns
assert "score" in image_pred.columns
assert "geometry" in image_pred.columns

def test_batch_inference_consistency(m, raster_path):
tile = np.array(Image.open(raster_path))
ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300)
dl = DataLoader(ds, batch_size=4)

# batch_predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# batch_predictions.append(prediction)
batch_predictions = []
for batch in dl:
prediction = m.predict_batch(batch)
batch_predictions.extend(prediction)

# single_predictions = []
# for image in ds:
# prediction = m.predict_image(image=image)
# single_predictions.append(prediction)
single_predictions = []
for image in ds:
image = image.permute(1,2,0).numpy() * 255
prediction = m.predict_image(image=image)
single_predictions.append(prediction)

# batch_df = pd.concat(batch_predictions, ignore_index=True)
# single_df = pd.concat(single_predictions, ignore_index=True)
batch_df = pd.concat(batch_predictions, ignore_index=True)
single_df = pd.concat(single_predictions, ignore_index=True)

# pd.testing.assert_frame_equal(batch_df, single_df)
# Make all xmin, ymin, xmax, ymax integers
for col in ["xmin", "ymin", "xmax", "ymax"]:
batch_df[col] = batch_df[col].astype(int)
single_df[col] = single_df[col].astype(int)
pd.testing.assert_frame_equal(batch_df[["xmin", "ymin", "xmax", "ymax"]], single_df[["xmin", "ymin", "xmax", "ymax"]], check_dtype=False)

# def test_large_batch_handling(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=16)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) > 0
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }
# assert not batch_pred.empty

def test_epoch_evaluation_end(m):
preds = [{
Expand Down
Loading