Skip to content

Commit

Permalink
add StrEnum to yaml representer
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 26, 2024
1 parent 2733540 commit 122c02b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
38 changes: 25 additions & 13 deletions chanfig/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,21 @@

from __future__ import annotations

from yaml import add_representer
from yaml import add_multi_representer
from yaml.representer import SafeRepresenter

try:
from enum import StrEnum
except ImportError:
StrEnum = None

try:
from strenum import LowercaseStrEnum
from strenum import StrEnum as UppercaseStrEnum
except ImportError:
UppercaseStrEnum = None
LowercaseStrEnum = None

from . import utils
from ._version import __version__, __version_tuple__, version
from .config import Config
Expand Down Expand Up @@ -55,15 +67,15 @@
]


add_representer(FlatDict, SafeRepresenter.represent_dict)
add_representer(NestedDict, SafeRepresenter.represent_dict)
add_representer(DefaultDict, SafeRepresenter.represent_dict)
add_representer(Config, SafeRepresenter.represent_dict)
add_representer(Registry, SafeRepresenter.represent_dict)
add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(FlatDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(NestedDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(DefaultDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Config, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Registry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
add_multi_representer(FlatDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_multi_representer(FlatDict, SafeRepresenter.represent_dict)

if StrEnum:
add_multi_representer(StrEnum, SafeRepresenter.represent_str)
SafeRepresenter.add_multi_representer(StrEnum, SafeRepresenter.represent_str)
if UppercaseStrEnum:
add_multi_representer(UppercaseStrEnum, SafeRepresenter.represent_str)
SafeRepresenter.add_multi_representer(UppercaseStrEnum, SafeRepresenter.represent_str)
if LowercaseStrEnum:
add_multi_representer(LowercaseStrEnum, SafeRepresenter.represent_str)
SafeRepresenter.add_multi_representer(LowercaseStrEnum, SafeRepresenter.represent_str)
45 changes: 45 additions & 0 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# CHANfiG, Easier Configuration.
# Copyright (c) 2022-Present, CHANfiG Contributors

# This program is free software: you can redistribute it and/or modify
# it under the terms of the following licenses:
# - The Unlicense
# - GNU Affero General Public License v3.0 or later
# - GNU General Public License v2.0 or later
# - BSD 4-Clause "Original" or "Old" License
# - MIT License
# - Apache License 2.0

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the LICENSE file for more details.


from chanfig import NestedDict
from enum import auto

try:
from enum import StrEnum
except ImportError:
from strenum import LowercaseStrEnum as StrEnum # type: ignore[no-redef]


class Task(StrEnum):
__test__ = False
Regression = auto()
Binary = auto()
MultiClass = auto()
MultiLabel = auto()


class TaskConfig(NestedDict):
__test__ = False
task: Task = "regression"


class TestDump:
def test_yamls(self):
config = TaskConfig()
s = config.yamls()
assert s == "task: regression\n"

0 comments on commit 122c02b

Please sign in to comment.