Skip to content

Commit

Permalink
Integrate llm assisted endpoint
Browse files Browse the repository at this point in the history
Update llm assisted tagging query

Add dataloader of llm assisted tagging

Integrate llm auto classification

Cleanup unwanted code
  • Loading branch information
sudan45 authored and AdityaKhatri committed Dec 5, 2024
1 parent 2fdf67c commit 7f9579b
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 150 deletions.
9 changes: 8 additions & 1 deletion apps/assisted_tagging/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
from admin_auto_filters.filters import AutocompleteFilterFactory
from django.contrib import admin

from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry
from assisted_tagging.models import (
AssistedTaggingModelPredictionTag,
AssistedTaggingPrediction,
DraftEntry,
LLMAssistedTaggingPredication
)
from deep.admin import VersionAdmin

admin.site.register(LLMAssistedTaggingPredication)


@admin.register(DraftEntry)
class DraftEntryAdmin(VersionAdmin):
Expand Down
16 changes: 15 additions & 1 deletion apps/assisted_tagging/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.utils.functional import cached_property

from assisted_tagging.models import AssistedTaggingPrediction
from assisted_tagging.models import AssistedTaggingPrediction, LLMAssistedTaggingPredication

from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin

Expand All @@ -18,7 +18,21 @@ def batch_load_fn(self, keys):
return Promise.resolve([_map.get(key, []) for key in keys])


class LLMDraftEntryPredicationsLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
llm_assisted_tagging_qs = LLMAssistedTaggingPredication.objects.filter(draft_entry_id__in=keys)
_map = {
assisted_tagging.draft_entry_id: assisted_tagging
for assisted_tagging in llm_assisted_tagging_qs
}
return Promise.resolve([_map.get(key) for key in keys])


class DataLoaders(WithContextMixin):
@cached_property
def draft_entry_predications(self):
return DraftEntryPredicationsLoader(context=self.context)

@cached_property
def llm_draft_entry_predications(self):
return LLMDraftEntryPredicationsLoader(context=self.context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 3.2.25 on 2024-11-28 05:00

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('assisted_tagging', '0012_auto_20231222_0554'),
]

operations = [
migrations.CreateModel(
name='LLMAssistedTaggingPredication',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('value', models.CharField(blank=True, max_length=255)),
('model_tags', models.JSONField(blank=True, null=True)),
('draft_entry', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='llmpredictions', to='assisted_tagging.draftentry')),
('model_version', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='assisted_tagging.assistedtaggingmodelversion')),
],
),
]
12 changes: 12 additions & 0 deletions apps/assisted_tagging/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ def __str__(self):
return str(self.id)


class LLMAssistedTaggingPredication(models.Model):
model_version = models.ForeignKey(AssistedTaggingModelVersion, on_delete=models.CASCADE, related_name='+')
draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name='llmpredictions')
value = models.CharField(max_length=255, blank=True)
model_tags = models.JSONField(null=True, blank=True)

id: int

def __str__(self):
return str(self.id)


