Skip to content

Commit

Permalink
add include
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Dec 13, 2024
1 parent 4c664f6 commit 9d40b9b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 11 deletions.
44 changes: 36 additions & 8 deletions chanfig/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import os
import sys
from argparse import ArgumentTypeError
from collections.abc import Callable, Mapping, Sequence
Expand All @@ -31,6 +32,8 @@
import typing_extensions
from typing_extensions import get_args, get_origin
from yaml import SafeDumper, SafeLoader
from yaml.constructor import ConstructorError
from yaml.nodes import ScalarNode, SequenceNode

try: # python 3.10+
from types import UnionType # type: ignore[attr-defined] # pylint: disable=C0412
Expand Down Expand Up @@ -293,18 +296,43 @@ def increase_indent(self, flow: bool = False, indentless: bool = False): # pyli
return super().increase_indent(flow, indentless)


class YamlLoader(SafeLoader): # pylint: disable=R0901,R0903
class YamlLoader(SafeLoader):
r"""
YAML Loader for Config.
"""


try:
from yamlinclude import YamlIncludeConstructor

YamlIncludeConstructor.add_to_loader_class(loader_class=YamlLoader, relative=True)
except ImportError:
pass
def __init__(self, stream):
super().__init__(stream)
self._root = os.path.abspath(os.path.dirname(stream.name)) if hasattr(stream, "name") else os.getcwd()
self.add_constructor("!include", self._include)
self.add_constructor("!includes", self._includes)
self.add_constructor("!env", self._env)

@staticmethod
def _include(loader: YamlLoader, node):
relative_path = loader.construct_scalar(node)
include_path = os.path.join(loader._root, relative_path)

if not os.path.exists(include_path):
raise FileNotFoundError(f"Included file not found: {include_path}")
from .functional import load

return load(include_path)

@staticmethod
def _includes(loader: YamlLoader, node):
if not isinstance(node, SequenceNode):
raise ConstructorError(None, None, f"!includes tag expects a sequence, got {node.id}", node.start_mark)
files = loader.construct_sequence(node)
return [YamlLoader._include(loader, ScalarNode("tag:yaml.org,2002:str", file)) for file in files]

@staticmethod
def _env(loader: YamlLoader, node):
env_var = loader.construct_scalar(node)
value = os.getenv(env_var)
if value is None:
raise ValueError(f"Environment variable '{env_var}' not set.")
return value


Null = NULL()
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ dependencies = [
"pyyaml",
"typing-extensions",
]
optional-dependencies.include = [
"pyyaml-include",
]
urls.documentation = "https://chanfig.danling.org"
urls.homepage = "https://chanfig.danling.org"
urls.repository = "https://github.com/ZhiyuanChen/CHANfiG"
Expand Down
3 changes: 3 additions & 0 deletions tests/child.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
optim:
name: "adam"
lr: 0.01
2 changes: 2 additions & 0 deletions tests/model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
num_channels: 512
multiple: 256
2 changes: 2 additions & 0 deletions tests/parent.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
model: !include model.yaml
port: 80
6 changes: 6 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ def test_interpolate_eval(self):
assert config.data.imagenet.data_dirs[2] == "localhost:80/X-C"
assert config.model.num_heads == config.model.num_channels // 64
assert config.model.num_hidden_size == config.model.num_channels // 64 * config.model.multiple

def test_include(self):
config = chanfig.load("tests/parent.yaml")
model = chanfig.load("tests/model.yaml")
assert config.model == model
assert config.port == 80

0 comments on commit 9d40b9b

Please sign in to comment.