Skip to content

Commit

Permalink
Merge pull request #47 from the-deep-nlp/fix/topic-modeling
Browse files Browse the repository at this point in the history
Fix/topic modeling
  • Loading branch information
sudan45 authored Feb 8, 2024
2 parents 47c7eac + 01ede78 commit 34012ba
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 33 deletions.
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ TEXTEXTRACTION_ECS_ENDPOINT=
SUMMARIZATION_V3_ECS_ENDPOINT=
ENTRYEXTRACTION_ECS_ENDPOINT=
GEOLOCATION_ECS_ENDPOINT=
TOPICMODEL_ECS_ENDPOINT=

1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
SUMMARIZATION_V3_ECS_ENDPOINT: ''
ENTRYEXTRACTION_ECS_ENDPOINT: ''
GEOLOCATION_ECS_ENDPOINT: ''
TOPICMODEL_ECS_ENDPOINT: ''

# Celery
CELERY_BROKER_URL: ''
Expand Down
2 changes: 1 addition & 1 deletion analysis_module/mockserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def process_topicmodeling(body) -> Any:
for x in range(0, len(excerpt_ids), ceil(len(excerpt_ids) / clusters))
]

data = {key: val for key, val in enumerate(data)}
data = dict(enumerate(data))

filepath = save_data_local_and_get_url(
dir_name="topicmodel", client_id=client_id, data=data
Expand Down
13 changes: 1 addition & 12 deletions analysis_module/tests/test_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@ def test_topicmodel_incomplete_data(self):
errors = resp.json()["field_errors"]
assert param in errors

@patch('analysis_module.views.analysis_module.spin_ecs_container')
def test_topicmodel_valid_request(self, spin_ecs_mock):
def test_topicmodel_valid_request(self):
"""
This tests for a topicmodel api with valid data
"""
requests_count = NLPRequest.objects.count()
valid_data = {
"entries_url": "https://someurl.com/entries",
"cluster_size": 2,
Expand All @@ -52,15 +50,6 @@ def test_topicmodel_valid_request(self, spin_ecs_mock):
self.set_credentials()
resp = self.client.post(self.TOPICMODELING_URL, valid_data)
assert resp.status_code == 202
spin_ecs_mock.delay.assert_called_once()
new_requests_count = NLPRequest.objects.count()
assert \
new_requests_count == requests_count + 1, \
"One more NLPRequest object should be created"
assert NLPRequest.objects.filter(
type="topicmodel",
created_by=self.user,
).exists()

def test_ngrams_incomplete_data(self):
"""
Expand Down
37 changes: 18 additions & 19 deletions analysis_module/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
SUMMARIZATION_V3_ECS_ENDPOINT,
TEXT_EXTRACTION_ECS_ENDPOINT,
ENTRYEXTRACTION_ECS_ENDPOINT,
GEOLOCATION_ECS_ENDPOINT
GEOLOCATION_ECS_ENDPOINT,
TOPICMODEL_ECS_ENDPOINT
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -245,24 +246,22 @@ def send_ecs_http_request(nlp_request: NLPRequest):


def get_ecs_id_param_name(request_type: NLPRequest.FeaturesType):
    """Return the ECS id query-parameter name for the given feature type.

    Args:
        request_type: The NLP feature type of the request.

    Returns:
        The request-type-specific id parameter name (e.g. ``"topicmodel_id"``),
        or ``None`` when the feature type has no dedicated id parameter.
    """
    # A single lookup table keeps the per-feature names in one place instead
    # of an if/elif chain; unknown feature types fall through to None.
    mapper = {
        NLPRequest.FeaturesType.TOPICMODEL: "topicmodel_id",
        NLPRequest.FeaturesType.GEOLOCATION: "geolocation_id",
        # NOTE(review): entry extraction may not actually need an id param;
        # kept for symmetry with the other feature types.
        NLPRequest.FeaturesType.ENTRY_EXTRACTION: "entryextraction_id",
        NLPRequest.FeaturesType.TEXT_EXTRACTION: "textextraction_id",
        NLPRequest.FeaturesType.SUMMARIZATION_V3: "summarization_id",
    }
    return mapper.get(request_type)


def get_ecs_url(request_type: NLPRequest.FeaturesType):
    """Return the full ECS endpoint URL for the given feature type.

    Args:
        request_type: The NLP feature type of the request.

    Returns:
        The joined endpoint URL for that feature, or ``None`` when the
        feature type has no configured ECS endpoint.
    """
    # Lookup table mirrors get_ecs_id_param_name; each entry joins the
    # configured base endpoint with that feature's route. All URLs are
    # built eagerly on every call, which is cheap (pure string work).
    mapper = {
        NLPRequest.FeaturesType.TOPICMODEL: urljoin(TOPICMODEL_ECS_ENDPOINT, "/get_excerpt_clusters"),
        NLPRequest.FeaturesType.GEOLOCATION: urljoin(GEOLOCATION_ECS_ENDPOINT, "/get_geolocations"),
        NLPRequest.FeaturesType.ENTRY_EXTRACTION: urljoin(ENTRYEXTRACTION_ECS_ENDPOINT, "/extract_entries"),
        NLPRequest.FeaturesType.TEXT_EXTRACTION: urljoin(TEXT_EXTRACTION_ECS_ENDPOINT, "/extract_document"),
        NLPRequest.FeaturesType.SUMMARIZATION_V3: urljoin(SUMMARIZATION_V3_ECS_ENDPOINT, "/generate_report"),
    }
    return mapper.get(request_type)
27 changes: 26 additions & 1 deletion analysis_module/views/analysis_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,32 @@ def process_request(
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def topic_modeling(request: Request):
    """Accept a topic-modeling request and dispatch it to the ECS service.

    Validates the payload with ``TopicModelDeepRequest``; returns a mock
    response when the client asked for one (``mock`` flag) or the server
    runs in mock mode. Otherwise persists an ``NLPRequest`` row and sends
    the ECS HTTP request after the surrounding transaction commits.

    Returns:
        202 Response echoing the client id, feature type, and request id.

    Raises:
        rest_framework.exceptions.ValidationError: on invalid payload
        (via ``is_valid(raise_exception=True)``).
    """
    serializer = TopicModelDeepRequest(data=request.data)
    serializer.is_valid(raise_exception=True)

    # Mock path: either the client explicitly requested a mock response
    # or the whole server runs as a mock server.
    if serializer.validated_data.get("mock") or IS_MOCKSERVER:
        return process_mock_request(
            request=serializer.validated_data,
            request_type=NLPRequest.FeaturesType.TOPICMODEL,
        )

    nlp_request = NLPRequest.objects.create(
        client_id=serializer.validated_data["client_id"],
        type=NLPRequest.FeaturesType.TOPICMODEL,
        request_params=serializer.validated_data,
        created_by=request.user,
    )
    # Defer the ECS call until the DB row is committed, so the downstream
    # service can always look the request up by its id.
    transaction.on_commit(lambda: send_ecs_http_request(nlp_request))
    resp = {
        "client_id": serializer.data.get("client_id"),
        "type": NLPRequest.FeaturesType.TOPICMODEL,
        "message": "Request has been successfully processed.",
        "request_id": str(nlp_request.unique_id),
    }
    return Response(
        resp,
        status=status.HTTP_202_ACCEPTED,
    )


@api_view(["POST"])
Expand Down
1 change: 1 addition & 0 deletions core_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
TEXT_EXTRACTION_ECS_ENDPOINT = env("TEXTEXTRACTION_ECS_ENDPOINT")
ENTRYEXTRACTION_ECS_ENDPOINT = env("ENTRYEXTRACTION_ECS_ENDPOINT")
GEOLOCATION_ECS_ENDPOINT = env("GEOLOCATION_ECS_ENDPOINT")
TOPICMODEL_ECS_ENDPOINT = env("TOPICMODEL_ECS_ENDPOINT")


CALLBACK_MAX_RETRIES_LIMIT = env("CALLBACK_MAX_RETRIES_LIMIT")
Expand Down
1 change: 1 addition & 0 deletions docker-compose-prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ x-server: &base-server-config
TEXTEXTRACTION_ECS_ENDPOINT: ${TEXTEXTRACTION_ECS_ENDPOINT:?Provide text extraction endpoint}
ENTRYEXTRACTION_ECS_ENDPOINT: ${ENTRYEXTRACTION_ECS_ENDPOINT:?Provide entry extraction endpoint}
GEOLOCATION_ECS_ENDPOINT: ${GEOLOCATION_ECS_ENDPOINT:?Provide geolocation endpoint}
TOPICMODEL_ECS_ENDPOINT: ${TOPICMODEL_ECS_ENDPOINT:?Provide topic model endpoint}

# MODEL_INFO
CLASSIFICATION_MODEL_ID: ${CLASSIFICATION_MODEL_ID:-classification-model}
Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ x-server: &base-server-config
TEXTEXTRACTION_ECS_ENDPOINT: ${TEXTEXTRACTION_ECS_ENDPOINT:?Provide text extraction endpoint}
ENTRYEXTRACTION_ECS_ENDPOINT: ${ENTRYEXTRACTION_ECS_ENDPOINT:?Provide entry extraction endpoint}
GEOLOCATION_ECS_ENDPOINT: ${GEOLOCATION_ECS_ENDPOINT:?Provide geolocation endpoint}
TOPICMODEL_ECS_ENDPOINT: ${TOPICMODEL_ECS_ENDPOINT:?Provide topic model endpoint}

# SENTRY
SENTRY_DSN: ${SENTRY_DSN:-}
Expand Down

0 comments on commit 34012ba

Please sign in to comment.