Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Nov 9, 2023
1 parent ac11994 commit 54a283e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 37 deletions.
30 changes: 14 additions & 16 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
TORCH_AVAILABLE = False


def to_dict(obj: Any) -> Mapping[str, Any]: # pylint: disable=R0911
def to_dict(obj: Any) -> Mapping | list | set | tuple: # pylint: disable=R0911
r"""
Convert an object to a dict.
Expand Down Expand Up @@ -97,14 +97,14 @@ def to_dict(obj: Any) -> Mapping[str, Any]: # pylint: disable=R0911
if isinstance(obj, Mapping):
return {k: to_dict(v) for k, v in obj.items()}
if isinstance(obj, list):
return [to_dict(v) for v in obj] # type: ignore
return [to_dict(v) for v in obj]
if isinstance(obj, tuple):
return tuple(to_dict(v) for v in obj) # type: ignore
return tuple(to_dict(v) for v in obj)
if isinstance(obj, set):
try:
return {to_dict(v) for v in obj} # type: ignore
return {to_dict(v) for v in obj}
except TypeError:
return tuple(to_dict(v) for v in obj) # type: ignore
return tuple(to_dict(v) for v in obj)
if isinstance(obj, Variable):
return obj.value
return obj
Expand Down Expand Up @@ -849,9 +849,7 @@ def difference(self, other: Mapping | Iterable | PathStr) -> Self:
other = self.empty(other).items()
if not isinstance(other, Iterable):
raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty(
**{key: value for key, value in other if key not in self or self[key] != value} # type: ignore
)
return self.empty(**{key: value for key, value in other if key not in self or self[key] != value}) # type: ignore

def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self:
r"""
Expand Down Expand Up @@ -1028,7 +1026,7 @@ def clone(self, memo: Mapping | None = None) -> Self:
"""
return self.deepcopy(memo=memo)

def save(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113
def save(self, file: File, method: str = None, *args: Any, **kwargs: Any) -> None: # type: ignore # pylint: disable=W1113
r"""
Save `FlatDict` to file.
Expand All @@ -1055,22 +1053,22 @@ def save(self, file: File, method: str | None = None, *args: Any, **kwargs: Any)
if method is None:
if isinstance(file, (IOBase, IO)):
raise ValueError("`method` must be specified when saving to IO.")
method = splitext(file)[-1][1:] # type: ignore
extension = method.lower() # type: ignore
method = splitext(file)[-1][1:]
extension = method.lower()
if extension in YAML:
return self.yaml(file=file, *args, **kwargs) # type: ignore
if extension in JSON:
return self.json(file=file, *args, **kwargs) # type: ignore
raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.")

def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113
def dump(self, file: File, method: str = None, *args: Any, **kwargs: Any) -> None: # type: ignore # pylint: disable=W1113
r"""
Alias of [`save`][chanfig.FlatDict.save].
"""
return self.save(file, method, *args, **kwargs)

@classmethod
def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> Self: # pylint: disable=W1113
def load(cls, file: File, method: str = None, *args: Any, **kwargs: Any) -> Self: # type: ignore # pylint: disable=W1113
"""
Load `FlatDict` from file.
Expand Down Expand Up @@ -1101,8 +1099,8 @@ def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any)
if method is None:
if isinstance(file, (IOBase, IO)):
raise ValueError("`method` must be specified when loading from IO.")
method = splitext(file)[-1][1:] # type: ignore
extension = method.lower() # type: ignore
method = splitext(file)[-1][1:]
extension = method.lower()
if extension in JSON:
return cls.from_json(file, *args, **kwargs)
if extension in YAML:
Expand Down Expand Up @@ -1233,7 +1231,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str:

kwargs.setdefault("Dumper", YamlDumper)
kwargs.setdefault("indent", self.getattr("indent", 2))
return yaml_dump(self.dict(), *args, **kwargs) # type: ignore
return yaml_dump(self.dict(), *args, **kwargs)

@classmethod
def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self:
Expand Down
17 changes: 10 additions & 7 deletions chanfig/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .utils import JSON, YAML, File, PathStr


def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113
def save(obj, file: File, method: str = None, *args: Any, **kwargs: Any) -> None: # type: ignore # pylint: disable=W1113
r"""
Save `FlatDict` to file.
Expand Down Expand Up @@ -61,7 +61,7 @@ def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any)
if isinstance(file, IOBase):
raise ValueError("`method` must be specified when saving to IO.")
method = splitext(file)[-1][1:] # type: ignore
extension = method.lower() # type: ignore
extension = method.lower()
if extension in YAML:
with FlatDict.open(file, mode="w") as fp: # pylint: disable=C0103
yaml_dump(data, fp, *args, **kwargs)
Expand All @@ -70,17 +70,20 @@ def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any)
with FlatDict.open(file, mode="w") as fp: # pylint: disable=C0103
fp.write(json_dumps(data, *args, **kwargs))
return
raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.") # type: ignore
raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.")


