From 4184c05acc4a13307bd61d8aee8ab314d7778347 Mon Sep 17 00:00:00 2001 From: Braden MacDonald Date: Thu, 5 Oct 2023 11:29:25 -0700 Subject: [PATCH] feat: new implementation of get_filtered_tags --- openedx_tagging/core/tagging/api.py | 118 ++--- openedx_tagging/core/tagging/data.py | 24 + .../core/tagging/import_export/parsers.py | 6 +- openedx_tagging/core/tagging/models/base.py | 249 +++++++--- .../core/tagging/models/system_defined.py | 47 -- .../core/fixtures/tagging.yaml | 21 + .../openedx_tagging/core/tagging/test_api.py | 13 +- .../core/tagging/test_models.py | 446 +++++++++++++----- 8 files changed, 583 insertions(+), 341 deletions(-) create mode 100644 openedx_tagging/core/tagging/data.py diff --git a/openedx_tagging/core/tagging/api.py b/openedx_tagging/core/tagging/api.py index e9140968..58a04356 100644 --- a/openedx_tagging/core/tagging/api.py +++ b/openedx_tagging/core/tagging/api.py @@ -13,9 +13,10 @@ from __future__ import annotations from django.db import transaction -from django.db.models import F, QuerySet +from django.db.models import QuerySet from django.utils.translation import gettext as _ +from .data import TagData from .models import ObjectTag, Tag, Taxonomy # Export this as part of the API @@ -70,54 +71,56 @@ def get_taxonomies(enabled=True) -> QuerySet[Taxonomy]: return queryset.filter(enabled=enabled) -def get_tags(taxonomy: Taxonomy) -> list[Tag]: +def get_tags(taxonomy: Taxonomy) -> QuerySet[TagData]: """ - Returns a list of predefined tags for the given taxonomy. + Returns a QuerySet of all the tags in the given taxonomy. - Note that if the taxonomy allows free-text tags, then the returned list will be empty. + Note that if the taxonomy is dynamic or free-text, only tags that have + already been applied to some object will be returned. 
""" - return taxonomy.cast().get_tags() + return taxonomy.cast().get_filtered_tags() -def get_root_tags(taxonomy: Taxonomy) -> list[Tag]: +def get_root_tags(taxonomy: Taxonomy) -> QuerySet[TagData]: """ Returns a list of the root tags for the given taxonomy. Note that if the taxonomy allows free-text tags, then the returned list will be empty. """ - return list(taxonomy.cast().get_filtered_tags()) + return taxonomy.cast().get_filtered_tags(depth=1) -def search_tags(taxonomy: Taxonomy, search_term: str) -> list[Tag]: +def search_tags(taxonomy: Taxonomy, search_term: str, exclude_object_id: int | None = None) -> QuerySet[TagData]: """ - Returns a list of all tags that contains `search_term` of the given taxonomy. + Returns a list of all tags that contains `search_term` of the given + taxonomy, as well as their ancestors (so they can be displayed in a tree). - Note that if the taxonomy allows free-text tags, then the returned list will be empty. + If exclude_object_id is set, any tags applied to that object will be + excluded from the results, e.g. to power an autocomplete search when adding + additional tags to an object. """ - return list( - taxonomy.cast().get_filtered_tags( - search_term=search_term, - search_in_all=True, + qs = taxonomy.cast().get_filtered_tags(search_term=search_term) + if exclude_object_id: + # Fetch tags that the object already has to exclude them from the result + excluded_values = list( + taxonomy.objecttag_set.filter(object_id=exclude_object_id).values_list( + "_value", flat=True + ) ) - ) + qs = qs.exclude(value__in=excluded_values) + return qs def get_children_tags( taxonomy: Taxonomy, - parent_tag_id: int, - search_term: str | None = None, -) -> list[Tag]: + parent_tag_value: str, +) -> QuerySet[TagData]: """ - Returns a list of children tags for the given parent tag. + Returns a QuerySet of children tags for the given parent tag. Note that if the taxonomy allows free-text tags, then the returned list will be empty. 
""" - return list( - taxonomy.cast().get_filtered_tags( - parent_tag_id=parent_tag_id, - search_term=search_term, - ) - ) + return taxonomy.cast().get_filtered_tags(parent_tag_value=parent_tag_value) def resync_object_tags(object_tags: QuerySet | None = None) -> int: @@ -250,68 +253,3 @@ def _check_new_tag_count(new_tag_count: int) -> None: for object_tag in updated_tags: object_tag.full_clean() # Run validation object_tag.save() - - -# TODO: return tags from closed taxonomies as well as the count of how many times each is used. -def autocomplete_tags( - taxonomy: Taxonomy, - search: str, - object_id: str | None = None, - object_tags_only=True, -) -> QuerySet: - """ - Provides auto-complete suggestions by matching the `search` string against existing - ObjectTags linked to the given taxonomy. A case-insensitive search is used in order - to return the highest number of relevant tags. - - If `object_id` is provided, then object tag values already linked to this object - are omitted from the returned suggestions. (ObjectTag values must be unique for a - given object + taxonomy, and so omitting these suggestions helps users avoid - duplication errors.). - - Returns a QuerySet of dictionaries containing distinct `value` (string) and - `tag` (numeric ID) values, sorted alphabetically by `value`. - The `value` is what should be shown as a suggestion to users, - and if it's a free-text taxonomy, `tag` will be `None`: we include the `tag` ID - in anticipation of the second use case listed below. - - Use cases: - * This method is useful for reducing tag variation in free-text taxonomies by showing - users tags that are similar to what they're typing. E.g., if the `search` string "dn" - shows that other objects have been tagged with "DNA", "DNA electrophoresis", and "DNA fingerprinting", - this encourages users to use those existing tags if relevant, instead of creating new ones that - look similar (e.g. "dna finger-printing"). 
- * It could also be used to assist tagging for closed taxonomies with a list of possible tags which is too - large to return all at once, e.g. a user model taxonomy that dynamically creates tags on request for any - registered user in the database. (Note that this is not implemented yet, but may be as part of a future change.) - """ - if not object_tags_only: - raise NotImplementedError( - _( - "Using this would return a query set of tags instead of object tags." - "For now we recommend fetching all of the taxonomy's tags " - "using get_tags() and filtering them on the frontend." - ) - ) - # Fetch tags that the object already has to exclude them from the result - excluded_tags: list[str] = [] - if object_id: - excluded_tags = list( - taxonomy.objecttag_set.filter(object_id=object_id).values_list( - "_value", flat=True - ) - ) - return ( - # Fetch object tags from this taxonomy whose value contains the search - taxonomy.objecttag_set.filter(_value__icontains=search) - # omit any tags whose values match the tags on the given object - .exclude(_value__in=excluded_tags) - # alphabetical ordering - .order_by("_value") - # Alias the `_value` field to `value` to make it nicer for users - .annotate(value=F("_value")) - # obtain tag values - .values("value", "tag_id") - # remove repeats - .distinct() - ) diff --git a/openedx_tagging/core/tagging/data.py b/openedx_tagging/core/tagging/data.py new file mode 100644 index 00000000..b1a0ec3e --- /dev/null +++ b/openedx_tagging/core/tagging/data.py @@ -0,0 +1,24 @@ +""" +Data models used by openedx-tagging +""" +from __future__ import annotations + +from typing import TypedDict + + +class TagData(TypedDict): + """ + Data about a single tag. Many of the tagging API methods return Django + QuerySets that resolve to these dictionaries. + + Even though the data will be in this same format, it will not necessarily + be an instance of this class but rather a plain dictionary. This is more a + type than a class. 
+ """ + value: str + external_id: str | None + child_count: int + depth: int + parent_value: str | None + # Note: usage_count may not actually be present but there's no way to indicate that w/ python types at the moment + usage_count: int diff --git a/openedx_tagging/core/tagging/import_export/parsers.py b/openedx_tagging/core/tagging/import_export/parsers.py index c0b8207f..e5db49aa 100644 --- a/openedx_tagging/core/tagging/import_export/parsers.py +++ b/openedx_tagging/core/tagging/import_export/parsers.py @@ -168,12 +168,12 @@ def _load_tags_for_export(cls, taxonomy: Taxonomy) -> list[dict]: The tags are ordered by hierarchy, first, parents and then children. `get_tags` is in charge of returning this in a hierarchical way. """ - tags = get_tags(taxonomy) + tags = Taxonomy.get_filtered_tags().all() result = [] for tag in tags: result_tag = { - "id": tag.external_id or tag.id, - "value": tag.value, + "id": tag["external_id"] or tag["id"], + "value": tag["value"], } if tag.parent: result_tag["parent_id"] = tag.parent.external_id or tag.parent.id diff --git a/openedx_tagging/core/tagging/models/base.py b/openedx_tagging/core/tagging/models/base.py index d24d67d4..b5656752 100644 --- a/openedx_tagging/core/tagging/models/base.py +++ b/openedx_tagging/core/tagging/models/base.py @@ -8,12 +8,17 @@ from django.core.exceptions import ValidationError from django.db import models +from django.db.models import F, Q, Value +from django.db.models.functions import Coalesce, Concat +from django.utils.functional import cached_property from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ from typing_extensions import Self # Until we upgrade to python 3.11 from openedx_learning.lib.fields import MultiCollationTextField, case_insensitive_char_field, case_sensitive_char_field +from ..data import TagData + log = logging.getLogger(__name__) @@ -105,6 +110,35 @@ def get_lineage(self) -> Lineage: tag = tag.parent depth -= 1 return 
lineage + + @cached_property + def num_ancestors(self) -> int: + """ + How many ancestors this Tag has. Equivalent to its "depth" in the tree. + Zero for root tags. + """ + num_ancestors = 0 + tag = self + while tag.parent: + num_ancestors += 1 + tag = tag.parent + return num_ancestors + + @staticmethod + def annotate_depth(qs: models.QuerySet) -> models.QuerySet: + """ + Given a query that loads Tag objects, annotate it with the depth of + each tag. + """ + return qs.annotate(depth=models.Case( + models.When(parent_id=None, then=0), + models.When(parent__parent_id=None, then=1), + models.When(parent__parent__parent_id=None, then=2), + models.When(parent__parent__parent__parent_id=None, then=3), + # If the depth is 4 or more, currently we just "collapse" the depth + # to 4 in order not to add too many joins to this query in general. + default=4, + )) class Taxonomy(models.Model): @@ -260,87 +294,174 @@ def copy(self, taxonomy: Taxonomy) -> Taxonomy: self._taxonomy_class = taxonomy._taxonomy_class # pylint: disable=protected-access return self - def get_tags( + def get_filtered_tags( self, - tag_set: models.QuerySet[Tag] | None = None, - ) -> list[Tag]: + depth: int | None = TAXONOMY_MAX_DEPTH, + parent_tag_value: str | None = None, + search_term: str | None = None, + include_counts: bool = True, + ) -> models.QuerySet[TagData]: """ - Returns a list of all Tags in the current taxonomy, from the root(s) - down to TAXONOMY_MAX_DEPTH tags, in tree order. + Returns a filtered QuerySet of tag values. + For free text or dynamic taxonomies, this will only return tag values + that have actually been used. - Use `tag_set` to do an initial filtering of the tags. + By default returns all the tags of the given taxonomy - Annotates each returned Tag with its ``depth`` in the tree (starting at - 0). + Use `depth=1` to return a single level of tags, without any child + tags included. 
Use `depth=None` or `depth=TAXONOMY_MAX_DEPTH` to return + all descendants of the tags, up to our maximum supported depth. - Performance note: may perform as many as TAXONOMY_MAX_DEPTH select - queries. + Use `parent_tag_value` to return only the children/descendants of a specific tag. + + Use `search_term` to filter the results by values that contains `search_term`. + + Note: This is mostly an 'internal' API and generally code outside of openedx_tagging + should use the APIs in openedx_tagging.api which in turn use this. """ - tags: list[Tag] = [] if self.allow_free_text: - return tags - - if tag_set is None: - tag_set = self.tag_set.all() - - parents = None - - for depth in range(TAXONOMY_MAX_DEPTH): - filtered_tags = tag_set.prefetch_related("parent") - if parents is None: - filtered_tags = filtered_tags.filter(parent=None) - else: - filtered_tags = filtered_tags.filter(parent__in=parents) - next_parents = list( - filtered_tags.annotate( - annotated_field=models.Value( - depth, output_field=models.IntegerField() - ) - ) - .order_by("parent__value", "value", "id") - .all() + if parent_tag_value is not None: + raise ValueError("Cannot specify a parent tag ID for free text taxonomies") + return self._get_filtered_tags_free_text(search_term=search_term, include_counts=include_counts) + elif depth == 1: + return self._get_filtered_tags_one_level( + parent_tag_value=parent_tag_value, + search_term=search_term, + include_counts=include_counts, + ) + elif depth is None or depth == TAXONOMY_MAX_DEPTH: + return self._get_filtered_tags_deep( + parent_tag_value=parent_tag_value, + search_term=search_term, + include_counts=include_counts, ) - tags.extend(next_parents) - parents = next_parents - if not parents: - break - return tags + else: + raise ValueError("Unsupported depth value for get_filtered_tags()") - def get_filtered_tags( + def _get_filtered_tags_free_text( self, - tag_set: models.QuerySet[Tag] | None = None, - parent_tag_id: int | None = None, - search_term: str 
| None = None, - search_in_all: bool = False, - ) -> models.QuerySet[Tag]: + search_term: str | None, + include_counts: bool, + ) -> models.QuerySet[TagData]: """ - Returns a filtered QuerySet of tags. - By default returns the root tags of the given taxonomy - - Use `parent_tag_id` to return the children of a tag. - - Use `search_term` to filter the results by values that contains `search_term`. + Implementation of get_filtered_tags() for free text taxonomies. + """ + assert self.allow_free_text + qs: models.QuerySet = self.objecttag_set.all() + if search_term: + qs = qs.filter(_value__icontains=search_term) + # Rename "_value" to "value" + qs = qs.annotate(value=F('_value')) + # Add in all these fixed fields that don't really apply to free text tags, but we include for consistency: + qs = qs.annotate( + depth=Value(0), + child_count=Value(0), + external_id=Value(None, output_field=models.CharField()), + parent_value=Value(None, output_field=models.CharField()), + ) + qs = qs.values("value", "child_count", "depth", "parent_value", "external_id").order_by("value") + if include_counts: + return qs.annotate(usage_count=models.Count("value")) + else: + return qs.distinct() - Set `search_in_all` to True to make the search in all tags on the given taxonomy. + def _get_filtered_tags_one_level( + self, + parent_tag_value: str | None, + search_term: str | None, + include_counts: bool, + ) -> models.QuerySet[TagData]: + """ + Implementation of get_filtered_tags() for closed taxonomies, where + depth=1. When depth=1, we're only looking at a single "level" of the + taxonomy, like all root tags or all children of a specific tag. + """ + # A closed, and possibly hierarchical taxonomy. We're just fetching a single "level" of tags. 
+ if parent_tag_value: + parent_tag = self.tag_for_value(parent_tag_value) + qs: models.QuerySet = self.tag_set.filter(parent_id=parent_tag.pk) + qs = qs.annotate(depth=Value(parent_tag.num_ancestors + 1)) + # Use parent_tag.value not parent_tag_value because they may differ in case + qs = qs.annotate(parent_value=Value(parent_tag.value)) + else: + qs = self.tag_set.filter(parent=None).annotate(depth=Value(0)) + qs = qs.annotate(parent_value=Value(None, output_field=models.CharField())) + qs = qs.annotate(child_count=models.Count("children")) + # Filter by search term: + if search_term: + qs = qs.filter(value__icontains=search_term) + qs = qs.values("value", "child_count", "depth", "parent_value", "external_id").order_by("value") + if include_counts: + # We need to include the count of how many times this tag is used to tag objects. + # You'd think we could just use: + # qs = qs.annotate(usage_count=models.Count("objecttag__pk")) + # but that adds another join which starts creating a cross product and the children and usage_count become + # intertwined and multiplied with each other. So we use a subquery. + obj_tags = ObjectTag.objects.filter(tag_id=models.OuterRef("pk")).order_by().annotate( + # We need to use Func() to get Count() without GROUP BY - see https://stackoverflow.com/a/69031027 + count=models.Func(F('id'), function='Count') + ) + qs = qs.annotate(usage_count=models.Subquery(obj_tags.values('count'))) + return qs - Note: This is mostly an 'internal' API and generally code outside of openedx_tagging - should use the APIs in openedx_tagging.api which in turn use this. 
+ def _get_filtered_tags_deep( + self, + parent_tag_value: str | None, + search_term: str | None, + include_counts: bool, + ) -> models.QuerySet[TagData]: """ - if tag_set is None: - tag_set = self.tag_set.all() - - if self.allow_free_text: - return tag_set.none() + Implementation of get_filtered_tags() for closed taxonomies, where + we're including tags from multiple levels of the hierarchy. + """ + # All tags (possibly below a certain tag?) in the closed taxonomy, up to depth TAXONOMY_MAX_DEPTH + if parent_tag_value: + main_parent_id = self.tag_for_value(parent_tag_value).pk + else: + main_parent_id = None - if not search_in_all: - # If not search in all taxonomy, then apply parent filter. - tag_set = tag_set.filter(parent=parent_tag_id) + assert TAXONOMY_MAX_DEPTH == 3 # If we change TAXONOMY_MAX_DEPTH we need to change this query code: + qs: models.QuerySet = self.tag_set.filter( + Q(parent_id=main_parent_id) | + Q(parent__parent_id=main_parent_id) | + Q(parent__parent__parent_id=main_parent_id) + ) if search_term: - # Apply search filter - tag_set = tag_set.filter(value__icontains=search_term) - - return tag_set.order_by("value", "id") + # We need to do an additional query to find all the tags that match the search term, then limit the + # search to those tags and their ancestors. 
+ matching_tags = qs.filter(value__icontains=search_term).values( + 'id', 'parent_id', 'parent__parent_id', 'parent__parent__parent_id' + ) + matching_ids = [] + for row in matching_tags: + for pk in row.values(): + if pk is not None: + matching_ids.append(pk) + qs = qs.filter(pk__in=matching_ids) + + qs = qs.annotate(child_count=models.Count("children")) + # Add the "depth" to each tag: + qs = Tag.annotate_depth(qs) + # Add the "lineage" field to sort them in order correctly: + qs = qs.annotate(sort_key=Concat( + Coalesce(F("parent__parent__parent__value"), Value("")), + Coalesce(F("parent__parent__value"), Value("")), + Coalesce(F("parent__value"), Value("")), + F("value"), + output_field=models.CharField(), + )) + # Add the parent value + qs = qs.annotate(parent_value=F("parent__value")) + qs = qs.values("value", "child_count", "depth", "parent_value", "external_id").order_by("sort_key") + if include_counts: + # Including the counts is a bit tricky; see the comment above in _get_filtered_tags_one_level() + obj_tags = ObjectTag.objects.filter(tag_id=models.OuterRef("pk")).order_by().annotate( + # We need to use Func() to get Count() without GROUP BY - see https://stackoverflow.com/a/69031027 + count=models.Func(F('id'), function='Count') + ) + qs = qs.annotate(usage_count=models.Subquery(obj_tags.values('count'))) + return qs def validate_value(self, value: str) -> bool: """ diff --git a/openedx_tagging/core/tagging/models/system_defined.py b/openedx_tagging/core/tagging/models/system_defined.py index 6851ab9d..3efb68f2 100644 --- a/openedx_tagging/core/tagging/models/system_defined.py +++ b/openedx_tagging/core/tagging/models/system_defined.py @@ -198,53 +198,6 @@ class LanguageTaxonomy(SystemDefinedTaxonomy): class Meta: proxy = True - def get_tags( - self, - tag_set: models.QuerySet[Tag] | None = None, - ) -> list[Tag]: - """ - Returns a list of all the available Language Tags, annotated with ``depth`` = 0. 
- """ - available_langs = self._get_available_languages() - tag_set = self.tag_set.filter(external_id__in=available_langs) - return super().get_tags(tag_set=tag_set) - - def get_filtered_tags( - self, - tag_set: models.QuerySet[Tag] | None = None, - parent_tag_id: int | None = None, - search_term: str | None = None, - search_in_all: bool = False, - ) -> models.QuerySet[Tag]: - """ - Returns a filtered QuerySet of available Language Tags. - By default returns all the available Language Tags. - - `parent_tag_id` returns an empty result because all Language tags are root tags. - - Use `search_term` to filter the results by values that contains `search_term`. - """ - if parent_tag_id: - return self.tag_set.none() - - available_langs = self._get_available_languages() - tag_set = self.tag_set.filter(external_id__in=available_langs) - return super().get_filtered_tags( - tag_set=tag_set, - search_term=search_term, - search_in_all=search_in_all, - ) - - @classmethod - def _get_available_languages(cls) -> set[str]: - """ - Get available languages from Django LANGUAGE. - """ - langs = set() - for django_lang in settings.LANGUAGES: - langs.add(django_lang[0]) - return langs - def validate_value(self, value: str): """ Check if 'value' is part of this Taxonomy, based on the specified model. 
diff --git a/tests/openedx_tagging/core/fixtures/tagging.yaml b/tests/openedx_tagging/core/fixtures/tagging.yaml index 4715b667..164b9399 100644 --- a/tests/openedx_tagging/core/fixtures/tagging.yaml +++ b/tests/openedx_tagging/core/fixtures/tagging.yaml @@ -1,3 +1,24 @@ +# - Bacteria +# |- Archaebacteria +# |- Eubacteria +# - Archaea +# |- DPANN +# |- Euryarchaeida +# |- Proteoarchaeota +# - Eukaryota +# |- Animalia +# | |- Arthropoda +# | |- Chordata +# | | |- Mammalia +# | |- Cnidaria +# | |- Ctenophora +# | |- Gastrotrich +# | |- Placozoa +# | |- Porifera +# |- Fungi +# |- Monera +# |- Plantae +# |- Protista - model: oel_tagging.tag pk: 1 fields: diff --git a/tests/openedx_tagging/core/tagging/test_api.py b/tests/openedx_tagging/core/tagging/test_api.py index f82a2201..e1dbd03d 100644 --- a/tests/openedx_tagging/core/tagging/test_api.py +++ b/tests/openedx_tagging/core/tagging/test_api.py @@ -7,6 +7,7 @@ import ddt # type: ignore[import] import pytest +from django.db.models import QuerySet from django.test import TestCase, override_settings import openedx_tagging.core.tagging.api as tagging_api @@ -594,15 +595,11 @@ def test_autocomplete_tags(self, search: str, expected_values: list[str], expect expected_ids, ) - def test_autocompleate_not_implemented(self) -> None: - with self.assertRaises(NotImplementedError): - tagging_api.autocomplete_tags(self.taxonomy, 'test', None, object_tags_only=False) - - def _get_tag_values(self, tags) -> list[str]: + def _get_tag_values(self, tags: QuerySet[tagging_api.TagData]) -> list[str]: """ Get tag values from tagging_api.autocomplete_tags() result """ - return [tag.get("value") for tag in tags] + return [tag["value"] for tag in tags] def _get_tag_ids(self, tags) -> list[int]: """ @@ -622,7 +619,7 @@ def _validate_autocomplete_tags( """ # Normal search - result = tagging_api.autocomplete_tags(taxonomy, search) + result = tagging_api.search_tags(taxonomy, search) tag_values = self._get_tag_values(result) for value in 
tag_values: assert search.lower() in value.lower() @@ -644,6 +641,6 @@ def _validate_autocomplete_tags( ).save() # Search with object - result = tagging_api.autocomplete_tags(taxonomy, search, object_id) + result = tagging_api.search_tags(taxonomy, search, object_id) assert self._get_tag_values(result) == expected_values[1:] assert self._get_tag_ids(result) == expected_ids[1:] diff --git a/tests/openedx_tagging/core/tagging/test_models.py b/tests/openedx_tagging/core/tagging/test_models.py index 2a7131f4..af2c812b 100644 --- a/tests/openedx_tagging/core/tagging/test_models.py +++ b/tests/openedx_tagging/core/tagging/test_models.py @@ -1,6 +1,8 @@ """ Test the tagging base models """ +from __future__ import annotations + import ddt # type: ignore[import] import pytest from django.contrib.auth import get_user_model @@ -55,54 +57,6 @@ def setUp(self): ) self.user_2.save() - # Domain tags (depth=0) - # https://en.wikipedia.org/wiki/Domain_(biology) - self.domain_tags = [ - get_tag("Archaea"), - get_tag("Bacteria"), - get_tag("Eukaryota"), - ] - # Domain tags that contains 'ar' - self.filtered_domain_tags = [ - get_tag("Archaea"), - get_tag("Eukaryota"), - ] - - # Kingdom tags (depth=1) - self.kingdom_tags = [ - # Kingdoms of https://en.wikipedia.org/wiki/Archaea - get_tag("DPANN"), - get_tag("Euryarchaeida"), - get_tag("Proteoarchaeota"), - # Kingdoms of https://en.wikipedia.org/wiki/Bacterial_taxonomy - get_tag("Archaebacteria"), - get_tag("Eubacteria"), - # Kingdoms of https://en.wikipedia.org/wiki/Eukaryote - get_tag("Animalia"), - get_tag("Fungi"), - get_tag("Monera"), - get_tag("Plantae"), - get_tag("Protista"), - ] - - # Phylum tags (depth=2) - self.phylum_tags = [ - # Some phyla of https://en.wikipedia.org/wiki/Animalia - get_tag("Arthropoda"), - get_tag("Chordata"), - get_tag("Cnidaria"), - get_tag("Ctenophora"), - get_tag("Gastrotrich"), - get_tag("Placozoa"), - get_tag("Porifera"), - ] - # Phylum tags that contains 'da' - self.filtered_phylum_tags = [ - 
get_tag("Arthropoda"), - get_tag("Chordata"), - get_tag("Cnidaria"), - ] - # Biology tags that contains 'eu' self.filtered_tags = [ get_tag("Eubacteria"), @@ -132,17 +86,6 @@ def setUp(self): ) self.dummy_taxonomies.append(taxonomy) - def setup_tag_depths(self): - """ - Annotate our tags with depth so we can compare them. - """ - for tag in self.domain_tags: - tag.depth = 0 - for tag in self.kingdom_tags: - tag.depth = 1 - for tag in self.phylum_tags: - tag.depth = 2 - class TaxonomyTestSubclassA(Taxonomy): """ @@ -237,6 +180,20 @@ def test_taxonomy_cast_bad_value(self): self.taxonomy.taxonomy_class = str assert " must be a subclass of Taxonomy" in str(exc.exception) + def test_unique_tags(self): + # Creating new tag + Tag( + taxonomy=self.taxonomy, + value='New value' + ).save() + + # Creating repeated tag + with self.assertRaises(IntegrityError): + Tag( + taxonomy=self.taxonomy, + value=self.archaea.value, + ).save() + @ddt.data( # Root tags just return their own value ("bacteria", ["Bacteria"]), @@ -251,80 +208,311 @@ def test_taxonomy_cast_bad_value(self): def test_get_lineage(self, tag_attr, lineage): assert getattr(self, tag_attr).get_lineage() == lineage - def test_get_tags(self): - self.setup_tag_depths() - assert self.taxonomy.get_tags() == [ - *self.domain_tags, - *self.kingdom_tags, - *self.phylum_tags, + + +@ddt.ddt +class TestFilteredTagsClosedTaxonomy(TestTagTaxonomyMixin, TestCase): + """ + Test the get_filtered_tags() method of closed taxonomies + """ + def test_get_root(self) -> None: + """ + Test basic retrieval of root tags in the closed taxonomy, using + get_filtered_tags(). Without counts included. 
+ """ + result = list(self.taxonomy.get_filtered_tags(depth=1, include_counts=False)) + common_fields = {"depth": 0, "parent_value": None, "external_id": None} + assert result == [ + # These are the root tags, in alphabetical order: + {"value": "Archaea", "child_count": 3, **common_fields}, + {"value": "Bacteria", "child_count": 2, **common_fields}, + {"value": "Eukaryota", "child_count": 5, **common_fields}, ] - def test_get_root_tags(self): - assert list(self.taxonomy.get_filtered_tags()) == self.domain_tags - assert list( - self.taxonomy.get_filtered_tags(search_term='aR') - ) == self.filtered_domain_tags - - def test_get_tags_free_text(self): - self.taxonomy.allow_free_text = True - with self.assertNumQueries(0): - assert self.taxonomy.get_tags() == [] - - def test_get_children_tags(self): - assert list( - self.taxonomy.get_filtered_tags(parent_tag_id=self.animalia.id) - ) == self.phylum_tags - assert list( - self.taxonomy.get_filtered_tags( - parent_tag_id=self.animalia.id, - search_term='dA', - ) - ) == self.filtered_phylum_tags - assert not list( - self.system_taxonomy.get_filtered_tags( - parent_tag_id=self.system_taxonomy_tag.id - ) - ) + def test_get_child_tags_one_level(self) -> None: + """ + Test basic retrieval of tags one level below the "Eukaryota" root tag in + the closed taxonomy, using get_filtered_tags(). With counts included. 
+ """ + result = list(self.taxonomy.get_filtered_tags(depth=1, parent_tag_value="Eukaryota")) + common_fields = {"depth": 1, "parent_value": "Eukaryota", "usage_count": 0, "external_id": None} + assert result == [ + # These are the Eukaryota tags, in alphabetical order: + {"value": "Animalia", "child_count": 7, **common_fields}, + {"value": "Fungi", "child_count": 0, **common_fields}, + {"value": "Monera", "child_count": 0, **common_fields}, + {"value": "Plantae", "child_count": 0, **common_fields}, + {"value": "Protista", "child_count": 0, **common_fields}, + ] - def test_get_children_tags_free_text(self): - self.taxonomy.allow_free_text = True - assert not list(self.taxonomy.get_filtered_tags( - parent_tag_id=self.animalia.id - )) - assert not list(self.taxonomy.get_filtered_tags( - parent_tag_id=self.animalia.id, - search_term='dA', - )) + def test_get_grandchild_tags_one_level(self) -> None: + """ + Test basic retrieval of a single level of tags at two levels below the + "Eukaryota" root tag in the closed taxonomy, using get_filtered_tags(). 
+ """ + result = list(self.taxonomy.get_filtered_tags(depth=1, parent_tag_value="Animalia")) + common_fields = {"depth": 2, "parent_value": "Animalia", "usage_count": 0, "external_id": None} + assert result == [ + # These are the Animalia tags, in alphabetical order: + {"value": "Arthropoda", "child_count": 0, **common_fields}, + {"value": "Chordata", "child_count": 1, **common_fields}, + {"value": "Cnidaria", "child_count": 0, **common_fields}, + {"value": "Ctenophora", "child_count": 0, **common_fields}, + {"value": "Gastrotrich", "child_count": 0, **common_fields}, + {"value": "Placozoa", "child_count": 0, **common_fields}, + {"value": "Porifera", "child_count": 0, **common_fields}, + ] - def test_search_tags(self): - assert list(self.taxonomy.get_filtered_tags( - search_term='eU', - search_in_all=True - )) == self.filtered_tags - - def test_get_tags_shallow_taxonomy(self): - taxonomy = Taxonomy.objects.create(name="Difficulty") - tags = [ - Tag.objects.create(taxonomy=taxonomy, value="1. Easy"), - Tag.objects.create(taxonomy=taxonomy, value="2. Moderate"), - Tag.objects.create(taxonomy=taxonomy, value="3. 
Hard"), + def test_get_depth_1_search_term(self) -> None: + """ + Filter the root tags to only those that match a search term + """ + result = list(self.taxonomy.get_filtered_tags(depth=1, search_term="ARCH")) + assert result == [ + { + "value": "Archaea", + "child_count": 3, + "depth": 0, + "usage_count": 0, + "parent_value": None, + "external_id": None, + }, + ] + # Note that other tags in the taxonomy match "ARCH" but are excluded because of the depth=1 search + + def test_get_depth_1_child_search_term(self) -> None: + """ + Filter the child tags of "Bacteria" to only those that match a search term + """ + result = list(self.taxonomy.get_filtered_tags(depth=1, search_term="ARCH", parent_tag_value="Bacteria")) + assert result == [ + { + "value": "Archaebacteria", + "child_count": 0, + "depth": 1, + "usage_count": 0, + "parent_value": "Bacteria", + "external_id": None, + }, ] + # Note that other tags in the taxonomy match "ARCH" but are excluded because of the depth=1 search + + def test_depth_1_queries(self) -> None: + """ + Test the number of queries used by get_filtered_tags() with closed + taxonomies when depth=1. This should be a constant, not O(n). 
+ """ + with self.assertNumQueries(1): + self.test_get_root() + with self.assertNumQueries(1): + self.test_get_depth_1_search_term() + # When listing the tags below a specific tag, there is one additional query to load each ancestor tag: + with self.assertNumQueries(2): + self.test_get_child_tags_one_level() with self.assertNumQueries(2): - assert taxonomy.get_tags() == tags + self.test_get_depth_1_child_search_term() + with self.assertNumQueries(3): + self.test_get_grandchild_tags_one_level() - def test_unique_tags(self): - # Creating new tag - Tag( - taxonomy=self.taxonomy, - value='New value' - ).save() + ################## - # Creating repeated tag - with self.assertRaises(IntegrityError): - Tag( - taxonomy=self.taxonomy, - value=self.archaea.value, - ).save() + @staticmethod + def _pretty_format_result(result) -> list[str]: + """ + Format a result to be more human readable. + """ + return [ + f"{t['depth'] * ' '}{t['value']} ({t['parent_value']}) " + + f"(used: {t['usage_count']}, children: {t['child_count']})" + for t in result + ] + + def test_get_all(self) -> None: + """ + Test getting all of the tags in the taxonomy, using get_filtered_tags() + """ + result = self._pretty_format_result(self.taxonomy.get_filtered_tags()) + assert result == [ + "Archaea (None) (used: 0, children: 3)", + " DPANN (Archaea) (used: 0, children: 0)", + " Euryarchaeida (Archaea) (used: 0, children: 0)", + " Proteoarchaeota (Archaea) (used: 0, children: 0)", + "Bacteria (None) (used: 0, children: 2)", + " Archaebacteria (Bacteria) (used: 0, children: 0)", + " Eubacteria (Bacteria) (used: 0, children: 0)", + "Eukaryota (None) (used: 0, children: 5)", + " Animalia (Eukaryota) (used: 0, children: 7)", + " Arthropoda (Animalia) (used: 0, children: 0)", + " Chordata (Animalia) (used: 0, children: 1)", # note this has a child but the child is not included + " Cnidaria (Animalia) (used: 0, children: 0)", + " Ctenophora (Animalia) (used: 0, children: 0)", + " Gastrotrich (Animalia) (used: 
0, children: 0)", + " Placozoa (Animalia) (used: 0, children: 0)", + " Porifera (Animalia) (used: 0, children: 0)", + " Fungi (Eukaryota) (used: 0, children: 0)", + " Monera (Eukaryota) (used: 0, children: 0)", + " Plantae (Eukaryota) (used: 0, children: 0)", + " Protista (Eukaryota) (used: 0, children: 0)", + ] + + def test_search(self) -> None: + """ + Search the whole taxonomy (up to max depth) for a given term. Should + return all tags that match the term as well as their ancestors. + """ + result = self._pretty_format_result(self.taxonomy.get_filtered_tags(search_term="ARCH")) + assert result == [ + "Archaea (None) (used: 0, children: 3)", # Matches the value of this root tag, ARCHaea + " Euryarchaeida (Archaea) (used: 0, children: 0)", # Matches the value of this child tag + " Proteoarchaeota (Archaea) (used: 0, children: 0)", # Matches the value of this child tag + "Bacteria (None) (used: 0, children: 2)", # Does not match this tag but matches a descendant: + " Archaebacteria (Bacteria) (used: 0, children: 0)", # Matches the value of this child tag + ] + + def test_search_2(self) -> None: + """ + Another search test, that matches a tag deeper in the taxonomy to check + that all its ancestors are returned by the search. + """ + result = self._pretty_format_result(self.taxonomy.get_filtered_tags(search_term="chordata")) + assert result == [ + "Eukaryota (None) (used: 0, children: 5)", + " Animalia (Eukaryota) (used: 0, children: 7)", + " Chordata (Animalia) (used: 0, children: 1)", # this is the matching tag. 
+ ] + + def test_tags_deep(self) -> None: + """ + Test getting a deep tag in the taxonomy + """ + result = list(self.taxonomy.get_filtered_tags(parent_tag_value="Chordata")) + assert result == [ + { + "value": "Mammalia", + "parent_value": "Chordata", + "depth": 3, + "usage_count": 0, + "child_count": 0, + "external_id": None, + } + ] + + def test_deep_queries(self) -> None: + """ + Test the number of queries used by get_filtered_tags() with closed + taxonomies when depth=None. This should be a constant, not O(n). + """ + with self.assertNumQueries(1): + self.test_get_all() + # Searching below a specific tag requires an additional query to load that tag: + with self.assertNumQueries(2): + self.test_tags_deep() + # Keyword search requires an additional query: + with self.assertNumQueries(2): + self.test_search() + with self.assertNumQueries(2): + self.test_search_2() + + def test_get_external_id(self) -> None: + """ + Test that if our tags have external IDs, those external IDs are returned + """ + self.bacteria.external_id = "bct001" + self.bacteria.save() + result = list(self.taxonomy.get_filtered_tags(search_term="Eubacteria")) + assert result[0]["value"] == "Bacteria" + assert result[0]["external_id"] == "bct001" + + def test_usage_count(self) -> None: + """ + Test that the usage count in the results is right + """ + api.tag_object(object_id="obj01", taxonomy=self.taxonomy, tags=["Bacteria"]) + api.tag_object(object_id="obj02", taxonomy=self.taxonomy, tags=["Bacteria"]) + api.tag_object(object_id="obj03", taxonomy=self.taxonomy, tags=["Bacteria"]) + api.tag_object(object_id="obj04", taxonomy=self.taxonomy, tags=["Eubacteria"]) + # Now the API should reflect these usage counts: + result = self._pretty_format_result(self.taxonomy.get_filtered_tags(search_term="bacteria")) + assert result == [ + "Bacteria (None) (used: 3, children: 2)", + " Archaebacteria (Bacteria) (used: 0, children: 0)", + " Eubacteria (Bacteria) (used: 1, children: 0)", + ] + # Same with 
depth=1, which uses a different query internally: + result1 = self._pretty_format_result(self.taxonomy.get_filtered_tags(search_term="bacteria", depth=1)) + assert result1 == [ + "Bacteria (None) (used: 3, children: 2)", + ] + + +class TestFilteredTagsFreeTextTaxonomy(TestCase): + """ + Tests for listing/autocompleting/searching for tags in a free text taxonomy. + + Free text taxonomies only return tags that are actually used. + """ + + def setUp(self): + super().setUp() + self.taxonomy = Taxonomy.objects.create(allow_free_text=True, name="FreeText") + # The "triple" tag will be applied to three objects, "double" to two, and "solo" to one: + api.tag_object(object_id="obj1", taxonomy=self.taxonomy, tags=["triple"]) + api.tag_object(object_id="obj2", taxonomy=self.taxonomy, tags=["triple", "double"]) + api.tag_object(object_id="obj3", taxonomy=self.taxonomy, tags=["triple", "double"]) + api.tag_object(object_id="obj4", taxonomy=self.taxonomy, tags=["solo"]) + + def test_get_filtered_tags(self): + """ + Test basic retrieval of all tags in the taxonomy. + Without counts included. + """ + result = list(self.taxonomy.get_filtered_tags(include_counts=False)) + common_fields = {"child_count": 0, "depth": 0, "parent_value": None, "external_id": None} + assert result == [ + # These should appear in alphabetical order: + {"value": "double", **common_fields}, + {"value": "solo", **common_fields}, + {"value": "triple", **common_fields}, + ] + + def test_get_filtered_tags_with_count(self): + """ + Test basic retrieval of all tags in the taxonomy. + With counts included. 
+ """ + result = list(self.taxonomy.get_filtered_tags(include_counts=True)) + common_fields = {"child_count": 0, "depth": 0, "parent_value": None, "external_id": None} + assert result == [ + # These should appear in alphabetical order: + {"value": "double", "usage_count": 2, **common_fields}, + {"value": "solo", "usage_count": 1, **common_fields}, + {"value": "triple", "usage_count": 3, **common_fields}, + ] + + def test_get_filtered_tags_num_queries(self): + """ + Test that the number of queries used by get_filtered_tags() is fixed + and not O(n) or worse. + """ + with self.assertNumQueries(1): + self.test_get_filtered_tags() + with self.assertNumQueries(1): + self.test_get_filtered_tags_with_count() + + def test_get_filtered_tags_with_search(self) -> None: + """ + Test basic retrieval of only matching tags. + """ + result1 = list(self.taxonomy.get_filtered_tags(search_term="le")) + common_fields = {"child_count": 0, "depth": 0, "parent_value": None, "external_id": None} + assert result1 == [ + # These should appear in alphabetical order: + {"value": "double", "usage_count": 2, **common_fields}, + {"value": "triple", "usage_count": 3, **common_fields}, + ] + # And it should be case insensitive: + result2 = list(self.taxonomy.get_filtered_tags(search_term="LE")) + assert result1 == result2 class TestObjectTag(TestTagTaxonomyMixin, TestCase): @@ -450,10 +638,10 @@ def test_tag_case(self) -> None: Test that the object_id is case sensitive. 
""" # Tag with object_id with lower case - api.tag_object(self.taxonomy, [self.domain_tags[0].value], object_id="case:id:2") + api.tag_object(self.taxonomy, [self.chordata.value], object_id="case:id:2") # Tag with object_id with upper case should not trigger IntegrityError - api.tag_object(self.taxonomy, [self.domain_tags[0].value], object_id="CASE:id:2") + api.tag_object(self.taxonomy, [self.chordata.value], object_id="CASE:id:2") # Create another ObjectTag with lower case object_id should trigger IntegrityError with transaction.atomic(): @@ -461,7 +649,7 @@ def test_tag_case(self) -> None: ObjectTag( object_id="case:id:2", taxonomy=self.taxonomy, - tag=self.domain_tags[0], + tag=self.chordata, ).save() # Create another ObjectTag with upper case object_id should trigger IntegrityError @@ -470,7 +658,7 @@ def test_tag_case(self) -> None: ObjectTag( object_id="CASE:id:2", taxonomy=self.taxonomy, - tag=self.domain_tags[0], + tag=self.chordata, ).save() def test_is_deleted(self):