Skip to content

Commit

Permalink
Add validators and dataloaders
Browse files Browse the repository at this point in the history
- Add basic test cases
- Add enable_publicly_viewable_analysis_report_snapshot in Project for
  global switch
  • Loading branch information
thenav56 committed Oct 31, 2023
1 parent 542efed commit d53da42
Show file tree
Hide file tree
Showing 19 changed files with 1,531 additions and 178 deletions.
98 changes: 98 additions & 0 deletions apps/analysis/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
from .models import (
Analysis,
AnalysisPillar,
AnalysisReport,
AnalyticalStatement,
AnalyticalStatementEntry,
DiscardedEntry,
TopicModelCluster,
AnalysisReportUpload,
AnalysisReportContainerData,
AnalysisReportContainer,
AnalysisReportSnapshot,
)


Expand Down Expand Up @@ -140,6 +145,75 @@ def batch_load_fn(self, keys):
return Promise.resolve([_map.get(key, []) for key in keys])


# -------------- Report Module -------------------------------
class AnalysisReportUploadsLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReportUpload.objects.filter(
id__in=keys,
)
_map = {
item.pk: item
for item in qs
}
return Promise.resolve([_map.get(key, []) for key in keys])


class AnalysisReportContainerDataByContainerLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReportContainerData.objects.filter(
container__in=keys,
)
_map = defaultdict(list)
for item in qs:
_map[item.container_id].append(item)
return Promise.resolve([_map.get(key, []) for key in keys])


class OrganizationByAnalysisReportLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReport.organizations.through.objects.filter(
analysisreport__in=keys,
).select_related('organization')
_map = defaultdict(list)
for item in qs:
_map[item.analysisreport_id].append(item.organization)
return Promise.resolve([_map[key] for key in keys])


class ReportUploadByAnalysisReportLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReportUpload.objects.filter(
report__in=keys,
)
_map = defaultdict(list)
for item in qs:
_map[item.report_id].append(item)
return Promise.resolve([_map[key] for key in keys])


class AnalysisReportContainerByAnalysisReportLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReportContainer.objects.filter(
report__in=keys,
)
_map = defaultdict(list)
for item in qs:
_map[item.report_id].append(item)
return Promise.resolve([_map[key] for key in keys])


class LatestReportSnapshotByAnalysisReportLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
qs = AnalysisReportSnapshot.objects.filter(
report__in=keys,
).order_by('report_id', '-published_on').distinct('report_id')
_map = {
snapshot.report_id: snapshot
for snapshot in qs
}
return Promise.resolve([_map.get(key) for key in keys])


class DataLoaders(WithContextMixin):
@cached_property
def analysis_publication_date(self):
Expand Down Expand Up @@ -176,3 +250,27 @@ def analytical_statement_entries(self):
@cached_property
def topic_model_cluster_entries(self):
return AnalysisTopicModelClusterEntryLoader(context=self.context)

@cached_property
def analysis_report_uploads(self):
return AnalysisReportUploadsLoader(context=self.context)

@cached_property
def analysis_report_container_data_by_container(self):
return AnalysisReportContainerDataByContainerLoader(context=self.context)

@cached_property
def organization_by_analysis_report(self):
return OrganizationByAnalysisReportLoader(context=self.context)

@cached_property
def analysis_report_uploads_by_analysis_report(self):
return ReportUploadByAnalysisReportLoader(context=self.context)

@cached_property
def analysis_report_container_by_analysis_report(self):
return AnalysisReportContainerByAnalysisReportLoader(context=self.context)

@cached_property
def latest_report_snapshot_by_analysis_report(self):
return LatestReportSnapshotByAnalysisReportLoader(context=self.context)
16 changes: 16 additions & 0 deletions apps/analysis/factories.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import factory
from factory.django import DjangoModelFactory

from gallery.factories import FileFactory
from .models import (
Analysis,
AnalysisPillar,
AnalyticalStatement,
AnalyticalStatementEntry,
DiscardedEntry,
AnalysisReport,
AnalysisReportUpload,
)


Expand Down Expand Up @@ -44,3 +47,16 @@ class AnalyticalStatementEntryFactory(DjangoModelFactory):

class Meta:
model = AnalyticalStatementEntry


class AnalysisReportFactory(DjangoModelFactory):
class Meta:
model = AnalysisReport


class AnalysisReportUploadFactory(DjangoModelFactory):
type = AnalysisReportUpload.Type.CSV
file = factory.SubFactory(FileFactory)

class Meta:
model = AnalysisReportUpload
5 changes: 5 additions & 0 deletions apps/analysis/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ class AnalyticalStatementGeoEntry(models.Model):
data = models.JSONField(default=list)


