From 9f1fe5c541460b2f998b3ee0eb749bd417abe04d Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Thu, 23 Jun 2022 11:57:40 -0400 Subject: [PATCH] WIP: starting slicing --- .pre-commit-config.yaml | 2 +- src/boost_histogram/_internal/axis.py | 22 +++++++++++++++++++++- tests/test_axis.py | 20 ++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d403248c..2fcc6f97 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -70,7 +70,7 @@ repos: hooks: - id: mypy files: ^src - additional_dependencies: [numpy==1.22.4, pytest, uhi, types-dataclasses] + additional_dependencies: [numpy~=1.23.0, pytest, uhi, types-dataclasses] - repo: https://github.com/mgedmin/check-manifest rev: "0.48" diff --git a/src/boost_histogram/_internal/axis.py b/src/boost_histogram/_internal/axis.py index 11005eed..9fc8f3d2 100644 --- a/src/boost_histogram/_internal/axis.py +++ b/src/boost_histogram/_internal/axis.py @@ -226,10 +226,16 @@ def extent(self) -> int: """ return self._ax.extent # type: ignore[no-any-return] - def __getitem__(self, i: AxCallOrInt) -> Union[int, str, Tuple[float, float]]: + def __getitem__( + self: T, i: Union[AxCallOrInt, slice] + ) -> Union[int, str, Tuple[float, float], T]: """ Access a bin, using normal Python syntax for wraparound. """ + if isinstance(i, slice): + raise NotImplementedError( + f"Slicing not supported on {self.__class__.__name__}" + ) # UHI support if callable(i): i = i(self) @@ -241,6 +247,7 @@ def __getitem__(self, i: AxCallOrInt) -> Union[int, str, Tuple[float, float]]: f"Out of range access, {i} is more than {self._ax.size}" ) assert not callable(i) + assert not isinstance(i, slice) return self.bin(i) @property @@ -612,6 +619,9 @@ def _repr_args_(self) -> List[str]: return ret +TStrC = TypeVar("TStrC", bound="StrCategory") + + @set_module("boost_histogram.axis") @register({ca.category_str_growth, ca.category_str}) class StrCategory(BaseCategory, family=boost_histogram): @@ -660,6 +670,16 @@ def __init__( super().__init__(ax, metadata, __dict__) + def __getitem__( + self: TStrC, i: Union[AxCallOrInt, slice] + ) -> Union[int, str, Tuple[float, float], TStrC]: + + if isinstance(i, slice): + new_cats = list(self)[i] + return self.__class__(new_cats, __dict__=self.__dict__) # type: ignore[arg-type] + else: + return super().__getitem__(i) + def index(self, value: Union[float, str]) -> int: """ Return the fractional index(es) given a value (or values) on the axis. diff --git a/tests/test_axis.py b/tests/test_axis.py index 089694c2..0daf3aee 100644 --- a/tests/test_axis.py +++ b/tests/test_axis.py @@ -839,6 +839,26 @@ def test_edges_centers_widths(self, ref, growth): assert_allclose(a.centers, [0.5, 1.5, 2.5]) assert_allclose(a.widths, [1, 1, 1]) + def test_slicing(self, growth): + Cat = bh.axis.StrCategory + ref = ["a", "b", "c", "d", "e"] + + a = Cat(ref, growth=growth) + b = a[1:3] + assert list(a)[1:3] == list(b) + assert a.__dict__ == b.__dict__ + assert a.traits.growth == b.traits.growth + + def test_empty_slice(self, growth): + Cat = bh.axis.StrCategory + ref = ["a", "b", "c", "d", "e"] + a = Cat(ref, growth=growth) + if growth: + assert a[0:0] == Cat([], growth=True) + else: + with pytest.raises(RuntimeError): + a[0:0] + class TestBoolean: def test_init(self):