diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 8e55ed92..a42a6752 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,6 +1,8 @@ # entry point for deepforest model import importlib import os +import contextlib +import io import typing import warnings @@ -486,7 +488,8 @@ def predict_tile(self, thickness=1, crop_model=None, crop_transform=None, - crop_augment=False): + crop_augment=False, + verbose=True): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. @@ -507,19 +510,27 @@ def predict_tile(self, return_plot: return a plot of the image with predictions overlaid (deprecated) color: color of the bounding box as a tuple of BGR color (deprecated) thickness: thickness of the rectangle border line in px (deprecated) - + verbose: whether to show progress bar Returns: pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple """ self.model.eval() self.model.nms_thresh = self.config["nms_thresh"] - # if more than one GPU present, use only a the first available gpu + # if more than one GPU present, use only the first available gpu if torch.cuda.device_count() > 1: - # Get available gpus and regenerate trainer warnings.warn( "More than one GPU detected. Using only the first GPU for predict_tile.") self.config["devices"] = 1 + + # Configure trainer based on verbose setting + if not verbose: + callbacks = [ + cb for cb in self.trainer.callbacks + if not isinstance(cb, pl.callbacks.ProgressBar) + ] + self.create_trainer(enable_progress_bar=False, callbacks=callbacks) + else: self.create_trainer() if (raster_path is None) and (image is None): @@ -551,6 +562,7 @@ def predict_tile(self, patch_overlap=patch_overlap, patch_size=patch_size) + # Predict using trainer batched_results = self.trainer.predict(self, self.predict_dataloader(ds)) # Flatten list from batched prediction @@ -560,11 +572,21 @@ def predict_tile(self, results.append(boxes) if mosaic: - results = predict.mosiac(results, - ds.windows, - sigma=sigma, - thresh=thresh, - iou_threshold=iou_threshold) + # Suppress output if not verbose + if not verbose: + f = io.StringIO() + with contextlib.redirect_stdout(f): + results = predict.mosiac(results, + ds.windows, + sigma=sigma, + thresh=thresh, + iou_threshold=iou_threshold) + else: + results = predict.mosiac(results, + ds.windows, + sigma=sigma, + thresh=thresh, + iou_threshold=iou_threshold) results["label"] = results.label.apply( lambda x: self.numeric_to_label_dict[x]) if raster_path: diff --git a/tests/test_main.py b/tests/test_main.py index c234aa83..d745bc94 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -906,4 +906,24 @@ def test_evaluate_on_epoch_interval(m): m.create_trainer() m.trainer.fit(m) assert m.trainer.logged_metrics["box_precision"] - assert m.trainer.logged_metrics["box_recall"] \ No newline at end of file + assert m.trainer.logged_metrics["box_recall"] + +@pytest.mark.parametrize("verbose", [True, False]) +def test_predict_tile_verbose(m, raster_path, capsys, verbose): + """Test that verbose output can be controlled in predict_tile""" + m.config["train"]["fast_dev_run"] = False + m.create_trainer() + + m.predict_tile( + raster_path=raster_path, + patch_size=300, + patch_overlap=0, + mosaic=True, + verbose=verbose + ) + + captured = capsys.readouterr() + if verbose: + assert captured.out.strip() + else: + assert not captured.out.strip()