class WrongPredictionReview(UserResource):
prediction = models.ForeignKey(
AssistedTaggingPrediction,
Expand Down
26 changes: 23 additions & 3 deletions apps/assisted_tagging/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.graphene.enums import EnumDescription
from user_resource.schema import UserResourceMixin
from deep.permissions import ProjectPermissions as PP
from graphene.types.generic import GenericScalar

from geo.schema import (
ProjectGeoAreaType,
Expand All @@ -20,6 +21,7 @@
AssistedTaggingModelVersion,
AssistedTaggingModelPredictionTag,
AssistedTaggingPrediction,
LLMAssistedTaggingPredication,
MissingPredictionReview,
WrongPredictionReview,
)
Expand Down Expand Up @@ -145,6 +147,22 @@ class Meta:
'''


class LLMAssistedTaggingPredictionType(DjangoObjectType):
model_version = graphene.ID(source='model_version_id', required=True)
draft_entry = graphene.ID(source='draft_entry_id', required=True)
model_tags = GenericScalar()

class Meta:
model = LLMAssistedTaggingPredication
only_fields = (
'id',
'model_tags'
)
'''
NOTE: model_version_deepl_model_id and wrong_prediction_review are not included here because they are not used in client
'''


class MissingPredictionReviewType(UserResourceMixin, DjangoObjectType):
category = graphene.ID(source='category_id', required=True)
tag = graphene.ID(source='tag_id', required=True)
Expand All @@ -160,9 +178,7 @@ class Meta:
class DraftEntryType(DjangoObjectType):
prediction_status = graphene.Field(DraftEntryPredictionStatusEnum, required=True)
prediction_status_display = EnumDescription(source='get_prediction_status_display', required=True)
prediction_tags = graphene.List(
graphene.NonNull(AssistedTaggingPredictionType)
)
tags = graphene.Field(LLMAssistedTaggingPredictionType)
geo_areas = graphene.List(
graphene.NonNull(ProjectGeoAreaType)
)
Expand All @@ -187,6 +203,10 @@ def resolve_prediction_tags(root, info, **kwargs):
def resolve_geo_areas(root, info, **_):
return info.context.dl.geo.draft_entry_geo_area.load(root.pk)

@staticmethod
def resolve_tags(root, info, **_):
return info.context.dl.assisted_tagging.llm_draft_entry_predications.load(root.pk)


class DraftEntryListType(CustomDjangoListObjectType):
class Meta:
Expand Down
2 changes: 1 addition & 1 deletion apps/assisted_tagging/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def validate_lead(self, lead):
if lead.project != self.project:
raise serializers.ValidationError('Only lead from current project are allowed.')
af = lead.project.analysis_framework
if af is None or not af.assisted_tagging_enabled:
if af is None:
raise serializers.ValidationError('Assisted tagging is disabled for the Framework used by this project.')
return lead

Expand Down
8 changes: 4 additions & 4 deletions apps/assisted_tagging/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from utils.common import redis_lock
from deep.deepl import DeeplServiceEndpoint
from deepl_integration.handlers import (
AssistedTaggingDraftEntryHandler,
AutoAssistedTaggingDraftEntryHandler,
LlmAssistedTaggingDraftEntryHandler,
LLMAutoAssistedTaggingDraftEntryHandler,
BaseHandler as DeepHandler
)

Expand Down Expand Up @@ -95,14 +95,14 @@ def sync_models_with_deepl():
@redis_lock('trigger_request_for_draft_entry_task_{0}', 60 * 60 * 0.5)
def trigger_request_for_draft_entry_task(draft_entry_id):
draft_entry = DraftEntry.objects.get(pk=draft_entry_id)
return AssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)
return LlmAssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)


@shared_task
@redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5)
def trigger_request_for_auto_draft_entry_task(lead_id):
lead = Lead.objects.get(id=lead_id)
return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)
return LLMAutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)


@shared_task
Expand Down
123 changes: 2 additions & 121 deletions apps/assisted_tagging/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,8 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
ENABLE_NOW_PATCHER = True

ASSISTED_TAGGING_NLP_DATA = '''
query MyQuery ($taggingModelId: ID!, $predictionTag: ID!) {
query MyQuery ($taggingModelId: ID! ) {
assistedTagging {
predictionTags {
id
group
isCategory
isDeprecated
hideInAnalysisFrameworkMapping
parentTag
tagId
}
taggingModels {
id
modelId
Expand All @@ -65,15 +56,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
version
}
}
predictionTag(id: $predictionTag) {
id
group
isCategory
isDeprecated
hideInAnalysisFrameworkMapping
parentTag
tagId
}
}
}
'''
Expand All @@ -88,15 +70,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
predictionStatus
predictionStatusDisplay
predictionReceivedAt
predictionTags {
id
modelVersion
dataType
dataTypeDisplay
value
category
tag
}
geoAreas {
title
}
Expand All @@ -111,14 +84,12 @@ def test_unified_connector_nlp_data(self):

model1, *other_models = AssistedTaggingModelFactory.create_batch(2)
AssistedTaggingModelVersionFactory.create_batch(2, model=model1)
tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)

# -- without login
content = self.query_check(
self.ASSISTED_TAGGING_NLP_DATA,
variables=dict(
taggingModelId=model1.id,
predictionTag=tag1.id,
),
assert_for_error=True,
)
Expand All @@ -129,31 +100,8 @@ def test_unified_connector_nlp_data(self):
self.ASSISTED_TAGGING_NLP_DATA,
variables=dict(
taggingModelId=model1.id,
predictionTag=tag1.id,
)
)['data']['assistedTagging']
self.assertEqual(content['predictionTags'], [
dict(
id=str(tag.id),
tagId=tag.tag_id,
isDeprecated=tag.is_deprecated,
isCategory=tag.is_category,
group=tag.group,
hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping,
parentTag=tag.parent_tag_id and str(tag.parent_tag_id),
)
for tag in [tag1, *other_tags]
])
self.assertEqual(content['predictionTag'], dict(
id=str(tag1.id),
tagId=tag1.tag_id,
isDeprecated=tag1.is_deprecated,
isCategory=tag1.is_category,
group=tag1.group,
hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping,
parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id),
))

self.assertEqual(content['taggingModels'], [
dict(
id=str(_model.id),
Expand Down Expand Up @@ -196,38 +144,8 @@ def test_unified_connector_draft_entry(self):
GeoAreaFactory.create(admin_level=admin_level, title='Nepal')
GeoAreaFactory.create(admin_level=admin_level, title='Bagmati')
GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu')
model1 = AssistedTaggingModelFactory.create()
geo_model = AssistedTaggingModelFactory.create(model_id=AssistedTaggingModel.ModelID.GEO)
latest_model1_version = AssistedTaggingModelVersionFactory.create_batch(2, model=model1)[0]
latest_geo_model_version = AssistedTaggingModelVersionFactory.create(model=geo_model)
category1, tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)

draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt='sample excerpt')

prediction1 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.TAG,
model_version=latest_model1_version,
draft_entry=draft_entry1,
category=category1,
tag=tag1,
prediction=0.1,
threshold=0.05,
is_selected=True,
)
prediction2 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.RAW,
model_version=latest_geo_model_version,
draft_entry=draft_entry1,
value='Nepal',
is_selected=True,
)
prediction3 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.RAW,
model_version=latest_geo_model_version,
draft_entry=draft_entry1,
value='Kathmandu',
is_selected=True,
)
draft_entry1.save_geo_data()

def _query_check(**kwargs):
Expand Down Expand Up @@ -257,44 +175,7 @@ def _query_check(**kwargs):
predictionReceivedAt=None,
predictionStatus=self.genum(draft_entry1.prediction_status),
predictionStatusDisplay=draft_entry1.get_prediction_status_display(),
predictionTags=[
dict(
id=str(prediction1.pk),
modelVersion=str(prediction1.model_version_id),
dataType=self.genum(prediction1.data_type),
dataTypeDisplay=prediction1.get_data_type_display(),
value='',
category=str(prediction1.category_id),
tag=str(prediction1.tag_id),
),
dict(
id=str(prediction2.id),
modelVersion=str(prediction2.model_version.id),
dataType=self.genum(prediction2.data_type),
dataTypeDisplay=prediction2.get_data_type_display(),
value=prediction2.value,
category=None,
tag=None,
),
dict(
id=str(prediction3.id),
modelVersion=str(prediction3.model_version.id),
dataType=self.genum(prediction3.data_type),
dataTypeDisplay=prediction3.get_data_type_display(),
value=prediction3.value,
category=None,
tag=None,
)
],
geoAreas=[
dict(
title='Nepal',
),
dict(
title='Kathmandu',
)

],
geoAreas=[]
))


Expand Down
Loading

0 comments on commit 7f9579b

Please sign in to comment.