diff --git a/aggify/aggify.py b/aggify/aggify.py index dfd6a57..83fd07e 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Literal, Type +from typing import Any, Literal, Dict, Type from mongoengine import Document, EmbeddedDocument, fields @@ -54,8 +54,9 @@ def project(self, **kwargs: QueryParams) -> "Aggify": return self @last_out_stage_check - def group(self, key: str = "_id") -> "Aggify": - self.pipelines.append({"$group": {"_id": f"${key}"}}) + def group(self, expression: str | None = "_id") -> "Aggify": + expression = f"${expression}" if expression else None + self.pipelines.append({"$group": {"_id": expression}}) return self @last_out_stage_check @@ -169,7 +170,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None: raise ValueError(f"Invalid field: {split_query[0]}") # This is a nested query. if "document_type_obj" not in join_field.__dict__ or issubclass( - join_field.document_type, EmbeddedDocument + join_field.document_type, EmbeddedDocument ): match = self.__match({key: value}) if (match.get("$match")) != {}: @@ -213,7 +214,7 @@ def __getitem__(self, index: slice | int) -> "Aggify": @staticmethod def unwind( - path: str, preserve: bool = True + path: str, preserve: bool = True ) -> dict[ Literal["$unwind"], dict[Literal["path", "preserveNullAndEmptyArrays"], str | bool], @@ -239,49 +240,77 @@ def aggregate(self): """ return self.base_model.objects.aggregate(*self.pipelines) # type: ignore - def annotate(self, annotate_name, accumulator, f): - try: - if (stage := list(self.pipelines[-1].keys())[0]) != "$group": - raise AnnotationError( - f"Annotations apply only to $group, not to {stage}." - ) + def annotate(self, annotate_name: str, accumulator: str, + f: str | dict | F | int) -> "Aggify": + """ + Annotate a MongoDB aggregation pipeline with a new field. + Usage: https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#accumulator-operator + + Args: + annotate_name (str): The name of the new annotated field. + accumulator (str): The aggregation accumulator operator (e.g., "$sum", "$avg"). + f (str | dict | F | int): The value for the annotated field. + + Returns: + self. + + Raises: + AnnotationError: If the pipeline is empty or if an invalid accumulator is provided. + + Example: + annotate("totalSales", "sum", "sales") + """ - except IndexError as error: - raise AnnotationError( - "Annotations apply only to $group, you're pipeline is empty." - ) from error - - accumulator_dict = { - "sum": "$sum", - "avg": "$avg", - "first": "$first", - "last": "$last", - "max": "$max", - "min": "$min", - "push": "$push", - "addToSet": "$addToSet", - "stdDevPop": "$stdDevPop", - "stdDevSamp": "$stdDevSamp", - "merge": "$mergeObjects", + # Some of the accumulator fields might be false and should be checked. + aggregation_mapping: Dict[str, Type] = { + "sum": (fields.FloatField(), "$sum"), + "avg": (fields.FloatField(), "$avg"), + "stdDevPop": (fields.FloatField(), "$stdDevPop"), + "stdDevSamp": (fields.FloatField(), "$stdDevSamp"), + "push": (fields.ListField(), "$push"), + "addToSet": (fields.ListField(), "$addToSet"), + "count": (fields.IntField(), "$count"), + "first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"), + "last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"), + "max": (fields.DynamicField(), "$max"), + "accumulator": (fields.DynamicField(), "$accumulator"), + "min": (fields.DynamicField(), "$min"), + "median": (fields.DynamicField(), "$median"), + "mergeObjects": (fields.DictField(), "$mergeObjects"), + "top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"), + "bottom": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$bottom"), + "topN": (fields.ListField(), "$topN"), + "bottomN": (fields.ListField(), "$bottomN"), + "firstN": (fields.ListField(), "$firstN"), + "lastN": (fields.ListField(), "$lastN"), + "maxN": (fields.ListField(), "$maxN"), } - # Determine the data type based on the accumulator - if accumulator in ["sum", "avg", "stdDevPop", "stdDevSamp"]: - field_type = fields.FloatField() - elif accumulator in ["push", "addToSet"]: - field_type = fields.ListField() - else: - field_type = fields.StringField() + try: + stage = list(self.pipelines[-1].keys())[0] + if stage != "$group": + raise AnnotationError(f"Annotations apply only to $group, not to {stage}") + except IndexError: + raise AnnotationError("Annotations apply only to $group, your pipeline is empty") try: - acc = accumulator_dict[accumulator] + field_type, acc = aggregation_mapping[accumulator] except KeyError as error: raise AnnotationError(f"Invalid accumulator: {accumulator}") from error if isinstance(f, F): value = f.to_dict() else: - value = f"${f}" + if isinstance(f, str): + try: + self.get_model_field(self.base_model, f) # noqa + value = f"${f}" + except InvalidField: + value = f + else: + value = f + + # Determine the data type based on the aggregation operator self.pipelines[-1]["$group"] |= {annotate_name: {acc: value}} self.base_model._fields[annotate_name] = field_type # noqa return self @@ -300,7 +329,7 @@ def __match(self, matches: dict[str, Any]): @staticmethod def __lookup( - from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" + from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id" ) -> dict[str, dict[str, str]]: """ Generates a MongoDB lookup pipeline stage. @@ -345,8 +374,9 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]: return merged_pipeline + @last_out_stage_check def lookup( - self, from_collection: Document, let: list[str], query: list[Q], as_name: str + self, from_collection: Document, let: list[str], query: list[Q], as_name: str ) -> "Aggify": """ Generates a MongoDB lookup pipeline stage. @@ -363,8 +393,8 @@ def lookup( check_fields_exist(self.base_model, let) # noqa let_dict = { - field: f"${self.base_model._fields[field].db_field}" for field in let - } # noqa + field: f"${self.base_model._fields[field].db_field}" for field in let # noqa + } from_collection = from_collection._meta.get("collection") # noqa lookup_stages = [] @@ -443,6 +473,7 @@ def _replace_base(self, embedded_field) -> str: return f"${model_field.db_field}" + @last_out_stage_check def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. @@ -480,6 +511,7 @@ def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Ag return self + @last_out_stage_check def replace_with(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify": """ Replace the root document in the aggregation pipeline with a specified embedded field or a merged result. diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 1c1b3d9..998d933 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -105,10 +105,10 @@ def test_complex_conditional_expression_in_projection(self): "$gt": ["$age", 30] } assert ( - aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["then"] == "Adult" + aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["then"] == "Adult" ) assert ( - aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["else"] == "Child" + aggify.pipelines[0]["$project"]["custom_field"]["$cond"]["else"] == "Child" ) # Test filtering using not operator @@ -170,7 +170,7 @@ def test_annotate_empty_pipeline_value_error(self): with pytest.raises(AnnotationError) as err: Aggify(BaseModel).annotate("size", "sum", None) - assert "you're pipeline is empty" in err.__str__().lower() + assert "your pipeline is empty" in err.__str__().lower() def test_annotate_not_group_value_error(self): with pytest.raises(AnnotationError) as err: @@ -185,16 +185,27 @@ def test_annotate_invalid_accumulator(self): @pytest.mark.parametrize( "accumulator", ( - "sum", - "avg", - "first", - "last", - "max", - "min", - "push", - "addToSet", - "stdDevPop", - "stdDevSamp", + "sum", + "avg", + "stdDevPop", + "stdDevSamp", + "push", + "addToSet", + "count", + "first", + "last", + "max", + "accumulator", + "min", + "median", + "mergeObjects", + "top", + "bottom", + "topN", + "bottomN", + "firstN", + "lastN", + "maxN", ), ) def test_annotate_with_raw_f(self, accumulator): @@ -206,16 +217,27 @@ def test_annotate_with_raw_f(self, accumulator): @pytest.mark.parametrize( "accumulator", ( - "sum", - "avg", - "first", - "last", - "max", - "min", - "push", - "addToSet", - "stdDevPop", - "stdDevSamp", + "sum", + "avg", + "stdDevPop", + "stdDevSamp", + "push", + "addToSet", + "count", + "first", + "last", + "max", + "accumulator", + "min", + "median", + "mergeObjects", + "top", + "bottom", + "topN", + "bottomN", + "firstN", + "lastN", + "maxN", ), ) def test_annotate_with_f(self, accumulator): @@ -229,25 +251,105 @@ def test_annotate_with_f(self, accumulator): @pytest.mark.parametrize( "accumulator", ( - "sum", - "avg", - "first", - "last", - "max", - "min", - "push", - "addToSet", - "stdDevPop", - "stdDevSamp", + "sum", + "avg", + "stdDevPop", + "stdDevSamp", + "push", + "addToSet", + "count", + "first", + "last", + "max", + "accumulator", + "min", + "median", + "mergeObjects", + "top", + "bottom", + "topN", + "bottomN", + "firstN", + "lastN", + "maxN", ), ) def test_annotate_raw_value(self, accumulator): + aggify = Aggify(BaseModel) + thing = aggify.group().annotate("some_name", accumulator, "name") + assert len(thing.pipelines) == 1 + assert thing.pipelines[-1]["$group"]["some_name"] == { + f"${accumulator}": "$name" + } + + @pytest.mark.parametrize( + "accumulator", + ( + "sum", + "avg", + "stdDevPop", + "stdDevSamp", + "push", + "addToSet", + "count", + "first", + "last", + "max", + "accumulator", + "min", + "median", + "mergeObjects", + "top", + "bottom", + "topN", + "bottomN", + "firstN", + "lastN", + "maxN", + ), + ) + def test_annotate_raw_value_not_model_field(self, accumulator): + aggify = Aggify(BaseModel) + thing = aggify.group().annotate("some_name", accumulator, "some_value") + assert len(thing.pipelines) == 1 + assert thing.pipelines[-1]["$group"]["some_name"] == { + f"${accumulator}": "some_value" + } + + @pytest.mark.parametrize( + "accumulator", + ( + "sum", + "avg", + "stdDevPop", + "stdDevSamp", + "push", + "addToSet", + "count", + "first", + "last", + "max", + "accumulator", + "min", + "median", + "mergeObjects", + "top", + "bottom", + "topN", + "bottomN", + "firstN", + "lastN", + "maxN", + ), + ) + def test_annotate_add_annotated_field_to_base_model(self, accumulator): aggify = Aggify(BaseModel) thing = aggify.group().annotate("some_name", accumulator, "some_value") assert len(thing.pipelines) == 1 assert thing.pipelines[-1]["$group"]["some_name"] == { - f"${accumulator}": "$some_value" + f"${accumulator}": "some_value" } + assert aggify.filter(some_name=123).pipelines[-1] == {"$match": {"some_name": 123}} def test_out_with_project_stage_error(self): with pytest.raises(OutStageError): @@ -256,11 +358,11 @@ def test_out_with_project_stage_error(self): @pytest.mark.parametrize( ("method", "args"), ( - ("group", ("_id",)), - ("order_by", ("field",)), - ("raw", ({"$query": "query"},)), - ("add_fields", ({"$field": "value"},)), - ("filter", (Q(age=20),)), + ("group", ("_id",)), + ("order_by", ("field",)), + ("raw", ({"$query": "query"},)), + ("add_fields", ({"$field": "value"},)), + ("filter", (Q(age=20),)), ), ) def test_out_stage_error(self, method, args):