Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _Axis.getattr_from and _Axis.getitem_from. #183

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 68 additions & 19 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Tuple,
TypeVar,
Union,
overload,
)

import anndata
Expand All @@ -19,7 +20,7 @@
import pandas as pd
import pyarrow as pa
from scipy import sparse
from typing_extensions import Literal, Protocol, Self, TypedDict, assert_never
from typing_extensions import Literal, Protocol, Self, TypedDict

from .. import data
from .. import measurement
Expand Down Expand Up @@ -368,14 +369,10 @@ def _read_axis_dataframe(
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
column_names = axis_column_names.get(axis.value)
if axis is _Axis.OBS:
axis_df = self._obs_df
axis_query = self._matrix_axis_query.obs
elif axis is _Axis.VAR:
axis_df = self._var_df
axis_query = self._matrix_axis_query.var
else:
assert_never(axis) # must be obs or var

axis_df = axis.getattr_from(self, pre="_", suf="_df")
assert isinstance(axis_df, data.DataFrame)
axis_query = axis.getattr_from(self._matrix_axis_query)

# If we can cache join IDs, prepare to add them to the cache.
joinids_cached = self._joinids._is_cached(axis)
Expand Down Expand Up @@ -420,19 +417,24 @@ def _axisp_inner(
axis: "_Axis",
layer: str,
) -> data.SparseRead:
key = axis.value + "p"

if key not in self._ms:
raise ValueError(f"Measurement does not contain {key} data")
p_name = f"{axis.value}p"
try:
axisp = axis.getitem_from(self._ms, suf="p")
except KeyError as ke:
raise ValueError(f"Measurement does not contain {p_name} data") from ke

axisp = self._ms.obsp if axis is _Axis.OBS else self._ms.varp
if not (layer and layer in axisp):
raise ValueError(f"Must specify '{key}' layer")
if not isinstance(axisp[layer], data.SparseNDArray):
raise TypeError(f"Unexpected SOMA type stored in '{key}' layer")
try:
ap_layer = axisp[layer]
except KeyError as ke:
raise ValueError(f"layer {layer!r} is not available in {p_name}") from ke
if not isinstance(ap_layer, data.SparseNDArray):
raise TypeError(
f"Unexpected SOMA type {type(ap_layer).__name__}"
f" stored in {p_name} layer {layer!r}"
)

joinids = getattr(self._joinids, axis.value)
return axisp[layer].read((joinids, joinids))
return ap_layer.read((joinids, joinids))

@property
def _obs_df(self) -> data.DataFrame:
Expand Down Expand Up @@ -493,6 +495,30 @@ class _Axis(enum.Enum):
def value(self) -> Literal["obs", "var"]:
return super().value

# Typing-only signature: a source known to expose ``obs``/``var`` of type
# ``_T`` yields ``_T`` when no prefix/suffix is given.
@overload
def getattr_from(self, __source: "_HasObsVar[_T]") -> "_T":
    ...

# Typing-only signature: explicitly empty prefix and suffix.
@overload
def getattr_from(
    self, __source: Any, *, pre: Literal[""], suf: Literal[""]
) -> object:
    ...

# Typing-only signature: arbitrary prefix/suffix; the result type is unknown.
@overload
def getattr_from(self, __source: Any, *, pre: str = ..., suf: str = ...) -> object:
    ...

def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
    """Reads the attribute of ``__source`` that is named after this axis.

    The attribute name is ``pre + self.value + suf``, so with the default
    empty ``pre`` and ``suf`` this is the same as ``__source.obs`` or
    ``__source.var``.

    Args:
        __source: The object to read the attribute from.
        pre: Text placed before the axis name.
        suf: Text placed after the axis name.

    Returns:
        Whatever ``getattr`` returns for the computed attribute name.
    """
    attr_name = pre + self.value + suf
    return getattr(__source, attr_name)

def getitem_from(
    self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
) -> "_T":
    """Looks up the entry of ``__source`` that is keyed after this axis.

    The key is ``pre + self.value + suf``, so with the default empty ``pre``
    and ``suf`` this is the same as ``__source["obs"]`` or
    ``__source["var"]``.

    Args:
        __source: The mapping to index into.
        pre: Text placed before the axis name.
        suf: Text placed after the axis name.

    Returns:
        The value stored in ``__source`` under the computed key.
    """
    key = pre + self.value + suf
    return __source[key]


@attrs.define(frozen=True)
class _MatrixAxisQuery:
Expand Down Expand Up @@ -605,6 +631,14 @@ def _to_numpy(it: _Numpyable) -> np.ndarray:
return it.to_numpy()


#
# Type shenanigans
#
# Type variables shared by the typing helpers in this module.
#

# Plain type variable; appears (as a string annotation) in the signatures of
# ``_Axis.getattr_from`` and ``_Axis.getitem_from``.
_T = TypeVar("_T")
# Covariant type variable; parameterizes the ``_HasObsVar`` protocol.
_T_co = TypeVar("_T_co", covariant=True)


class _Experimentish(Protocol):
"""The API we need from an Experiment."""

Expand All @@ -615,3 +649,18 @@ def ms(self) -> Mapping[str, measurement.Measurement]:
@property
def obs(self) -> data.DataFrame:
...


class _HasObsVar(Protocol[_T_co]):
    """Something which has an ``obs`` and ``var`` field.

    Used to give nicer type inference in :meth:`_Axis.getattr_from`.

    This is a structural (``Protocol``) type, generic in the covariant type
    variable ``_T_co``: both properties are declared to return that same type.
    """

    # Declaration only; the body is intentionally empty.
    @property
    def obs(self) -> _T_co:
        ...

    # Declaration only; the body is intentionally empty.
    @property
    def var(self) -> _T_co:
        ...
23 changes: 23 additions & 0 deletions python-spec/testing/test_query_axis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Tuple

import attrs
import numpy as np
import pytest
from pytest import mark

import somacore
from somacore import options
from somacore.query import query


@mark.parametrize(
Expand Down Expand Up @@ -49,3 +51,24 @@ def test_canonicalization_nparray() -> None:
def test_canonicalization_bad(coords) -> None:
with pytest.raises(TypeError):
somacore.AxisQuery(coords=coords)


# Test fixture: a frozen attrs class whose attribute names contain "obs" and
# "var", used by ``test_axis_helpers`` to exercise ``_Axis.getattr_from``.
@attrs.define(frozen=True)
class IHaveObsVarStuff:
    # Read with the default (empty) prefix and suffix.
    obs: int
    var: int
    # Read with ``pre="the_"`` and ``suf="_suf"``.
    the_obs_suf: str
    the_var_suf: str


def test_axis_helpers() -> None:
    """Checks ``_Axis.getattr_from`` and ``_Axis.getitem_from`` on both axes."""
    obs_axis = query._Axis.OBS
    var_axis = query._Axis.VAR

    # Attribute access, first bare and then with a prefix and suffix.
    thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
    assert obs_axis.getattr_from(thing) == 1
    assert var_axis.getattr_from(thing) == 2
    assert obs_axis.getattr_from(thing, pre="the_", suf="_suf") == "observe"
    assert var_axis.getattr_from(thing, pre="the_", suf="_suf") == "vary"

    # Mapping access, first bare and then with a prefix and suffix.
    ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
    assert obs_axis.getitem_from(ovdict) == "erve"
    assert var_axis.getitem_from(ovdict) == "y"
    assert obs_axis.getitem_from(ovdict, pre="i_", suf="cure") == "hide"
    assert var_axis.getitem_from(ovdict, pre="i_", suf="cure") == "???"