Skip to content

Commit

Permalink
fix attributes shadows value in naming conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Jan 18, 2023
1 parent 4149e3e commit c5bea2c
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 74 deletions.
50 changes: 46 additions & 4 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def frozen_check(func: Callable):
@wraps(func)
def decorator(self, *args, **kwargs):
if self.getattr("frozen", False):
raise ValueError("Attempting to alter a frozen config. Run config.defrost() to defrost first")
raise ValueError("Attempting to alter a frozen config. Run config.defrost() to defrost first.")
return func(self, *args, **kwargs)

return decorator
Expand All @@ -123,6 +123,46 @@ class Config(NestedDict):
accessing anything that does not exist will create a new empty Config sub-attribute.
It is recommended to call `config.freeze()` or `config.to(NestedDict)` to avoid this behavior.
Attributes
----------
parser: ConfigParser = ConfigParser()
Parser for command line arguments.
frozen: bool = False
If `True`, the config is frozen and cannot be altered.
Examples
--------
```python
>>> c = Config(**{"f.n": "chang"})
>>> c.i.d = 1013
>>> c.i.d
1013
>>> c.d.i
Config()
>>> c.freeze()
Config(
(f): Config(
(n): 'chang'
)
(i): Config(
(d): 1013
)
(d): Config(
(i): Config()
)
)
>>> c.d.i = 1013
Traceback (most recent call last):
ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first.
>>> c.d.e
Traceback (most recent call last):
KeyError: 'Config does not contain e'
>>> with c.unlocked():
... del c.d
>>> c.dict()
{'f': {'n': 'chang'}, 'i': {'d': 1013}}
"""

parser: ConfigParser
Expand All @@ -131,7 +171,9 @@ class Config(NestedDict):
def __init__(self, *args, **kwargs):
if not self.hasattr("default_mapping"):
self.setattr("default_mapping", Config)
super().__init__(*args, default_factory=Config, **kwargs)
if "default_factory" not in kwargs:
kwargs["default_factory"] = Config
super().__init__(*args, **kwargs)
self.setattr("parser", ConfigParser())

def get(self, name: str, default: Optional[Any] = None) -> Any:
Expand Down Expand Up @@ -234,7 +276,7 @@ def set(
{'i': {'d': 1013}}
>>> c['i.d'] = 1013
Traceback (most recent call last):
ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first
ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first.
>>> c.defrost().dict()
{'i': {'d': 1013}}
>>> c['i.d'] = 1013
Expand Down Expand Up @@ -316,7 +358,7 @@ def pop(self, name: str, default: Optional[Any] = None) -> Any:
{'i': {}}
>>> c['i.d'] = 1013
Traceback (most recent call last):
ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first
ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first.
>>> c.defrost().dict()
{'i': {}}
>>> c['i.d'] = 1013
Expand Down
114 changes: 55 additions & 59 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from os import PathLike
from os.path import splitext
from typing import IO, Any, Callable, Iterable, Optional, Union
from warnings import warn

from yaml import dump as yaml_dump
from yaml import load as yaml_load
Expand Down Expand Up @@ -51,6 +52,19 @@ class FlatDict(OrderedDict):
`FlatDict` works best with `Variable` objects.
Note that since `FlatDict` supports attribute-style access to keys.
Therefore, all internal attributes should be set and get through `FlatDict.setattr` and `FlatDict.getattr`.
`__class__`, `__dict__`, and `getattr` are reserved and cannot be override in any manner.
Although it is possible to override other internal methods, it is not recommended.
Attributes
----------
indent: int
Indentation level in printing and dumping to json or yaml.
default_factory: Optional[Callable]
Default factory for defaultdict behavior.
Examples
--------
```python
Expand Down Expand Up @@ -93,7 +107,7 @@ def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs)
self.setattr("default_factory", default_factory)
else:
raise TypeError(
f"default_factory={default_factory} should be of type Callable, but got {type(default_factory)}"
f"default_factory={default_factory} should be of type Callable, but got {type(default_factory)}."
)
self._init(*args, **kwargs)

Expand All @@ -114,6 +128,11 @@ def _init(self, *args, **kwargs) -> None:
for key, value in kwargs.items():
self.set(key, value)

def __getattribute__(self, name):
if name not in ("__class__", "__dict__", "getattr") and name in self:
return self[name]
return super().__getattribute__(name)

