Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new accumulators for annotation #14

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 73 additions & 41 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Literal, Type
from typing import Any, Literal, Dict, Type

from mongoengine import Document, EmbeddedDocument, fields

Expand Down Expand Up @@ -54,8 +54,9 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
return self

@last_out_stage_check
def group(self, expression: str | None = "_id") -> "Aggify":
    """Append a $group stage to the pipeline.

    Args:
        expression: Field name to group by. It is prefixed with "$" to
            form the group key; pass None to group the whole collection
            into a single bucket (MongoDB's ``{"_id": None}`` form).

    Returns:
        self, so calls can be chained fluently.
    """
    # None (or "") means "group everything"; otherwise build a "$field" key.
    group_key = f"${expression}" if expression else None
    self.pipelines.append({"$group": {"_id": group_key}})
    return self

@last_out_stage_check
Expand Down Expand Up @@ -169,7 +170,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
raise ValueError(f"Invalid field: {split_query[0]}")
# This is a nested query.
if "document_type_obj" not in join_field.__dict__ or issubclass(
join_field.document_type, EmbeddedDocument
join_field.document_type, EmbeddedDocument
):
match = self.__match({key: value})
if (match.get("$match")) != {}:
Expand Down Expand Up @@ -213,7 +214,7 @@ def __getitem__(self, index: slice | int) -> "Aggify":

@staticmethod
def unwind(
path: str, preserve: bool = True
path: str, preserve: bool = True
) -> dict[
Literal["$unwind"],
dict[Literal["path", "preserveNullAndEmptyArrays"], str | bool],
Expand All @@ -239,49 +240,77 @@ def aggregate(self):
"""
return self.base_model.objects.aggregate(*self.pipelines) # type: ignore

def annotate(
    self, annotate_name: str, accumulator: str, f: str | dict | F | int
) -> "Aggify":
    """
    Annotate a MongoDB aggregation pipeline with a new field.
    Usage: https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#accumulator-operator

    Args:
        annotate_name (str): The name of the new annotated field.
        accumulator (str): The aggregation accumulator operator (e.g., "sum", "avg").
        f (str | dict | F | int): The value for the annotated field.

    Returns:
        self.

    Raises:
        AnnotationError: If the pipeline is empty, the last stage is not
            $group, or an invalid accumulator is provided.

    Example:
        annotate("totalSales", "sum", "sales")
    """
    # Maps each accumulator name to (mongoengine field type for the
    # annotated attribute, MongoDB accumulator operator).
    # Some of the accumulator fields might be false and should be checked.
    aggregation_mapping: Dict[str, tuple] = {
        "sum": (fields.FloatField(), "$sum"),
        "avg": (fields.FloatField(), "$avg"),
        "stdDevPop": (fields.FloatField(), "$stdDevPop"),
        "stdDevSamp": (fields.FloatField(), "$stdDevSamp"),
        "push": (fields.ListField(), "$push"),
        "addToSet": (fields.ListField(), "$addToSet"),
        "count": (fields.IntField(), "$count"),
        "first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"),
        "last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"),
        "max": (fields.DynamicField(), "$max"),
        "accumulator": (fields.DynamicField(), "$accumulator"),
        "min": (fields.DynamicField(), "$min"),
        "median": (fields.DynamicField(), "$median"),
        "mergeObjects": (fields.DictField(), "$mergeObjects"),
        "top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"),
        "bottom": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$bottom"),
        "topN": (fields.ListField(), "$topN"),
        "bottomN": (fields.ListField(), "$bottomN"),
        "firstN": (fields.ListField(), "$firstN"),
        "lastN": (fields.ListField(), "$lastN"),
        "maxN": (fields.ListField(), "$maxN"),
    }

    # Annotations attach fields to the most recent $group stage, so the
    # pipeline must be non-empty and end with $group.
    try:
        stage = list(self.pipelines[-1].keys())[0]
    except IndexError as error:
        raise AnnotationError(
            "Annotations apply only to $group, your pipeline is empty"
        ) from error
    if stage != "$group":
        raise AnnotationError(f"Annotations apply only to $group, not to {stage}")

    try:
        field_type, acc = aggregation_mapping[accumulator]
    except KeyError as error:
        raise AnnotationError(f"Invalid accumulator: {accumulator}") from error

    # Resolve the accumulated value: F expressions become raw MongoDB
    # expressions; strings naming a model field become "$field" references;
    # anything else (literal str, dict, int) is passed through unchanged.
    if isinstance(f, F):
        value = f.to_dict()
    elif isinstance(f, str):
        try:
            self.get_model_field(self.base_model, f)  # noqa
            value = f"${f}"
        except InvalidField:
            value = f
    else:
        value = f

    # Merge the annotation into the last $group stage and register the
    # synthesized field on the model so later stages can reference it.
    self.pipelines[-1]["$group"] |= {annotate_name: {acc: value}}
    self.base_model._fields[annotate_name] = field_type  # noqa
    return self
Expand All @@ -300,7 +329,7 @@ def __match(self, matches: dict[str, Any]):

@staticmethod
def __lookup(
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
) -> dict[str, dict[str, str]]:
"""
Generates a MongoDB lookup pipeline stage.
Expand Down Expand Up @@ -345,8 +374,9 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:

return merged_pipeline

@last_out_stage_check
def lookup(
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Expand All @@ -363,8 +393,8 @@ def lookup(
check_fields_exist(self.base_model, let) # noqa

let_dict = {
field: f"${self.base_model._fields[field].db_field}" for field in let
} # noqa
field: f"${self.base_model._fields[field].db_field}" for field in let # noqa
}
from_collection = from_collection._meta.get("collection") # noqa

lookup_stages = []
Expand Down Expand Up @@ -443,6 +473,7 @@ def _replace_base(self, embedded_field) -> str:

return f"${model_field.db_field}"

@last_out_stage_check
def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down Expand Up @@ -480,6 +511,7 @@ def replace_root(self, *, embedded_field: str, merge: dict | None = None) -> "Ag

return self

@last_out_stage_check
def replace_with(self, *, embedded_field: str, merge: dict | None = None) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down
Loading