From cc15df8ab7ff88c4b403fec93feac9fe2069d35d Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 18 Feb 2024 10:45:10 +0900 Subject: [PATCH] Add logging_strategy option to fit_online --- d3rlpy/algos/qlearning/base.py | 19 ++++++++++++++++--- d3rlpy/logging/wandb_adapter.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 905967bb..c9036251 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -461,8 +461,9 @@ def fitter( the directory name will be `{class name}_{timestamp}`. with_timestamp: Flag to add timestamp string to the last of directory name. - logging_steps: number of steps to log metrics. - logging_strategy: what logging strategy to use. + logging_steps: Number of steps to log metrics. This will be ignored + if loggig_strategy is EPOCH. + logging_strategy: Logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. @@ -601,6 +602,8 @@ def fit_online( save_interval: int = 1, experiment_name: Optional[str] = None, with_timestamp: bool = True, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, callback: Optional[Callable[[Self, int, int], None]] = None, @@ -623,6 +626,9 @@ def fit_online( the directory name will be ``{class name}_online_{timestamp}``. with_timestamp: Flag to add timestamp string to the last of directory name. + logging_steps: Number of steps to log metrics. This will be ignored + if logging_strategy is EPOCH. + logging_strategy: Logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. callback: Callable function that takes ``(algo, epoch, total_step)`` @@ -726,6 +732,12 @@ def fit_online( for name, val in loss.items(): logger.add_metric(name, val) + if ( + logging_strategy == LoggingStrategy.STEPS + and total_step % logging_steps == 0 + ): + logger.commit(epoch, total_step) + # call callback if given if callback: callback(self, epoch, total_step) @@ -742,7 +754,8 @@ def fit_online( logger.save_model(total_step, self) # save metrics - logger.commit(epoch, total_step) + if logging_strategy == LoggingStrategy.EPOCH: + logger.commit(epoch, total_step) # clip the last episode buffer.clip_episode(False) diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index d06bf8ea..616b6850 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -77,7 +77,7 @@ def create(self, experiment_name: str) -> LoggerAdapter: experiment_name (str): Name of the experiment. Returns: - LoggerAdapter: Instance of the WandB Logger Adapter. + Instance of the WandB Logger Adapter. """ return WanDBAdapter( experiment_name=experiment_name,