diff --git a/apps/assisted_tagging/admin.py b/apps/assisted_tagging/admin.py index 8bcc2de9f4..5327f40863 100644 --- a/apps/assisted_tagging/admin.py +++ b/apps/assisted_tagging/admin.py @@ -2,9 +2,16 @@ from admin_auto_filters.filters import AutocompleteFilterFactory from django.contrib import admin -from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry +from assisted_tagging.models import ( + AssistedTaggingModelPredictionTag, + AssistedTaggingPrediction, + DraftEntry, + LLMAssistedTaggingPredication +) from deep.admin import VersionAdmin +admin.site.register(LLMAssistedTaggingPredication) + @admin.register(DraftEntry) class DraftEntryAdmin(VersionAdmin): diff --git a/apps/assisted_tagging/dataloaders.py b/apps/assisted_tagging/dataloaders.py index 959403e55f..7b978d22f3 100644 --- a/apps/assisted_tagging/dataloaders.py +++ b/apps/assisted_tagging/dataloaders.py @@ -3,7 +3,7 @@ from django.utils.functional import cached_property -from assisted_tagging.models import AssistedTaggingPrediction +from assisted_tagging.models import AssistedTaggingPrediction, LLMAssistedTaggingPredication from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin @@ -18,7 +18,21 @@ def batch_load_fn(self, keys): return Promise.resolve([_map.get(key, []) for key in keys]) +class LLMDraftEntryPredicationsLoader(DataLoaderWithContext): + def batch_load_fn(self, keys): + llm_assisted_tagging_qs = LLMAssistedTaggingPredication.objects.filter(draft_entry_id__in=keys) + _map = { + assisted_tagging.draft_entry_id: assisted_tagging + for assisted_tagging in llm_assisted_tagging_qs + } + return Promise.resolve([_map.get(key) for key in keys]) + + class DataLoaders(WithContextMixin): @cached_property def draft_entry_predications(self): return DraftEntryPredicationsLoader(context=self.context) + + @cached_property + def llm_draft_entry_predications(self): + return LLMDraftEntryPredicationsLoader(context=self.context) diff --git a/apps/assisted_tagging/migrations/0013_llmassistedtaggingpredication.py b/apps/assisted_tagging/migrations/0013_llmassistedtaggingpredication.py new file mode 100644 index 0000000000..3b88bdfde6 --- /dev/null +++ b/apps/assisted_tagging/migrations/0013_llmassistedtaggingpredication.py @@ -0,0 +1,24 @@ +# Generated by Django 3.2.25 on 2024-11-28 05:00 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('assisted_tagging', '0012_auto_20231222_0554'), + ] + + operations = [ + migrations.CreateModel( + name='LLMAssistedTaggingPredication', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('value', models.CharField(blank=True, max_length=255)), + ('model_tags', models.JSONField(blank=True, null=True)), + ('draft_entry', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='llmpredictions', to='assisted_tagging.draftentry')), + ('model_version', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='assisted_tagging.assistedtaggingmodelversion')), + ], + ), + ] diff --git a/apps/assisted_tagging/models.py b/apps/assisted_tagging/models.py index 85875daf3b..9d8762de97 100644 --- a/apps/assisted_tagging/models.py +++ b/apps/assisted_tagging/models.py @@ -189,6 +189,18 @@ def __str__(self): return str(self.id) +class LLMAssistedTaggingPredication(models.Model): + model_version = models.ForeignKey(AssistedTaggingModelVersion, on_delete=models.CASCADE, related_name='+') + draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name='llmpredictions') + value = models.CharField(max_length=255, blank=True) + model_tags = models.JSONField(null=True, blank=True) + + id: int + + def __str__(self): + return str(self.id) + + class WrongPredictionReview(UserResource): prediction = models.ForeignKey( AssistedTaggingPrediction, diff --git a/apps/assisted_tagging/schema.py b/apps/assisted_tagging/schema.py index 8bc4265035..228716d43f 100644 --- a/apps/assisted_tagging/schema.py +++ b/apps/assisted_tagging/schema.py @@ -7,6 +7,7 @@ from utils.graphene.enums import EnumDescription from user_resource.schema import UserResourceMixin from deep.permissions import ProjectPermissions as PP +from graphene.types.generic import GenericScalar from geo.schema import ( ProjectGeoAreaType, @@ -20,6 +21,7 @@ AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, + LLMAssistedTaggingPredication, MissingPredictionReview, WrongPredictionReview, ) @@ -145,6 +147,22 @@ class Meta: ''' +class LLMAssistedTaggingPredictionType(DjangoObjectType): + model_version = graphene.ID(source='model_version_id', required=True) + draft_entry = graphene.ID(source='draft_entry_id', required=True) + model_tags = GenericScalar() + + class Meta: + model = LLMAssistedTaggingPredication + only_fields = ( + 'id', + 'model_tags' + ) + ''' + NOTE: model_version_deepl_model_id and wrong_prediction_review are not included here because they are not used in client + ''' + + class MissingPredictionReviewType(UserResourceMixin, DjangoObjectType): category = graphene.ID(source='category_id', required=True) tag = graphene.ID(source='tag_id', required=True) @@ -160,9 +178,7 @@ class Meta: class DraftEntryType(DjangoObjectType): prediction_status = graphene.Field(DraftEntryPredictionStatusEnum, required=True) prediction_status_display = EnumDescription(source='get_prediction_status_display', required=True) - prediction_tags = graphene.List( - graphene.NonNull(AssistedTaggingPredictionType) - ) + tags = graphene.Field(LLMAssistedTaggingPredictionType) geo_areas = graphene.List( graphene.NonNull(ProjectGeoAreaType) ) @@ -187,6 +203,10 @@ def resolve_prediction_tags(root, info, **kwargs): def resolve_geo_areas(root, info, **_): return info.context.dl.geo.draft_entry_geo_area.load(root.pk) + @staticmethod + def resolve_tags(root, info, **_): + return info.context.dl.assisted_tagging.llm_draft_entry_predications.load(root.pk) + class DraftEntryListType(CustomDjangoListObjectType): class Meta: diff --git a/apps/assisted_tagging/serializers.py b/apps/assisted_tagging/serializers.py index a0baecaa2b..aa3566a823 100644 --- a/apps/assisted_tagging/serializers.py +++ b/apps/assisted_tagging/serializers.py @@ -23,7 +23,7 @@ def validate_lead(self, lead): if lead.project != self.project: raise serializers.ValidationError('Only lead from current project are allowed.') af = lead.project.analysis_framework - if af is None or not af.assisted_tagging_enabled: + if af is None: raise serializers.ValidationError('Assisted tagging is disabled for the Framework used by this project.') return lead diff --git a/apps/assisted_tagging/tasks.py b/apps/assisted_tagging/tasks.py index c809c30bd4..f75cc04983 100644 --- a/apps/assisted_tagging/tasks.py +++ b/apps/assisted_tagging/tasks.py @@ -7,8 +7,8 @@ from utils.common import redis_lock from deep.deepl import DeeplServiceEndpoint from deepl_integration.handlers import ( - AssistedTaggingDraftEntryHandler, - AutoAssistedTaggingDraftEntryHandler, + LlmAssistedTaggingDraftEntryHandler, + LLMAutoAssistedTaggingDraftEntryHandler, BaseHandler as DeepHandler ) @@ -95,14 +95,14 @@ def sync_models_with_deepl(): @redis_lock('trigger_request_for_draft_entry_task_{0}', 60 * 60 * 0.5) def trigger_request_for_draft_entry_task(draft_entry_id): draft_entry = DraftEntry.objects.get(pk=draft_entry_id) - return AssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry) + return LlmAssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry) @shared_task @redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5) def trigger_request_for_auto_draft_entry_task(lead_id): lead = Lead.objects.get(id=lead_id) - return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead) + return LLMAutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead) @shared_task diff --git a/apps/assisted_tagging/tests/test_query.py b/apps/assisted_tagging/tests/test_query.py index 19601d9fa2..e00711a0a6 100644 --- a/apps/assisted_tagging/tests/test_query.py +++ b/apps/assisted_tagging/tests/test_query.py @@ -36,17 +36,8 @@ class TestAssistedTaggingQuery(GraphQLTestCase): ENABLE_NOW_PATCHER = True ASSISTED_TAGGING_NLP_DATA = ''' - query MyQuery ($taggingModelId: ID!, $predictionTag: ID!) { + query MyQuery ($taggingModelId: ID! ) { assistedTagging { - predictionTags { - id - group - isCategory - isDeprecated - hideInAnalysisFrameworkMapping - parentTag - tagId - } taggingModels { id modelId @@ -65,15 +56,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase): version } } - predictionTag(id: $predictionTag) { - id - group - isCategory - isDeprecated - hideInAnalysisFrameworkMapping - parentTag - tagId - } } } ''' @@ -88,15 +70,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase): predictionStatus predictionStatusDisplay predictionReceivedAt - predictionTags { - id - modelVersion - dataType - dataTypeDisplay - value - category - tag - } geoAreas { title } @@ -111,14 +84,12 @@ def test_unified_connector_nlp_data(self): model1, *other_models = AssistedTaggingModelFactory.create_batch(2) AssistedTaggingModelVersionFactory.create_batch(2, model=model1) - tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5) # -- without login content = self.query_check( self.ASSISTED_TAGGING_NLP_DATA, variables=dict( taggingModelId=model1.id, - predictionTag=tag1.id, ), assert_for_error=True, ) @@ -129,31 +100,8 @@ def test_unified_connector_nlp_data(self): self.ASSISTED_TAGGING_NLP_DATA, variables=dict( taggingModelId=model1.id, - predictionTag=tag1.id, ) )['data']['assistedTagging'] - self.assertEqual(content['predictionTags'], [ - dict( - id=str(tag.id), - tagId=tag.tag_id, - isDeprecated=tag.is_deprecated, - isCategory=tag.is_category, - group=tag.group, - hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping, - parentTag=tag.parent_tag_id and str(tag.parent_tag_id), - ) - for tag in [tag1, *other_tags] - ]) - self.assertEqual(content['predictionTag'], dict( - id=str(tag1.id), - tagId=tag1.tag_id, - isDeprecated=tag1.is_deprecated, - isCategory=tag1.is_category, - group=tag1.group, - hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping, - parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id), - )) - self.assertEqual(content['taggingModels'], [ dict( id=str(_model.id), @@ -196,38 +144,8 @@ def test_unified_connector_draft_entry(self): GeoAreaFactory.create(admin_level=admin_level, title='Nepal') GeoAreaFactory.create(admin_level=admin_level, title='Bagmati') GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu') - model1 = AssistedTaggingModelFactory.create() - geo_model = AssistedTaggingModelFactory.create(model_id=AssistedTaggingModel.ModelID.GEO) - latest_model1_version = AssistedTaggingModelVersionFactory.create_batch(2, model=model1)[0] - latest_geo_model_version = AssistedTaggingModelVersionFactory.create(model=geo_model) - category1, tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5) - draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt='sample excerpt') - prediction1 = AssistedTaggingPredictionFactory.create( - data_type=AssistedTaggingPrediction.DataType.TAG, - model_version=latest_model1_version, - draft_entry=draft_entry1, - category=category1, - tag=tag1, - prediction=0.1, - threshold=0.05, - is_selected=True, - ) - prediction2 = AssistedTaggingPredictionFactory.create( - data_type=AssistedTaggingPrediction.DataType.RAW, - model_version=latest_geo_model_version, - draft_entry=draft_entry1, - value='Nepal', - is_selected=True, - ) - prediction3 = AssistedTaggingPredictionFactory.create( - data_type=AssistedTaggingPrediction.DataType.RAW, - model_version=latest_geo_model_version, - draft_entry=draft_entry1, - value='Kathmandu', - is_selected=True, - ) draft_entry1.save_geo_data() def _query_check(**kwargs): @@ -257,44 +175,7 @@ def _query_check(**kwargs): predictionReceivedAt=None, predictionStatus=self.genum(draft_entry1.prediction_status), predictionStatusDisplay=draft_entry1.get_prediction_status_display(), - predictionTags=[ - dict( - id=str(prediction1.pk), - modelVersion=str(prediction1.model_version_id), - dataType=self.genum(prediction1.data_type), - dataTypeDisplay=prediction1.get_data_type_display(), - value='', - category=str(prediction1.category_id), - tag=str(prediction1.tag_id), - ), - dict( - id=str(prediction2.id), - modelVersion=str(prediction2.model_version.id), - dataType=self.genum(prediction2.data_type), - dataTypeDisplay=prediction2.get_data_type_display(), - value=prediction2.value, - category=None, - tag=None, - ), - dict( - id=str(prediction3.id), - modelVersion=str(prediction3.model_version.id), - dataType=self.genum(prediction3.data_type), - dataTypeDisplay=prediction3.get_data_type_display(), - value=prediction3.value, - category=None, - tag=None, - ) - ], - geoAreas=[ - dict( - title='Nepal', - ), - dict( - title='Kathmandu', - ) - - ], + geoAreas=[] )) diff --git a/apps/deepl_integration/handlers.py b/apps/deepl_integration/handlers.py index 70e685c592..7731c5ab11 100644 --- a/apps/deepl_integration/handlers.py +++ b/apps/deepl_integration/handlers.py @@ -27,6 +27,7 @@ AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, + LLMAssistedTaggingPredication, ) from unified_connector.models import ( ConnectorLead, @@ -1128,3 +1129,267 @@ def save_data( else: geo_task.status = AnalyticalStatementGeoTask.Status.FAILED geo_task.save(update_fields=('status',)) + + +class LlmAssistedTaggingDraftEntryHandler(BaseHandler): + model = DraftEntry + callback_url_name = 'llm-assisted_tagging_draft_entry_prediction_callback' + + @classmethod + def send_trigger_request_to_extractor(cls, draft_entry): + source_organization = draft_entry.lead.source + author_organizations = [ + author.data.title + for author in draft_entry.lead.authors.all() + ] + payload = { + 'entries': [ + { + 'client_id': cls.get_client_id(draft_entry), + 'entry': draft_entry.excerpt, + } + ], + + 'project_id': draft_entry.project_id, + 'af_id': draft_entry.project.analysis_framework.id, + 'publishing_organization': source_organization and source_organization.data.title, + 'authoring_organization': author_organizations, + 'callback_url': cls.get_callback_url(), + } + response_content = None + try: + response = requests.post( + DeeplServiceEndpoint.LLM_ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT, + headers=cls.REQUEST_HEADERS, + json=payload + ) + response_content = response.content + if response.status_code == 202: + return True + except Exception: + logger.error('Assisted tagging send failed, Exception occurred!!', exc_info=True) + draft_entry.prediction_status = DraftEntry.PredictionStatus.SEND_FAILED + draft_entry.save(update_fields=('prediction_status',)) + logger.error( + 'Assisted tagging send failed!!', + extra={ + 'data': { + 'payload': payload, + 'response': response_content, + }, + }, + ) + + # --- Callback logics + @staticmethod + def _get_or_create_models_version(models_data): + def get_versions_map(): + return { + (model_version.model.model_id, model_version.version): model_version + for model_version in AssistedTaggingModelVersion.objects.filter( + reduce( + lambda acc, item: acc | item, + [ + models.Q( + model__model_id=model_data['id'], + version=model_data['version'], + ) + for model_data in models_data + ], + ) + ).select_related('model').all() + } + + existing_model_versions = get_versions_map() + new_model_versions = [ + model_data + for model_data in models_data + if (model_data['id'], model_data['version']) not in existing_model_versions + ] + + if new_model_versions: + AssistedTaggingModelVersion.objects.bulk_create([ + AssistedTaggingModelVersion( + model=AssistedTaggingModel.objects.get_or_create( + model_id=model_data['id'], + defaults=dict( + name=model_data['id'], + ), + )[0], + version=model_data['version'], + ) + for model_data in models_data + ]) + existing_model_versions = get_versions_map() + return existing_model_versions + + @classmethod + def _process_model_preds(cls, model_version, draft_entry, model_prediction): + LLMAssistedTaggingPredication.objects.create( + model_tags=model_prediction['model_tags'], + draft_entry=draft_entry, + model_version=model_version + ) + + @classmethod + def save_data(cls, draft_entry, data): + model_preds = data + models_version_map = cls._get_or_create_models_version( + [ + model_preds['model_info'] + ] + ) + with transaction.atomic(): + draft_entry.clear_data() # Clear old data if exists + draft_entry.calculated_at = timezone.now() + model_version = models_version_map[(model_preds['model_info']['id'], model_preds['model_info']['version'])] + cls._process_model_preds(model_version, draft_entry, model_preds) + draft_entry.prediction_status = DraftEntry.PredictionStatus.DONE + draft_entry.save_geo_data() + draft_entry.save() + return draft_entry + + +class LLMAutoAssistedTaggingDraftEntryHandler(BaseHandler): + model = Lead + callback_url_name = 'auto-llm-assisted_tagging_draft_entry_prediction_callback' + + @classmethod + def auto_trigger_request_to_extractor(cls, lead): + lead_preview = LeadPreview.objects.get(lead=lead) + payload = { + "documents": [ + { + "client_id": cls.get_client_id(lead), + "text_extraction_id": str(lead_preview.text_extraction_id), + } + ], + 'project_id': lead.project_id, + 'af_id': lead.project.analysis_framework_id, + "callback_url": cls.get_callback_url() + } + response_content = None + try: + response = requests.post( + url=DeeplServiceEndpoint.LLM_ENTRY_EXTRACTION_CLASSIFICATION, + headers=cls.REQUEST_HEADERS, + json=payload + ) + response_content = response.content + if response.status_code == 202: + lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.PENDING + lead.save(update_fields=('auto_entry_extraction_status',)) + return True + + except Exception: + logger.error('Entry Extraction send failed, Exception occurred!!', exc_info=True) + lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.FAILED + lead.save(update_fields=('auto_entry_extraction_status',)) + logger.error( + 'Entry Extraction send failed!!', + extra={ + 'data': { + 'payload': payload, + 'response': response_content, + }, + }, + ) + + # --- Callback logics + @staticmethod + def _get_or_create_models_version(models_data): + def get_versions_map(): + return { + (model_version.model.model_id, model_version.version): model_version + for model_version in AssistedTaggingModelVersion.objects.filter( + reduce( + lambda acc, item: acc | item, + [ + models.Q( + model__model_id=model_data['name'], + version=model_data['version'], + ) + for model_data in models_data + ], + ) + ).select_related('model').all() + } + + existing_model_versions = get_versions_map() + new_model_versions = [ + model_data + for model_data in models_data + if (model_data['name'], model_data['version']) not in existing_model_versions + ] + + if new_model_versions: + AssistedTaggingModelVersion.objects.bulk_create([ + AssistedTaggingModelVersion( + model=AssistedTaggingModel.objects.get_or_create( + model_id=model_data['name'], + defaults=dict( + name=model_data['name'], + ), + )[0], + version=model_data['version'], + ) + for model_data in models_data + ]) + existing_model_versions = get_versions_map() + return existing_model_versions + + @classmethod + def _process_model_preds(cls, model_version, draft_entry, model_prediction): + prediction_status = model_prediction['prediction_status'] + if not prediction_status: # If False no tags are provided + return + + tags = model_prediction.get('classification', {}) # NLP TagId + + common_attrs = dict( + model_version=model_version, + draft_entry_id=draft_entry.id, + ) + LLMAssistedTaggingPredication.objects.create( + **common_attrs, + model_tags=tags + ) + + @classmethod + @transaction.atomic + def save_data(cls, lead, data_url): + # NOTE: Schema defined here + # - https://docs.google.com/document/d/1NmjOO5sOrhJU6b4QXJBrGAVk57_NW87mLJ9wzeY_NZI/edit#heading=h.t3u7vdbps5pt + data = RequestHelper(url=data_url, ignore_error=True).json() + draft_entry_qs = DraftEntry.objects.filter(lead=lead, type=DraftEntry.Type.AUTO) + if draft_entry_qs.exists(): + raise serializers.ValidationError('Draft entries already exit') + for model_preds in data['blocks']: + if not model_preds['relevant']: + continue + models_version_map = cls._get_or_create_models_version([ + data['classification_model_info'] + ]) + draft = DraftEntry.objects.create( + page=model_preds['page'], + text_order=model_preds['textOrder'], + project=lead.project, + lead=lead, + excerpt=model_preds['text'], + prediction_status=DraftEntry.PredictionStatus.STARTED, + type=DraftEntry.Type.AUTO + ) + if model_preds['geolocations']: + geo_areas_qs = GeoAreaGqlFilterSet( + data={'titles': [geo['entity'] for geo in model_preds['geolocations']]}, + queryset=GeoArea.get_for_project(lead.project) + ).qs.distinct('title') + draft.related_geoareas.set(geo_areas_qs) + + model_version = models_version_map[ + (data['classification_model_info']['name'], data['classification_model_info']['version']) + ] + cls._process_model_preds(model_version, draft, model_preds) + lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.SUCCESS + lead.save(update_fields=('auto_entry_extraction_status',)) + return lead diff --git a/apps/deepl_integration/serializers.py b/apps/deepl_integration/serializers.py index a2fe470cbd..6ecdf98f19 100644 --- a/apps/deepl_integration/serializers.py +++ b/apps/deepl_integration/serializers.py @@ -8,12 +8,14 @@ BaseHandler, AssistedTaggingDraftEntryHandler, LeadExtractionHandler, + LlmAssistedTaggingDraftEntryHandler, UnifiedConnectorLeadHandler, AnalysisTopicModelHandler, AnalysisAutomaticSummaryHandler, AnalyticalStatementNGramHandler, AnalyticalStatementGeoHandler, - AutoAssistedTaggingDraftEntryHandler + AutoAssistedTaggingDraftEntryHandler, + LLMAutoAssistedTaggingDraftEntryHandler, ) from deduplication.tasks.indexing import index_lead_and_calculate_duplicates @@ -261,6 +263,22 @@ def create(self, validated_data): ) +class LlmAssistedTaggingDraftEntryPredictionCallbackSerializer(BaseCallbackSerializer): + model_tags = serializers.DictField(child=serializers.DictField()) + prediction_status = serializers.BooleanField() + model_info = serializers.DictField() + nlp_handler = LlmAssistedTaggingDraftEntryHandler + + def create(self, validated_data): + draft_entry = validated_data['object'] + if draft_entry.prediction_status == DraftEntry.PredictionStatus.DONE: + return draft_entry + return self.nlp_handler.save_data( + draft_entry, + validated_data, + ) + + class AutoAssistedBlockPredicationCallbackSerializer(serializers.Serializer): page = serializers.IntegerField() textOrder = serializers.IntegerField() @@ -285,6 +303,20 @@ def create(self, validated_data): ) +class AutoLLMAssistedTaggingDraftEntryCallbackSerializer(BaseCallbackSerializer): + entry_extraction_classification_path = serializers.URLField(required=True) + text_extraction_id = serializers.CharField(required=True) + status = serializers.IntegerField() + nlp_handler = LLMAutoAssistedTaggingDraftEntryHandler + + def create(self, validated_data): + obj = validated_data['object'] + return self.nlp_handler.save_data( + obj, + validated_data['entry_extraction_classification_path'], + ) + + class EntriesCollectionBaseCallbackSerializer(DeeplServerBaseCallbackSerializer): model: Type[DeeplTrackBaseModel] presigned_s3_url = serializers.URLField() diff --git a/apps/deepl_integration/views.py b/apps/deepl_integration/views.py index 264d0ce6ec..890c794f6f 100644 --- a/apps/deepl_integration/views.py +++ b/apps/deepl_integration/views.py @@ -9,13 +9,15 @@ from .serializers import ( AssistedTaggingDraftEntryPredictionCallbackSerializer, + AutoLLMAssistedTaggingDraftEntryCallbackSerializer, + LlmAssistedTaggingDraftEntryPredictionCallbackSerializer, LeadExtractCallbackSerializer, UnifiedConnectorLeadExtractCallbackSerializer, AnalysisTopicModelCallbackSerializer, AnalysisAutomaticSummaryCallbackSerializer, AnalyticalStatementNGramCallbackSerializer, AnalyticalStatementGeoCallbackSerializer, - AutoAssistedTaggingDraftEntryCallbackSerializer + AutoAssistedTaggingDraftEntryCallbackSerializer, ) @@ -34,10 +36,18 @@ class AssistedTaggingDraftEntryPredictionCallbackView(BaseCallbackView): serializer = AssistedTaggingDraftEntryPredictionCallbackSerializer +class LlmAssistedTaggingDraftEntryPredictionCallbackView(BaseCallbackView): + serializer = LlmAssistedTaggingDraftEntryPredictionCallbackSerializer + + class AutoTaggingDraftEntryPredictionCallbackView(BaseCallbackView): serializer = AutoAssistedTaggingDraftEntryCallbackSerializer +class AutoLLMTaggingDraftEntryPredictionCallbackView(BaseCallbackView): + serializer = AutoLLMAssistedTaggingDraftEntryCallbackSerializer + + class LeadExtractCallbackView(BaseCallbackView): serializer = LeadExtractCallbackSerializer diff --git a/apps/geo/mutations.py b/apps/geo/mutations.py index 58538b405a..a895781144 100644 --- a/apps/geo/mutations.py +++ b/apps/geo/mutations.py @@ -100,6 +100,7 @@ def mutate(root, info, admin_level_id): ) ], ok=False) admin_level.delete() + # check boundsfile is empty or not in Region return DeleteAdminLevel(errors=None, ok=True) diff --git a/apps/geo/serializers.py b/apps/geo/serializers.py index b98a843318..204bf8a143 100644 --- a/apps/geo/serializers.py +++ b/apps/geo/serializers.py @@ -223,8 +223,9 @@ def update(self, instance, validated_data): validated_data, ) region = admin_level.region + region.status = Region.Status.INITIATED region.modified_by = self.context['request'].user - region.save(update_fields=('modified_by', 'modified_at',)) + region.save(update_fields=('modified_by', 'modified_at', 'status')) transaction.on_commit(lambda: load_geo_areas.delay(region.id)) diff --git a/apps/lead/filter_set.py b/apps/lead/filter_set.py index 20c944eb5c..cd6f3bb5d6 100644 --- a/apps/lead/filter_set.py +++ b/apps/lead/filter_set.py @@ -18,7 +18,7 @@ from project.models import Project from organization.models import OrganizationType from user.models import User -from entry.models import Entry +from entry.models import Entry, EntryAttachment from entry.filter_set import EntryGQFilterSet, EntriesFilterDataInputType, EntriesFilterDataType from user_resource.filters import UserResourceGqlFilterSet @@ -577,7 +577,18 @@ def filter_exclude_lead_attachment_ids(self, qs, _, value): def filter_exclude_leadattachment_created_entries(self, qs, _, value): if value: - qs = qs.exclude(lead__entry__isnull=value) + ids = qs.values_list('id', flat=True) + + entry_attachment_qs = EntryAttachment.objects.filter( + lead_attachment__in=ids, + lead_attachment__lead__project=self.request.active_project + ).values_list('lead_attachment__id', flat=True).distinct() + + entry_qs = Entry.objects.filter( + project=self.request.active_project, + entry_attachment__in=entry_attachment_qs + ).values_list('entry_attachment__id', flat=True).distinct() + qs = qs.exclude(id__in=entry_qs) return qs return qs diff --git a/deep/deepl.py b/deep/deepl.py index 78e47a5ab6..eba353177a 100644 --- a/deep/deepl.py +++ b/deep/deepl.py @@ -19,3 +19,5 @@ class DeeplServiceEndpoint(): ANALYSIS_GEO = f'{DEEPL_SERVER_DOMAIN}/api/v1/geolocation/' ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification/' ENTRY_EXTRACTION_CLASSIFICATION = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification/' + LLM_ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification-llm/' + LLM_ENTRY_EXTRACTION_CLASSIFICATION = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification-llm/' diff --git a/deep/urls.py b/deep/urls.py index 8b298bb667..3f14a5db8b 100644 --- a/deep/urls.py +++ b/deep/urls.py @@ -150,7 +150,9 @@ ) from deepl_integration.views import ( AssistedTaggingDraftEntryPredictionCallbackView, + LlmAssistedTaggingDraftEntryPredictionCallbackView, AutoTaggingDraftEntryPredictionCallbackView, + AutoLLMTaggingDraftEntryPredictionCallbackView, LeadExtractCallbackView, UnifiedConnectorLeadExtractCallbackView, AnalysisTopicModelCallbackView, @@ -580,6 +582,18 @@ def get_api_path(path): name='auto-assisted_tagging_draft_entry_prediction_callback', ), + re_path( + get_api_path(r'callback/llm-assisted-tagging-draft-entry-prediction/$'), + LlmAssistedTaggingDraftEntryPredictionCallbackView.as_view(), + name='llm-assisted_tagging_draft_entry_prediction_callback', + ), + + re_path( + get_api_path(r'callback/auto-llm-assisted-tagging-draft-entry-prediction/$'), + AutoLLMTaggingDraftEntryPredictionCallbackView.as_view(), + name='auto-llm-assisted_tagging_draft_entry_prediction_callback', + ), + re_path( get_api_path(r'callback/analysis-topic-model/$'), AnalysisTopicModelCallbackView.as_view(), diff --git a/schema.graphql b/schema.graphql index ac3808e7e1..2c1519b1f2 100644 --- a/schema.graphql +++ b/schema.graphql @@ -3238,20 +3238,6 @@ enum AssistedTaggingPredictionDataTypeEnum { TAG } -type AssistedTaggingPredictionType { - id: ID! - value: String! - prediction: Decimal - threshold: Decimal - isSelected: Boolean! - modelVersion: ID! - draftEntry: ID! - dataType: AssistedTaggingPredictionDataTypeEnum! - dataTypeDisplay: EnumDescription! - category: ID - tag: ID -} - type AssistedTaggingQueryType { draftEntry(id: ID!): DraftEntryType draftEntries(lead: ID, draftEntryTypes: [DraftEntryTypeEnum!], ignoreIds: [ID!], isDiscarded: Boolean, page: Int = 1, pageSize: Int): DraftEntryListType @@ -3725,8 +3711,6 @@ type DateCountType { scalar DateTime -scalar Decimal - type DeleteAdminLevel { errors: [GenericScalar!] ok: Boolean @@ -3907,7 +3891,7 @@ type DraftEntryType { predictionReceivedAt: DateTime predictionStatus: DraftEntryPredictionStatusEnum! predictionStatusDisplay: EnumDescription! - predictionTags: [AssistedTaggingPredictionType!] + tags: LLMAssistedTaggingPredictionType geoAreas: [ProjectGeoAreaType!] } @@ -4628,6 +4612,13 @@ type JwtTokenType { expiresIn: String } +type LLMAssistedTaggingPredictionType { + id: ID! + modelTags: GenericScalar + modelVersion: ID! + draftEntry: ID! +} + enum LeadAutoEntryExtractionTypeEnum { NONE STARTED