Skip to content

Commit

Permalink
return Self type
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed Nov 9, 2023
1 parent a97aac8 commit d15f67e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 44 deletions.
18 changes: 9 additions & 9 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from functools import wraps
from typing import Any
from typing import Any, Self

from .nested_dict import NestedDict
from .parser import ConfigParser
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs
default_factory = Config
super().__init__(*args, default_factory=default_factory, **kwargs)

def post(self) -> Config | None:
def post(self) -> Self | None:
r"""
Post process of `Config`.
Expand Down Expand Up @@ -138,7 +138,7 @@ def post(self) -> Config | None:
self.validate()
return self

def boot(self) -> Config:
def boot(self) -> Self:
r"""
Apply `post` recursively.
Expand Down Expand Up @@ -192,7 +192,7 @@ def parse(
default_config: str | None = None,
no_default_config_action: str = "raise",
boot: bool = True,
) -> Config:
) -> Self:
r"""
Parse command-line arguments with `ConfigParser`.
Expand Down Expand Up @@ -235,7 +235,7 @@ def parse_config(
default_config: str | None = None,
no_default_config_action: str = "raise",
boot: bool = True,
) -> Config:
) -> Self:
r"""
Parse command-line arguments with `ConfigParser`.
Expand Down Expand Up @@ -289,7 +289,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None:
self.setattr("parser", ConfigParser())
return self.getattr("parser").add_argument(*args, **kwargs)

def freeze(self, recursive: bool = True) -> Config:
def freeze(self, recursive: bool = True) -> Self:
r"""
Freeze `Config`.
Expand Down Expand Up @@ -327,7 +327,7 @@ def freeze(config: Config) -> None:
freeze(self)
return self

def lock(self, recursive: bool = True) -> Config:
def lock(self, recursive: bool = True) -> Self:
r"""
Alias of [`freeze`][chanfig.Config.freeze].
"""
Expand Down Expand Up @@ -357,7 +357,7 @@ def locked(self):
if not was_frozen:
self.defrost()

def defrost(self, recursive: bool = True) -> Config:
def defrost(self, recursive: bool = True) -> Self:
r"""
Defrost `Config`.
Expand Down Expand Up @@ -399,7 +399,7 @@ def defrost(config: Config) -> None:
defrost(self)
return self

def unlock(self, recursive: bool = True) -> Config:
def unlock(self, recursive: bool = True) -> Self:
r"""
Alias of [`defrost`][chanfig.Config.defrost].
"""
Expand Down
58 changes: 29 additions & 29 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from json import loads as json_loads
from os import PathLike
from os.path import splitext
from typing import IO, Any
from typing import IO, Any, Self
from warnings import warn

from yaml import dump as yaml_dump
Expand Down Expand Up @@ -524,7 +524,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911
return [cls(json) for json in obj]
raise TypeError(f"Expected Mapping or Sequence, but got {type(obj)}.")

def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict:
def sort(self, key: Callable | None = None, reverse: bool = False) -> Self:
r"""
Sort `FlatDict`.
Expand Down Expand Up @@ -553,7 +553,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict:

def interpolate( # pylint: disable=R0912
self, use_variable: bool = True, interpolators: MutableMapping | None = None
) -> FlatDict:
) -> Self:
r"""
Perform Variable interpolation.
Expand Down Expand Up @@ -667,7 +667,7 @@ def substitute(placeholder, interpolators, value):
except KeyError as exc:
raise ValueError(f"{exc} is not found in {interpolators}.") from None

def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict:
def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> Self:
r"""
Merge `other` into `FlatDict`.
Expand Down Expand Up @@ -736,13 +736,13 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping:
this.set(key, value)
return this

def union(self, *args: Any, **kwargs: Any) -> FlatDict:
def union(self, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`merge`][chanfig.FlatDict.merge].
"""
return self.merge(*args, **kwargs)

def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Merge content of `file` into `FlatDict`.
Expand All @@ -762,7 +762,7 @@ def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict:

return self.merge(self.load(file, *args, **kwargs))

def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict:
def intersect(self, other: Mapping | Iterable | PathStr) -> Self:
r"""
Intersection of `FlatDict` and `other`.
Expand Down Expand Up @@ -801,13 +801,13 @@ def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict:
raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty(**{key: value for key, value in other if key in self and self[key] == value}) # type: ignore

def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict:
def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`intersect`][chanfig.FlatDict.intersect].
"""
return self.intersect(other, *args, **kwargs)

def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict:
def difference(self, other: Mapping | Iterable | PathStr) -> Self:
r"""
Difference between `FlatDict` and `other`.
Expand Down Expand Up @@ -848,13 +848,13 @@ def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict:
**{key: value for key, value in other if key not in self or self[key] != value} # type: ignore
)

def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict:
def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`difference`][chanfig.FlatDict.difference].
"""
return self.difference(other, *args, **kwargs)

