diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index e6de38b0..ba92f3d3 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -18,7 +18,11 @@ from typing_extensions import Self from ...base import ImplBase, LearnableBase, LearnableConfig, save_config -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ( + IMPL_NOT_INITIALIZED_ERROR, + ActionSpace, + LoggingStrategy, +) from ...dataset import ( ReplayBuffer, TransitionMiniBatch, @@ -360,6 +364,8 @@ def fit( n_steps_per_epoch: int = 10000, experiment_name: Optional[str] = None, with_timestamp: bool = True, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, @@ -383,6 +389,8 @@ def fit( the directory name will be `{class name}_{timestamp}`. with_timestamp: Flag to add timestamp string to the last of directory name. + logging_steps: number of steps to log metrics. + logging_strategy: what logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. @@ -404,6 +412,8 @@ def fit( n_steps_per_epoch=n_steps_per_epoch, experiment_name=experiment_name, with_timestamp=with_timestamp, + logging_steps=logging_steps, + logging_strategy=logging_strategy, logger_adapter=logger_adapter, show_progress=show_progress, save_interval=save_interval, @@ -420,6 +430,8 @@ def fitter( dataset: ReplayBuffer, n_steps: int, n_steps_per_epoch: int = 10000, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, experiment_name: Optional[str] = None, with_timestamp: bool = True, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), @@ -448,6 +460,8 @@ def fitter( the directory name will be `{class name}_{timestamp}`. 
with_timestamp: Flag to add timestamp string to the last of directory name. + logging_steps: number of steps to log metrics. + logging_strategy: what logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. @@ -540,6 +554,12 @@ def fitter( total_step += 1 + if ( + logging_strategy == LoggingStrategy.STEPS + and total_step % logging_steps == 0 + ): + metrics = logger.commit(epoch, total_step) + # call callback if given if callback: callback(self, epoch, total_step) @@ -554,7 +574,8 @@ def fitter( logger.add_metric(name, test_score) # save metrics - metrics = logger.commit(epoch, total_step) + if logging_strategy == LoggingStrategy.EPOCH: + metrics = logger.commit(epoch, total_step) # save model parameters if epoch % save_interval == 0: diff --git a/d3rlpy/constants.py b/d3rlpy/constants.py index 9b233760..6e63c2a6 100644 --- a/d3rlpy/constants.py +++ b/d3rlpy/constants.py @@ -47,3 +47,8 @@ class ActionSpace(Enum): class PositionEncodingType(Enum): SIMPLE = "simple" GLOBAL = "global" + + + class LoggingStrategy(Enum): + STEPS = "steps" + EPOCH = "epoch" diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py new file mode 100644 index 00000000..373986eb --- /dev/null +++ b/d3rlpy/logging/wandb_adapter.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, Optional +from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol + + +__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"] + + +class LoggerWanDBAdapter(LoggerAdapter): + r"""WandB Logger Adapter class. + + This class logs data to Weights & Biases (WandB) for experiment tracking. + + Args: + project (Optional[str]): Name of the WandB project. Defaults to None. + experiment_name (str): Name of the experiment. 
+ """ + + def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None): + try: + import wandb + except ImportError as e: + raise ImportError("Please install wandb") from e + self.run = wandb.init(project=project, name=experiment_name) + + def write_params(self, params: Dict[str, Any]) -> None: + """Writes hyperparameters to WandB config.""" + self.run.config.update(params) + + def before_write_metric(self, epoch: int, step: int) -> None: + """Callback executed before writing metric.""" + pass + + def write_metric(self, epoch: int, step: int, name: str, value: float) -> None: + """Writes metric to WandB.""" + self.run.log({name: value, 'epoch': epoch}, step=step) + + def after_write_metric(self, epoch: int, step: int) -> None: + """Callback executed after writing metric.""" + pass + + def save_model(self, epoch: int, algo: SaveProtocol) -> None: + """Saves models to Weights & Biases (not implemented for WandB).""" + # Implement saving model to wandb if needed + pass + + def close(self) -> None: + """Closes the logger and finishes the WandB run.""" + self.run.finish() + + + class WanDBAdapterFactory(LoggerAdapterFactory): + r"""WandB Logger Adapter Factory class. + + This class creates instances of the WandB Logger Adapter for experiment tracking. + + """ + + _project: Optional[str] + + def __init__(self, project: Optional[str] = None) -> None: + """Initialize the WandB Logger Adapter Factory. + + Args: + project (Optional[str], optional): The name of the WandB project. Defaults to None. + + """ + super().__init__() + self._project = project + + def create(self, experiment_name: str) -> LoggerAdapter: + """Creates a WandB Logger Adapter instance. + + Args: + experiment_name (str): Name of the experiment. + + Returns: + LoggerAdapter: Instance of the WandB Logger Adapter. + + """ + return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name)