Skip to content

Commit

Permalink
Add logging_strategy option to fit_online
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Feb 18, 2024
1 parent 48058ed commit cc15df8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
19 changes: 16 additions & 3 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,9 @@ def fitter(
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logging_steps: Number of steps to log metrics. This will be ignored
if loggig_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand Down Expand Up @@ -601,6 +602,8 @@ def fit_online(
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
callback: Optional[Callable[[Self, int, int], None]] = None,
Expand All @@ -623,6 +626,9 @@ def fit_online(
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
callback: Callable function that takes ``(algo, epoch, total_step)``
Expand Down Expand Up @@ -726,6 +732,12 @@ def fit_online(
for name, val in loss.items():
logger.add_metric(name, val)

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
logger.commit(epoch, total_step)

# call callback if given
if callback:
callback(self, epoch, total_step)
Expand All @@ -742,7 +754,8 @@ def fit_online(
logger.save_model(total_step, self)

# save metrics
logger.commit(epoch, total_step)
if logging_strategy == LoggingStrategy.EPOCH:
logger.commit(epoch, total_step)

# clip the last episode
buffer.clip_episode(False)
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/logging/wandb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def create(self, experiment_name: str) -> LoggerAdapter:
experiment_name (str): Name of the experiment.
Returns:
LoggerAdapter: Instance of the WandB Logger Adapter.
Instance of the WandB Logger Adapter.
"""
return WanDBAdapter(
experiment_name=experiment_name,
Expand Down

0 comments on commit cc15df8

Please sign in to comment.