added wandb logger and extended when to log (#379)
* added wandb logger and extended when to log

* add docs

* included pr review

* formatted code

* updated logging strategy name

---------

Co-authored-by: Claudius Kienle <[email protected]>
claudius-kienle and Claudius Kienle authored Feb 18, 2024
1 parent 06dd7a7 commit 3d46d71
Showing 3 changed files with 107 additions and 2 deletions.
25 changes: 23 additions & 2 deletions d3rlpy/algos/qlearning/base.py
@@ -18,7 +18,11 @@
from typing_extensions import Self

from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...constants import (
    IMPL_NOT_INITIALIZED_ERROR,
    ActionSpace,
    LoggingStrategy,
)
from ...dataset import (
    ReplayBuffer,
    TransitionMiniBatch,
@@ -360,6 +364,8 @@ def fit(
        n_steps_per_epoch: int = 10000,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logging_steps: int = 500,
        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
        logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
        show_progress: bool = True,
        save_interval: int = 1,
@@ -383,6 +389,8 @@
                the directory name will be `{class name}_{timestamp}`.
            with_timestamp: Flag to add timestamp string to the last of
                directory name.
            logging_steps: Number of steps between metric logging calls.
            logging_strategy: Strategy that determines when to log metrics.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
            save_interval: Interval to save parameters.
Expand All @@ -404,6 +412,8 @@ def fit(
                n_steps_per_epoch=n_steps_per_epoch,
                experiment_name=experiment_name,
                with_timestamp=with_timestamp,
                logging_steps=logging_steps,
                logging_strategy=logging_strategy,
                logger_adapter=logger_adapter,
                show_progress=show_progress,
                save_interval=save_interval,
@@ -420,6 +430,8 @@ def fitter(
        dataset: ReplayBuffer,
        n_steps: int,
        n_steps_per_epoch: int = 10000,
        logging_steps: int = 500,
        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
@@ -448,6 +460,8 @@
                the directory name will be `{class name}_{timestamp}`.
            with_timestamp: Flag to add timestamp string to the last of
                directory name.
            logging_steps: Number of steps between metric logging calls.
            logging_strategy: Strategy that determines when to log metrics.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
            save_interval: Interval to save parameters.
@@ -540,6 +554,12 @@ def fitter(

                total_step += 1

                if (
                    logging_strategy == LoggingStrategy.STEPS
                    and total_step % logging_steps == 0
                ):
                    metrics = logger.commit(epoch, total_step)

                # call callback if given
                if callback:
                    callback(self, epoch, total_step)
@@ -554,7 +574,8 @@ def fitter(
                    logger.add_metric(name, test_score)

            # save metrics
            metrics = logger.commit(epoch, total_step)
            if logging_strategy == LoggingStrategy.EPOCH:
                metrics = logger.commit(epoch, total_step)

            # save model parameters
            if epoch % save_interval == 0:
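The new `logging_steps` and `logging_strategy` arguments only change when `logger.commit()` is called: with the default `LoggingStrategy.EPOCH` metrics are flushed once per epoch as before, while `LoggingStrategy.STEPS` flushes every `logging_steps` gradient steps inside the epoch loop. A minimal usage sketch, assuming d3rlpy's bundled CartPole dataset helper and DQN purely for illustration:

import d3rlpy
from d3rlpy.constants import LoggingStrategy

# assumed setup: small discrete-action dataset shipped with d3rlpy
dataset, env = d3rlpy.datasets.get_cartpole()
dqn = d3rlpy.algos.DQNConfig().create()

# flush metrics every 500 gradient steps instead of once per epoch
dqn.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    logging_steps=500,
    logging_strategy=LoggingStrategy.STEPS,
)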
5 changes: 5 additions & 0 deletions d3rlpy/constants.py
@@ -47,3 +47,8 @@ class ActionSpace(Enum):
class PositionEncodingType(Enum):
    SIMPLE = "simple"
    GLOBAL = "global"


class LoggingStrategy(Enum):
    STEPS = "steps"
    EPOCH = "epoch"
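Since `LoggingStrategy` is a string-valued `Enum`, a value coming from a config file or CLI flag maps onto it with the standard by-value lookup; a small sketch:

from d3rlpy.constants import LoggingStrategy

# standard Enum lookup by value; raises ValueError for unknown strings
strategy = LoggingStrategy("steps")
assert strategy is LoggingStrategy.STEPS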
79 changes: 79 additions & 0 deletions d3rlpy/logging/wandb_adapter.py
@@ -0,0 +1,79 @@
from typing import Any, Dict, Optional
from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol


__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"]


class LoggerWanDBAdapter(LoggerAdapter):
    r"""WandB Logger Adapter class.

    This class logs data to Weights & Biases (WandB) for experiment tracking.

    Args:
        project (Optional[str]): Name of the WandB project. Defaults to None.
        experiment_name (Optional[str]): Name of the experiment.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        experiment_name: Optional[str] = None,
    ):
        try:
            import wandb
        except ImportError as e:
            raise ImportError("Please install wandb") from e
        self.run = wandb.init(project=project, name=experiment_name)

    def write_params(self, params: Dict[str, Any]) -> None:
        """Writes hyperparameters to WandB config."""
        self.run.config.update(params)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed before writing metric."""
        pass

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Writes metric to WandB."""
        self.run.log({name: value, "epoch": epoch}, step=step)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed after writing metric."""
        pass

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Saves model to WandB. Not implemented yet."""
        # Implement saving model to wandb if needed
        pass

    def close(self) -> None:
        """Closes the logger and finishes the WandB run."""
        self.run.finish()


class WanDBAdapterFactory(LoggerAdapterFactory):
    r"""WandB Logger Adapter Factory class.

    This class creates instances of the WandB Logger Adapter for experiment
    tracking.
    """

    _project: Optional[str]

    def __init__(self, project: Optional[str] = None) -> None:
        """Initializes the WandB Logger Adapter Factory.

        Args:
            project (Optional[str]): Name of the WandB project. Defaults to None.
        """
        super().__init__()
        self._project = project

    def create(self, experiment_name: str) -> LoggerAdapter:
        """Creates a WandB Logger Adapter instance.

        Args:
            experiment_name (str): Name of the experiment.

        Returns:
            LoggerAdapter: Instance of the WandB Logger Adapter.
        """
        return LoggerWanDBAdapter(
            project=self._project, experiment_name=experiment_name
        )
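Putting the pieces together, the factory is passed to `fit()` as `logger_adapter` so metrics go to a WandB run instead of the default `FileAdapterFactory` output. A hedged sketch: the project and experiment names are invented, `wandb` must be installed and logged in, and the factory is imported from its module path because this commit does not touch `d3rlpy/logging/__init__.py`:

import d3rlpy
from d3rlpy.constants import LoggingStrategy
from d3rlpy.logging.wandb_adapter import WanDBAdapterFactory

# assumed setup: requires `pip install wandb` and `wandb login`
dataset, env = d3rlpy.datasets.get_cartpole()
cql = d3rlpy.algos.DiscreteCQLConfig().create()

cql.fit(
    dataset,
    n_steps=100000,
    n_steps_per_epoch=1000,
    logging_steps=500,
    logging_strategy=LoggingStrategy.STEPS,
    # hypothetical project name for illustration
    logger_adapter=WanDBAdapterFactory(project="d3rlpy-demo"),
    experiment_name="discrete_cql_cartpole",
)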
