Skip to content

Commit

Permalink
Support pathlib.Path values for dvcyaml argument in Live(). (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdullahMakhdoom authored Sep 19, 2024
1 parent 6888462 commit f69d198
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
25 changes: 17 additions & 8 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
resume: bool = False,
report: Literal["md", "notebook", "html", None] = None,
save_dvc_exp: bool = True,
dvcyaml: Optional[str] = "dvc.yaml",
dvcyaml: Union[str, os.PathLike, bool, None] = "dvc.yaml",
cache_images: bool = False,
exp_name: Optional[str] = None,
exp_message: Optional[str] = None,
Expand All @@ -104,11 +104,11 @@ def __init__(
part of `Live.end()`. Defaults to `True`. If you are using DVCLive
inside a DVC Pipeline and running with `dvc exp run`, the option will be
ignored.
dvcyaml (str | None): where to write dvc.yaml file, which adds DVC
dvcyaml (str | Path | None): where to write dvc.yaml file, which adds DVC
configuration for metrics, plots, and parameters as part of
`Live.next_step()` and `Live.end()`. If `None`, no dvc.yaml file is
written. Defaults to `"dvc.yaml"`. See `Live.make_dvcyaml()`.
If a string like `"subdir/dvc.yaml"`, DVCLive will write the
If a string or Path like `"subdir/dvc.yaml"`, DVCLive will write the
configuration to that path (file must be named "dvc.yaml").
If `False`, DVCLive will not write to "dvc.yaml" (useful if you are
tracking DVCLive metrics, plots, and parameters independently and
Expand Down Expand Up @@ -265,11 +265,19 @@ def _init_dvc(self): # noqa: C901
self._include_untracked.append(self.dir)

def _init_dvc_file(self) -> str:
if isinstance(self._dvcyaml, str):
if os.path.basename(self._dvcyaml) == "dvc.yaml":
return self._dvcyaml
raise InvalidDvcyamlError
return "dvc.yaml"
if self._dvcyaml is None:
return "dvc.yaml"
if isinstance(self._dvcyaml, bool):
return "dvc.yaml"

self._dvcyaml = os.fspath(self._dvcyaml)
if (
isinstance(self._dvcyaml, str)
and os.path.basename(self._dvcyaml) == "dvc.yaml"
):
return self._dvcyaml

raise InvalidDvcyamlError

def _init_dvc_pipeline(self):
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
Expand Down Expand Up @@ -334,6 +342,7 @@ def _init_test(self):
"""
with tempfile.TemporaryDirectory() as dirpath:
self._dir = os.path.join(dirpath, self._dir)
self._dvcyaml = os.fspath(self._dvcyaml)
if isinstance(self._dvcyaml, str):
self._dvc_file = os.path.join(dirpath, self._dvcyaml)
self._save_dvc_exp = False
Expand Down
3 changes: 2 additions & 1 deletion tests/test_make_dvcyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from PIL import Image
from pathlib import Path

from dvclive import Live
from dvclive.dvc import make_dvcyaml
Expand Down Expand Up @@ -423,7 +424,7 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam

@pytest.mark.parametrize(
"dvcyaml",
[True, False, "dvc.yaml"],
[True, False, "dvc.yaml", Path("dvc.yaml")],
)
def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml):
dvclive = Live("logs", dvcyaml=dvcyaml)
Expand Down

0 comments on commit f69d198

Please sign in to comment.