From 3ec5a4d71beb2ee9393e70792b33d3eb97e629f8 Mon Sep 17 00:00:00 2001 From: Claudius Kienle Date: Fri, 16 Feb 2024 14:28:32 +0100 Subject: [PATCH 1/5] added wandb logger and extended when to log --- d3rlpy/algos/qlearning/base.py | 14 +++++++++++- d3rlpy/logging/wandb_adapter.py | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 d3rlpy/logging/wandb_adapter.py diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index e6de38b0..33be4203 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -360,6 +360,8 @@ def fit( n_steps_per_epoch: int = 10000, experiment_name: Optional[str] = None, with_timestamp: bool = True, + logging_steps: int = 500, + logging_strategy: str = "epoch", logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, @@ -404,6 +406,8 @@ def fit( n_steps_per_epoch=n_steps_per_epoch, experiment_name=experiment_name, with_timestamp=with_timestamp, + logging_steps=logging_steps, + logging_strategy=logging_strategy, logger_adapter=logger_adapter, show_progress=show_progress, save_interval=save_interval, @@ -420,6 +424,8 @@ def fitter( dataset: ReplayBuffer, n_steps: int, n_steps_per_epoch: int = 10000, + logging_steps: int = 500, + logging_strategy: str = "epoch", experiment_name: Optional[str] = None, with_timestamp: bool = True, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), @@ -467,6 +473,8 @@ def fitter( # check action space assert_action_space_with_dataset(self, dataset.dataset_info) + assert logging_strategy in ['steps', 'epoch'], 'Logging strategy invalid' + # initialize scalers build_scalers_with_transition_picker(self, dataset) @@ -540,6 +548,9 @@ def fitter( total_step += 1 + if logging_strategy == 'steps' and total_step % logging_steps == 0: + metrics = logger.commit(epoch, total_step) + # call callback if given if callback: callback(self, epoch, total_step) @@ -554,7 +565,8 @@ def fitter( logger.add_metric(name, test_score) # save metrics - metrics = logger.commit(epoch, total_step) + if logging_strategy == 'epoch': + metrics = logger.commit(epoch, total_step) # save model parameters if epoch % save_interval == 0: diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py new file mode 100644 index 00000000..6e5577f4 --- /dev/null +++ b/d3rlpy/logging/wandb_adapter.py @@ -0,0 +1,40 @@ +import wandb +from typing import Any, Dict, Optional +from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol + + +class LoggerWanDBAdapter(LoggerAdapter): + + def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None): + self.run = wandb.init(project=project, name=experiment_name) + + def write_params(self, params: Dict[str, Any]) -> None: + self.run.config.update(params) + + def before_write_metric(self, epoch: int, step: int) -> None: + pass + + def write_metric(self, epoch: int, step: int, name: str, value: float) -> None: + self.run.log({name: value, 'epoch': epoch}, step=step) + + def after_write_metric(self, epoch: int, step: int) -> None: + pass + + def save_model(self, epoch: int, algo: SaveProtocol) -> None: + # Implement saving model to wandb if needed + pass + + def close(self) -> None: + self.run.finish() + + +class WanDBAdapterFactory(LoggerAdapterFactory): + + _project: str + + def __init__(self, project: Optional[str] = None) -> None: + super().__init__() + self._project = project + + def create(self, experiment_name: str) -> LoggerAdapter: + return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name) From c1b2059868c8ff29cfb60f33e9659411f7e35e89 Mon Sep 17 00:00:00 2001 From: Claudius Kienle Date: Fri, 16 Feb 2024 14:33:12 +0100 Subject: [PATCH 2/5] add docs --- d3rlpy/logging/wandb_adapter.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index 6e5577f4..69c03944 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -4,37 +4,70 @@ class LoggerWanDBAdapter(LoggerAdapter): + r"""WandB Logger Adapter class. + + This class logs data to Weights & Biases (WandB) for experiment tracking. + + Args: + experiment_name (str): Name of the experiment. + """ def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None): self.run = wandb.init(project=project, name=experiment_name) def write_params(self, params: Dict[str, Any]) -> None: + """Writes hyperparameters to WandB config.""" self.run.config.update(params) def before_write_metric(self, epoch: int, step: int) -> None: + """Callback executed before writing metric.""" pass def write_metric(self, epoch: int, step: int, name: str, value: float) -> None: + """Writes metric to WandB.""" self.run.log({name: value, 'epoch': epoch}, step=step) def after_write_metric(self, epoch: int, step: int) -> None: + """Callback executed after writing metric.""" pass def save_model(self, epoch: int, algo: SaveProtocol) -> None: + """Saves models to Weights & Biases. Not implemented for WandB.""" # Implement saving model to wandb if needed pass def close(self) -> None: + """Closes the logger and finishes the WandB run.""" self.run.finish() class WanDBAdapterFactory(LoggerAdapterFactory): + r"""WandB Logger Adapter Factory class. + + This class creates instances of the WandB Logger Adapter for experiment tracking. + + """ _project: str def __init__(self, project: Optional[str] = None) -> None: + """Initialize the WandB Logger Adapter Factory. + + Args: + project (Optional[str], optional): The name of the WandB project. Defaults to None. + + """ super().__init__() self._project = project def create(self, experiment_name: str) -> LoggerAdapter: + """Creates a WandB Logger Adapter instance. + + Args: + experiment_name (str): Name of the experiment. + + Returns: + LoggerAdapter: Instance of the WandB Logger Adapter. + + """ return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name) From a7814783eb98a63173458061cf2b7f1306512b34 Mon Sep 17 00:00:00 2001 From: Claudius Kienle Date: Sat, 17 Feb 2024 10:20:24 +0100 Subject: [PATCH 3/5] included pr review --- d3rlpy/algos/qlearning/base.py | 16 +++++++++------- d3rlpy/constants.py | 5 +++++ d3rlpy/logging/wandb_adapter.py | 8 +++++++- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 33be4203..f8c18f84 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -18,7 +18,7 @@ from typing_extensions import Self from ...base import ImplBase, LearnableBase, LearnableConfig, save_config -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace, LoggingStrategyEnum from ...dataset import ( ReplayBuffer, TransitionMiniBatch, @@ -361,7 +361,7 @@ def fit( experiment_name: Optional[str] = None, with_timestamp: bool = True, logging_steps: int = 500, - logging_strategy: str = "epoch", + logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, @@ -385,6 +385,8 @@ def fit( the directory name will be `{class name}_{timestamp}`. with_timestamp: Flag to add timestamp string to the last of directory name. + logging_steps: number of steps to log metrics. + logging_strategy: what logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. @@ -425,7 +427,7 @@ def fitter( n_steps: int, n_steps_per_epoch: int = 10000, logging_steps: int = 500, - logging_strategy: str = "epoch", + logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH, experiment_name: Optional[str] = None, with_timestamp: bool = True, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), @@ -454,6 +456,8 @@ def fitter( the directory name will be `{class name}_{timestamp}`. with_timestamp: Flag to add timestamp string to the last of directory name. + logging_steps: number of steps to log metrics. + logging_strategy: what logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. @@ -473,8 +477,6 @@ def fitter( # check action space assert_action_space_with_dataset(self, dataset.dataset_info) - assert logging_strategy in ['steps', 'epoch'], 'Logging strategy invalid' - # initialize scalers build_scalers_with_transition_picker(self, dataset) @@ -548,7 +550,7 @@ def fitter( total_step += 1 - if logging_strategy == 'steps' and total_step % logging_steps == 0: + if logging_strategy == LoggingStrategyEnum.STEPS and total_step % logging_steps == 0: metrics = logger.commit(epoch, total_step) # call callback if given @@ -565,7 +567,7 @@ def fitter( logger.add_metric(name, test_score) # save metrics - if logging_strategy == 'epoch': + if logging_strategy == LoggingStrategyEnum.EPOCH: metrics = logger.commit(epoch, total_step) # save model parameters diff --git a/d3rlpy/constants.py b/d3rlpy/constants.py index 9b233760..54d3adfb 100644 --- a/d3rlpy/constants.py +++ b/d3rlpy/constants.py @@ -47,3 +47,8 @@ class ActionSpace(Enum): class PositionEncodingType(Enum): SIMPLE = "simple" GLOBAL = "global" + + +class LoggingStrategyEnum(Enum): + STEPS = "steps" + EPOCH = "epoch" \ No newline at end of file diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index 69c03944..373986eb 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -1,8 +1,10 @@ -import wandb from typing import Any, Dict, Optional from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol +__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"] + + class LoggerWanDBAdapter(LoggerAdapter): r"""WandB Logger Adapter class. @@ -13,6 +15,10 @@ class LoggerWanDBAdapter(LoggerAdapter): """ def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None): + try: + import wandb + except ImportError as e: + raise ImportError("Please install wandb") from e self.run = wandb.init(project=project, name=experiment_name) def write_params(self, params: Dict[str, Any]) -> None: From 5547014ca817585eceb5ba72b1cb51b35131d721 Mon Sep 17 00:00:00 2001 From: Claudius Kienle Date: Sat, 17 Feb 2024 10:42:16 +0100 Subject: [PATCH 4/5] formatted code --- d3rlpy/algos/qlearning/base.py | 11 +++++++++-- d3rlpy/constants.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index f8c18f84..b5887d13 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -18,7 +18,11 @@ from typing_extensions import Self from ...base import ImplBase, LearnableBase, LearnableConfig, save_config -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace, LoggingStrategyEnum +from ...constants import ( + IMPL_NOT_INITIALIZED_ERROR, + ActionSpace, + LoggingStrategyEnum, +) from ...dataset import ( ReplayBuffer, TransitionMiniBatch, @@ -550,7 +554,10 @@ def fitter( total_step += 1 - if logging_strategy == LoggingStrategyEnum.STEPS and total_step % logging_steps == 0: + if ( + logging_strategy == LoggingStrategyEnum.STEPS + and total_step % logging_steps == 0 + ): metrics = logger.commit(epoch, total_step) # call callback if given diff --git a/d3rlpy/constants.py b/d3rlpy/constants.py index 54d3adfb..02394877 100644 --- a/d3rlpy/constants.py +++ b/d3rlpy/constants.py @@ -51,4 +51,4 @@ class PositionEncodingType(Enum): class LoggingStrategyEnum(Enum): STEPS = "steps" - EPOCH = "epoch" \ No newline at end of file + EPOCH = "epoch" From 74ffe203e56ba9e291b6df0cf5d55a3c29948ded Mon Sep 17 00:00:00 2001 From: Claudius Kienle Date: Sat, 17 Feb 2024 17:37:53 +0100 Subject: [PATCH 5/5] updated logging strategy name --- d3rlpy/algos/qlearning/base.py | 10 +++++----- d3rlpy/constants.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index b5887d13..ba92f3d3 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -21,7 +21,7 @@ from ...constants import ( IMPL_NOT_INITIALIZED_ERROR, ActionSpace, - LoggingStrategyEnum, + LoggingStrategy, ) from ...dataset import ( ReplayBuffer, @@ -365,7 +365,7 @@ def fit( experiment_name: Optional[str] = None, with_timestamp: bool = True, logging_steps: int = 500, - logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, @@ -431,7 +431,7 @@ def fitter( n_steps: int, n_steps_per_epoch: int = 10000, logging_steps: int = 500, - logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, experiment_name: Optional[str] = None, with_timestamp: bool = True, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), @@ -555,7 +555,7 @@ def fitter( total_step += 1 if ( - logging_strategy == LoggingStrategyEnum.STEPS + logging_strategy == LoggingStrategy.STEPS and total_step % logging_steps == 0 ): metrics = logger.commit(epoch, total_step) @@ -574,7 +574,7 @@ def fitter( logger.add_metric(name, test_score) # save metrics - if logging_strategy == LoggingStrategyEnum.EPOCH: + if logging_strategy == LoggingStrategy.EPOCH: metrics = logger.commit(epoch, total_step) # save model parameters diff --git a/d3rlpy/constants.py b/d3rlpy/constants.py index 02394877..6e63c2a6 100644 --- a/d3rlpy/constants.py +++ b/d3rlpy/constants.py @@ -49,6 +49,6 @@ class PositionEncodingType(Enum): GLOBAL = "global" -class LoggingStrategyEnum(Enum): +class LoggingStrategy(Enum): STEPS = "steps" EPOCH = "epoch"