diff --git a/deepforest/data/deepforest_config.yml b/deepforest/data/deepforest_config.yml
index 9290c509..0a541578 100644
--- a/deepforest/data/deepforest_config.yml
+++ b/deepforest/data/deepforest_config.yml
@@ -23,6 +23,26 @@ train:
     # Optimizer initial learning rate
     lr: 0.001
+    scheduler:
+        type:
+        params:
+            # Common parameters
+            T_max: 10
+            eta_min: 0.00001
+            lr_lambda: "lambda epoch: 0.95 ** epoch" # For lambdaLR and multiplicativeLR
+            step_size: 30 # For stepLR
+            gamma: 0.1 # For stepLR, multistepLR, and exponentialLR
+            milestones: [50, 100] # For multistepLR
+
+            # ReduceLROnPlateau parameters (used if type is not explicitly mentioned)
+            mode: "min"
+            factor: 0.1
+            patience: 10
+            threshold: 0.0001
+            threshold_mode: "rel"
+            cooldown: 0
+            min_lr: 0
+            eps: 1e-08
 
     # Print loss every n epochs
     epochs: 1
 
@@ -37,4 +57,4 @@ validation:
     root_dir:
     # Intersection over union evaluation
     iou_threshold: 0.4
-    val_accuracy_interval: 20
\ No newline at end of file
+    val_accuracy_interval: 20
diff --git a/deepforest/main.py b/deepforest/main.py
index ad1a2fba..d50e8f3f 100644
--- a/deepforest/main.py
+++ b/deepforest/main.py
@@ -635,15 +635,41 @@ def configure_optimizers(self):
         optimizer = optim.SGD(self.model.parameters(),
                               lr=self.config["train"]["lr"],
                               momentum=0.9)
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
-                                                               mode='min',
-                                                               factor=0.1,
-                                                               patience=10,
-                                                               threshold=0.0001,
-                                                               threshold_mode='rel',
-                                                               cooldown=0,
-                                                               min_lr=0,
-                                                               eps=1e-08)
+
+        scheduler_config = self.config["train"]["scheduler"]
+        scheduler_type = scheduler_config["type"]
+        params = scheduler_config["params"]
+
+        if scheduler_type == "cosine":
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                                   T_max=params["T_max"],
+                                                                   eta_min=params["eta_min"])
+
+        elif scheduler_type == "lambdaLR":
+            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=params["lr_lambda"])
+
+        elif scheduler_type == "multiplicativeLR":
+            scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=params["lr_lambda"])
+
+        elif scheduler_type == "stepLR":
+            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=params["step_size"], gamma=params["gamma"])
+
+        elif scheduler_type == "multistepLR":
+            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params["milestones"], gamma=params["gamma"])
+
+        elif scheduler_type == "exponentialLR":
+            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])
+
+        else:
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+                                                                   mode=params["mode"],
+                                                                   factor=params["factor"],
+                                                                   patience=params["patience"],
+                                                                   threshold=params["threshold"],
+                                                                   threshold_mode=params["threshold_mode"],
+                                                                   cooldown=params["cooldown"],
+                                                                   min_lr=params["min_lr"],
+                                                                   eps=params["eps"])
 
         # Monitor rate is val data is used
         if self.config["validation"]["csv_file"] is not None:
diff --git a/deepforest_config.yml b/deepforest_config.yml
index 9f5af746..5f28694a 100644
--- a/deepforest_config.yml
+++ b/deepforest_config.yml
@@ -23,6 +23,26 @@ train:
     # Optimizer initial learning rate
     lr: 0.001
+    scheduler:
+        type:
+        params:
+            # Common parameters
+            T_max: 10
+            eta_min: 0.00001
+            lr_lambda: "lambda epoch: 0.95 ** epoch" # For lambdaLR and multiplicativeLR
+            step_size: 30 # For stepLR
+            gamma: 0.1 # For stepLR, multistepLR, and exponentialLR
+            milestones: [50, 100] # For multistepLR
+
+            # ReduceLROnPlateau parameters (used if type is not explicitly mentioned)
+            mode: "min"
+            factor: 0.1
+            patience: 10
+            threshold: 0.0001
+            threshold_mode: "rel"
+            cooldown: 0
+            min_lr: 0
+            eps: 1e-08
 
     # Print loss every n epochs
     epochs: 1
 
diff --git a/docs/ConfigurationFile.md b/docs/ConfigurationFile.md
index de935145..86dcf618 100644
--- a/docs/ConfigurationFile.md
+++ b/docs/ConfigurationFile.md
@@ -33,6 +33,26 @@ train:
     # Optimizer initial learning rate
     lr: 0.001
+    scheduler:
+        type:
+        params:
+            # Common parameters
+            T_max: 10
+            eta_min: 0.00001
+            lr_lambda: "lambda epoch: 0.95 ** epoch" # For lambdaLR and multiplicativeLR
+            step_size: 30 # For stepLR
+            gamma: 0.1 # For stepLR, multistepLR, and exponentialLR
+            milestones: [50, 100] # For multistepLR
+
+            # ReduceLROnPlateau parameters (used if type is not explicitly mentioned)
+            mode: "min"
+            factor: 0.1
+            patience: 10
+            threshold: 0.0001
+            threshold_mode: "rel"
+            cooldown: 0
+            min_lr: 0
+            eps: 1e-08
 
     # Print loss every n epochs
     epochs: 1
@@ -48,6 +68,7 @@ validation:
     # Intersection over union evaluation
     iou_threshold: 0.4
     val_accuracy_interval: 20
+
 ```
 
 ## Passing config arguments at runtime using a dict
@@ -127,6 +148,8 @@ Learning rate for the training optimization. By default the optimizer is stochas
 optim.SGD(self.model.parameters(), lr=self.config["train"]["lr"], momentum=0.9)
 ```
 
+A learning rate scheduler is used to adjust the learning rate based on validation loss. The default scheduler is ReduceLROnPlateau:
+
 ```
 self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min',
                                                          factor=0.1, patience=10,
@@ -134,13 +157,42 @@ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode
                                                          threshold_mode='rel', cooldown=0,
                                                          min_lr=0, eps=1e-08)
 ```
-This scheduler can be overwritten by replacing the model class
+This default scheduler can be overridden by specifying a different scheduler in the config_args:
 
 ```
-m = main.deepforest()
-m.scheduler = <>
+scheduler_config = {
+    "type": "cosine",  # or "lambdaLR", "multiplicativeLR", "stepLR", "multistepLR", "exponentialLR", "reduceLROnPlateau"
+    "params": {
+        # Scheduler-specific parameters
+    }
+}
+
+config_args = {
+    "train": {
+        "lr": 0.01,
+        "scheduler": scheduler_config,
+        "csv_file": "path/to/annotations.csv",
+        "root_dir": "path/to/root_dir",
+        "fast_dev_run": False,
+        "epochs": 2
+    },
+    "validation": {
+        "csv_file": "path/to/annotations.csv",
+        "root_dir": "path/to/root_dir"
+    }
+}
 ```
 
+The scheduler types supported are:
+
+- **cosine**: CosineAnnealingLR
+- **lambdaLR**: LambdaLR
+- **multiplicativeLR**: MultiplicativeLR
+- **stepLR**: StepLR
+- **multistepLR**: MultiStepLR
+- **exponentialLR**: ExponentialLR
+- **reduceLROnPlateau**: ReduceLROnPlateau
+
 ### Epochs
 
 The number of times to run a full pass of the dataloader during model training.
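
A minimal runnable sketch (outside the patch itself) of the config_args override described in the ConfigurationFile.md change above. It assumes the deepforest package and its bundled sample annotations (`testfile_deepforest.csv`, the same file used by the new test further down), and it passes `lr_lambda` as a Python callable because the YAML default stores the lambda as a string; parameter values are illustrative only.

```python
# Sketch: exercise the scheduler override via config_args, mirroring the new test.
import os

from deepforest import get_data, main

annotations_file = get_data("testfile_deepforest.csv")  # sample data shipped with deepforest
root_dir = os.path.dirname(annotations_file)

config_args = {
    "train": {
        "lr": 0.01,
        "scheduler": {
            "type": "lambdaLR",
            # A real callable here; the YAML config keeps this entry as a string.
            "params": {"lr_lambda": lambda epoch: 0.95 ** epoch},
        },
        "csv_file": annotations_file,
        "root_dir": root_dir,
        "fast_dev_run": False,
        "epochs": 2,
    },
    "validation": {"csv_file": annotations_file, "root_dir": root_dir},
}

m = main.deepforest(config_args=config_args)
m.create_trainer(limit_train_batches=1.0)
m.trainer.fit(m)

# The scheduler picked by configure_optimizers can be inspected after fit()
print(type(m.trainer.lr_scheduler_configs[0].scheduler).__name__)  # LambdaLR
```
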
diff --git a/tests/deepforest_config_test.yml b/tests/deepforest_config_test.yml
index 41ecd451..5f28694a 100644
--- a/tests/deepforest_config_test.yml
+++ b/tests/deepforest_config_test.yml
@@ -1,35 +1,60 @@
 # Config file for DeepForest pytorch module
 
-#cpu workers for data loaders
-#Dataloaders
+# Cpu workers for data loaders
+# Dataloaders
 workers: 1
 devices: auto
 accelerator: auto
-batch_size: 9999
+batch_size: 1
 
-#Non-max supression of overlapping predictions
-nms_thresh: 0.9
-score_thresh: 0.9
+# Model Architecture
+architecture: 'retinanet'
+num_classes: 1
+nms_thresh: 0.05
 
-train:
+# Architecture specific params
+retinanet:
+    # Non-max supression of overlapping predictions
+    score_thresh: 0.1
 
+train:
     csv_file:
     root_dir:
 
-    #Optomizer initial learning rate
+    # Optimizer initial learning rate
     lr: 0.001
+    scheduler:
+        type:
+        params:
+            # Common parameters
+            T_max: 10
+            eta_min: 0.00001
+            lr_lambda: "lambda epoch: 0.95 ** epoch" # For lambdaLR and multiplicativeLR
+            step_size: 30 # For stepLR
+            gamma: 0.1 # For stepLR, multistepLR, and exponentialLR
+            milestones: [50, 100] # For multistepLR
+
+            # ReduceLROnPlateau parameters (used if type is not explicitly mentioned)
+            mode: "min"
+            factor: 0.1
+            patience: 10
+            threshold: 0.0001
+            threshold_mode: "rel"
+            cooldown: 0
+            min_lr: 0
+            eps: 1e-08
 
-    #Print loss every n epochs
+    # Print loss every n epochs
     epochs: 1
 
-    #Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
+    # Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
     fast_dev_run: False
 
-    #pin images to GPU memory for fast training. This depends on GPU size and number of images.
+    # pin images to GPU memory for fast training. This depends on GPU size and number of images.
     preload_images: False
 
 validation:
-    #callback args
+    # callback args
     csv_file:
     root_dir:
-    #Intersection over union evaluation
+    # Intersection over union evaluation
     iou_threshold: 0.4
-    val_accuracy_interval: 5
\ No newline at end of file
+    val_accuracy_interval: 20
diff --git a/tests/test_main.py b/tests/test_main.py
index c95aefc4..e25f20f7 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -553,4 +553,64 @@ def test_existing_predict_dataloader(m, tmpdir):
     ds = dataset.TileDataset(tile=np.random.random((400,400,3)).astype("float32"), patch_overlap=0.1, patch_size=100)
     existing_loader = m.predict_dataloader(ds)
     batches = m.trainer.predict(m, existing_loader)
-    len(batches[0]) == m.config["batch_size"] + 1
\ No newline at end of file
+    len(batches[0]) == m.config["batch_size"] + 1
+
+# Test train with each scheduler
+@pytest.mark.parametrize("scheduler,expected", [("cosine", "CosineAnnealingLR"),
+                                                ("lambdaLR", "LambdaLR"),
+                                                ("multiplicativeLR", "MultiplicativeLR"),
+                                                ("stepLR", "StepLR"),
+                                                ("multistepLR", "MultiStepLR"),
+                                                ("exponentialLR", "ExponentialLR"),
+                                                ("reduceLROnPlateau", "ReduceLROnPlateau")])
+def test_configure_optimizers(scheduler, expected):
+    scheduler_config = {
+        "type": scheduler,
+        "params": {
+            "T_max": 10,
+            "eta_min": 0.00001,
+            "lr_lambda": lambda epoch: 0.95 ** epoch,  # For lambdaLR and multiplicativeLR
+            "step_size": 30,  # For stepLR
+            "gamma": 0.1,  # For stepLR, multistepLR, and exponentialLR
+            "milestones": [50, 100],  # For multistepLR
+
+            # ReduceLROnPlateau parameters (used if type is not explicitly mentioned)
+            "mode": "min",
+            "factor": 0.1,
+            "patience": 10,
+            "threshold": 0.0001,
+            "threshold_mode": "rel",
+            "cooldown": 0,
+            "min_lr": 0,
+            "eps": 1e-08
+        },
+        "expected": expected
+    }
+
+    annotations_file = get_data("testfile_deepforest.csv")
+    root_dir = os.path.dirname(annotations_file)
+
+    config_args = {
+        "train": {
+            "lr": 0.01,
+            "scheduler": scheduler_config,
+            "csv_file": annotations_file,
+            "root_dir": root_dir,
+            "fast_dev_run": False,
+            "epochs": 2
+        },
+        "validation": {
+            "csv_file": annotations_file,
+            "root_dir": root_dir
+        }
+    }
+
+    # Initialize the model with the config arguments
+    m = main.deepforest(config_args=config_args)
+
+    # Create and run the trainer
+    m.create_trainer(limit_train_batches=1.0)
+    m.trainer.fit(m)
+
+    # Assert the scheduler type
+    assert type(m.trainer.lr_scheduler_configs[0].scheduler).__name__ == scheduler_config["expected"], f"Scheduler type mismatch for {scheduler_config['type']}"
\ No newline at end of file
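
To make the fallback behaviour of the new `configure_optimizers` dispatch easy to see in isolation, here is a standalone sketch (not part of the patch): it mirrors two branches of the if/elif chain against a toy torch model and shows that an empty or unrecognised `type`, such as the blank default shipped in the YAML, falls through to ReduceLROnPlateau. The `build_scheduler` helper is hypothetical and only torch is required.

```python
# Standalone illustration of the scheduler dispatch added to configure_optimizers.
import torch
from torch import nn, optim

params = {
    "step_size": 30, "gamma": 0.1,                 # stepLR branch
    "mode": "min", "factor": 0.1, "patience": 10,  # ReduceLROnPlateau fallback
    "threshold": 0.0001, "threshold_mode": "rel",
    "cooldown": 0, "min_lr": 0, "eps": 1e-08,
}


def build_scheduler(scheduler_type, optimizer, params):
    # Mirrors a slice of the chain in main.py: any other value, including the
    # blank YAML default, falls back to ReduceLROnPlateau.
    if scheduler_type == "stepLR":
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=params["step_size"], gamma=params["gamma"])
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode=params["mode"], factor=params["factor"],
        patience=params["patience"], threshold=params["threshold"],
        threshold_mode=params["threshold_mode"], cooldown=params["cooldown"],
        min_lr=params["min_lr"], eps=params["eps"])


model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print(type(build_scheduler(None, optimizer, params)).__name__)      # ReduceLROnPlateau
print(type(build_scheduler("stepLR", optimizer, params)).__name__)  # StepLR
```
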