From 809e378f9d520af87d33f18d120aa54a31c79e87 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 9 Nov 2023 09:14:20 +0000 Subject: [PATCH] fix mypy Signed-off-by: Zhiyuan Chen --- chanfig/flat_dict.py | 32 +++++++++++++++++++------------- chanfig/functional.py | 19 ++++++++++++------- chanfig/nested_dict.py | 10 +++++----- chanfig/parser.py | 20 ++++++++++---------- chanfig/registry.py | 2 +- 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index a7f74944..a3093db9 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -63,7 +63,7 @@ TORCH_AVAILABLE = False -def to_dict(obj: Any) -> Mapping[str, Any]: # pylint: disable=R0911 +def to_dict(obj: Any) -> Mapping | list | set | tuple: # pylint: disable=R0911 r""" Convert an object to a dict. @@ -97,14 +97,14 @@ def to_dict(obj: Any) -> Mapping[str, Any]: # pylint: disable=R0911 if isinstance(obj, Mapping): return {k: to_dict(v) for k, v in obj.items()} if isinstance(obj, list): - return [to_dict(v) for v in obj] # type: ignore + return [to_dict(v) for v in obj] if isinstance(obj, tuple): - return tuple(to_dict(v) for v in obj) # type: ignore + return tuple(to_dict(v) for v in obj) if isinstance(obj, set): try: - return {to_dict(v) for v in obj} # type: ignore + return {to_dict(v) for v in obj} except TypeError: - return tuple(to_dict(v) for v in obj) # type: ignore + return tuple(to_dict(v) for v in obj) if isinstance(obj, Variable): return obj.value return obj @@ -1028,7 +1028,9 @@ def clone(self, memo: Mapping | None = None) -> Self: """ return self.deepcopy(memo=memo) - def save(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113 + def save( # pylint: disable=W1113 + self, file: File, method: str = None, *args: Any, **kwargs: Any # type: ignore + ) -> None: r""" Save `FlatDict` to file. @@ -1055,22 +1057,26 @@ def save(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) if method is None: if isinstance(file, (IOBase, IO)): raise ValueError("`method` must be specified when saving to IO.") - method = splitext(file)[-1][1:] # type: ignore - extension = method.lower() # type: ignore + method = splitext(file)[-1][1:] + extension = method.lower() if extension in YAML: return self.yaml(file=file, *args, **kwargs) # type: ignore if extension in JSON: return self.json(file=file, *args, **kwargs) # type: ignore raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.") - def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113 + def dump( # pylint: disable=W1113 + self, file: File, method: str = None, *args: Any, **kwargs: Any # type: ignore + ) -> None: r""" Alias of [`save`][chanfig.FlatDict.save]. """ return self.save(file, method, *args, **kwargs) @classmethod - def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> Self: # pylint: disable=W1113 + def load( # pylint: disable=W1113 + cls, file: File, method: str = None, *args: Any, **kwargs: Any # type: ignore + ) -> Self: """ Load `FlatDict` from file. @@ -1101,8 +1107,8 @@ def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any) if method is None: if isinstance(file, (IOBase, IO)): raise ValueError("`method` must be specified when loading from IO.") - method = splitext(file)[-1][1:] # type: ignore - extension = method.lower() # type: ignore + method = splitext(file)[-1][1:] + extension = method.lower() if extension in JSON: return cls.from_json(file, *args, **kwargs) if extension in YAML: @@ -1233,7 +1239,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str: kwargs.setdefault("Dumper", YamlDumper) kwargs.setdefault("indent", self.getattr("indent", 2)) - return yaml_dump(self.dict(), *args, **kwargs) # type: ignore + return yaml_dump(self.dict(), *args, **kwargs) @classmethod def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self: diff --git a/chanfig/functional.py b/chanfig/functional.py index 252ad14c..44a38d98 100644 --- a/chanfig/functional.py +++ b/chanfig/functional.py @@ -27,7 +27,9 @@ from .utils import JSON, YAML, File, PathStr -def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=W1113 +def save( # pylint: disable=W1113 + obj, file: File, method: str = None, *args: Any, **kwargs: Any # type: ignore +) -> None: r""" Save `FlatDict` to file. @@ -61,7 +63,7 @@ def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any) if isinstance(file, IOBase): raise ValueError("`method` must be specified when saving to IO.") method = splitext(file)[-1][1:] # type: ignore - extension = method.lower() # type: ignore + extension = method.lower() if extension in YAML: with FlatDict.open(file, mode="w") as fp: # pylint: disable=C0103 yaml_dump(data, fp, *args, **kwargs) @@ -70,17 +72,20 @@ def save(obj, file: File, method: str | None = None, *args: Any, **kwargs: Any) with FlatDict.open(file, mode="w") as fp: # pylint: disable=C0103 fp.write(json_dumps(data, *args, **kwargs)) return - raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.") # type: ignore + raise TypeError(f"`file={file!r}` should be in {JSON} or {YAML}, but got {extension}.") -def load(file: PathStr, cls: type = NestedDict, *args: Any, **kwargs: Any) -> NestedDict: # pylint: disable=W1113 +def load( # pylint: disable=W1113 + file: PathStr, cls: type[FlatDict] = NestedDict, *args: Any, **kwargs: Any +) -> FlatDict: r""" - Load a file into a `NestedDict`. + Load a file into a `FlatDict`. - This function simply calls `NestedDict.load`. + This function simply calls `cls.load`, by default, `cls` is `NestedDict`. Args: file: The file to load. + cls: The class of the file to load. Defaults to `NestedDict`. *args: The arguments to pass to `NestedDict.load`. **kwargs: The keyword arguments to pass to `NestedDict.load`. @@ -98,4 +103,4 @@ def load(file: PathStr, cls: type = NestedDict, *args: Any, **kwargs: Any) -> Ne ) """ - return cls.load(file, *args, **kwargs) # type: ignore + return cls.load(file, *args, **kwargs) diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index d8c9df5a..ecfbae3e 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -562,7 +562,7 @@ def pop(self, name: Any, default: Any = Null) -> Any: raise KeyError(name) return super().pop(name) - def setdefault( # type: ignore # pylint: disable=R0912,W0221 + def setdefault( # type: ignore[override] # pylint: disable=R0912,W0221 self, name: Any, value: Any, @@ -690,7 +690,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bo for value in self.values(): if isinstance(value, FlatDict): value.sort(key=key, reverse=reverse) - return super().sort(key=key, reverse=reverse) # type: ignore + return super().sort(key=key, reverse=reverse) @staticmethod def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: @@ -749,7 +749,7 @@ def intersect(self, other: Mapping | Iterable | PathStr, recursive: bool = True) other = self.empty(other).items() if not isinstance(other, Iterable): raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.") - return self.empty(self._intersect(self, other, recursive)) # type: ignore + return self.empty(self._intersect(self, other, recursive)) @staticmethod def _intersect(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapping: @@ -797,7 +797,7 @@ def difference( # pylint: disable=W0221, C0103 other = self.empty(other).items() if not isinstance(other, Iterable): raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.") - return self.empty(self._difference(self, other, recursive)) # type: ignore + return self.empty(self._difference(self, other, recursive)) @staticmethod def _difference(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapping: @@ -822,7 +822,7 @@ def converting(self): finally: self.setattr("convert_mapping", convert_mapping) - def __contains__(self, name: Any) -> bool: # type: ignore + def __contains__(self, name: Any) -> bool: delimiter = self.getattr("delimiter", ".") try: while isinstance(name, str) and delimiter in name: diff --git a/chanfig/parser.py b/chanfig/parser.py index 004b6b92..6fc02151 100644 --- a/chanfig/parser.py +++ b/chanfig/parser.py @@ -56,7 +56,7 @@ def __init__(self, *args: Any, **kwargs: Any): def parse_config( # pylint: disable=R0912 self, args: Sequence[str] | None = None, - config: NestedDict | None = None, + config: Config | None = None, default_config: str | None = None, no_default_config_action: str = "raise", ) -> Config: @@ -127,12 +127,12 @@ def parse_config( # pylint: disable=R0912 if config.getattr("parser", None) is not self: config.setattr("parser", self) - return config.merge(parsed) # type: ignore + return config.merge(parsed) def parse( # pylint: disable=R0912 self, args: Sequence[str] | None = None, - config: NestedDict | None = None, + config: Config | None = None, default_config: str | None = None, no_default_config_action: str = "raise", ) -> Config: @@ -245,9 +245,9 @@ def parse( # pylint: disable=R0912 if config.getattr("parser", None) is not self: config.setattr("parser", self) - return config.merge(parsed) # type: ignore + return config.merge(parsed) - def parse_args( # type: ignore + def parse_args( # type: ignore[override] self, args: Sequence[str] | None = None, namespace: NestedDict | None = None, eval_str: bool = True ) -> NestedDict: r""" @@ -262,18 +262,18 @@ def parse_args( # type: ignore namespace (NestedDict | None, optional): existing configuration. eval_str (bool, optional): Whether to evaluate string values. """ - parsed = super().parse_args(args, namespace) + parsed: dict | Namespace = super().parse_args(args, namespace) if isinstance(parsed, Namespace): - parsed = vars(parsed) # type: ignore + parsed = vars(parsed) if not isinstance(parsed, NestedDict): - parsed = NestedDict({key: value for key, value in parsed.items() if value is not Null}) # type: ignore + parsed = NestedDict({key: value for key, value in parsed.items() if value is not Null}) if eval_str: for key, value in parsed.all_items(): if isinstance(value, str): with suppress(TypeError, ValueError, SyntaxError): value = literal_eval(value) parsed[key] = value - return parsed # type: ignore + return parsed def add_config_arguments(self, config): for key, value in config.all_items(): @@ -298,7 +298,7 @@ def merge_default_config(self, parsed, default_config: str, no_default_config_ac if default_config in parsed: path = parsed[default_config] warn(f"Config has 'default_config={path}' specified, its values will override values in Config") - return NestedDict.load(path).merge(parsed) # type: ignore + return NestedDict.load(path).merge(parsed) if no_default_config_action == "ignore": pass elif no_default_config_action == "warn": diff --git a/chanfig/registry.py b/chanfig/registry.py index bb780674..3e120117 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -131,7 +131,7 @@ def register(self, component: Any = None, name: Any | None = None) -> Callable: # Registry.register() if name is not None: self.set(name, component) - return component # type: ignore + return component # @Registry.register if callable(component) and name is None: self.set(component.__name__, component)