def get(self, name: str, default: Optional[Any] = None) -> Any:
r"""
Get value from FlatDict.
Expand Down Expand Up @@ -155,7 +174,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
2
>>> d.get('f')
Traceback (most recent call last):
KeyError: 'FlatDict does not contain f'
KeyError: 'FlatDict does not contain f.'
```
"""
Expand Down Expand Up @@ -233,11 +252,11 @@ def delete(self, name: str) -> None:
>>> d.delete('d')
>>> d.d
Traceback (most recent call last):
KeyError: 'FlatDict does not contain d'
KeyError: 'FlatDict does not contain d.'
>>> del d.n
>>> d.n
Traceback (most recent call last):
KeyError: 'FlatDict does not contain n'
KeyError: 'FlatDict does not contain n.'
>>> del d.f
Traceback (most recent call last):
KeyError: 'f'
Expand Down Expand Up @@ -276,7 +295,7 @@ def getattr(self, name: str, default: Optional[Any] = None) -> Any:
2
>>> d.getattr('a')
Traceback (most recent call last):
AttributeError: FlatDict has no attribute a
AttributeError: FlatDict has no attribute a.
```
"""
Expand All @@ -290,7 +309,7 @@ def getattr(self, name: str, default: Optional[Any] = None) -> Any:
except AttributeError:
if default is not None:
return default
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") from None
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}.") from None

def setattr(self, name: str, value: Any) -> None:
r"""
Expand All @@ -303,17 +322,36 @@ def setattr(self, name: str, value: Any) -> None:
name: str
value: Any
Warns
------
RuntimeWarning
If name already exists in FlatDict.
Examples
--------
```python
>>> d = FlatDict()
>>> d.setattr('attr', 'value')
>>> d.getattr('attr')
'value'
>>> d.set('d', 1013)
>>> d.setattr('d', 1031) # Trigger RuntimeWarning: d already exists in FlatDict.
>>> d.get('d')
1013
>>> d.d
1013
>>> d.getattr('d')
1031
```
"""

if name in self:
warn(
f"{name} already exists in {self.__class__.__name__}.\n"
"Users must call `{self.__class__.__name__}.getattr()` to retrieve conflicting attribute value.",
RuntimeWarning,
)
self.__dict__[name] = value

def hasattr(self, name: str) -> bool:
Expand Down Expand Up @@ -365,45 +403,18 @@ def delattr(self, name: str) -> None:
>>> d.delattr('name')
>>> d.getattr('name')
Traceback (most recent call last):
AttributeError: FlatDict has no attribute name
AttributeError: FlatDict has no attribute name.
```
"""

del self.__dict__[name]

def __missing__(self, name: str, default: Optional[Any] = None) -> Any:
r"""
Allow dict to have default value if it doesn't exist.
Parameters
----------
name: str
default: Optional[Any] = None
Returns
-------
value: Any
If name does not exist, return `default`.
Examples
--------
```python
>>> d = FlatDict(default_factory=list)
>>> d.n
[]
>>> d.get('d', 1013)
1013
>>> d.__missing__('d', 1013)
1013
```
"""

if default is None:
# default_factory might not in __dict__ and cannot be replaced with if self.getattr("default_factory")
if "default_factory" not in self.__dict__:
raise KeyError(f"{self.__class__.__name__} does not contain {name}")
raise KeyError(f"{self.__class__.__name__} does not contain {name}.")
default_factory = self.getattr("default_factory")
default = default_factory()
if isinstance(default, FlatDict):
Expand Down Expand Up @@ -488,7 +499,7 @@ def difference(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict:
{'d': 4}
>>> d.difference(1)
Traceback (most recent call last):
TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got <class 'int'>
TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got <class 'int'>.
```
"""
Expand All @@ -498,7 +509,7 @@ def difference(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict:
if isinstance(other, (Mapping,)):
other = other.items()
if not isinstance(other, Iterable):
raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}")
raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.")

return self.empty_like(
**{key: value for key, value in other if key not in self or self[key] != value} # type: ignore
Expand Down Expand Up @@ -534,7 +545,7 @@ def intersection(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict:
{'c': 3}
>>> d.intersection(1)
Traceback (most recent call last):
TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got <class 'int'>
TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got <class 'int'>.
```
"""
Expand All @@ -544,7 +555,7 @@ def intersection(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict:
if isinstance(other, (Mapping,)):
other = other.items()
if not isinstance(other, Iterable):
raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}")
raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty_like(
**{key: value for key, value in other if key in self and self[key] == value} # type: ignore
)
Expand Down Expand Up @@ -715,7 +726,7 @@ def to(self, cls: Union[str, TorchDevice, TorchDtype]) -> FlatDict:
self[k] = v.to(cls)
return self

raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}")
raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.")

def cpu(self) -> FlatDict:
r"""
Expand Down Expand Up @@ -1014,7 +1025,7 @@ def dump(self, file: File, method: Optional[str] = None, *args, **kwargs) -> Non

if method is None:
if isinstance(file, IO):
raise ValueError("method must be specified when dumping to file-like object")
raise ValueError("method must be specified when dumping to file-like object.")
method = splitext(file)[-1][1:] # type: ignore
extension = method.lower() # type: ignore
if extension in YAML:
Expand Down Expand Up @@ -1044,14 +1055,14 @@ def load(cls, file: File, method: Optional[str] = None, *args, **kwargs) -> Flat

if method is None:
if isinstance(file, IO):
raise ValueError("method must be specified when loading from file-like object")
raise ValueError("method must be specified when loading from file-like object.")
method = splitext(file)[-1][1:] # type: ignore
extension = method.lower() # type: ignore
if extension in JSON:
return cls.from_json(file, *args, **kwargs)
if extension in YAML:
return cls.from_yaml(file, *args, **kwargs)
raise FileError("file {file} should be in {JSON} or {YAML}, but got {extension}")
raise FileError("file {file} should be in {JSON} or {YAML}, but got {extension}.")

@staticmethod
@contextmanager
Expand All @@ -1069,28 +1080,13 @@ def open(file: File, *args, **kwargs):
elif isinstance(file, (IO,)):
yield file
else:
raise TypeError(
f"file={file} should be of type (str, os.PathLike) or (io.IOBase), but got {type(file)}" # type: ignore
)
raise TypeError(f"file={file!r} should be of type (str, os.PathLike) or (io.IOBase), but got {type(file)}.")

@staticmethod
def extra_repr() -> str: # pylint: disable=C0116
return ""

def __repr__(self) -> str:
r"""
Representation of FlatDict.
Examples
--------
```python
>>> d = FlatDict(a=1, b=2, c=3)
>>> repr(d)
'FlatDict(\n (a): 1\n (b): 2\n (c): 3\n)'
```
"""

extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
Expand Down
Loading

0 comments on commit c5bea2c

Please sign in to comment.