def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cover
def to(self, cls: str | TorchDevice | TorchDType) -> Self: # pragma: no cover
r"""
Convert values of `FlatDict` to target `cls`.
Expand All @@ -881,7 +881,7 @@ def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cov

raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.")

def cpu(self) -> FlatDict: # pragma: no cover
def cpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to cpu.
Expand All @@ -897,7 +897,7 @@ def cpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("cpu"))

def gpu(self) -> FlatDict: # pragma: no cover
def gpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to gpu.
Expand All @@ -919,13 +919,13 @@ def gpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("cuda"))

def cuda(self) -> FlatDict: # pragma: no cover
def cuda(self) -> Self: # pragma: no cover
r"""
Alias of [`gpu`][chanfig.FlatDict.gpu].
"""
return self.gpu()

def tpu(self) -> FlatDict: # pragma: no cover
def tpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to tpu.
Expand All @@ -947,13 +947,13 @@ def tpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("xla"))

def xla(self) -> FlatDict: # pragma: no cover
def xla(self) -> Self: # pragma: no cover
r"""
Alias of [`tpu`][chanfig.FlatDict.tpu].
"""
return self.tpu()

def copy(self) -> FlatDict:
def copy(self) -> Self:
r"""
Create a shallow copy of `FlatDict`.
Expand All @@ -975,7 +975,7 @@ def copy(self) -> FlatDict:

return copy(self)

def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict:
def __deepcopy__(self, memo: Mapping | None = None) -> Self:
# pylint: disable=C0103

if memo is not None and id(self) in memo:
Expand All @@ -989,7 +989,7 @@ def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict:
ret[k] = deepcopy(v)
return ret

def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=W0613
def deepcopy(self, memo: Mapping | None = None) -> Self: # pylint: disable=W0613
r"""
Create a deep copy of `FlatDict`.
Expand Down Expand Up @@ -1017,7 +1017,7 @@ def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=

return deepcopy(self)

def clone(self, memo: Mapping | None = None) -> FlatDict:
def clone(self, memo: Mapping | None = None) -> Self:
r"""
Alias of [`deepcopy`][chanfig.FlatDict.deepcopy].
"""
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any)
@classmethod
def load( # pylint: disable=W1113
cls, file: File, method: str | None = None, *args: Any, **kwargs: Any
) -> FlatDict:
) -> Self:
"""
Load `FlatDict` from file.
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def json(self, file: File, *args: Any, **kwargs: Any) -> None:
fp.write(self.jsons(*args, **kwargs))

@classmethod
def from_json(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def from_json(cls, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from json file.
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def jsons(self, *args: Any, **kwargs: Any) -> str:
return json_dumps(self.dict(), *args, **kwargs)

@classmethod
def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict:
def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from json string.
Expand Down Expand Up @@ -1195,7 +1195,7 @@ def yaml(self, file: File, *args: Any, **kwargs: Any) -> None:
self.yamls(fp, *args, **kwargs)

@classmethod
def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from yaml file.
Expand Down Expand Up @@ -1233,7 +1233,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str:
return yaml_dump(self.dict(), *args, **kwargs) # type: ignore

@classmethod
def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict:
def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from yaml string.
Expand Down Expand Up @@ -1304,7 +1304,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene
raise TypeError(f"expected str, bytes, os.PathLike, IO or IOBase, not {type(file).__name__}")

@classmethod
def empty(cls, *args: Any, **kwargs: Any) -> FlatDict:
def empty(cls, *args: Any, **kwargs: Any) -> Self:
r"""
Initialise an empty `FlatDict`.
Expand All @@ -1330,7 +1330,7 @@ def empty(cls, *args: Any, **kwargs: Any) -> FlatDict:
empty.merge(*args, **kwargs) # pylint: disable=W0212
return empty

def empty_like(self, *args: Any, **kwargs: Any) -> FlatDict:
def empty_like(self, *args: Any, **kwargs: Any) -> Self:
r"""
Initialise an empty copy of `FlatDict`.
Expand Down Expand Up @@ -1392,7 +1392,7 @@ def all_items(self) -> Generator:
"""
yield from self.items()

def dropnull(self) -> FlatDict:
def dropnull(self) -> Self:
r"""
Drop key-value pairs with `Null` value.
Expand All @@ -1415,7 +1415,7 @@ def dropnull(self) -> FlatDict:

return self.empty({k: v for k, v in self.all_items() if v is not Null})

def dropna(self) -> FlatDict:
def dropna(self) -> Self:
r"""
Alias of [`dropnull`][chanfig.FlatDict.dropnull].
"""
Expand Down
12 changes: 6 additions & 6 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from functools import wraps
from inspect import ismethod
from os import PathLike
from typing import Any
from typing import Any, Self

try:
from functools import cached_property # pylint: disable=C0412
Expand Down Expand Up @@ -240,7 +240,7 @@ def all_items(self, prefix=Null):

return all_items(self)

def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict:
def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Self:
r"""
Recursively apply a function to `NestedDict` and its children.
Expand Down Expand Up @@ -269,7 +269,7 @@ def apply(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict:

return apply(self, func, *args, **kwargs)

def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict:
def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> Self:
r"""
Recursively apply a function to `NestedDict` and its children.
Expand Down Expand Up @@ -655,7 +655,7 @@ def validate(self) -> None:

self.apply_(self._validate)

def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> NestedDict:
def sort(self, key: Callable | None = None, reverse: bool = False, recursive: bool = True) -> Self:
r"""
Sort `NestedDict`.
Expand Down Expand Up @@ -715,7 +715,7 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping:

def intersect( # pylint: disable=W0221
self, other: Mapping | Iterable | PathStr, recursive: bool = True
) -> NestedDict:
) -> Self:
r"""
Intersection of `NestedDict` and `other`.
Expand Down Expand Up @@ -763,7 +763,7 @@ def _intersect(this: NestedDict, that: Iterable, recursive: bool = True) -> Mapp

def difference( # pylint: disable=W0221, C0103
self, other: Mapping | Iterable | PathStr, recursive: bool = True
) -> NestedDict:
) -> Self:
r"""
Difference between `NestedDict` and `other`.
Expand Down

0 comments on commit d15f67e

Please sign in to comment.