diff --git a/src/dvclive/live.py b/src/dvclive/live.py index c0b4aa8..754137f 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -82,7 +82,7 @@ def __init__( resume: bool = False, report: Literal["md", "notebook", "html", None] = None, save_dvc_exp: bool = True, - dvcyaml: Optional[str] = "dvc.yaml", + dvcyaml: Union[str, os.PathLike, bool, None] = "dvc.yaml", cache_images: bool = False, exp_name: Optional[str] = None, exp_message: Optional[str] = None, @@ -104,11 +104,11 @@ def __init__( part of `Live.end()`. Defaults to `True`. If you are using DVCLive inside a DVC Pipeline and running with `dvc exp run`, the option will be ignored. - dvcyaml (str | None): where to write dvc.yaml file, which adds DVC + dvcyaml (str | Path | None): where to write dvc.yaml file, which adds DVC configuration for metrics, plots, and parameters as part of `Live.next_step()` and `Live.end()`. If `None`, no dvc.yaml file is written. Defaults to `"dvc.yaml"`. See `Live.make_dvcyaml()`. - If a string like `"subdir/dvc.yaml"`, DVCLive will write the + If a string or Path like `"subdir/dvc.yaml"`, DVCLive will write the configuration to that path (file must be named "dvc.yaml"). If `False`, DVCLive will not write to "dvc.yaml" (useful if you are tracking DVCLive metrics, plots, and parameters independently and @@ -265,11 +265,19 @@ def _init_dvc(self): # noqa: C901 self._include_untracked.append(self.dir) def _init_dvc_file(self) -> str: - if isinstance(self._dvcyaml, str): - if os.path.basename(self._dvcyaml) == "dvc.yaml": - return self._dvcyaml - raise InvalidDvcyamlError - return "dvc.yaml" + if self._dvcyaml is None: + return "dvc.yaml" + if isinstance(self._dvcyaml, bool): + return "dvc.yaml" + + self._dvcyaml = os.fspath(self._dvcyaml) + if ( + isinstance(self._dvcyaml, str) + and os.path.basename(self._dvcyaml) == "dvc.yaml" + ): + return self._dvcyaml + + raise InvalidDvcyamlError def _init_dvc_pipeline(self): if os.getenv(env.DVC_EXP_BASELINE_REV, None): @@ -334,6 +342,7 @@ def _init_test(self): """ with tempfile.TemporaryDirectory() as dirpath: self._dir = os.path.join(dirpath, self._dir) + self._dvcyaml = os.fspath(self._dvcyaml) if isinstance(self._dvcyaml, str): self._dvc_file = os.path.join(dirpath, self._dvcyaml) self._save_dvc_exp = False diff --git a/tests/test_make_dvcyaml.py b/tests/test_make_dvcyaml.py index 7f5da8b..1639785 100644 --- a/tests/test_make_dvcyaml.py +++ b/tests/test_make_dvcyaml.py @@ -2,6 +2,7 @@ import pytest from PIL import Image +from pathlib import Path from dvclive import Live from dvclive.dvc import make_dvcyaml @@ -423,7 +424,7 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam @pytest.mark.parametrize( "dvcyaml", - [True, False, "dvc.yaml"], + [True, False, "dvc.yaml", Path("dvc.yaml")], ) def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml): dvclive = Live("logs", dvcyaml=dvcyaml)