From fa4b8a0249324998453b3c0f68c76f606a7b74ff Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 9 Nov 2023 16:18:19 +0800 Subject: [PATCH] return Self type --- chanfig/config.py | 21 ++++++++----- chanfig/flat_dict.py | 67 ++++++++++++++++++++++-------------------- chanfig/nested_dict.py | 21 +++++++------ chanfig/registry.py | 2 +- 4 files changed, 61 insertions(+), 50 deletions(-) diff --git a/chanfig/config.py b/chanfig/config.py index c0db0804..99549ba8 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -20,6 +20,11 @@ from functools import wraps from typing import Any +try: + from typing import Self +except ImportError: + from typing_extensions import Self + from .nested_dict import NestedDict from .parser import ConfigParser from .utils import Null @@ -101,7 +106,7 @@ def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs default_factory = Config super().__init__(*args, default_factory=default_factory, **kwargs) - def post(self) -> Config | None: + def post(self) -> Self | None: r""" Post process of `Config`. @@ -138,7 +143,7 @@ def post(self) -> Config | None: self.validate() return self - def boot(self) -> Config: + def boot(self) -> Self: r""" Apply `post` recursively. @@ -192,7 +197,7 @@ def parse( default_config: str | None = None, no_default_config_action: str = "raise", boot: bool = True, - ) -> Config: + ) -> Self: r""" Parse command-line arguments with `ConfigParser`. @@ -235,7 +240,7 @@ def parse_config( default_config: str | None = None, no_default_config_action: str = "raise", boot: bool = True, - ) -> Config: + ) -> Self: r""" Parse command-line arguments with `ConfigParser`. @@ -289,7 +294,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None: self.setattr("parser", ConfigParser()) return self.getattr("parser").add_argument(*args, **kwargs) - def freeze(self, recursive: bool = True) -> Config: + def freeze(self, recursive: bool = True) -> Self: r""" Freeze `Config`. @@ -327,7 +332,7 @@ def freeze(config: Config) -> None: freeze(self) return self - def lock(self, recursive: bool = True) -> Config: + def lock(self, recursive: bool = True) -> Self: r""" Alias of [`freeze`][chanfig.Config.freeze]. """ @@ -357,7 +362,7 @@ def locked(self): if not was_frozen: self.defrost() - def defrost(self, recursive: bool = True) -> Config: + def defrost(self, recursive: bool = True) -> Self: r""" Defrost `Config`. @@ -399,7 +404,7 @@ def defrost(config: Config) -> None: defrost(self) return self - def unlock(self, recursive: bool = True) -> Config: + def unlock(self, recursive: bool = True) -> Self: r""" Alias of [`defrost`][chanfig.Config.defrost]. """ diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index 20f27de6..125c4333 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -28,6 +28,11 @@ from typing import IO, Any from warnings import warn +try: + from typing import Self +except ImportError: + from typing_extensions import Self + from yaml import dump as yaml_dump from yaml import load as yaml_load @@ -514,7 +519,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911 if obj is None: return cls() if issubclass(cls, FlatDict): - cls = cls.empty # type: ignore # pylint: disable=W0642 + cls = cls.empty # pylint: disable=W0642 if isinstance(obj, Mapping): return cls(obj) if isinstance(obj, Sequence): @@ -524,7 +529,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911 return [cls(json) for json in obj] raise TypeError(f"Expected Mapping or Sequence, but got {type(obj)}.") - def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict: + def sort(self, key: Callable | None = None, reverse: bool = False) -> Self: r""" Sort `FlatDict`. @@ -553,7 +558,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict: def interpolate( # pylint: disable=R0912 self, use_variable: bool = True, interpolators: MutableMapping | None = None - ) -> FlatDict: + ) -> Self: r""" Perform Variable interpolation. @@ -667,7 +672,7 @@ def substitute(placeholder, interpolators, value): except KeyError as exc: raise ValueError(f"{exc} is not found in {interpolators}.") from None - def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict: + def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> Self: r""" Merge `other` into `FlatDict`. @@ -736,13 +741,13 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: this.set(key, value) return this - def union(self, *args: Any, **kwargs: Any) -> FlatDict: + def union(self, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`merge`][chanfig.FlatDict.merge]. """ return self.merge(*args, **kwargs) - def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> Self: r""" Merge content of `file` into `FlatDict`. @@ -762,7 +767,7 @@ def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict: return self.merge(self.load(file, *args, **kwargs)) - def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict: + def intersect(self, other: Mapping | Iterable | PathStr) -> Self: r""" Intersection of `FlatDict` and `other`. @@ -801,13 +806,13 @@ def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict: raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.") return self.empty(**{key: value for key, value in other if key in self and self[key] == value}) # type: ignore - def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict: + def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`intersect`][chanfig.FlatDict.intersect]. """ return self.intersect(other, *args, **kwargs) - def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict: + def difference(self, other: Mapping | Iterable | PathStr) -> Self: r""" Difference between `FlatDict` and `other`. @@ -848,13 +853,13 @@ def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict: **{key: value for key, value in other if key not in self or self[key] != value} # type: ignore ) - def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict: + def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`difference`][chanfig.FlatDict.difference]. """ return self.difference(other, *args, **kwargs) - def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cover + def to(self, cls: str | TorchDevice | TorchDType) -> Self: # pragma: no cover r""" Convert values of `FlatDict` to target `cls`. @@ -881,7 +886,7 @@ def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cov raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.") - def cpu(self) -> FlatDict: # pragma: no cover + def cpu(self) -> Self: # pragma: no cover r""" Move all tensors to cpu. @@ -897,7 +902,7 @@ def cpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("cpu")) - def gpu(self) -> FlatDict: # pragma: no cover + def gpu(self) -> Self: # pragma: no cover r""" Move all tensors to gpu. @@ -919,13 +924,13 @@ def gpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("cuda")) - def cuda(self) -> FlatDict: # pragma: no cover + def cuda(self) -> Self: # pragma: no cover r""" Alias of [`gpu`][chanfig.FlatDict.gpu]. """ return self.gpu() - def tpu(self) -> FlatDict: # pragma: no cover + def tpu(self) -> Self: # pragma: no cover r""" Move all tensors to tpu. @@ -947,13 +952,13 @@ def tpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("xla")) - def xla(self) -> FlatDict: # pragma: no cover + def xla(self) -> Self: # pragma: no cover r""" Alias of [`tpu`][chanfig.FlatDict.tpu]. """ return self.tpu() - def copy(self) -> FlatDict: + def copy(self) -> Self: r""" Create a shallow copy of `FlatDict`. @@ -975,7 +980,7 @@ def copy(self) -> FlatDict: return copy(self) - def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict: + def __deepcopy__(self, memo: Mapping | None = None) -> Self: # pylint: disable=C0103 if memo is not None and id(self) in memo: @@ -989,7 +994,7 @@ def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict: ret[k] = deepcopy(v) return ret - def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=W0613 + def deepcopy(self, memo: Mapping | None = None) -> Self: # pylint: disable=W0613 r""" Create a deep copy of `FlatDict`. @@ -1017,7 +1022,7 @@ def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable= return deepcopy(self) - def clone(self, memo: Mapping | None = None) -> FlatDict: + def clone(self, memo: Mapping | None = None) -> Self: r""" Alias of [`deepcopy`][chanfig.FlatDict.deepcopy]. """ @@ -1065,9 +1070,7 @@ def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) return self.save(file, method, *args, **kwargs) @classmethod - def load( # pylint: disable=W1113 - cls, file: File, method: str | None = None, *args: Any, **kwargs: Any - ) -> FlatDict: + def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> Self: # pylint: disable=W1113 """ Load `FlatDict` from file. @@ -1122,7 +1125,7 @@ def json(self, file: File, *args: Any, **kwargs: Any) -> None: fp.write(self.jsons(*args, **kwargs)) @classmethod - def from_json(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def from_json(cls, file: File, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from json file. @@ -1161,7 +1164,7 @@ def jsons(self, *args: Any, **kwargs: Any) -> str: return json_dumps(self.dict(), *args, **kwargs) @classmethod - def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict: + def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from json string. @@ -1195,7 +1198,7 @@ def yaml(self, file: File, *args: Any, **kwargs: Any) -> None: self.yamls(fp, *args, **kwargs) @classmethod - def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from yaml file. @@ -1233,7 +1236,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str: return yaml_dump(self.dict(), *args, **kwargs) # type: ignore @classmethod - def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict: + def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from yaml string. @@ -1295,7 +1298,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene yield file elif isinstance(file, (PathLike, str, bytes)): try: - file = open(file, *args, encoding=encoding, **kwargs) # type: ignore # noqa: SIM115 + file = open(file, *args, encoding=encoding, **kwargs) # noqa: SIM115 yield file # type: ignore finally: with suppress(Exception): @@ -1304,7 +1307,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene raise TypeError(f"expected str, bytes, os.PathLike, IO or IOBase, not {type(file).__name__}") @classmethod - def empty(cls, *args: Any, **kwargs: Any) -> FlatDict: + def empty(cls, *args: Any, **kwargs: Any) -> Self: r""" Initialise an empty `FlatDict`. @@ -1330,7 +1333,7 @@ def empty(cls, *args: Any, **kwargs: Any) -> FlatDict: empty.merge(*args, **kwargs) # pylint: disable=W0212 return empty - def empty_like(self, *args: Any, **kwargs: Any) -> FlatDict: + def empty_like(self, *args: Any, **kwargs: Any) -> Self: r""" Initialise an empty copy of `FlatDict`. @@ -1392,7 +1395,7 @@ def all_items(self) -> Generator: """ yield from self.items() - def dropnull(self) -> FlatDict: + def dropnull(self) -> Self: r""" Drop key-value pairs with `Null` value. @@ -1415,7 +1418,7 @@ def dropnull(self) -> FlatDict: return self.empty({k: v for k, v in self.all_items() if v is not Null}) - def dropna(self) -> FlatDict: + def dropna(self) -> Self: r""" Alias of [`dropnull`][chanfig.FlatDict.dropnull]. """ diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index 862e9601..f9f6a042 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -28,7 +28,12 @@ try: from backports.cached_property import cached_property # type: ignore except ImportError: - cached_property = property # type: ignore # pylint: disable=C0103 + cached_property = property # pylint: disable=C0103 + +try: + from typing import Self +except ImportError: + from typing_extensions import Self from .default_dict import DefaultDict from .flat_dict import FlatDict @@ -240,7 +245,7 @@ def all_items(self, prefix=Null): return all_items(self) - def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: + def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Self: r""" Recursively apply a function to `NestedDict` and its children. @@ -269,7 +274,7 @@ def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: return apply(self, func, *args, **kwargs) - def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: + def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> Self: r""" Recursively apply a function to `NestedDict` and its children. @@ -557,7 +562,7 @@ def pop(self, name: Any, default: Any = Null) -> Any: raise KeyError(name) return super().pop(name) - def setdefault( # type: ignore # pylint: disable=R0912,W0221 + def setdefault( # pylint: disable=R0912,W0221 self, name: Any, value: Any, @@ -655,7 +660,7 @@ def validate(self) -> None: self.apply_(self._validate) - def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> NestedDict: + def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> Self: r""" Sort `NestedDict`. @@ -713,9 +718,7 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: this[key] = value return this - def intersect( # pylint: disable=W0221 - self, other: Mapping | Iterable | PathStr, recursive: bool = True - ) -> NestedDict: + def intersect(self, other: Mapping | Iterable | PathStr, recursive: bool = True) -> Self: # pylint: disable=W0221 r""" Intersection of `NestedDict` and `other`. @@ -763,7 +766,7 @@ def _intersect(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapp def difference( # pylint: disable=W0221, C0103 self, other: Mapping | Iterable | PathStr, recursive: bool = True - ) -> NestedDict: + ) -> Self: r""" Difference between `NestedDict` and `other`. diff --git a/chanfig/registry.py b/chanfig/registry.py index 7ae26b2e..bb780674 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -178,7 +178,7 @@ def lookup(self, name: str) -> Any: return self[name] @staticmethod - def init(cls: Callable, *args: Any, **kwargs: Any) -> Any: # type: ignore # pylint: disable=W0211 + def init(cls: Callable, *args: Any, **kwargs: Any) -> Any: # pylint: disable=W0211 r""" Constructor of component.