Merge pull request #1111 from AI4Bharat/cumulative_fix
sup_cumulative_tasks_count
ishvindersethi22 authored Sep 23, 2024
2 parents c6ed2fe + 1726670 commit 64a9b1b
Showing 6 changed files with 281 additions and 31 deletions.
109 changes: 103 additions & 6 deletions backend/functions/tasks.py
@@ -57,7 +57,10 @@

from shoonya_backend.locks import Lock
from utils.constants import LANG_CHOICES

from projects.tasks import filter_data_items
from projects.models import BATCH
from dataset import models as dataset_models
from projects.registry_helper import ProjectRegistry
import logging

logger = logging.getLogger(__name__)
@@ -73,6 +76,10 @@ def sentence_text_translate_and_save_translation_pairs(
input_dataset_instance_id,
output_dataset_instance_id,
batch_size,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
api_type="indic-trans-v2",
checks_for_particular_languages=False,
automate_missing_data_items=True,
@@ -88,6 +95,10 @@ def sentence_text_translate_and_save_translation_pairs(
Allowed - [indic-trans, google, indic-trans-v2, azure, blank]
checks_for_particular_languages (bool): If True, checks for the particular languages in the translations.
automate_missing_data_items (bool): If True, consider only those data items that are missing in the target dataset instance.
filter_string (str): string used to filter the input data items.
sampling_mode (str): either batch or full.
sampling_parameters (json): JSON containing the batch number and batch size.
"""
task_name = "sentence_text_translate_and_save_translation_pairs"
output_sentences = list(
@@ -114,6 +125,14 @@
"metadata_json",
)
)
if filter_string and sampling_mode and sampling_parameters:
input_sentences = get_filtered_items(
"SentenceText",
input_dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)

# Convert the input_sentences list into a dataframe
input_sentences_complete_df = pd.DataFrame(
@@ -404,7 +423,15 @@ def conversation_data_machine_translation(

@shared_task(bind=True)
def generate_ocr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate OCR prediction data and to save to the same data item.
Args:
@@ -437,7 +464,14 @@ def generate_ocr_prediction_json(
)
except Exception as e:
ocr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
ocr_data_items = get_filtered_items(
"OCRDocument",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
ocr_data_items_df = pd.DataFrame(
ocr_data_items,
@@ -556,7 +590,15 @@ def generate_ocr_prediction_json(

@shared_task(bind=True)
def generate_asr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate ASR prediction data and to save to the same data item.
Args:
@@ -590,7 +632,14 @@ def generate_asr_prediction_json(
)
except Exception as e:
asr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
asr_data_items = get_filtered_items(
"SpeechConversation",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
asr_data_items_df = pd.DataFrame(
asr_data_items,
@@ -704,7 +753,16 @@ def generate_asr_prediction_json(


@shared_task(bind=True)
def populate_draft_data_json(self, pk, user_id, fields_list):
def populate_draft_data_json(
self,
pk,
user_id,
fields_list,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
task_name = "populate_draft_data_json"
try:
dataset_instance = DatasetInstance.objects.get(pk=pk)
@@ -713,6 +771,10 @@ def populate_draft_data_json(self, pk, user_id, fields_list):
dataset_type = dataset_instance.dataset_type
dataset_model = apps.get_model("dataset", dataset_type)
dataset_items = dataset_model.objects.filter(instance_id=dataset_instance)
if filter_string and sampling_mode and sampling_parameters:
dataset_items = get_filtered_items(
dataset_type, pk, filter_string, sampling_mode, sampling_parameters
)
cnt = 0
for dataset_item in dataset_items:
new_draft_data_json = {}
@@ -1696,3 +1758,38 @@ def upload_all_projects_to_blob_and_get_url(csv_files_directory):
return "Error in generating url"
blob_url = f"https://{account_name}.blob.{endpoint_suffix}/{CONTAINER_NAME_FOR_DOWNLOAD_ALL_PROJECTS}/{blob_client.blob_name}?{sas_token}"
return blob_url


def get_filtered_items(
dataset_model,
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
):
registry_helper = ProjectRegistry.get_instance()
project_type = registry_helper.get_project_name_from_dataset(dataset_model)
if not isinstance(dataset_instance_id, list):
dataset_instance_id = [dataset_instance_id]
filtered_items = filter_data_items(
project_type=project_type,
dataset_instance_ids=dataset_instance_id,
filter_string=filter_string,
)
# Apply sampling
if sampling_mode == BATCH:
batch_size = sampling_parameters["batch_size"]
try:
batch_number = sampling_parameters["batch_number"]
if len(batch_number) == 0:
batch_number = [1]
except KeyError:
batch_number = [1]
sampled_items = []
for batch_num in batch_number:
sampled_items += filtered_items[
batch_size * (batch_num - 1) : batch_size * batch_num
]
else:
sampled_items = filtered_items
return sampled_items
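
For orientation, a minimal sketch of how the new get_filtered_items helper might be called from one of the prediction tasks above. The instance id, filter string, and batch values are illustrative assumptions; only the helper signature and the dataset class name come from this diff.

from projects.models import BATCH  # already imported at the top of tasks.py in this commit

# Hypothetical values: instance id 42, an assumed filter expression, batches 1 and 2 of 100 items.
sampling_parameters = {"batch_size": 100, "batch_number": [1, 2]}
items = get_filtered_items(
    dataset_model="OCRDocument",       # dataset class name, as used by generate_ocr_prediction_json
    dataset_instance_id=42,            # a single id is wrapped into a list inside the helper
    filter_string="assumed_filter_expression",  # actual filter syntax is defined by filter_data_items
    sampling_mode=BATCH,
    sampling_parameters=sampling_parameters,
)
# With BATCH sampling this returns slices [0:100] and [100:200] of the filtered
# items; any other sampling_mode falls through to the full filtered list.

Note that all three of filter_string, sampling_mode, and sampling_parameters must be truthy for the calling tasks to take this path; otherwise they keep operating on the full dataset instance as before.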
40 changes: 37 additions & 3 deletions backend/functions/views.py
@@ -274,6 +274,10 @@ def schedule_sentence_text_translate_job(request):
automate_missing_data_items = request.data.get(
"automate_missing_data_items", "true"
)
filter_string = request.data.get("filter_string", None)
sampling_mode = request.data.get("sampling_mode", None)
sampling_parameters = request.data.get("sampling_parameters_json", None)
variable_parameters = request.data.get("variable_parameters", None)

# Convert checks for languages into boolean
checks_for_particular_languages = checks_for_particular_languages.lower() == "true"
@@ -311,6 +315,10 @@ def schedule_sentence_text_translate_job(request):
input_dataset_instance_id=input_dataset_instance_id,
output_dataset_instance_id=output_dataset_instance_id,
batch_size=batch_size,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
api_type=api_type,
checks_for_particular_languages=checks_for_particular_languages,
automate_missing_data_items=automate_missing_data_items,
@@ -537,7 +545,10 @@ def schedule_ocr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

@@ -546,6 +557,10 @@
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

# Returning response
@@ -574,8 +589,20 @@ def schedule_draft_data_json_population(request):
pk = request.data["dataset_instance_id"]

uid = request.user.id
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

populate_draft_data_json.delay(pk=pk, user_id=uid, fields_list=fields_list)
populate_draft_data_json(
pk=pk,
user_id=uid,
fields_list=fields_list,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

ret_dict = {"message": "draft_data_json population started"}
ret_status = status.HTTP_200_OK
@@ -624,7 +651,10 @@ def schedule_asr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

@@ -633,6 +663,10 @@
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

ret_dict = {"message": "Generating ASR Predictions"}
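
For API callers, a hedged sketch of how the new optional fields might be supplied when scheduling one of these jobs. The endpoint path, host, token, and all field values are placeholders; only the key names (filter_string, sampling_mode, sampling_parameters_json, variable_parameters) come from the views above.

import requests  # hypothetical external client, not part of this codebase

payload = {
    # Job-specific fields (dataset instance ids, api_type, ...) are unchanged and omitted here.
    "filter_string": "assumed_filter_expression",
    "sampling_mode": "batch",  # assumed value; it is compared against the BATCH constant downstream
    "sampling_parameters_json": {"batch_size": 100, "batch_number": [1]},
    "variable_parameters": {},  # accepted by the views and tasks but not otherwise used in this commit
}
requests.post(
    "https://<backend-host>/<schedule-endpoint>/",  # placeholder URL for any of the schedule_* views
    json=payload,
    headers={"Authorization": "Token <token>"},
)

If any of filter_string, sampling_mode, or sampling_parameters_json is omitted, request.data.get returns None and the tasks fall back to processing the full dataset instance.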
52 changes: 41 additions & 11 deletions backend/organizations/views.py
@@ -52,6 +52,7 @@
send_project_analytics_mail_org,
send_user_analytics_mail_org,
)
from utils.filter_tasks_by_ann_type import filter_tasks_by_ann_type


def get_task_count(proj_ids, status, annotator, return_count=True):
@@ -2743,24 +2744,41 @@ def cumulative_tasks_count(self, request, pk=None):
other_lang = []
for lang in languages:
proj_lang_filter = proj_objs.filter(tgt_language=lang)
annotation_tasks_count = 0
reviewer_task_count = 0
annotation_tasks = Task.objects.filter(
project_id__in=proj_lang_filter,
task_status__in=[
"annotated",
"reviewed",
"super_checked",
],
)
reviewer_tasks = Task.objects.filter(
project_id__in=proj_lang_filter,
project_id__project_stage__in=[REVIEW_STAGE, SUPERCHECK_STAGE],
task_status__in=["reviewed", "exported", "super_checked"],
task_status__in=["reviewed", "super_checked"],
)

annotation_tasks = Task.objects.filter(
supercheck_tasks = Task.objects.filter(
project_id__in=proj_lang_filter,
project_id__project_stage__in=[SUPERCHECK_STAGE],
task_status__in=["super_checked"],
)
annotation_tasks_exported = Task.objects.filter(
project_id__in=proj_lang_filter,
project_id__project_stage__in=[ANNOTATION_STAGE],
task_status__in=[
"annotated",
"reviewed",
"exported",
"super_checked",
],
)

reviewer_tasks_exported = Task.objects.filter(
project_id__in=proj_lang_filter,
project_id__project_stage__in=[REVIEW_STAGE],
task_status__in=["exported"],
)
supercheck_tasks_exported = Task.objects.filter(
project_id__in=proj_lang_filter,
project_id__project_stage__in=[SUPERCHECK_STAGE],
task_status__in=["exported"],
)
if metainfo == True:
result = {}

@@ -2975,14 +2993,23 @@ def cumulative_tasks_count(self, request, pk=None):
}

else:
reviewer_task_count = reviewer_tasks.count()
reviewer_task_count = (
reviewer_tasks.count() + reviewer_tasks_exported.count()
)

annotation_tasks_count = annotation_tasks.count()
annotation_tasks_count = (
annotation_tasks.count() + annotation_tasks_exported.count()
)

supercheck_tasks_count = (
supercheck_tasks.count() + supercheck_tasks_exported.count()
)

result = {
"language": lang,
"ann_cumulative_tasks_count": annotation_tasks_count,
"rew_cumulative_tasks_count": reviewer_task_count,
"sup_cumulative_tasks_count": supercheck_tasks_count,
}

if lang == None or lang == "":
@@ -2992,6 +3019,7 @@

ann_task_count = 0
rew_task_count = 0
sup_task_count = 0
ann_word_count = 0
rew_word_count = 0
ann_aud_dur = 0
Expand All @@ -3006,6 +3034,7 @@ def cumulative_tasks_count(self, request, pk=None):
if metainfo != True:
ann_task_count += dat["ann_cumulative_tasks_count"]
rew_task_count += dat["rew_cumulative_tasks_count"]
sup_task_count += dat["sup_cumulative_tasks_count"]
else:
if project_type in get_audio_project_types():
ann_aud_dur += convert_hours_to_seconds(
@@ -3048,6 +3077,7 @@
"language": "Others",
"ann_cumulative_tasks_count": ann_task_count,
"rew_cumulative_tasks_count": rew_task_count,
"sup_cumulative_tasks_count": sup_task_count,
}
else:
if project_type in get_audio_project_types():
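
To make the revised counting rule concrete, a small worked example under assumed task numbers (not real organization data): exported tasks are now attributed to the stage of their own project, and super-checked work gets its own cumulative counter.

# Assumed per-language numbers for one review-stage and one supercheck-stage project:
#   review-stage project:     10 tasks reviewed/super_checked, 3 exported
#   supercheck-stage project:  4 tasks super_checked,          2 exported
reviewer_task_count = (10 + 4) + 3   # reviewer_tasks + reviewer_tasks_exported -> 17
supercheck_tasks_count = 4 + 2       # supercheck_tasks + supercheck_tasks_exported -> 6
# These totals feed the "rew_cumulative_tasks_count" and new "sup_cumulative_tasks_count"
# keys of the per-language result, and the "Others" bucket sums them across unlisted languages.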
12 changes: 12 additions & 0 deletions backend/projects/registry_helper.py
@@ -253,3 +253,15 @@ def validate_registry(self):
)

return True

def get_project_name_from_dataset(self, dataset_name: str):
for project_key, project_type in self.project_types.items():
input_dataset = project_type.get("input_dataset", {})
output_dataset = project_type.get("output_dataset", {})

if (
input_dataset.get("class") == dataset_name
or output_dataset.get("class") == dataset_name
):
return project_key
return None
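
A minimal usage sketch of the new registry lookup, as the tasks above use it. The dataset class name below is one that actually appears in tasks.py; the returned project key depends on the registry contents and is not shown here.

registry_helper = ProjectRegistry.get_instance()
# Returns the first project key whose input or output dataset class is
# "SpeechConversation", or None if no registered project type references it.
project_type = registry_helper.get_project_name_from_dataset("SpeechConversation")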