Skip to content

Commit

Permalink
return Self type
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed Nov 9, 2023
1 parent a97aac8 commit fa4b8a0
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 50 deletions.
21 changes: 13 additions & 8 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from functools import wraps
from typing import Any

try:
from typing import Self
except ImportError:
from typing_extensions import Self

from .nested_dict import NestedDict
from .parser import ConfigParser
from .utils import Null
Expand Down Expand Up @@ -101,7 +106,7 @@ def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs
default_factory = Config
super().__init__(*args, default_factory=default_factory, **kwargs)

def post(self) -> Config | None:
def post(self) -> Self | None:
r"""
Post process of `Config`.
Expand Down Expand Up @@ -138,7 +143,7 @@ def post(self) -> Config | None:
self.validate()
return self

def boot(self) -> Config:
def boot(self) -> Self:
r"""
Apply `post` recursively.
Expand Down Expand Up @@ -192,7 +197,7 @@ def parse(
default_config: str | None = None,
no_default_config_action: str = "raise",
boot: bool = True,
) -> Config:
) -> Self:
r"""
Parse command-line arguments with `ConfigParser`.
Expand Down Expand Up @@ -235,7 +240,7 @@ def parse_config(
default_config: str | None = None,
no_default_config_action: str = "raise",
boot: bool = True,
) -> Config:
) -> Self:
r"""
Parse command-line arguments with `ConfigParser`.
Expand Down Expand Up @@ -289,7 +294,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None:
self.setattr("parser", ConfigParser())
return self.getattr("parser").add_argument(*args, **kwargs)

def freeze(self, recursive: bool = True) -> Config:
def freeze(self, recursive: bool = True) -> Self:
r"""
Freeze `Config`.
Expand Down Expand Up @@ -327,7 +332,7 @@ def freeze(config: Config) -> None:
freeze(self)
return self

def lock(self, recursive: bool = True) -> Config:
def lock(self, recursive: bool = True) -> Self:
r"""
Alias of [`freeze`][chanfig.Config.freeze].
"""
Expand Down Expand Up @@ -357,7 +362,7 @@ def locked(self):
if not was_frozen:
self.defrost()

def defrost(self, recursive: bool = True) -> Config:
def defrost(self, recursive: bool = True) -> Self:
r"""
Defrost `Config`.
Expand Down Expand Up @@ -399,7 +404,7 @@ def defrost(config: Config) -> None:
defrost(self)
return self

def unlock(self, recursive: bool = True) -> Config:
def unlock(self, recursive: bool = True) -> Self:
r"""
Alias of [`defrost`][chanfig.Config.defrost].
"""
Expand Down
67 changes: 35 additions & 32 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from typing import IO, Any
from warnings import warn

try:
from typing import Self
except ImportError:
from typing_extensions import Self

from yaml import dump as yaml_dump
from yaml import load as yaml_load

Expand Down Expand Up @@ -514,7 +519,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911
if obj is None:
return cls()
if issubclass(cls, FlatDict):
cls = cls.empty # type: ignore # pylint: disable=W0642
cls = cls.empty # pylint: disable=W0642
if isinstance(obj, Mapping):
return cls(obj)
if isinstance(obj, Sequence):
Expand All @@ -524,7 +529,7 @@ def from_dict(cls, obj: Mapping | Sequence) -> Any: # pylint: disable=R0911
return [cls(json) for json in obj]
raise TypeError(f"Expected Mapping or Sequence, but got {type(obj)}.")

def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict:
def sort(self, key: Callable | None = None, reverse: bool = False) -> Self:
r"""
Sort `FlatDict`.
Expand Down Expand Up @@ -553,7 +558,7 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> FlatDict:

def interpolate( # pylint: disable=R0912
self, use_variable: bool = True, interpolators: MutableMapping | None = None
) -> FlatDict:
) -> Self:
r"""
Perform Variable interpolation.
Expand Down Expand Up @@ -667,7 +672,7 @@ def substitute(placeholder, interpolators, value):
except KeyError as exc:
raise ValueError(f"{exc} is not found in {interpolators}.") from None

