Trigger creation of lazy logger experiment in Trainer setup #17818

Merged (12 commits) · Jun 26, 2023
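This change makes the `Trainer` access each logger's `experiment` property during its setup phase, so lazily-initialized loggers (notably `WandbLogger`) create their experiment, and with it the metadata that checkpoint and log directory names are derived from, before training begins.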
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -125,6 +125,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed deriving default map location in `LightningModule.load_from_checkpoint` when there is extra state ([#17812](https://github.com/Lightning-AI/lightning/pull/17812))


- Fixed delayed creation of experiment metadata and checkpoint/log dir name when using `WandbLogger` ([#17818](https://github.com/Lightning-AI/lightning/pull/17818))


## [2.0.3] - 2023-06-07

### Changed
4 changes: 4 additions & 0 deletions src/lightning/pytorch/trainer/call.py
@@ -73,6 +73,10 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
    assert trainer.state.fn is not None
    fn = trainer.state.fn

    # Trigger lazy creation of experiment in loggers so loggers have their metadata available
    for logger in trainer.loggers:
        _ = logger.experiment

    trainer.strategy.barrier("pre_setup")

    if trainer.datamodule is not None:
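For context, here is a minimal sketch of the lazy-initialization pattern the new loop relies on. `SketchLogger` is a hypothetical stand-in, not Lightning's code: the experiment object, and the run id that directory names are derived from, only exist once the `experiment` property is first accessed, which is exactly what the loop forces during setup.

```python
import os


class SketchLogger:
    """Hypothetical stand-in for a lazily-initialized logger such as WandbLogger."""

    def __init__(self):
        self._experiment = None

    @property
    def experiment(self):
        if self._experiment is None:
            # stand-in for e.g. wandb.init(); this is where the run id
            # that directory names are derived from gets decided
            self._experiment = {"id": "run-42"}
        return self._experiment

    @property
    def version(self):
        # before the experiment exists, there is no real version to report
        return self._experiment["id"] if self._experiment else "placeholder"


logger = SketchLogger()
print(os.path.join("logs", logger.version))  # logs/placeholder

# the same access pattern as the new lines in _call_setup_hook:
_ = logger.experiment

print(os.path.join("logs", logger.version))  # logs/run-42
```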
30 changes: 19 additions & 11 deletions tests/tests_pytorch/loggers/test_all.py
@@ -61,7 +61,6 @@
    WandbLogger,
)
ALL_LOGGER_CLASSES_WO_NEPTUNE = tuple(filter(lambda cls: cls is not NeptuneLogger, ALL_LOGGER_CLASSES))
ALL_LOGGER_CLASSES_WO_NEPTUNE_WANDB = tuple(filter(lambda cls: cls is not WandbLogger, ALL_LOGGER_CLASSES_WO_NEPTUNE))


def _get_logger_args(logger_class, save_dir):
@@ -224,10 +223,19 @@ def __init__(self, lr=0.1, batch_size=1):
    assert logger2 == logger3, "Finder altered the logger of model"


class RankZeroLoggerCheck(Callback):
    # this class has to be defined outside the test function, otherwise we get pickle error
    # due to the way ddp process is launched
class LazyInitExperimentCheck(Callback):
    def setup(self, trainer, pl_module, stage=None):
        if trainer.global_rank > 0:
            return
        if isinstance(trainer.logger, MLFlowLogger):
            assert trainer.logger._mlflow_client
        elif isinstance(trainer.logger, NeptuneLogger):
            assert trainer.logger._run_instance
        else:
            assert trainer.logger._experiment


class RankZeroLoggerCheck(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        is_dummy = isinstance(trainer.logger.experiment, DummyExperiment)
        if trainer.is_global_zero:
@@ -237,18 +245,19 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            assert pl_module.logger.experiment.something(foo="bar") is None


@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES_WO_NEPTUNE_WANDB)
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES_WO_NEPTUNE)
@RunIf(skip_windows=True)
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
def test_logger_initialization(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is
available at the right time in Trainer."""
    _patch_comet_atexit(monkeypatch)
    try:
        _test_logger_created_on_rank_zero_only(tmpdir, logger_class)
        _test_logger_initialization(tmpdir, logger_class)
    except (ImportError, ModuleNotFoundError):
        pytest.xfail(f"multi-process test requires {logger_class.__class__} dependencies to be installed.")


def _test_logger_created_on_rank_zero_only(tmpdir, logger_class):
def _test_logger_initialization(tmpdir, logger_class):
    logger_args = _get_logger_args(logger_class, tmpdir)
    logger = logger_class(**logger_args)
    model = BoringModel()
@@ -259,10 +268,9 @@ def _test_logger_created_on_rank_zero_only(tmpdir, logger_class):
accelerator="cpu",
devices=2,
max_steps=1,
callbacks=[RankZeroLoggerCheck()],
callbacks=[RankZeroLoggerCheck(), LazyInitExperimentCheck()],
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"


def test_logger_with_prefix_all(tmpdir, monkeypatch):
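As an aside, the comment in the test file about defining the callback outside the test function refers to Python's pickling rules: with DDP's spawn launch, callback objects are pickled into the worker processes, and instances of classes defined inside a function cannot be pickled because the class is not importable by name. A standalone sketch of that failure mode (plain `pickle`, no Lightning involved):

```python
import pickle


class TopLevelCallback:
    pass  # importable as <module>.TopLevelCallback, so instances pickle fine


def make_local_callback():
    class LocalCallback:  # qualname is make_local_callback.<locals>.LocalCallback
        pass

    return LocalCallback()


pickle.dumps(TopLevelCallback())  # works

try:
    pickle.dumps(make_local_callback())
except (pickle.PicklingError, AttributeError) as err:
    print(f"cannot pickle instance of a local class: {err}")
```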