diff --git a/chanfig/utils.py b/chanfig/utils.py index 6f1f86ba..8ae72d1b 100644 --- a/chanfig/utils.py +++ b/chanfig/utils.py @@ -17,6 +17,7 @@ from __future__ import annotations +import os import sys from argparse import ArgumentTypeError from collections.abc import Callable, Mapping, Sequence @@ -31,6 +32,8 @@ import typing_extensions from typing_extensions import get_args, get_origin from yaml import SafeDumper, SafeLoader +from yaml.constructor import ConstructorError +from yaml.nodes import ScalarNode, SequenceNode try: # python 3.10+ from types import UnionType # type: ignore[attr-defined] # pylint: disable=C0412 @@ -293,18 +296,43 @@ def increase_indent(self, flow: bool = False, indentless: bool = False): # pyli return super().increase_indent(flow, indentless) -class YamlLoader(SafeLoader): # pylint: disable=R0901,R0903 +class YamlLoader(SafeLoader): r""" YAML Loader for Config. """ - -try: - from yamlinclude import YamlIncludeConstructor - - YamlIncludeConstructor.add_to_loader_class(loader_class=YamlLoader, relative=True) -except ImportError: - pass + def __init__(self, stream): + super().__init__(stream) + self._root = os.path.abspath(os.path.dirname(stream.name)) if hasattr(stream, "name") else os.getcwd() + self.add_constructor("!include", self._include) + self.add_constructor("!includes", self._includes) + self.add_constructor("!env", self._env) + + @staticmethod + def _include(loader: YamlLoader, node): + relative_path = loader.construct_scalar(node) + include_path = os.path.join(loader._root, relative_path) + + if not os.path.exists(include_path): + raise FileNotFoundError(f"Included file not found: {include_path}") + from .functional import load + + return load(include_path) + + @staticmethod + def _includes(loader: YamlLoader, node): + if not isinstance(node, SequenceNode): + raise ConstructorError(None, None, f"!includes tag expects a sequence, got {node.id}", node.start_mark) + files = loader.construct_sequence(node) + return [YamlLoader._include(loader, ScalarNode("tag:yaml.org,2002:str", file)) for file in files] + + @staticmethod + def _env(loader: YamlLoader, node): + env_var = loader.construct_scalar(node) + value = os.getenv(env_var) + if value is None: + raise ValueError(f"Environment variable '{env_var}' not set.") + return value Null = NULL() diff --git a/pyproject.toml b/pyproject.toml index 80e477c0..878bc197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,6 @@ dependencies = [ "pyyaml", "typing-extensions", ] -optional-dependencies.include = [ - "pyyaml-include", -] urls.documentation = "https://chanfig.danling.org" urls.homepage = "https://chanfig.danling.org" urls.repository = "https://github.com/ZhiyuanChen/CHANfiG" diff --git a/tests/child.yaml b/tests/child.yaml new file mode 100644 index 00000000..a5b5967e --- /dev/null +++ b/tests/child.yaml @@ -0,0 +1,3 @@ +optim: + name: "adam" + lr: 0.01 \ No newline at end of file diff --git a/tests/model.yaml b/tests/model.yaml new file mode 100644 index 00000000..2136fdcc --- /dev/null +++ b/tests/model.yaml @@ -0,0 +1,2 @@ +num_channels: 512 +multiple: 256 \ No newline at end of file diff --git a/tests/parent.yaml b/tests/parent.yaml new file mode 100644 index 00000000..6c36c74e --- /dev/null +++ b/tests/parent.yaml @@ -0,0 +1,2 @@ +model: !include model.yaml +port: 80 \ No newline at end of file diff --git a/tests/test_misc.py b/tests/test_misc.py index 9b83f009..e3c2f7d4 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -36,3 +36,9 @@ def test_interpolate_eval(self): assert config.data.imagenet.data_dirs[2] == "localhost:80/X-C" assert config.model.num_heads == config.model.num_channels // 64 assert config.model.num_hidden_size == config.model.num_channels // 64 * config.model.multiple + + def test_include(self): + config = chanfig.load("tests/parent.yaml") + model = chanfig.load("tests/model.yaml") + assert config.model == model + assert config.port == 80