Skip to content

Commit

Permalink
Merge pull request #1540 from the-deep/feature/assisted-tagging-with-llm
Browse files Browse the repository at this point in the history
LLM Assisted Tagging
  • Loading branch information
AdityaKhatri authored Dec 10, 2024
2 parents 2fdf67c + f773779 commit 130c5a0
Show file tree
Hide file tree
Showing 17 changed files with 438 additions and 153 deletions.
9 changes: 8 additions & 1 deletion apps/assisted_tagging/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
from admin_auto_filters.filters import AutocompleteFilterFactory
from django.contrib import admin

from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry
from assisted_tagging.models import (
AssistedTaggingModelPredictionTag,
AssistedTaggingPrediction,
DraftEntry,
LLMAssistedTaggingPredication
)
from deep.admin import VersionAdmin

admin.site.register(LLMAssistedTaggingPredication)


@admin.register(DraftEntry)
class DraftEntryAdmin(VersionAdmin):
Expand Down
16 changes: 15 additions & 1 deletion apps/assisted_tagging/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.utils.functional import cached_property

from assisted_tagging.models import AssistedTaggingPrediction
from assisted_tagging.models import AssistedTaggingPrediction, LLMAssistedTaggingPredication

from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin

Expand All @@ -18,7 +18,21 @@ def batch_load_fn(self, keys):
return Promise.resolve([_map.get(key, []) for key in keys])


class LLMDraftEntryPredicationsLoader(DataLoaderWithContext):
    """Batch-load the LLM prediction attached to each draft entry.

    Resolves to one ``LLMAssistedTaggingPredication`` (or ``None`` when a
    draft entry has none) per key, in the same order as ``keys``.
    """

    def batch_load_fn(self, keys):
        """Return a promise of predictions aligned with ``keys`` (draft entry ids)."""
        # ``draft_entry`` is a plain ForeignKey, so several predictions may
        # point at the same draft entry. Order by primary key so the row kept
        # in the map (the last one iterated, i.e. the highest id) is
        # deterministic rather than dependent on the database's row order.
        # NOTE(review): assumes the highest id is the prediction callers want
        # -- confirm against the handler that creates these rows.
        llm_assisted_tagging_qs = (
            LLMAssistedTaggingPredication.objects
            .filter(draft_entry_id__in=keys)
            .order_by('id')
        )
        _map = {
            assisted_tagging.draft_entry_id: assisted_tagging
            for assisted_tagging in llm_assisted_tagging_qs
        }
        return Promise.resolve([_map.get(key) for key in keys])


class DataLoaders(WithContextMixin):
    """Dataloaders of the assisted tagging app, each built once on first access."""

    @cached_property
    def draft_entry_predications(self):
        """Loader backed by ``DraftEntryPredicationsLoader``."""
        loader = DraftEntryPredicationsLoader(context=self.context)
        return loader

    @cached_property
    def llm_draft_entry_predications(self):
        """Loader backed by ``LLMDraftEntryPredicationsLoader``."""
        loader = LLMDraftEntryPredicationsLoader(context=self.context)
        return loader
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 3.2.25 on 2024-11-28 05:00

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
    """Create the ``LLMAssistedTaggingPredication`` table."""

    dependencies = [
        ('assisted_tagging', '0012_auto_20231222_0554'),
    ]

    operations = [
        migrations.CreateModel(
            name='LLMAssistedTaggingPredication',
            fields=[
                (
                    'id',
                    models.AutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name='ID',
                    ),
                ),
                (
                    'value',
                    models.CharField(blank=True, max_length=255),
                ),
                (
                    'model_tags',
                    models.JSONField(blank=True, null=True),
                ),
                (
                    'draft_entry',
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name='llmpredictions',
                        to='assisted_tagging.draftentry',
                    ),
                ),
                (
                    'model_version',
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name='+',
                        to='assisted_tagging.assistedtaggingmodelversion',
                    ),
                ),
            ],
        ),
    ]
12 changes: 12 additions & 0 deletions apps/assisted_tagging/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ def __str__(self):
return str(self.id)


class LLMAssistedTaggingPredication(models.Model):
    """LLM assisted tagging prediction stored against a draft entry.

    Each row links a draft entry to the model version it is recorded for and
    carries an optional text ``value`` plus an optional JSON ``model_tags``
    payload.
    """

    id: int

    # related_name='+' -> no reverse accessor is created on the model version.
    model_version = models.ForeignKey(
        AssistedTaggingModelVersion,
        on_delete=models.CASCADE,
        related_name='+',
    )
    # Reachable from a draft entry as ``draft_entry.llmpredictions``.
    draft_entry = models.ForeignKey(
        DraftEntry,
        on_delete=models.CASCADE,
        related_name='llmpredictions',
    )
    value = models.CharField(max_length=255, blank=True)
    model_tags = models.JSONField(null=True, blank=True)

    def __str__(self):
        """Return the primary key rendered as text."""
        return f'{self.id}'