def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict:
def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> Self:
r"""
Merge `other` into `FlatDict`.
Expand Down Expand Up @@ -736,13 +741,13 @@ def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping:
this.set(key, value)
return this

def union(self, *args: Any, **kwargs: Any) -> FlatDict:
def union(self, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`merge`][chanfig.FlatDict.merge].
"""
return self.merge(*args, **kwargs)

def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Merge content of `file` into `FlatDict`.
Expand All @@ -762,7 +767,7 @@ def merge_from_file(self, file: File, *args: Any, **kwargs: Any) -> FlatDict:

return self.merge(self.load(file, *args, **kwargs))

def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict:
def intersect(self, other: Mapping | Iterable | PathStr) -> Self:
r"""
Intersection of `FlatDict` and `other`.
Expand Down Expand Up @@ -801,13 +806,13 @@ def intersect(self, other: Mapping | Iterable | PathStr) -> FlatDict:
raise TypeError(f"`other={other}` should be of type Mapping, Iterable or PathStr, but got {type(other)}.")
return self.empty(**{key: value for key, value in other if key in self and self[key] == value}) # type: ignore

def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict:
def inter(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`intersect`][chanfig.FlatDict.intersect].
"""
return self.intersect(other, *args, **kwargs)

def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict:
def difference(self, other: Mapping | Iterable | PathStr) -> Self:
r"""
Difference between `FlatDict` and `other`.
Expand Down Expand Up @@ -848,13 +853,13 @@ def difference(self, other: Mapping | Iterable | PathStr) -> FlatDict:
**{key: value for key, value in other if key not in self or self[key] != value} # type: ignore
)

def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> FlatDict:
def diff(self, other: Mapping | Iterable | PathStr, *args: Any, **kwargs: Any) -> Self:
r"""
Alias of [`difference`][chanfig.FlatDict.difference].
"""
return self.difference(other, *args, **kwargs)

def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cover
def to(self, cls: str | TorchDevice | TorchDType) -> Self: # pragma: no cover
r"""
Convert values of `FlatDict` to target `cls`.
Expand All @@ -881,7 +886,7 @@ def to(self, cls: str | TorchDevice | TorchDType) -> FlatDict: # pragma: no cov

raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.")

def cpu(self) -> FlatDict: # pragma: no cover
def cpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to cpu.
Expand All @@ -897,7 +902,7 @@ def cpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("cpu"))

def gpu(self) -> FlatDict: # pragma: no cover
def gpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to gpu.
Expand All @@ -919,13 +924,13 @@ def gpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("cuda"))

def cuda(self) -> FlatDict: # pragma: no cover
def cuda(self) -> Self: # pragma: no cover
r"""
Alias of [`gpu`][chanfig.FlatDict.gpu].
"""
return self.gpu()

def tpu(self) -> FlatDict: # pragma: no cover
def tpu(self) -> Self: # pragma: no cover
r"""
Move all tensors to tpu.
Expand All @@ -947,13 +952,13 @@ def tpu(self) -> FlatDict: # pragma: no cover

return self.to(TorchDevice("xla"))

def xla(self) -> FlatDict: # pragma: no cover
def xla(self) -> Self: # pragma: no cover
r"""
Alias of [`tpu`][chanfig.FlatDict.tpu].
"""
return self.tpu()

def copy(self) -> FlatDict:
def copy(self) -> Self:
r"""
Create a shallow copy of `FlatDict`.
Expand All @@ -975,7 +980,7 @@ def copy(self) -> FlatDict:

return copy(self)

def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict:
def __deepcopy__(self, memo: Mapping | None = None) -> Self:
# pylint: disable=C0103

if memo is not None and id(self) in memo:
Expand All @@ -989,7 +994,7 @@ def __deepcopy__(self, memo: Mapping | None = None) -> FlatDict:
ret[k] = deepcopy(v)
return ret

