Skip to content

Commit

Permalink
Modified create_trainer function to allow the arguments take preceden…
Browse files Browse the repository at this point in the history
…ce (#680)
  • Loading branch information
Om-Doiphode authored Jun 7, 2024
1 parent 7de55c1 commit ffcabb0
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,22 @@ def create_trainer(self, logger=None, callbacks=[], **kwargs):
enable_checkpointing = True
else:
enable_checkpointing = False

self.trainer = pl.Trainer(logger=logger,
max_epochs=self.config["train"]["epochs"],
enable_checkpointing=enable_checkpointing,
devices=self.config["devices"],
accelerator=self.config["accelerator"],
fast_dev_run=self.config["train"]["fast_dev_run"],
callbacks=callbacks,
limit_val_batches=limit_val_batches,
num_sanity_val_steps=num_sanity_val_steps,
**kwargs)

trainer_args = {
"logger":logger,
"max_epochs":self.config["train"]["epochs"],
"enable_checkpointing":enable_checkpointing,
"devices":self.config["devices"],
"accelerator":self.config["accelerator"],
"fast_dev_run":self.config["train"]["fast_dev_run"],
"callbacks":callbacks,
"limit_val_batches":limit_val_batches,
"num_sanity_val_steps":num_sanity_val_steps
}
# Update with kwargs to allow them to override config
trainer_args.update(kwargs)

self.trainer = pl.Trainer(**trainer_args)

def on_fit_start(self):
if self.config["train"]["csv_file"] is None:
Expand Down

0 comments on commit ffcabb0

Please sign in to comment.