Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added wandb logger and extended when to log #379

Merged
merged 5 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing_extensions import Self

from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...constants import (
IMPL_NOT_INITIALIZED_ERROR,
ActionSpace,
LoggingStrategy,
)
from ...dataset import (
ReplayBuffer,
TransitionMiniBatch,
Expand Down Expand Up @@ -360,6 +364,8 @@
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
Expand All @@ -383,6 +389,8 @@
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand All @@ -404,6 +412,8 @@
n_steps_per_epoch=n_steps_per_epoch,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logging_steps=logging_steps,
logging_strategy=logging_strategy,
logger_adapter=logger_adapter,
show_progress=show_progress,
save_interval=save_interval,
Expand All @@ -420,6 +430,8 @@
dataset: ReplayBuffer,
n_steps: int,
n_steps_per_epoch: int = 10000,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
Expand Down Expand Up @@ -448,6 +460,8 @@
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand Down Expand Up @@ -540,6 +554,12 @@

total_step += 1

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
metrics = logger.commit(epoch, total_step)

Check warning on line 561 in d3rlpy/algos/qlearning/base.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/algos/qlearning/base.py#L561

Added line #L561 was not covered by tests

# call callback if given
if callback:
callback(self, epoch, total_step)
Expand All @@ -554,7 +574,8 @@
logger.add_metric(name, test_score)

# save metrics
metrics = logger.commit(epoch, total_step)
if logging_strategy == LoggingStrategy.EPOCH:
metrics = logger.commit(epoch, total_step)

# save model parameters
if epoch % save_interval == 0:
Expand Down
5 changes: 5 additions & 0 deletions d3rlpy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ class ActionSpace(Enum):
class PositionEncodingType(Enum):
SIMPLE = "simple"
GLOBAL = "global"


class LoggingStrategy(Enum):
STEPS = "steps"
EPOCH = "epoch"
79 changes: 79 additions & 0 deletions d3rlpy/logging/wandb_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, Dict, Optional
from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol

Check warning on line 2 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L1-L2

Added lines #L1 - L2 were not covered by tests

claudius-kienle marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"]

Check warning on line 5 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L5

Added line #L5 was not covered by tests


class LoggerWanDBAdapter(LoggerAdapter):

Check warning on line 8 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L8

Added line #L8 was not covered by tests
r"""WandB Logger Adapter class.

This class logs data to Weights & Biases (WandB) for experiment tracking.

Args:
experiment_name (str): Name of the experiment.
"""

def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None):
try:
import wandb
except ImportError as e:
raise ImportError("Please install wandb") from e
self.run = wandb.init(project=project, name=experiment_name)

Check warning on line 22 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L17-L22

Added lines #L17 - L22 were not covered by tests
claudius-kienle marked this conversation as resolved.
Show resolved Hide resolved

def write_params(self, params: Dict[str, Any]) -> None:

Check warning on line 24 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L24

Added line #L24 was not covered by tests
"""Writes hyperparameters to WandB config."""
self.run.config.update(params)

Check warning on line 26 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L26

Added line #L26 was not covered by tests

def before_write_metric(self, epoch: int, step: int) -> None:

Check warning on line 28 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L28

Added line #L28 was not covered by tests
"""Callback executed before writing metric."""
pass

Check warning on line 30 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L30

Added line #L30 was not covered by tests

def write_metric(self, epoch: int, step: int, name: str, value: float) -> None:

Check warning on line 32 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L32

Added line #L32 was not covered by tests
"""Writes metric to WandB."""
self.run.log({name: value, 'epoch': epoch}, step=step)

Check warning on line 34 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L34

Added line #L34 was not covered by tests

def after_write_metric(self, epoch: int, step: int) -> None:

Check warning on line 36 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L36

Added line #L36 was not covered by tests
"""Callback executed after writing metric."""
pass

Check warning on line 38 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L38

Added line #L38 was not covered by tests

def save_model(self, epoch: int, algo: SaveProtocol) -> None:

Check warning on line 40 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L40

Added line #L40 was not covered by tests
"""Saves models to Weights & Biases. Not implemented for WandB."""
# Implement saving model to wandb if needed
pass

Check warning on line 43 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L43

Added line #L43 was not covered by tests

def close(self) -> None:

Check warning on line 45 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L45

Added line #L45 was not covered by tests
"""Closes the logger and finishes the WandB run."""
self.run.finish()

Check warning on line 47 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L47

Added line #L47 was not covered by tests


class WanDBAdapterFactory(LoggerAdapterFactory):

Check warning on line 50 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L50

Added line #L50 was not covered by tests
r"""WandB Logger Adapter Factory class.

This class creates instances of the WandB Logger Adapter for experiment tracking.

"""

_project: str

Check warning on line 57 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L57

Added line #L57 was not covered by tests

def __init__(self, project: Optional[str] = None) -> None:

Check warning on line 59 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L59

Added line #L59 was not covered by tests
"""Initialize the WandB Logger Adapter Factory.

Args:
project (Optional[str], optional): The name of the WandB project. Defaults to None.

"""
super().__init__()
self._project = project

Check warning on line 67 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L66-L67

Added lines #L66 - L67 were not covered by tests

def create(self, experiment_name: str) -> LoggerAdapter:

Check warning on line 69 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L69

Added line #L69 was not covered by tests
"""Creates a WandB Logger Adapter instance.

Args:
experiment_name (str): Name of the experiment.

Returns:
LoggerAdapter: Instance of the WandB Logger Adapter.

"""
return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name)

Check warning on line 79 in d3rlpy/logging/wandb_adapter.py

View check run for this annotation

Codecov / codecov/patch

d3rlpy/logging/wandb_adapter.py#L79

Added line #L79 was not covered by tests
Loading