def load(file: PathStr, cls: type = NestedDict, *args: Any, **kwargs: Any) -> NestedDict: # pylint: disable=W1113
def load( # pylint: disable=W1113
file: PathStr, cls: type[FlatDict] = NestedDict, *args: Any, **kwargs: Any
) -> FlatDict:
r"""
Load a file into a `NestedDict`.
Load a file into a `FlatDict`.
This function simply calls `NestedDict.load`.
This function simply calls `cls.load`, by default, `cls` is `NestedDict`.
Args:
file: The file to load.
cls: The class of the file to load. Defaults to `NestedDict`.
*args: The arguments to pass to `NestedDict.load`.
**kwargs: The keyword arguments to pass to `NestedDict.load`.
Expand All @@ -98,4 +101,4 @@ def load(file: PathStr, cls: type = NestedDict, *args: Any, **kwargs: Any) -> Ne
)
"""

return cls.load(file, *args, **kwargs) # type: ignore
return cls.load(file, *args, **kwargs)
8 changes: 4 additions & 4 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bo
for value in self.values():
if isinstance(value, FlatDict):
value.sort(key=key, reverse=reverse)
return super().sort(key=key, reverse=reverse) # type: ignore
return super().sort(key=key, reverse=reverse)

@staticmethod
def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping:
Expand Down Expand Up @@ -749,7 +749,7 @@ def intersect(self, other: Mapping | Iterable | PathStr, recursive: bool = True)
other = self.empty(other).items()
if not isinstance(other, Iterable):
raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty(self._intersect(self, other, recursive)) # type: ignore
return self.empty(self._intersect(self, other, recursive))

@staticmethod
def _intersect(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapping:
Expand Down Expand Up @@ -797,7 +797,7 @@ def difference( # pylint: disable=W0221, C0103
other = self.empty(other).items()
if not isinstance(other, Iterable):
raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty(self._difference(self, other, recursive)) # type: ignore
return self.empty(self._difference(self, other, recursive))

@staticmethod
def _difference(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapping:
Expand All @@ -822,7 +822,7 @@ def converting(self):
finally:
self.setattr("convert_mapping", convert_mapping)

def __contains__(self, name: Any) -> bool: # type: ignore
def __contains__(self, name: Any) -> bool:
delimiter = self.getattr("delimiter", ".")
try:
while isinstance(name, str) and delimiter in name:
Expand Down
18 changes: 9 additions & 9 deletions chanfig/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, *args: Any, **kwargs: Any):
def parse_config( # pylint: disable=R0912
self,
args: Sequence[str] | None = None,
config: NestedDict | None = None,
config: Config | None = None,
default_config: str | None = None,
no_default_config_action: str = "raise",
) -> Config:
Expand Down Expand Up @@ -127,12 +127,12 @@ def parse_config( # pylint: disable=R0912

if config.getattr("parser", None) is not self:
config.setattr("parser", self)
return config.merge(parsed) # type: ignore
return config.merge(parsed)

def parse( # pylint: disable=R0912
self,
args: Sequence[str] | None = None,
config: NestedDict | None = None,
config: Config | None = None,
default_config: str | None = None,
no_default_config_action: str = "raise",
) -> Config:
Expand Down Expand Up @@ -245,7 +245,7 @@ def parse( # pylint: disable=R0912

if config.getattr("parser", None) is not self:
config.setattr("parser", self)
return config.merge(parsed) # type: ignore
return config.merge(parsed)

def parse_args( # type: ignore
self, args: Sequence[str] | None = None, namespace: NestedDict | None = None, eval_str: bool = True
Expand All @@ -262,18 +262,18 @@ def parse_args( # type: ignore
namespace (NestedDict | None, optional): existing configuration.
eval_str (bool, optional): Whether to evaluate string values.
"""
parsed = super().parse_args(args, namespace)
parsed: dict | Namespace = super().parse_args(args, namespace)
if isinstance(parsed, Namespace):
parsed = vars(parsed) # type: ignore
parsed = vars(parsed)
if not isinstance(parsed, NestedDict):
parsed = NestedDict({key: value for key, value in parsed.items() if value is not Null}) # type: ignore
parsed = NestedDict({key: value for key, value in parsed.items() if value is not Null})
if eval_str:
for key, value in parsed.all_items():
if isinstance(value, str):
with suppress(TypeError, ValueError, SyntaxError):
value = literal_eval(value)
parsed[key] = value
return parsed # type: ignore
return parsed

def add_config_arguments(self, config):
for key, value in config.all_items():
Expand All @@ -298,7 +298,7 @@ def merge_default_config(self, parsed, default_config: str, no_default_config_ac
if default_config in parsed:
path = parsed[default_config]
warn(f"Config has 'default_config={path}' specified, its values will override values in Config")
return NestedDict.load(path).merge(parsed) # type: ignore
return NestedDict.load(path).merge(parsed)
if no_default_config_action == "ignore":
pass
elif no_default_config_action == "warn":
Expand Down
2 changes: 1 addition & 1 deletion chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def register(self, component: Any = None, name: Any | None = None) -> Callable:
# Registry.register()
if name is not None:
self.set(name, component)
return component # type: ignore
return component
# @Registry.register
if callable(component) and name is None:
self.set(component.__name__, component)
Expand Down

0 comments on commit 54a283e

Please sign in to comment.