class WrongPredictionReview(UserResource):
prediction = models.ForeignKey(
AssistedTaggingPrediction,
Expand Down
26 changes: 23 additions & 3 deletions apps/assisted_tagging/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.graphene.enums import EnumDescription
from user_resource.schema import UserResourceMixin
from deep.permissions import ProjectPermissions as PP
from graphene.types.generic import GenericScalar

from geo.schema import (
ProjectGeoAreaType,
Expand All @@ -20,6 +21,7 @@
AssistedTaggingModelVersion,
AssistedTaggingModelPredictionTag,
AssistedTaggingPrediction,
LLMAssistedTaggingPredication,
MissingPredictionReview,
WrongPredictionReview,
)
Expand Down Expand Up @@ -145,6 +147,22 @@ class Meta:
'''


class LLMAssistedTaggingPredictionType(DjangoObjectType):
    # NOTE: documented with comments rather than a class docstring on purpose;
    # graphene would otherwise publish the docstring as the type description.

    # Relations are exposed as plain IDs read from the ``*_id`` attributes.
    model_version = graphene.ID(source='model_version_id', required=True)
    draft_entry = graphene.ID(source='draft_entry_id', required=True)
    # JSON payload of the prediction, passed through as a generic scalar.
    model_tags = GenericScalar()

    class Meta:
        model = LLMAssistedTaggingPredication
        only_fields = (
            'id',
            'model_tags',
        )
        # NOTE: the model's ``value`` field is not listed above, so it is not
        # exposed through this type.
        # NOTE(review): presumably because the client does not use it --
        # confirm before adding it.


class MissingPredictionReviewType(UserResourceMixin, DjangoObjectType):
category = graphene.ID(source='category_id', required=True)
tag = graphene.ID(source='tag_id', required=True)
Expand All @@ -160,9 +178,7 @@ class Meta:
class DraftEntryType(DjangoObjectType):
prediction_status = graphene.Field(DraftEntryPredictionStatusEnum, required=True)
prediction_status_display = EnumDescription(source='get_prediction_status_display', required=True)
prediction_tags = graphene.List(
graphene.NonNull(AssistedTaggingPredictionType)
)
tags = graphene.Field(LLMAssistedTaggingPredictionType)
geo_areas = graphene.List(
graphene.NonNull(ProjectGeoAreaType)
)
Expand All @@ -187,6 +203,10 @@ def resolve_prediction_tags(root, info, **kwargs):
def resolve_geo_areas(root, info, **_):
return info.context.dl.geo.draft_entry_geo_area.load(root.pk)

@staticmethod
def resolve_tags(root, info, **_):
return info.context.dl.assisted_tagging.llm_draft_entry_predications.load(root.pk)


class DraftEntryListType(CustomDjangoListObjectType):
class Meta:
Expand Down
2 changes: 1 addition & 1 deletion apps/assisted_tagging/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def validate_lead(self, lead):
if lead.project != self.project:
raise serializers.ValidationError('Only lead from current project are allowed.')
af = lead.project.analysis_framework
if af is None or not af.assisted_tagging_enabled:
if af is None:
raise serializers.ValidationError('Assisted tagging is disabled for the Framework used by this project.')
return lead

Expand Down
8 changes: 4 additions & 4 deletions apps/assisted_tagging/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from utils.common import redis_lock
from deep.deepl import DeeplServiceEndpoint
from deepl_integration.handlers import (
AssistedTaggingDraftEntryHandler,
AutoAssistedTaggingDraftEntryHandler,
LlmAssistedTaggingDraftEntryHandler,
LLMAutoAssistedTaggingDraftEntryHandler,
BaseHandler as DeepHandler
)

Expand Down Expand Up @@ -95,14 +95,14 @@ def sync_models_with_deepl():
@redis_lock('trigger_request_for_draft_entry_task_{0}', 60 * 60 * 0.5)
def trigger_request_for_draft_entry_task(draft_entry_id):
draft_entry = DraftEntry.objects.get(pk=draft_entry_id)
return AssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)
return LlmAssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)


@shared_task
@redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5)
def trigger_request_for_auto_draft_entry_task(lead_id):
lead = Lead.objects.get(id=lead_id)
return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)
return LLMAutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)


@shared_task
Expand Down
123 changes: 2 additions & 121 deletions apps/assisted_tagging/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,8 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
ENABLE_NOW_PATCHER = True

ASSISTED_TAGGING_NLP_DATA = '''
query MyQuery ($taggingModelId: ID!, $predictionTag: ID!) {
query MyQuery ($taggingModelId: ID! ) {
assistedTagging {
predictionTags {
id
group
isCategory
isDeprecated
hideInAnalysisFrameworkMapping
parentTag
tagId
}
taggingModels {
id
modelId
Expand All @@ -65,15 +56,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
version
}
}
predictionTag(id: $predictionTag) {
id
group
isCategory
isDeprecated
hideInAnalysisFrameworkMapping
parentTag
tagId
}
}
}
'''
Expand All @@ -88,15 +70,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
predictionStatus
predictionStatusDisplay
predictionReceivedAt
predictionTags {
id
modelVersion
dataType
dataTypeDisplay
value
category
tag
}
geoAreas {
title
}
Expand All @@ -111,14 +84,12 @@ def test_unified_connector_nlp_data(self):

model1, *other_models = AssistedTaggingModelFactory.create_batch(2)
AssistedTaggingModelVersionFactory.create_batch(2, model=model1)
tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)

# -- without login
content = self.query_check(
self.ASSISTED_TAGGING_NLP_DATA,
variables=dict(
taggingModelId=model1.id,
predictionTag=tag1.id,
),
assert_for_error=True,
)
Expand All @@ -129,31 +100,8 @@ def test_unified_connector_nlp_data(self):
self.ASSISTED_TAGGING_NLP_DATA,
variables=dict(
taggingModelId=model1.id,
predictionTag=tag1.id,
)
)['data']['assistedTagging']
self.assertEqual(content['predictionTags'], [
dict(
id=str(tag.id),
tagId=tag.tag_id,
isDeprecated=tag.is_deprecated,
isCategory=tag.is_category,
group=tag.group,
hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping,
parentTag=tag.parent_tag_id and str(tag.parent_tag_id),
)
for tag in [tag1, *other_tags]
])
self.assertEqual(content['predictionTag'], dict(
id=str(tag1.id),
tagId=tag1.tag_id,
isDeprecated=tag1.is_deprecated,
isCategory=tag1.is_category,
group=tag1.group,
hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping,
parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id),
))

self.assertEqual(content['taggingModels'], [
dict(
id=str(_model.id),
Expand Down Expand Up @@ -196,38 +144,8 @@ def test_unified_connector_draft_entry(self):
GeoAreaFactory.create(admin_level=admin_level, title='Nepal')
GeoAreaFactory.create(admin_level=admin_level, title='Bagmati')
GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu')
model1 = AssistedTaggingModelFactory.create()
geo_model = AssistedTaggingModelFactory.create(model_id=AssistedTaggingModel.ModelID.GEO)
latest_model1_version = AssistedTaggingModelVersionFactory.create_batch(2, model=model1)[0]
latest_geo_model_version = AssistedTaggingModelVersionFactory.create(model=geo_model)
category1, tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)

draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt='sample excerpt')

prediction1 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.TAG,
model_version=latest_model1_version,
draft_entry=draft_entry1,
category=category1,
tag=tag1,
prediction=0.1,
threshold=0.05,
is_selected=True,
)
prediction2 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.RAW,
model_version=latest_geo_model_version,
draft_entry=draft_entry1,
value='Nepal',
is_selected=True,
)
prediction3 = AssistedTaggingPredictionFactory.create(
data_type=AssistedTaggingPrediction.DataType.RAW,
model_version=latest_geo_model_version,
draft_entry=draft_entry1,
value='Kathmandu',
is_selected=True,
)
draft_entry1.save_geo_data()

def _query_check(**kwargs):
Expand Down Expand Up @@ -257,44 +175,7 @@ def _query_check(**kwargs):
predictionReceivedAt=None,
predictionStatus=self.genum(draft_entry1.prediction_status),
predictionStatusDisplay=draft_entry1.get_prediction_status_display(),
predictionTags=[
dict(
id=str(prediction1.pk),
modelVersion=str(prediction1.model_version_id),
dataType=self.genum(prediction1.data_type),
dataTypeDisplay=prediction1.get_data_type_display(),
value='',
category=str(prediction1.category_id),
tag=str(prediction1.tag_id),
),
dict(
id=str(prediction2.id),
modelVersion=str(prediction2.model_version.id),
dataType=self.genum(prediction2.data_type),
dataTypeDisplay=prediction2.get_data_type_display(),
value=prediction2.value,
category=None,
tag=None,
),
dict(
id=str(prediction3.id),
modelVersion=str(prediction3.model_version.id),
dataType=self.genum(prediction3.data_type),
dataTypeDisplay=prediction3.get_data_type_display(),
value=prediction3.value,
category=None,
tag=None,
)
],
geoAreas=[
dict(
title='Nepal',
),
dict(
title='Kathmandu',
)

],
geoAreas=[]
))


Expand Down
Loading

0 comments on commit 130c5a0

Please sign in to comment.