diff --git a/python-spec/src/somacore/query/query.py b/python-spec/src/somacore/query/query.py
index 1032517e..136970c9 100644
--- a/python-spec/src/somacore/query/query.py
+++ b/python-spec/src/somacore/query/query.py
@@ -10,6 +10,7 @@
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 
 import anndata
@@ -19,7 +20,7 @@
 import pandas as pd
 import pyarrow as pa
 from scipy import sparse
-from typing_extensions import Literal, Protocol, Self, TypedDict, assert_never
+from typing_extensions import Literal, Protocol, Self, TypedDict
 
 from .. import data
 from .. import measurement
@@ -368,14 +369,10 @@ def _read_axis_dataframe(
     ) -> pa.Table:
         """Reads the specified axis. Will cache join IDs if not present."""
         column_names = axis_column_names.get(axis.value)
-        if axis is _Axis.OBS:
-            axis_df = self._obs_df
-            axis_query = self._matrix_axis_query.obs
-        elif axis is _Axis.VAR:
-            axis_df = self._var_df
-            axis_query = self._matrix_axis_query.var
-        else:
-            assert_never(axis)  # must be obs or var
+
+        axis_df = axis.getattr_from(self, pre="_", suf="_df")
+        assert isinstance(axis_df, data.DataFrame)
+        axis_query = axis.getattr_from(self._matrix_axis_query)
 
         # If we can cache join IDs, prepare to add them to the cache.
         joinids_cached = self._joinids._is_cached(axis)
@@ -420,19 +417,24 @@ def _axisp_inner(
         axis: "_Axis",
         layer: str,
     ) -> data.SparseRead:
-        key = axis.value + "p"
-
-        if key not in self._ms:
-            raise ValueError(f"Measurement does not contain {key} data")
+        p_name = f"{axis.value}p"
+        try:
+            axisp = axis.getitem_from(self._ms, suf="p")
+        except KeyError as ke:
+            raise ValueError(f"Measurement does not contain {p_name} data") from ke
 
-        axisp = self._ms.obsp if axis is _Axis.OBS else self._ms.varp
-        if not (layer and layer in axisp):
-            raise ValueError(f"Must specify '{key}' layer")
-        if not isinstance(axisp[layer], data.SparseNDArray):
-            raise TypeError(f"Unexpected SOMA type stored in '{key}' layer")
+        try:
+            ap_layer = axisp[layer]
+        except KeyError as ke:
+            raise ValueError(f"layer {layer!r} is not available in {p_name}") from ke
+        if not isinstance(ap_layer, data.SparseNDArray):
+            raise TypeError(
+                f"Unexpected SOMA type {type(ap_layer).__name__}"
+                f" stored in {p_name} layer {layer!r}"
+            )
 
         joinids = getattr(self._joinids, axis.value)
-        return axisp[layer].read((joinids, joinids))
+        return ap_layer.read((joinids, joinids))
 
     @property
     def _obs_df(self) -> data.DataFrame:
@@ -493,6 +495,30 @@ class _Axis(enum.Enum):
     def value(self) -> Literal["obs", "var"]:
         return super().value
 
+    @overload
+    def getattr_from(self, __source: "_HasObsVar[_T]") -> "_T":
+        ...
+
+    @overload
+    def getattr_from(
+        self, __source: Any, *, pre: Literal[""], suf: Literal[""]
+    ) -> object:
+        ...
+
+    @overload
+    def getattr_from(self, __source: Any, *, pre: str = ..., suf: str = ...) -> object:
+        ...
+
+    def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
+        """Equivalent to ``something.<pre><obs/var><suf>``."""
+        return getattr(__source, pre + self.value + suf)
+
+    def getitem_from(
+        self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
+    ) -> "_T":
+        """Equivalent to ``something[pre + "obs"/"var" + suf]``."""
+        return __source[pre + self.value + suf]
+
 
 @attrs.define(frozen=True)
 class _MatrixAxisQuery:
@@ -605,6 +631,14 @@ def _to_numpy(it: _Numpyable) -> np.ndarray:
     return it.to_numpy()
 
 
+#
+# Type shenanigans
+#
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+
 class _Experimentish(Protocol):
     """The API we need from an Experiment."""
 
@@ -615,3 +649,18 @@ def ms(self) -> Mapping[str, measurement.Measurement]:
     @property
     def obs(self) -> data.DataFrame:
         ...
+
+
+class _HasObsVar(Protocol[_T_co]):
+    """Something which has an ``obs`` and ``var`` field.
+
+    Used to give nicer type inference in :meth:`_Axis.getattr_from`.
+    """
+
+    @property
+    def obs(self) -> _T_co:
+        ...
+
+    @property
+    def var(self) -> _T_co:
+        ...
diff --git a/python-spec/testing/test_query_axis.py b/python-spec/testing/test_query_axis.py
index a235698d..e9d418f9 100644
--- a/python-spec/testing/test_query_axis.py
+++ b/python-spec/testing/test_query_axis.py
@@ -1,11 +1,13 @@
 from typing import Any, Tuple
 
+import attrs
 import numpy as np
 import pytest
 from pytest import mark
 
 import somacore
 from somacore import options
+from somacore.query import query
 
 
 @mark.parametrize(
@@ -49,3 +51,24 @@ def test_canonicalization_nparray() -> None:
 def test_canonicalization_bad(coords) -> None:
     with pytest.raises(TypeError):
         somacore.AxisQuery(coords=coords)
+
+
+@attrs.define(frozen=True)
+class IHaveObsVarStuff:
+    obs: int
+    var: int
+    the_obs_suf: str
+    the_var_suf: str
+
+
+def test_axis_helpers() -> None:
+    thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
+    assert 1 == query._Axis.OBS.getattr_from(thing)
+    assert 2 == query._Axis.VAR.getattr_from(thing)
+    assert "observe" == query._Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
+    assert "vary" == query._Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
+    ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
+    assert "erve" == query._Axis.OBS.getitem_from(ovdict)
+    assert "y" == query._Axis.VAR.getitem_from(ovdict)
+    assert "hide" == query._Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
+    assert "???" == query._Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")