Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
seyed-dev committed Nov 6, 2023
1 parent 35e5894 commit 55e304b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 53 deletions.
79 changes: 28 additions & 51 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
from typing import Any, Dict, Type, Union, List, TypeVar, Callable
from typing import Any, Dict, Type, Union, List

from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields
from mongoengine.base import TopLevelDocumentMetaclass
Expand All @@ -10,7 +9,6 @@
AnnotationError,
InvalidField,
InvalidEmbeddedField,
OutStageError,
InvalidArgument,
)
from aggify.types import QueryParams, CollectionType
Expand All @@ -21,30 +19,9 @@
convert_match_query,
check_field_exists,
get_db_field,
last_out_stage_check,
)

AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"])


def last_out_stage_check(method: AggifyType) -> AggifyType:
"""Check if the last stage is $out or not
This decorator check if the last stage is $out or not
MongoDB does not allow adding aggregation pipeline stage after $out stage
"""

@functools.wraps(method)
def decorator(*args, **kwargs):
try:
if bool(args[0].pipelines[-1].get("$out")):
raise OutStageError(method.__name__)
except IndexError:
return method(*args, **kwargs)
else:
return method(*args, **kwargs)

return decorator


class Aggify:
def __init__(self, base_model: Type[Document]):
Expand Down Expand Up @@ -94,7 +71,7 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
if value == 1:
to_keep_values.add(key)
elif key not in self.base_model._fields and isinstance( # noqa
kwargs[key], (str, dict)
kwargs[key], (str, dict)
):
to_keep_values.add(key)
self.base_model._fields[key] = mongoengine_fields.IntField() # noqa
Expand Down Expand Up @@ -175,7 +152,7 @@ def add_fields(self, **fields) -> "Aggify": # noqa

@last_out_stage_check
def filter(
self, arg: Union[Q, None] = None, **kwargs: Union[QueryParams, F, list]
self, arg: Union[Q, None] = None, **kwargs: Union[QueryParams, F, list]
) -> "Aggify":
"""
# TODO: missing docs
Expand Down Expand Up @@ -246,16 +223,16 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None:
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore
# Check conditions for creating a 'match' pipeline stage.
if (
isinstance(
join_field, TopLevelDocumentMetaclass
) # check whether field is added by lookup stage or not
or "document_type_obj"
not in join_field.__dict__ # Check whether this field is a join field or not.
or issubclass(
isinstance(
join_field, TopLevelDocumentMetaclass
) # check whether field is added by lookup stage or not
or "document_type_obj"
not in join_field.__dict__ # Check whether this field is a join field or not.
or issubclass(
join_field.document_type, EmbeddedDocument # noqa
) # Check whether this field is embedded field or not
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
):
# Create a 'match' pipeline stage.
match = self.__match({key: value})
Expand Down Expand Up @@ -301,10 +278,10 @@ def __getitem__(self, index: Union[slice, int]) -> "Aggify":

@last_out_stage_check
def unwind(
self,
path: str,
include_array_index: Union[str, None] = None,
preserve: bool = False,
self,
path: str,
include_array_index: Union[str, None] = None,
preserve: bool = False,
) -> "Aggify":
"""Generates a MongoDB unwind pipeline stage.
Expand Down Expand Up @@ -352,7 +329,7 @@ def unwind(
return self

def annotate(
self, annotate_name: str, accumulator: str, f: Union[Union[str, Dict], F, int]
self, annotate_name: str, accumulator: str, f: Union[Union[str, Dict], F, int]
) -> "Aggify":
"""
Annotate a MongoDB aggregation pipeline with a new field.
Expand Down Expand Up @@ -465,7 +442,7 @@ def __match(self, matches: Dict[str, Any]):

@staticmethod
def __lookup(
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
from_collection: str, local_field: str, as_name: str, foreign_field: str = "_id"
) -> Dict[str, Dict[str, str]]:
"""
Generates a MongoDB lookup pipeline stage.
Expand Down Expand Up @@ -546,13 +523,13 @@ def get_field_name_recursively(self, field: str) -> str:

@last_out_stage_check
def lookup(
self,
from_collection: CollectionType,
as_name: str,
query: Union[List[Q], Union[Q, None], List["Aggify"]] = None,
let: Union[List[str], None] = None,
local_field: Union[str, None] = None,
foreign_field: Union[str, None] = None,
self,
from_collection: CollectionType,
as_name: str,
query: Union[List[Q], Union[Q, None], List["Aggify"]] = None,
let: Union[List[str], None] = None,
local_field: Union[str, None] = None,
foreign_field: Union[str, None] = None,
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Expand Down Expand Up @@ -672,15 +649,15 @@ def _replace_base(self, embedded_field) -> str:
model_field = self.get_model_field(self.base_model, embedded_field) # noqa

if not hasattr(model_field, "document_type") or not issubclass(
model_field.document_type, EmbeddedDocument
model_field.document_type, EmbeddedDocument
):
raise InvalidEmbeddedField(field=embedded_field)

return f"${model_field.db_field}"

@last_out_stage_check
def replace_root(
self, *, embedded_field: str, merge: Union[Dict, None] = None
self, *, embedded_field: str, merge: Union[Dict, None] = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down Expand Up @@ -708,7 +685,7 @@ def replace_root(

@last_out_stage_check
def replace_with(
self, *, embedded_field: str, merge: Union[Dict, None] = None
self, *, embedded_field: str, merge: Union[Dict, None] = None
) -> "Aggify":
"""
Replace the root document in the aggregation pipeline with a specified embedded field or a merged result.
Expand Down
27 changes: 25 additions & 2 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, Union, List, Dict
import functools
from typing import Any, Union, List, Dict, TypeVar, Callable

from mongoengine import Document
from aggify.types import CollectionType
from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField
from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField, OutStageError

AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"])


def to_mongo_positive_index(index: Union[int, slice]) -> slice:
Expand Down Expand Up @@ -138,3 +141,23 @@ def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> st
return f"${db_field}" if add_dollar_sign else db_field
except AttributeError:
return field


def last_out_stage_check(method: AggifyType) -> AggifyType:
"""Check if the last stage is $out or not
This decorator check if the last stage is $out or not
MongoDB does not allow adding aggregation pipeline stage after $out stage
"""

@functools.wraps(method)
def decorator(*args, **kwargs):
try:
if bool(args[0].pipelines[-1].get("$out")):
raise OutStageError(method.__name__)
except IndexError:
return method(*args, **kwargs)
else:
return method(*args, **kwargs)

return decorator

0 comments on commit 55e304b

Please sign in to comment.