# ---- Analysis Report ----
class AnalysisReport(UserResource):
analysis = models.ForeignKey(Analysis, on_delete=models.CASCADE)
is_public = models.BooleanField(
Expand All @@ -569,6 +570,7 @@ def get_latest_snapshot(slug=None, report_id=None):
return
queryset = AnalysisReportSnapshot.objects.filter(
report__is_public=True,
report__analysis__project__enable_publicly_viewable_analysis_report_snapshot=True,
)
if slug is not None:
queryset = queryset.filter(report__slug=slug)
Expand Down Expand Up @@ -615,6 +617,9 @@ class AnalysisReportContainerData(models.Model):
# Generic for now. Client will define this later
data = models.JSONField(default=dict)

# Types
container_id: int


class AnalysisReportSnapshot(UserResource):
report = models.ForeignKey(AnalysisReport, on_delete=models.CASCADE)
Expand Down
9 changes: 8 additions & 1 deletion apps/analysis/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class Arguments:
result = graphene.Field(AnalyticalStatementGeoTaskType)


# Analysis Report
# ----------------- Analysis Report ------------------------------------------
class CreateAnalysisReport(AnalysisReportMutationMixin, PsGrapheneMutation):
class Arguments:
data = AnalysisReportInputType(required=True)
Expand All @@ -229,6 +229,13 @@ class Arguments:
serializer_class = AnalysisReportSerializer
result = graphene.Field(AnalysisReportType)

@classmethod
def get_serializer_context(cls, instance, context):
return {
**context,
'report': instance,
}


class DeleteAnalysisReport(AnalysisReportMutationMixin, PsDeleteMutation):
class Arguments:
Expand Down
39 changes: 21 additions & 18 deletions apps/analysis/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ class Meta:
model = AnalysisReportUpload
only_fields = (
'id',
'file', # TODO Dataloader
'file',
)

report = graphene.ID(source='report_id', required=True)
Expand All @@ -578,16 +578,24 @@ class Meta:
def get_custom_queryset(queryset, info, **_):
return get_analysis_report_upload_qs(info)

@staticmethod
def resolve_file(root, info, **_):
return info.context.dl.deep_gallery.file.load(root.file_id)


class AnalysisReportContainerDataType(ClientIdMixin, DjangoObjectType):
class Meta:
model = AnalysisReportContainerData
only_fields = (
'id',
'upload', # AnalysisReportUploadType # TODO: Dataloader
'upload', # AnalysisReportUploadType
'data', # NOTE: This is Generic for now
)

@staticmethod
def resolve_upload(root, info, **_):
return info.context.dl.analysis.analysis_report_uploads.load(root.upload_id)


class AnalysisReportContainerType(ClientIdMixin, DjangoObjectType):
class Meta:
Expand Down Expand Up @@ -619,9 +627,9 @@ class Meta:
)
content_data = graphene.List(graphene.NonNull(AnalysisReportContainerDataType), required=True)

@staticmethod
def resolve_content_data(root, info, **_):
# TODO: Dataloader
return root.analysisreportcontainerdata_set.all()
return info.context.dl.analysis.analysis_report_container_data_by_container.load(root.pk)


class AnalysisReportSnapshotType(DjangoObjectType):
Expand All @@ -647,15 +655,14 @@ def resolve_published_by(root, info, **_):

@staticmethod
def resolve_files(root, info, **_):
# TODO: Maybe filter this out?
# For now
# - organization logos
# - report uploads
# TODO: use queryset instead
related_file_id = [
*list(root.report.analysisreportupload_set.values_list('file_id', flat=True)),
*list(root.report.organizations.values_list('logo_id', flat=True)),
]
related_file_id = (
root.report.analysisreportupload_set.values_list('file').union(
root.report.organizations.values_list('logo')
)
)
return GalleryFile.objects.filter(id__in=related_file_id).all()


Expand Down Expand Up @@ -692,23 +699,19 @@ def get_custom_queryset(queryset, info, **_):

@staticmethod
def resolve_organizations(root, info, **_):
# TODO: Dataloader
return root.organizations.all()
return info.context.dl.analysis.organization_by_analysis_report.load(root.pk)

@staticmethod
def resolve_uploads(root, info, **_):
# TODO: Dataloader
return root.analysisreportupload_set.all()
return info.context.dl.analysis.analysis_report_uploads_by_analysis_report.load(root.pk)

@staticmethod
def resolve_containers(root, info, **_):
# TODO: Dataloader
return root.analysisreportcontainer_set.all()
return info.context.dl.analysis.analysis_report_container_by_analysis_report.load(root.pk)

@staticmethod
def resolve_latest_snapshot(root, info, **_):
# TODO: Dataloader
return AnalysisReport.get_latest_snapshot(report_id=root.id)
return info.context.dl.analysis.latest_report_snapshot_by_analysis_report.load(root.pk)


class AnalysisReportListType(CustomDjangoListObjectType):
Expand Down
Loading

0 comments on commit d53da42

Please sign in to comment.