
🎨 Replace relation of Collection to FeatureSet with indirect relation through `Artifact` (#1905)
falexwolf authored Sep 8, 2024
1 parent 39ae9e7 commit 34768a9
Showing 6 changed files with 77 additions and 88 deletions.
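
In short: the direct many-to-many between `Collection` and `FeatureSet` is removed; feature sets are now derived from a collection's artifacts on demand via the new `CollectionFeatureManager.get_feature_sets_union()`. A minimal usage sketch (assuming a saved collection; names taken from the diff and tests below):

import lamindb as ln

collection = ln.Collection.filter(name="My test").one()

# feature sets are no longer linked on the collection itself;
# they are computed as a per-slot union over the collection's artifacts
feature_sets_by_slot = collection.features.get_feature_sets_union()
for slot, feature_set in feature_sets_by_slot.items():
    print(slot, feature_set.members)  # members: QuerySet of the underlying registry
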
102 changes: 40 additions & 62 deletions lamindb/_collection.py
@@ -17,7 +17,6 @@
from lnschema_core.models import (
Collection,
CollectionArtifact,
FeatureManager,
FeatureSet,
)
from lnschema_core.types import VisibilityChoice
@@ -44,12 +43,45 @@
from ._query_set import QuerySet


class CollectionFeatureManager:
"""Query features of artifact in collection."""

def __init__(self, collection: Collection):
self._collection = collection

def get_feature_sets_union(self) -> dict[str, FeatureSet]:
links_feature_set_artifact = Artifact.feature_sets.through.objects.filter(
artifact_id__in=self._collection.artifacts.values_list("id", flat=True)
)
feature_sets_by_slots = defaultdict(list)
for link in links_feature_set_artifact:
feature_sets_by_slots[link.slot].append(link.featureset_id)
feature_sets_union = {}
for slot, feature_set_ids_slot in feature_sets_by_slots.items():
feature_set_1 = FeatureSet.get(id=feature_set_ids_slot[0])
related_name = feature_set_1._get_related_name()
features_registry = getattr(FeatureSet, related_name).field.model
# this way of writing the __in statement turned out to be the fastest
# evaluated on a link table with 16M entries connecting 500 feature sets with
# 60k genes
feature_ids = (
features_registry.feature_sets.through.objects.filter(
featureset_id__in=feature_set_ids_slot
)
.values(f"{features_registry.__name__.lower()}_id")
.distinct()
)
features = features_registry.filter(id__in=feature_ids)
feature_sets_union[slot] = FeatureSet(features, dtype=feature_set_1.dtype)
return feature_sets_union


def __init__(
collection: Collection,
*args,
**kwargs,
):
collection.features = FeatureManager(collection)
collection.features = CollectionFeatureManager(collection)
if len(args) == len(collection._meta.concrete_fields):
super(Collection, collection).__init__(*args, **kwargs)
return None
@@ -78,9 +110,6 @@ def __init__(
if "visibility" in kwargs
else VisibilityChoice.default.value
)
feature_sets: dict[str, FeatureSet] = (
kwargs.pop("feature_sets") if "feature_sets" in kwargs else {}
)
if "is_new_version_of" in kwargs:
logger.warning("`is_new_version_of` will be removed soon, please use `revises`")
revises = kwargs.pop("is_new_version_of")
@@ -98,7 +127,7 @@ def __init__(
if not hasattr(artifacts, "__getitem__"):
raise ValueError("Artifact or List[Artifact] is allowed.")
assert isinstance(artifacts[0], Artifact) # type: ignore # noqa: S101
hash, feature_sets = from_artifacts(artifacts) # type: ignore
hash = from_artifacts(artifacts) # type: ignore
if meta_artifact is not None:
if not isinstance(meta_artifact, Artifact):
raise ValueError("meta_artifact has to be an Artifact")
@@ -107,11 +136,6 @@ def __init__(
raise ValueError(
"Save meta_artifact artifact before creating collection!"
)
if not feature_sets:
feature_sets = meta_artifact.features._feature_set_by_slot
else:
if len(meta_artifact.features._feature_set_by_slot) > 0:
logger.info("overwriting feature sets linked to artifact")
# we ignore collections in trash containing the same hash
if hash is not None:
existing_collection = Collection.filter(hash=hash).one_or_none()
@@ -134,11 +158,6 @@ def __init__(
existing_collection.transform = run.transform
init_self_from_db(collection, existing_collection)
update_attributes(collection, {"description": description, "name": name})
for slot, feature_set in collection.features._feature_set_by_slot.items():
if slot in feature_sets:
if not feature_sets[slot] == feature_set:
collection.feature_sets.remove(feature_set)
logger.warning(f"removing feature set: {feature_set}")
else:
kwargs = {}
add_transform_to_kwargs(kwargs, run)
@@ -161,7 +180,6 @@ def __init__(
)
settings.creation.search_names = search_names_setting
collection._artifacts = artifacts
collection._feature_sets = feature_sets
# register provenance
if revises is not None:
_track_run_input(revises, run=run)
@@ -171,61 +189,21 @@
# internal function, not exposed to user
def from_artifacts(artifacts: Iterable[Artifact]) -> str:
# assert all artifacts are already saved
logger.debug("check not saved")
saved = not any(artifact._state.adding for artifact in artifacts)
if not saved:
raise ValueError("Not all artifacts are yet saved, please save them")
# query all feature sets of artifacts
logger.debug("artifact ids")
artifact_ids = [artifact.id for artifact in artifacts]
# query all feature sets at the same time rather
# than making a single query per artifact
logger.debug("links_feature_set_artifact")
links_feature_set_artifact = Artifact.feature_sets.through.objects.filter(
artifact_id__in=artifact_ids
)
feature_sets_by_slots = defaultdict(list)
logger.debug("slots")
for link in links_feature_set_artifact:
feature_sets_by_slots[link.slot].append(link.featureset_id)
feature_sets_union = {}
logger.debug("union")
for slot, feature_set_ids_slot in feature_sets_by_slots.items():
feature_set_1 = FeatureSet.get(id=feature_set_ids_slot[0])
related_name = feature_set_1._get_related_name()
features_registry = getattr(FeatureSet, related_name).field.model
start_time = logger.debug("run filter")
# this way of writing the __in statement turned out to be the fastest
# evaluated on a link table with 16M entries connecting 500 feature sets with
# 60k genes
feature_ids = (
features_registry.feature_sets.through.objects.filter(
featureset_id__in=feature_set_ids_slot
)
.values(f"{features_registry.__name__.lower()}_id")
.distinct()
)
start_time = logger.debug("done, start evaluate", time=start_time)
features = features_registry.filter(id__in=feature_ids)
feature_sets_union[slot] = FeatureSet(features, dtype=feature_set_1.dtype)
start_time = logger.debug("done", time=start_time)
# validate consistency of hashes
# we do not allow duplicate hashes
logger.debug("hashes")
# artifact.hash is None for zarr
# todo: more careful handling of such cases
# validate consistency of hashes - we do not allow duplicate hashes
hashes = [artifact.hash for artifact in artifacts if artifact.hash is not None]
if len(hashes) != len(set(hashes)):
hashes_set = set(hashes)
if len(hashes) != len(hashes_set):
seen = set()
non_unique = [x for x in hashes if x in seen or seen.add(x)] # type: ignore
raise ValueError(
"Please pass artifacts with distinct hashes: these ones are non-unique"
f" {non_unique}"
)
time = logger.debug("hash")
hash = hash_set(set(hashes))
logger.debug("done", time=time)
return hash, feature_sets_union
hash = hash_set(hashes_set)
return hash


# docstring handled through attach_func_to_class_method
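
Note on `from_artifacts` above: the collection hash is computed from the set of its artifacts' hashes, so duplicates are rejected first and ordering cannot matter. A hypothetical stand-in for `hash_set` (not lamindb's actual implementation):

import hashlib

def hash_set_sketch(hashes: set[str]) -> str:
    # sort for determinism, since set iteration order is arbitrary
    digest = hashlib.md5(":".join(sorted(hashes)).encode()).hexdigest()
    # registries typically store a short, truncated id
    return digest[:20]
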
18 changes: 9 additions & 9 deletions lamindb/_curate.py
@@ -334,9 +334,9 @@ def save_artifact(self, description: str | None = None, **kwargs) -> Artifact:
from lamindb.core._settings import settings

if not self._validated:
raise ValidationError(
f"Data object is not validated, please run {colors.yellow('validate()')}!"
)
self.validate()
if not self._validated:
raise ValidationError("Dataset does not validate. Please curate.")

# Make sure all labels are saved in the current instance
verbosity = settings.verbosity
@@ -442,7 +442,7 @@ def __init__(
exclude=exclude,
check_valid_keys=False,
)
self._obs_fields = categoricals
self._obs_fields = categoricals or {}
self._check_valid_keys(extra={"var_index"})

@property
@@ -563,9 +563,9 @@ def save_artifact(self, description: str | None = None, **kwargs) -> Artifact:
A saved artifact record.
"""
if not self._validated:
raise ValidationError(
f"Data object is not validated, please run {colors.yellow('validate()')}!"
)
self.validate()
if not self._validated:
raise ValidationError("Dataset does not validate. Please curate.")

self._artifact = save_artifact(
self._data,
@@ -1498,14 +1498,14 @@ def log_saved_labels(

if k == "without reference" and validated_only:
msg = colors.yellow(
f"{len(labels)} non-validated categories are not saved in {model_field}: {labels}!"
f"{len(labels)} non-validated values are not saved in {model_field}: {labels}!"
)
lookup_print = (
f"lookup().{key}" if key.isidentifier() else f".lookup()['{key}']"
)

hint = f".add_new_from('{key}')"
msg += f"\n → to lookup categories, use {lookup_print}"
msg += f"\n → to lookup values, use {lookup_print}"
msg += (
f"\n → to save, run {colors.yellow(hint)}"
if save_function == "add_new_from"
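
The `save_artifact` changes above mean curators now re-run validation themselves instead of failing fast on an unvalidated object. A hedged usage sketch (hypothetical DataFrame and categorical mapping; mirrors the tests below):

import lamindb as ln
import pandas as pd

df = pd.DataFrame({"cell_type": ["T cell", "B cell"]})
categoricals = {"cell_type": ln.ULabel.name}  # hypothetical mapping

curator = ln.Curator.from_df(df, categoricals=categoricals)
# save_artifact() now calls validate() internally and raises
# ValidationError("Dataset does not validate. Please curate.")
# only if validation still fails afterwards
artifact = curator.save_artifact(description="my dataset")
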
11 changes: 5 additions & 6 deletions lamindb/core/_feature_manager.py
@@ -19,8 +19,6 @@
Collection,
Feature,
FeatureManager,
FeatureManagerArtifact,
FeatureManagerCollection,
FeatureValue,
LinkORM,
Param,
@@ -360,7 +358,7 @@ def __getitem__(self, slot) -> QuerySet:


def filter_base(cls, **expression):
if cls in {FeatureManagerArtifact, FeatureManagerCollection}:
if cls is FeatureManager:
model = Feature
value_model = FeatureValue
else:
@@ -392,10 +390,11 @@ def filter_base(cls, **expression):
new_expression["ulabels"] = label
else:
raise NotImplementedError
if cls == FeatureManagerArtifact or cls == ParamManagerArtifact:
if cls == FeatureManager or cls == ParamManagerArtifact:
return Artifact.filter(**new_expression)
elif cls == FeatureManagerCollection:
return Collection.filter(**new_expression)
# might re-enable something similar in the future
# elif cls == FeatureManagerCollection:
# return Collection.filter(**new_expression)
elif cls == ParamManagerRun:
return Run.filter(**new_expression)

2 changes: 1 addition & 1 deletion sub/lnschema-core
28 changes: 20 additions & 8 deletions tests/core/test_collection.py
@@ -143,14 +143,16 @@ def test_from_inconsistent_artifacts(df, adata):


def test_from_consistent_artifacts(adata, adata2):
artifact1 = ln.Artifact.from_anndata(adata, description="My test")
artifact1.save()
artifact2 = ln.Artifact.from_anndata(adata2, description="My test2")
artifact2.save()
transform = ln.Transform(name="My test transform")
transform.save()
run = ln.Run(transform)
run.save()
ln.Feature(name="feat1", dtype="number").save()
curator = ln.Curator.from_anndata(adata, var_index=bt.Gene.symbol, organism="human")
curator.add_validated_from_var_index()
artifact1 = curator.save_artifact(description="My test")
curator = ln.Curator.from_anndata(
adata2, var_index=bt.Gene.symbol, organism="human"
)
artifact2 = curator.save_artifact(description="My test2").save()
transform = ln.Transform(name="My test transform").save()
run = ln.Run(transform).save()
collection = ln.Collection([artifact1, artifact2], name="My test", run=run)
assert collection._state.adding
collection.save()
@@ -159,6 +161,14 @@ def test_from_consistent_artifacts(adata, adata2):
assert "artifact_uid" in adata_joined.obs.columns
assert artifact1.uid in adata_joined.obs.artifact_uid.cat.categories

feature_sets = collection.features.get_feature_sets_union()
assert set(feature_sets["var"].members.values_list("symbol", flat=True)) == {
"MYC",
"TCF7",
"GATA1",
}
assert set(feature_sets["obs"].members.values_list("name", flat=True)) == {"feat1"}

# re-run with hash-based lookup
collection2 = ln.Collection([artifact1, artifact2], name="My test 1", run=run)
assert not collection2._state.adding
@@ -168,6 +178,8 @@
collection.delete(permanent=True)
artifact1.delete(permanent=True)
artifact2.delete(permanent=True)
ln.FeatureSet.filter().delete()
ln.Feature.filter().delete()


def test_collection_mapped(adata, adata2):
4 changes: 2 additions & 2 deletions tests/core/test_curate.py
@@ -131,7 +131,7 @@ def test_unvalidated_data_object(df, categoricals):
curate = ln.Curator.from_df(df, categoricals=categoricals)
with pytest.raises(ValidationError) as error:
curate.save_artifact()
assert "Data object is not validated" in str(error.value)
assert "Dataset does not validate. Please curate." in str(error.value)


def test_clean_up_failed_runs():
@@ -241,7 +241,7 @@ def test_unvalidated_adata_object(adata, categoricals):
)
with pytest.raises(ValidationError) as error:
curate.save_artifact()
assert "Data object is not validated" in str(error.value)
assert "Dataset does not validate. Please curate." in str(error.value)


def test_mudata_annotator(mdata):
