Skip to content

Commit

Permalink
log: add histogram metrics for gradients (#424)
Browse files Browse the repository at this point in the history
* initial implementation for gradient histogram

* updates

* more updates

* try to simplify the implementation

* mypy :sad:

* file adapter: add suffix for grad

* fix name conflict

* update tests

* ignore errors
  • Loading branch information
hasan-yaman authored Oct 15, 2024
1 parent f386f6a commit 3d51ee7
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 14 deletions.
17 changes: 16 additions & 1 deletion d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def fit(
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
Expand All @@ -403,6 +404,7 @@ def fit(
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
Expand All @@ -425,6 +427,7 @@ def fit(
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logging_steps=logging_steps,
gradient_logging_steps=gradient_logging_steps,
logging_strategy=logging_strategy,
logger_adapter=logger_adapter,
show_progress=show_progress,
Expand All @@ -442,6 +445,7 @@ def fitter(
n_steps: int,
n_steps_per_epoch: int = 10000,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
Expand Down Expand Up @@ -471,7 +475,8 @@ def fitter(
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if loggig_strategy is EPOCH.
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
Expand Down Expand Up @@ -520,6 +525,8 @@ def fitter(
# save hyperparameters
save_config(self, logger)

logger.watch_model(0, 0, gradient_logging_steps, self) # type: ignore

# training loop
n_epochs = n_steps // n_steps_per_epoch
total_step = 0
Expand Down Expand Up @@ -559,6 +566,8 @@ def fitter(

total_step += 1

logger.watch_model(epoch, total_step, gradient_logging_steps, self) # type: ignore

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
Expand Down Expand Up @@ -608,6 +617,7 @@ def fit_online(
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
gradient_logging_steps: Optional[int] = None,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
Expand Down Expand Up @@ -636,6 +646,7 @@ def fit_online(
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
gradient_logging_steps: Number of steps to log gradients.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
Expand Down Expand Up @@ -673,6 +684,8 @@ def fit_online(
# save hyperparameters
save_config(self, logger)

logger.watch_model(0, 0, gradient_logging_steps, self) # type: ignore

# switch based on show_progress flag
xrange = trange if show_progress else range

Expand Down Expand Up @@ -741,6 +754,8 @@ def fit_online(
for name, val in loss.items():
logger.add_metric(name, val)

logger.watch_model(epoch, total_step, gradient_logging_steps, self) # type: ignore

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
Expand Down
29 changes: 27 additions & 2 deletions d3rlpy/logging/file_adapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import json
import os
from enum import Enum, IntEnum
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np

from .logger import LOG, LoggerAdapter, LoggerAdapterFactory, SaveProtocol
from .logger import (
LOG,
LoggerAdapter,
LoggerAdapterFactory,
SaveProtocol,
TorchModuleProtocol,
)

__all__ = ["FileAdapter", "FileAdapterFactory"]

Expand Down Expand Up @@ -76,6 +82,25 @@ def close(self) -> None:
def logdir(self) -> str:
return self._logdir

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Appends gradient statistics to per-parameter CSV files.

    For every ``(name, gradient)`` pair yielded by
    ``algo.impl.modules.get_gradients()``, one
    ``epoch,step,name,min,max,mean`` row is appended to
    ``<logdir>/<name>_grad.csv``.

    Args:
        epoch: Epoch.
        step: Training step.
        logging_steps: Number of training steps between gradient logs.
            ``None`` or ``0`` disables gradient logging.
        algo: Algorithm whose gradients are recorded.
    """
    # ``not logging_steps`` also rejects 0, which would otherwise raise
    # ZeroDivisionError in the modulo below.
    if not logging_steps or step % logging_steps != 0:
        return

    for name, grad in algo.impl.modules.get_gradients():
        # min()/max() raise ValueError on zero-size arrays; there is
        # nothing to record for an empty gradient, so skip it before
        # touching the file system.
        if grad.size == 0:
            continue
        min_grad = grad.min()
        max_grad = grad.max()
        mean_grad = grad.mean()
        path = os.path.join(self._logdir, f"{name}_grad.csv")
        with open(path, "a") as f:
            print(
                f"{epoch},{step},{name},{min_grad},{max_grad},{mean_grad}",
                file=f,
            )


class FileAdapterFactory(LoggerAdapterFactory):
r"""FileAdapterFactory class.
Expand Down
43 changes: 42 additions & 1 deletion d3rlpy/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, DefaultDict, Dict, Iterator, List
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Tuple

import structlog
from torch import nn
from typing_extensions import Protocol

from ..types import Float32NDArray

__all__ = [
"LOG",
"set_log_context",
Expand Down Expand Up @@ -39,6 +42,19 @@ class SaveProtocol(Protocol):
def save(self, fname: str) -> None: ...


class ModuleProtocol(Protocol):
    r"""Structural type for an object exposing torch modules and gradients."""

    def get_torch_modules(self) -> Dict[str, nn.Module]:
        r"""Returns torch modules keyed by name."""
        ...

    def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]:
        r"""Returns an iterator over ``(name, gradient array)`` pairs."""
        ...


class ImplProtocol(Protocol):
    r"""Structural type for an implementation exposing ``modules``."""

    # Container providing ``get_torch_modules`` and ``get_gradients``.
    modules: ModuleProtocol


class TorchModuleProtocol(Protocol):
    r"""Structural type for an algorithm exposing ``impl``."""

    # Implementation object whose ``modules`` attribute is inspected by
    # the logger adapters.
    impl: ImplProtocol


class LoggerAdapter(Protocol):
r"""Interface of LoggerAdapter."""

Expand Down Expand Up @@ -88,6 +104,22 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None:
def close(self) -> None:
r"""Closes this LoggerAdapter."""

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Watch model parameters / gradients during training.

    Args:
        epoch: Epoch.
        step: Training step.
        logging_steps: Number of training steps between gradient logs
            (the ``gradient_logging_steps`` value given to the training
            loop).
        algo: Algorithm.
    """


class LoggerAdapterFactory(Protocol):
r"""Interface of LoggerAdapterFactory."""
Expand Down Expand Up @@ -171,3 +203,12 @@ def measure_time(self, name: str) -> Iterator[None]:
@property
def adapter(self) -> LoggerAdapter:
return self._adapter

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Forwards a model-watching request to the wrapped adapter.

    Args:
        epoch: Epoch.
        step: Training step.
        logging_steps: Number of training steps between gradient logs.
        algo: Algorithm.
    """
    adapter = self._adapter
    adapter.watch_model(epoch, step, logging_steps, algo)
18 changes: 16 additions & 2 deletions d3rlpy/logging/noop_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
SaveProtocol,
TorchModuleProtocol,
)

__all__ = ["NoopAdapter", "NoopAdapterFactory"]

Expand Down Expand Up @@ -32,6 +37,15 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None:
def close(self) -> None:
pass

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Does nothing; this adapter ignores model-watching requests."""
    return None


class NoopAdapterFactory(LoggerAdapterFactory):
r"""NoopAdapterFactory class.
Expand Down
22 changes: 20 additions & 2 deletions d3rlpy/logging/tensorboard_adapter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
SaveProtocol,
TorchModuleProtocol,
)

__all__ = ["TensorboardAdapter", "TensorboardAdapterFactory"]

Expand Down Expand Up @@ -64,6 +69,19 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None:
def close(self) -> None:
self._writer.close()

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Writes gradient histograms through the summary writer.

    One ``histograms/<name>_grad`` histogram is added for every
    ``(name, gradient)`` pair yielded by
    ``algo.impl.modules.get_gradients()``.

    Args:
        epoch: Epoch. Used as the index of the written histograms.
        step: Training step.
        logging_steps: Number of training steps between gradient logs.
            ``None`` or ``0`` disables gradient logging.
        algo: Algorithm whose gradients are recorded.
    """
    # ``not logging_steps`` also rejects 0, which would otherwise raise
    # ZeroDivisionError in the modulo below.
    if not logging_steps or step % logging_steps != 0:
        return

    # NOTE(review): histograms are indexed by ``epoch`` although logging
    # is gated on ``step``, so several logs within one epoch share the
    # same index -- confirm this is intended.
    for name, grad in algo.impl.modules.get_gradients():
        self._writer.add_histogram(f"histograms/{name}_grad", grad, epoch)


class TensorboardAdapterFactory(LoggerAdapterFactory):
r"""TensorboardAdapterFactory class.
Expand Down
19 changes: 17 additions & 2 deletions d3rlpy/logging/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
SaveProtocol,
TorchModuleProtocol,
)

__all__ = ["CombineAdapter", "CombineAdapterFactory"]

Expand Down Expand Up @@ -44,6 +49,16 @@ def close(self) -> None:
for adapter in self._adapters:
adapter.close()

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    r"""Forwards a model-watching request to every combined adapter.

    Adapters are called in the order in which they are stored.

    Args:
        epoch: Epoch.
        step: Training step.
        logging_steps: Number of training steps between gradient logs.
        algo: Algorithm.
    """
    for child in self._adapters:
        child.watch_model(epoch, step, logging_steps, algo)


class CombineAdapterFactory(LoggerAdapterFactory):
r"""CombineAdapterFactory class.
Expand Down
23 changes: 22 additions & 1 deletion d3rlpy/logging/wandb_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any, Dict, Optional

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
SaveProtocol,
TorchModuleProtocol,
)

__all__ = ["WanDBAdapter", "WanDBAdapterFactory"]

Expand All @@ -24,6 +29,7 @@ def __init__(
except ImportError as e:
raise ImportError("Please install wandb") from e
self.run = wandb.init(project=project, name=experiment_name)
self._is_model_watched = False

def write_params(self, params: Dict[str, Any]) -> None:
"""Writes hyperparameters to WandB config."""
Expand Down Expand Up @@ -52,6 +58,21 @@ def close(self) -> None:
"""Closes the logger and finishes the WandB run."""
self.run.finish()

def watch_model(
    self,
    epoch: int,
    step: int,
    logging_steps: Optional[int],
    algo: TorchModuleProtocol,
) -> None:
    """Registers the algorithm's torch modules with WandB gradient logging.

    ``run.watch`` is called at most once per adapter; later calls are
    no-ops once the modules have been registered.

    Args:
        epoch: Epoch (unused here).
        step: Training step (unused here).
        logging_steps: Number of steps between gradient logs, passed to
            ``run.watch`` as ``log_freq``. ``None`` disables gradient
            logging.
        algo: Algorithm whose torch modules are watched.
    """
    # ``None`` means gradient logging was not requested. ``run.watch``
    # expects an integer ``log_freq``, so do not register hooks at all
    # in that case.
    if logging_steps is None or self._is_model_watched:
        return

    self.run.watch(
        tuple(algo.impl.modules.get_torch_modules().values()),
        log="gradients",
        log_freq=logging_steps,
    )
    self._is_model_watched = True


class WanDBAdapterFactory(LoggerAdapterFactory):
r"""WandB Logger Adapter Factory class.
Expand Down
15 changes: 15 additions & 0 deletions d3rlpy/torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
Any,
BinaryIO,
Dict,
Iterator,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -388,6 +390,19 @@ def reset_optimizer_states(self) -> None:
if isinstance(v, torch.optim.Optimizer):
v.state = collections.defaultdict(dict)

def get_torch_modules(self) -> Dict[str, nn.Module]:
    """Collects the ``nn.Module`` entries of this object.

    Returns:
        The entries of ``asdict_without_copy(self)`` whose values are
        ``nn.Module`` instances, under the same keys.
    """
    return {
        name: value
        for name, value in asdict_without_copy(self).items()
        if isinstance(value, nn.Module)
    }

def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]:
    """Yields the gradients of all trainable parameters as NumPy arrays.

    Parameters that do not require gradients, or whose ``grad`` is
    ``None``, are skipped.

    Yields:
        ``("<module name>.<parameter name>", gradient)`` pairs, where the
        module name is its key in ``get_torch_modules()``.
    """
    for module_name, module in self.get_torch_modules().items():
        for param_name, param in module.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            grad_array = param.grad.cpu().detach().numpy()
            yield f"{module_name}.{param_name}", grad_array


TCallable = TypeVar("TCallable")

Expand Down
Loading

0 comments on commit 3d51ee7

Please sign in to comment.