Skip to content

Commit

Permalink
feat(serializerCog): add get_model_fields method
Browse files Browse the repository at this point in the history
  • Loading branch information
Xenepix committed Nov 25, 2024
1 parent c5b70a6 commit 52170f0
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion backend/nango/cogs/serializer_cog.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from django.db.models import fields as django_fields
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel

from nango.utils import AbstractCog


Expand All @@ -6,11 +9,19 @@ class SerializerCog(AbstractCog):
Settings:
--------
- speedy: A faster serializer than rest_framework.serializers.Serializer, but with less features.
- speedy (bool ; default = False): A faster serializer than rest_framework.serializers.Serializer, but with less features.
- selected_fields (list[str] ; default = []): list of selected fields' name to serialize.
- excluded_fields (list[str] ; default = []): list of fields' name to not serialize.
"""

id = "serializer"

# Fields to NEVER serialize.
_forbidden_fields: tuple[str] = (
"password",
"outstandingtoken",
)

def _get_drf_imports(self) -> list[str]:
"""Return list of imports from rest_framework module."""
if self.settings.get("speedy", False):
Expand All @@ -28,6 +39,42 @@ def _get_base_imports(self) -> list[str]:
"from typing import ClassVar",
]

def get_model_fields(self) -> dict[str, django_fields.Field]:
"""Return the model fields to serialize.
Returns:
-------
```
{
"field_name": field,
...,
}
```
"""
selected_fields_name: list[str] = self.settings.get("selected_fields", [])
if selected_fields_name:
return {field_name: getattr(self.model, field_name) for field_name in selected_fields_name}

model_fields_list: dict[str, django_fields.Field] = {}

# Select all fields except excluded_fields
excluded_fields_name: list[str] = self.settings.get("excluded_fields", [])
for field in self.model._meta.get_fields(): # noqa: SLF001
# Get field's name
if isinstance(field, ManyToOneRel | OneToOneRel):
# External Key point at actual model
field_name = field.related_name if field.related_name is not None else f"{field.related_model.__name__.lower()}"
else:
field_name = getattr(field, "field_name", None) or getattr(field, "name", None)

if field_name in excluded_fields_name or field_name in self._forbidden_fields:
continue

# Basic behavior
model_fields_list[field_name] = field

return model_fields_list

def get_serializer_name(self) -> dict[str]:
"""Return name of generated serializers.
Expand Down Expand Up @@ -66,5 +113,7 @@ def generated_customizable_serializer(self) -> None:
def run(self) -> None:
"""Generate a serializer for each django's model."""
super().run()
self._get_model_fields()

self.generate_cached_serializer()
self.generated_customizable_serializer()

0 comments on commit 52170f0

Please sign in to comment.