Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added wandb logger and extended when to log #379

Merged
merged 5 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing_extensions import Self

from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...constants import (
IMPL_NOT_INITIALIZED_ERROR,
ActionSpace,
LoggingStrategyEnum,
)
from ...dataset import (
ReplayBuffer,
TransitionMiniBatch,
Expand Down Expand Up @@ -360,6 +364,8 @@ def fit(
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
Expand All @@ -383,6 +389,8 @@ def fit(
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand All @@ -404,6 +412,8 @@ def fit(
n_steps_per_epoch=n_steps_per_epoch,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logging_steps=logging_steps,
logging_strategy=logging_strategy,
logger_adapter=logger_adapter,
show_progress=show_progress,
save_interval=save_interval,
Expand All @@ -420,6 +430,8 @@ def fitter(
dataset: ReplayBuffer,
n_steps: int,
n_steps_per_epoch: int = 10000,
logging_steps: int = 500,
logging_strategy: LoggingStrategyEnum = LoggingStrategyEnum.EPOCH,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
Expand Down Expand Up @@ -448,6 +460,8 @@ def fitter(
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand Down Expand Up @@ -540,6 +554,12 @@ def fitter(

total_step += 1

if (
logging_strategy == LoggingStrategyEnum.STEPS
and total_step % logging_steps == 0
):
metrics = logger.commit(epoch, total_step)

# call callback if given
if callback:
callback(self, epoch, total_step)
Expand All @@ -554,7 +574,8 @@ def fitter(
logger.add_metric(name, test_score)

# save metrics
metrics = logger.commit(epoch, total_step)
if logging_strategy == LoggingStrategyEnum.EPOCH:
metrics = logger.commit(epoch, total_step)

# save model parameters
if epoch % save_interval == 0:
Expand Down
5 changes: 5 additions & 0 deletions d3rlpy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ class ActionSpace(Enum):
class PositionEncodingType(Enum):
    """Enumerates the supported position encoding types."""

    SIMPLE = "simple"
    GLOBAL = "global"


class LoggingStrategyEnum(Enum):
    """Strategy deciding when training metrics are committed to the logger.

    STEPS: commit metrics every ``logging_steps`` training steps.
    EPOCH: commit metrics once at the end of each epoch.
    """

    STEPS = "steps"
    EPOCH = "epoch"
79 changes: 79 additions & 0 deletions d3rlpy/logging/wandb_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, Dict, Optional
from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol

claudius-kienle marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"]


class LoggerWanDBAdapter(LoggerAdapter):
    r"""WandB Logger Adapter class.

    This class logs data to Weights & Biases (WandB) for experiment tracking.

    Args:
        project: Name of the WandB project. ``None`` lets WandB pick its
            default project.
        experiment_name: Name of the experiment, used as the WandB run name.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        experiment_name: Optional[str] = None,
    ):
        # Import lazily so d3rlpy works without wandb installed; fail with a
        # clear message only when this adapter is actually used.
        try:
            import wandb
        except ImportError as e:
            raise ImportError("Please install wandb") from e
        self.run = wandb.init(project=project, name=experiment_name)

    def write_params(self, params: Dict[str, Any]) -> None:
        """Writes hyperparameters to WandB config."""
        self.run.config.update(params)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed before writing metric."""

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Writes metric to WandB, keyed by the global step."""
        self.run.log({name: value, "epoch": epoch}, step=step)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed after writing metric."""

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Saves models to Weights & Biases. Not implemented for WandB."""
        # NOTE(review): intentionally a no-op; wandb artifact upload could be
        # implemented here if model checkpointing to WandB is needed.

    def close(self) -> None:
        """Closes the logger and finishes the WandB run."""
        self.run.finish()


class WanDBAdapterFactory(LoggerAdapterFactory):
    r"""WandB Logger Adapter Factory class.

    This class creates instances of the WandB Logger Adapter for experiment
    tracking.

    Args:
        project: Name of the WandB project. Defaults to ``None``, in which
            case WandB chooses the project itself.
    """

    # Fixed annotation: the stored value may be None (see __init__ default).
    _project: Optional[str]

    def __init__(self, project: Optional[str] = None) -> None:
        super().__init__()
        self._project = project

    def create(self, experiment_name: str) -> LoggerAdapter:
        """Creates a WandB Logger Adapter instance.

        Args:
            experiment_name: Name of the experiment.

        Returns:
            Instance of the WandB Logger Adapter.
        """
        return LoggerWanDBAdapter(
            project=self._project, experiment_name=experiment_name
        )
Loading