From d15f67e6005fbc1d506813e0b2e877f7aa057b9d Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 9 Nov 2023 16:18:19 +0800 Subject: [PATCH] return Self type --- chanfig/config.py | 18 ++++++------- chanfig/flat_dict.py | 58 +++++++++++++++++++++--------------------- chanfig/nested_dict.py | 12 ++++----- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/chanfig/config.py b/chanfig/config.py index c0db0804..d843b4c9 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -18,7 +18,7 @@ from collections.abc import Callable, Iterable from contextlib import contextmanager from functools import wraps -from typing import Any +from typing import Any, Self from .nested_dict import NestedDict from .parser import ConfigParser @@ -101,7 +101,7 @@ def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs default_factory = Config super().__init__(*args, default_factory=default_factory, **kwargs) - def post(self) -> Config | None: + def post(self) -> Self | None: r""" Post process of `Config`. @@ -138,7 +138,7 @@ def post(self) -> Config | None: self.validate() return self - def boot(self) -> Config: + def boot(self) -> Self: r""" Apply `post` recursively. @@ -192,7 +192,7 @@ def parse( default_config: str | None = None, no_default_config_action: str = "raise", boot: bool = True, - ) -> Config: + ) -> Self: r""" Parse command-line arguments with `ConfigParser`. @@ -235,7 +235,7 @@ def parse_config( default_config: str | None = None, no_default_config_action: str = "raise", boot: bool = True, - ) -> Config: + ) -> Self: r""" Parse command-line arguments with `ConfigParser`. @@ -289,7 +289,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None: self.setattr("parser", ConfigParser()) return self.getattr("parser").add_argument(*args, **kwargs) - def freeze(self, recursive: bool = True) -> Config: + def freeze(self, recursive: bool = True) -> Self: r""" Freeze `Config`. @@ -327,7 +327,7 @@ def freeze(config: Config) -> None: freeze(self) return self - def lock(self, recursive: bool = True) -> Config: + def lock(self, recursive: bool = True) -> Self: r""" Alias of [`freeze`][chanfig.Config.freeze]. """ @@ -357,7 +357,7 @@ def locked(self): if not was_frozen: self.defrost() - def defrost(self, recursive: bool = True) -> Config: + def defrost(self, recursive: bool = True) -> Self: r""" Defrost `Config`. @@ -399,7 +399,7 @@ def defrost(config: Config) -> None: defrost(self) return self - def unlock(self, recursive: bool = True) -> Config: + def unlock(self, recursive: bool = True) -> Self: r""" Alias of [`defrost`][chanfig.Config.defrost]. """ diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index 20f27de6..8b89810b 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -25,7 +25,7 @@ from json import loads as json_loads from os import PathLike from os.path import splitext -from typing import IO, Any +from typing import IO, Any, Self from warnings import warn from yaml import dump as yaml_dump @@ -524,7 +524,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911 return [cls(json) for json in obj] raise TypeError(f"Expected Mapping or Sequence, but got {type(obj)}.") - def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict: + def sort(self, key: Callable | None = None, reverse: bool = False) -> Self: r""" Sort `FlatDict`. @@ -553,7 +553,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict: def interpolate( # pylint: disable=R0912 self, use_variable: bool = True, interpolators: MutableMapping | None = None - ) -> FlatDict: + ) -> Self: r""" Perform Variable interpolation. @@ -667,7 +667,7 @@ def substitute(placeholder, interpolators, value): except KeyError as exc: raise ValueError(f"{exc} is not found in {interpolators}.") from None - def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict: + def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> Self: r""" Merge `other` into `FlatDict`. @@ -736,13 +736,13 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: this.set(key, value) return this - def union(self, *args: Any, **kwargs: Any) -> FlatDict: + def union(self, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`merge`][chanfig.FlatDict.merge]. """ return self.merge(*args, **kwargs) - def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> Self: r""" Merge content of `file` into `FlatDict`. @@ -762,7 +762,7 @@ def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict: return self.merge(self.load(file, *args, **kwargs)) - def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict: + def intersect(self, other: Mapping | Iterable | PathStr) -> Self: r""" Intersection of `FlatDict` and `other`. @@ -801,13 +801,13 @@ def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict: raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.") return self.empty(**{key: value for key, value in other if key in self and self[key] == value}) # type: ignore - def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict: + def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`intersect`][chanfig.FlatDict.intersect]. """ return self.intersect(other, *args, **kwargs) - def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict: + def difference(self, other: Mapping | Iterable | PathStr) -> Self: r""" Difference between `FlatDict` and `other`. @@ -848,13 +848,13 @@ def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict: **{key: value for key, value in other if key not in self or self[key] != value} # type: ignore ) - def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict: + def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self: r""" Alias of [`difference`][chanfig.FlatDict.difference]. """ return self.difference(other, *args, **kwargs) - def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cover + def to(self, cls: str | TorchDevice | TorchDType) -> Self: # pragma: no cover r""" Convert values of `FlatDict` to target `cls`. @@ -881,7 +881,7 @@ def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cov raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.") - def cpu(self) -> FlatDict: # pragma: no cover + def cpu(self) -> Self: # pragma: no cover r""" Move all tensors to cpu. @@ -897,7 +897,7 @@ def cpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("cpu")) - def gpu(self) -> FlatDict: # pragma: no cover + def gpu(self) -> Self: # pragma: no cover r""" Move all tensors to gpu. @@ -919,13 +919,13 @@ def gpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("cuda")) - def cuda(self) -> FlatDict: # pragma: no cover + def cuda(self) -> Self: # pragma: no cover r""" Alias of [`gpu`][chanfig.FlatDict.gpu]. """ return self.gpu() - def tpu(self) -> FlatDict: # pragma: no cover + def tpu(self) -> Self: # pragma: no cover r""" Move all tensors to tpu. @@ -947,13 +947,13 @@ def tpu(self) -> FlatDict: # pragma: no cover return self.to(TorchDevice("xla")) - def xla(self) -> FlatDict: # pragma: no cover + def xla(self) -> Self: # pragma: no cover r""" Alias of [`tpu`][chanfig.FlatDict.tpu]. """ return self.tpu() - def copy(self) -> FlatDict: + def copy(self) -> Self: r""" Create a shallow copy of `FlatDict`. @@ -975,7 +975,7 @@ def copy(self) -> FlatDict: return copy(self) - def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict: + def __deepcopy__(self, memo: Mapping | None = None) -> Self: # pylint: disable=C0103 if memo is not None and id(self) in memo: @@ -989,7 +989,7 @@ def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict: ret[k] = deepcopy(v) return ret - def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=W0613 + def deepcopy(self, memo: Mapping | None = None) -> Self: # pylint: disable=W0613 r""" Create a deep copy of `FlatDict`. @@ -1017,7 +1017,7 @@ def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable= return deepcopy(self) - def clone(self, memo: Mapping | None = None) -> FlatDict: + def clone(self, memo: Mapping | None = None) -> Self: r""" Alias of [`deepcopy`][chanfig.FlatDict.deepcopy]. """ @@ -1067,7 +1067,7 @@ def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any) @classmethod def load( # pylint: disable=W1113 cls, file: File, method: str | None = None, *args: Any, **kwargs: Any - ) -> FlatDict: + ) -> Self: """ Load `FlatDict` from file. @@ -1122,7 +1122,7 @@ def json(self, file: File, *args: Any, **kwargs: Any) -> None: fp.write(self.jsons(*args, **kwargs)) @classmethod - def from_json(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def from_json(cls, file: File, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from json file. @@ -1161,7 +1161,7 @@ def jsons(self, *args: Any, **kwargs: Any) -> str: return json_dumps(self.dict(), *args, **kwargs) @classmethod - def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict: + def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from json string. @@ -1195,7 +1195,7 @@ def yaml(self, file: File, *args: Any, **kwargs: Any) -> None: self.yamls(fp, *args, **kwargs) @classmethod - def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict: + def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from yaml file. @@ -1233,7 +1233,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str: return yaml_dump(self.dict(), *args, **kwargs) # type: ignore @classmethod - def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict: + def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self: r""" Construct `FlatDict` from yaml string. @@ -1304,7 +1304,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene raise TypeError(f"expected str, bytes, os.PathLike, IO or IOBase, not {type(file).__name__}") @classmethod - def empty(cls, *args: Any, **kwargs: Any) -> FlatDict: + def empty(cls, *args: Any, **kwargs: Any) -> Self: r""" Initialise an empty `FlatDict`. @@ -1330,7 +1330,7 @@ def empty(cls, *args: Any, **kwargs: Any) -> FlatDict: empty.merge(*args, **kwargs) # pylint: disable=W0212 return empty - def empty_like(self, *args: Any, **kwargs: Any) -> FlatDict: + def empty_like(self, *args: Any, **kwargs: Any) -> Self: r""" Initialise an empty copy of `FlatDict`. @@ -1392,7 +1392,7 @@ def all_items(self) -> Generator: """ yield from self.items() - def dropnull(self) -> FlatDict: + def dropnull(self) -> Self: r""" Drop key-value pairs with `Null` value. @@ -1415,7 +1415,7 @@ def dropnull(self) -> FlatDict: return self.empty({k: v for k, v in self.all_items() if v is not Null}) - def dropna(self) -> FlatDict: + def dropna(self) -> Self: r""" Alias of [`dropnull`][chanfig.FlatDict.dropnull]. """ diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index 862e9601..38afa3f9 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -20,7 +20,7 @@ from functools import wraps from inspect import ismethod from os import PathLike -from typing import Any +from typing import Any, Self try: from functools import cached_property # pylint: disable=C0412 @@ -240,7 +240,7 @@ def all_items(self, prefix=Null): return all_items(self) - def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: + def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Self: r""" Recursively apply a function to `NestedDict` and its children. @@ -269,7 +269,7 @@ def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: return apply(self, func, *args, **kwargs) - def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: + def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> Self: r""" Recursively apply a function to `NestedDict` and its children. @@ -655,7 +655,7 @@ def validate(self) -> None: self.apply_(self._validate) - def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> NestedDict: + def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> Self: r""" Sort `NestedDict`. @@ -715,7 +715,7 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: def intersect( # pylint: disable=W0221 self, other: Mapping | Iterable | PathStr, recursive: bool = True - ) -> NestedDict: + ) -> Self: r""" Intersection of `NestedDict` and `other`. @@ -763,7 +763,7 @@ def _intersect(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapp def difference( # pylint: disable=W0221, C0103 self, other: Mapping | Iterable | PathStr, recursive: bool = True - ) -> NestedDict: + ) -> Self: r""" Difference between `NestedDict` and `other`.