From fdb7f436ebff69784fda01a87cd6d1d5c44fc78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 Jul 2023 18:30:11 +0200 Subject: [PATCH] Allow custom loggers without an experiment property (#18093) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos MocholĂ­ (cherry picked from commit 281d6a27d1cef6d1ebe3a860a1f362a25bbe8ef7) --- src/lightning/pytorch/CHANGELOG.md | 3 +++ src/lightning/pytorch/trainer/call.py | 3 ++- tests/tests_pytorch/loggers/test_all.py | 27 +++++++++++++++++++++---- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5b74600570a43..76c9e6324e189 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278)) +- Fixed an issue that prevented the use of custom logger classes without an `experiment` property defined ([#18093](https://github.com/Lightning-AI/lightning/pull/18093)) + + ## [2.0.7] - 2023-08-14 ### Added diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 9aee1ea67a1e9..97c737744486e 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -76,7 +76,8 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None: # Trigger lazy creation of experiment in loggers so loggers have their metadata available for logger in trainer.loggers: - _ = logger.experiment + if hasattr(logger, "experiment"): + _ = logger.experiment trainer.strategy.barrier("pre_setup") diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1c38811296cd9..c8fb8e9662535 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -30,7 +30,7 @@ TensorBoardLogger, WandbLogger, ) -from lightning.pytorch.loggers.logger import DummyExperiment +from lightning.pytorch.loggers.logger import DummyExperiment, Logger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.tuner.tuning import Tuner from tests_pytorch.helpers.runif import RunIf @@ -239,7 +239,7 @@ def setup(self, trainer, pl_module, stage=None): assert trainer.logger._mlflow_client elif isinstance(trainer.logger, NeptuneLogger): assert trainer.logger._run_instance - else: + elif hasattr(trainer.logger, "_experiment"): assert trainer.logger._experiment @@ -253,7 +253,23 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): assert pl_module.logger.experiment.something(foo="bar") is None -@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES_WO_NEPTUNE) +class CustomLoggerWithoutExperiment(Logger): + @property + def name(self): + return "" + + @property + def version(self): + return None + + def log_metrics(self, metrics, step=None) -> None: + pass + + def log_hyperparams(self, params, *args, **kwargs) -> None: + pass + + +@pytest.mark.parametrize("logger_class", [*ALL_LOGGER_CLASSES_WO_NEPTUNE, CustomLoggerWithoutExperiment]) @RunIf(skip_windows=True) def test_logger_initialization(tmpdir, monkeypatch, logger_class): """Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is available @@ -268,6 +284,9 @@ def test_logger_initialization(tmpdir, monkeypatch, logger_class): def _test_logger_initialization(tmpdir, logger_class): logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) + callbacks = [LazyInitExperimentCheck()] + if not isinstance(logger, CustomLoggerWithoutExperiment): + callbacks.append(RankZeroLoggerCheck()) model = BoringModel() trainer = Trainer( logger=logger, @@ -276,7 +295,7 @@ def _test_logger_initialization(tmpdir, logger_class): accelerator="cpu", devices=2, max_steps=1, - callbacks=[RankZeroLoggerCheck(), LazyInitExperimentCheck()], + callbacks=callbacks, ) trainer.fit(model)