diff --git a/python-spec/src/somacore/query/query.py b/python-spec/src/somacore/query/query.py index 1032517e..136970c9 100644 --- a/python-spec/src/somacore/query/query.py +++ b/python-spec/src/somacore/query/query.py @@ -10,6 +10,7 @@ Tuple, TypeVar, Union, + overload, ) import anndata @@ -19,7 +20,7 @@ import pandas as pd import pyarrow as pa from scipy import sparse -from typing_extensions import Literal, Protocol, Self, TypedDict, assert_never +from typing_extensions import Literal, Protocol, Self, TypedDict from .. import data from .. import measurement @@ -368,14 +369,10 @@ def _read_axis_dataframe( ) -> pa.Table: """Reads the specified axis. Will cache join IDs if not present.""" column_names = axis_column_names.get(axis.value) - if axis is _Axis.OBS: - axis_df = self._obs_df - axis_query = self._matrix_axis_query.obs - elif axis is _Axis.VAR: - axis_df = self._var_df - axis_query = self._matrix_axis_query.var - else: - assert_never(axis) # must be obs or var + + axis_df = axis.getattr_from(self, pre="_", suf="_df") + assert isinstance(axis_df, data.DataFrame) + axis_query = axis.getattr_from(self._matrix_axis_query) # If we can cache join IDs, prepare to add them to the cache. joinids_cached = self._joinids._is_cached(axis) @@ -420,19 +417,24 @@ def _axisp_inner( axis: "_Axis", layer: str, ) -> data.SparseRead: - key = axis.value + "p" - - if key not in self._ms: - raise ValueError(f"Measurement does not contain {key} data") + p_name = f"{axis.value}p" + try: + axisp = axis.getitem_from(self._ms, suf="p") + except KeyError as ke: + raise ValueError(f"Measurement does not contain {p_name} data") from ke - axisp = self._ms.obsp if axis is _Axis.OBS else self._ms.varp - if not (layer and layer in axisp): - raise ValueError(f"Must specify '{key}' layer") - if not isinstance(axisp[layer], data.SparseNDArray): - raise TypeError(f"Unexpected SOMA type stored in '{key}' layer") + try: + ap_layer = axisp[layer] + except KeyError as ke: + raise ValueError(f"layer {layer!r} is not available in {p_name}") from ke + if not isinstance(ap_layer, data.SparseNDArray): + raise TypeError( + f"Unexpected SOMA type {type(ap_layer).__name__}" + f" stored in {p_name} layer {layer!r}" + ) joinids = getattr(self._joinids, axis.value) - return axisp[layer].read((joinids, joinids)) + return ap_layer.read((joinids, joinids)) @property def _obs_df(self) -> data.DataFrame: @@ -493,6 +495,30 @@ class _Axis(enum.Enum): def value(self) -> Literal["obs", "var"]: return super().value + @overload + def getattr_from(self, __source: "_HasObsVar[_T]") -> "_T": + ... + + @overload + def getattr_from( + self, __source: Any, *, pre: Literal[""], suf: Literal[""] + ) -> object: + ... + + @overload + def getattr_from(self, __source: Any, *, pre: str = ..., suf: str = ...) -> object: + ... + + def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object: + """Equivalent to ``something.
``."""
+        return getattr(__source, pre + self.value + suf)
+
+    def getitem_from(
+        self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
+    ) -> "_T":
+        """Equivalent to ``something[pre + "obs"/"var" + suf]``."""
+        return __source[pre + self.value + suf]
+
 
 @attrs.define(frozen=True)
 class _MatrixAxisQuery:
@@ -605,6 +631,14 @@ def _to_numpy(it: _Numpyable) -> np.ndarray:
     return it.to_numpy()
 
 
+#
+# Type shenanigans
+#
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+
 class _Experimentish(Protocol):
     """The API we need from an Experiment."""
 
@@ -615,3 +649,18 @@ def ms(self) -> Mapping[str, measurement.Measurement]:
     @property
     def obs(self) -> data.DataFrame:
         ...
+
+
+class _HasObsVar(Protocol[_T_co]):
+    """Something which has an ``obs`` and ``var`` field.
+
+    Used to give nicer type inference in :meth:`_Axis.getattr_from`.
+    """
+
+    @property
+    def obs(self) -> _T_co:
+        ...
+
+    @property
+    def var(self) -> _T_co:
+        ...
diff --git a/python-spec/testing/test_query_axis.py b/python-spec/testing/test_query_axis.py
index a235698d..e9d418f9 100644
--- a/python-spec/testing/test_query_axis.py
+++ b/python-spec/testing/test_query_axis.py
@@ -1,11 +1,13 @@
 from typing import Any, Tuple
 
+import attrs
 import numpy as np
 import pytest
 from pytest import mark
 
 import somacore
 from somacore import options
+from somacore.query import query
 
 
 @mark.parametrize(
@@ -49,3 +51,24 @@ def test_canonicalization_nparray() -> None:
 def test_canonicalization_bad(coords) -> None:
     with pytest.raises(TypeError):
         somacore.AxisQuery(coords=coords)
+
+
+@attrs.define(frozen=True)
+class IHaveObsVarStuff:
+    obs: int
+    var: int
+    the_obs_suf: str
+    the_var_suf: str
+
+
+def test_axis_helpers() -> None:
+    thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
+    assert 1 == query._Axis.OBS.getattr_from(thing)
+    assert 2 == query._Axis.VAR.getattr_from(thing)
+    assert "observe" == query._Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
+    assert "vary" == query._Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
+    ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
+    assert "erve" == query._Axis.OBS.getitem_from(ovdict)
+    assert "y" == query._Axis.VAR.getitem_from(ovdict)
+    assert "hide" == query._Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
+    assert "???" == query._Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")