def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=W0613
def deepcopy(self, memo: Mapping | None = None) -> Self: # pylint: disable=W0613
r"""
Create a deep copy of `FlatDict`.
Expand Down Expand Up @@ -1017,7 +1022,7 @@ def deepcopy(self, memo: Mapping | None = None) -> FlatDict: # pylint: disable=

return deepcopy(self)

def clone(self, memo: Mapping | None = None) -> FlatDict:
def clone(self, memo: Mapping | None = None) -> Self:
r"""
Alias of [`deepcopy`][chanfig.FlatDict.deepcopy].
"""
Expand Down Expand Up @@ -1065,9 +1070,7 @@ def dump(self, file: File, method: str | None = None, *args: Any, **kwargs: Any)
return self.save(file, method, *args, **kwargs)

@classmethod
def load( # pylint: disable=W1113
cls, file: File, method: str | None = None, *args: Any, **kwargs: Any
) -> FlatDict:
def load(cls, file: File, method: str | None = None, *args: Any, **kwargs: Any) -> Self: # pylint: disable=W1113
"""
Load `FlatDict` from file.
Expand Down Expand Up @@ -1122,7 +1125,7 @@ def json(self, file: File, *args: Any, **kwargs: Any) -> None:
fp.write(self.jsons(*args, **kwargs))

@classmethod
def from_json(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def from_json(cls, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from json file.
Expand Down Expand Up @@ -1161,7 +1164,7 @@ def jsons(self, *args: Any, **kwargs: Any) -> str:
return json_dumps(self.dict(), *args, **kwargs)

@classmethod
def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict:
def from_jsons(cls, string: str, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from json string.
Expand Down Expand Up @@ -1195,7 +1198,7 @@ def yaml(self, file: File, *args: Any, **kwargs: Any) -> None:
self.yamls(fp, *args, **kwargs)

@classmethod
def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> FlatDict:
def from_yaml(cls, file: File, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from yaml file.
Expand Down Expand Up @@ -1233,7 +1236,7 @@ def yamls(self, *args: Any, **kwargs: Any) -> str:
return yaml_dump(self.dict(), *args, **kwargs) # type: ignore

@classmethod
def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> FlatDict:
def from_yamls(cls, string: str, *args: Any, **kwargs: Any) -> Self:
r"""
Construct `FlatDict` from yaml string.
Expand Down Expand Up @@ -1295,7 +1298,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene
yield file
elif isinstance(file, (PathLike, str, bytes)):
try:
file = open(file, *args, encoding=encoding, **kwargs) # type: ignore # noqa: SIM115
file = open(file, *args, encoding=encoding, **kwargs) # noqa: SIM115
yield file # type: ignore
finally:
with suppress(Exception):
Expand All @@ -1304,7 +1307,7 @@ def open(file: File, *args: Any, encoding: str = "utf-8", **kwargs: Any) -> Gene
raise TypeError(f"expected str, bytes, os.PathLike, IO or IOBase, not {type(file).__name__}")

@classmethod
def empty(cls, *args: Any, **kwargs: Any) -> FlatDict:
def empty(cls, *args: Any, **kwargs: Any) -> Self:
r"""
Initialise an empty `FlatDict`.
Expand All @@ -1330,7 +1333,7 @@ def empty(cls, *args: Any, **kwargs: Any) -> FlatDict:
empty.merge(*args, **kwargs) # pylint: disable=W0212
return empty

def empty_like(self, *args: Any, **kwargs: Any) -> FlatDict:
def empty_like(self, *args: Any, **kwargs: Any) -> Self:
r"""
Initialise an empty copy of `FlatDict`.
Expand Down Expand Up @@ -1392,7 +1395,7 @@ def all_items(self) -> Generator:
"""
yield from self.items()

def dropnull(self) -> FlatDict:
def dropnull(self) -> Self:
r"""
Drop key-value pairs with `Null` value.
Expand All @@ -1415,7 +1418,7 @@ def dropnull(self) -> FlatDict:

return self.empty({k: v for k, v in self.all_items() if v is not Null})

def dropna(self) -> FlatDict:
def dropna(self) -> Self:
r"""
Alias of [`dropnull`][chanfig.FlatDict.dropnull].
"""
Expand Down
Loading

0 comments on commit fa4b8a0

Please sign in to comment.