From da3bb7d8e7d37aa703ca06be0fcfbfe397ff8d69 Mon Sep 17 00:00:00 2001
From: Kunal Tiwary
Date: Wed, 21 Aug 2024 03:35:35 +0000
Subject: [PATCH] added filtering for datasets

---
 backend/functions/tasks.py          | 109 ++++++++++++++++++++++++++--
 backend/functions/views.py          |  40 +++++++++-
 backend/projects/registry_helper.py |  12 +++
 3 files changed, 152 insertions(+), 9 deletions(-)

diff --git a/backend/functions/tasks.py b/backend/functions/tasks.py
index e52b02a84..614641df1 100644
--- a/backend/functions/tasks.py
+++ b/backend/functions/tasks.py
@@ -57,7 +57,10 @@
 from shoonya_backend.locks import Lock
 from utils.constants import LANG_CHOICES
 
-
+from projects.tasks import filter_data_items
+from projects.models import BATCH
+from dataset import models as dataset_models
+from projects.registry_helper import ProjectRegistry
 import logging
 
 logger = logging.getLogger(__name__)
@@ -73,6 +76,10 @@ def sentence_text_translate_and_save_translation_pairs(
     input_dataset_instance_id,
     output_dataset_instance_id,
     batch_size,
+    filter_string,
+    sampling_mode,
+    sampling_parameters,
+    variable_parameters,
     api_type="indic-trans-v2",
     checks_for_particular_languages=False,
     automate_missing_data_items=True,
@@ -88,6 +95,10 @@
             Allowed - [indic-trans, google, indic-trans-v2, azure, blank]
         checks_for_particular_languages (bool): If True, checks for the particular languages in the translations.
         automate_missing_data_items (bool): If True, consider only those data items that are missing in the target dataset instance.
+        filter_string (str): string used to filter the input data items.
+        sampling_mode (str): either batch or full.
+        sampling_parameters (dict): JSON containing the batch number and batch size.
+
     """
     task_name = "sentence_text_translate_and_save_translation_pairs"
     output_sentences = list(
@@ -114,6 +125,14 @@
             "metadata_json",
         )
     )
+    if filter_string and sampling_mode and sampling_parameters:
+        input_sentences = get_filtered_items(
+            "SentenceText",
+            input_dataset_instance_id,
+            filter_string,
+            sampling_mode,
+            sampling_parameters,
+        )
 
     # Convert the input_sentences list into a dataframe
     input_sentences_complete_df = pd.DataFrame(
@@ -404,7 +423,15 @@ def conversation_data_machine_translation(
 
 @shared_task(bind=True)
 def generate_ocr_prediction_json(
-    self, dataset_instance_id, user_id, api_type, automate_missing_data_items
+    self,
+    dataset_instance_id,
+    user_id,
+    api_type,
+    automate_missing_data_items,
+    filter_string,
+    sampling_mode,
+    sampling_parameters,
+    variable_parameters,
 ):
     """Function to generate OCR prediction data and to save to the same data item.
     Args:
@@ -437,7 +464,14 @@ def generate_ocr_prediction_json(
         )
     except Exception as e:
         ocr_data_items = []
-
+    if filter_string and sampling_mode and sampling_parameters:
+        ocr_data_items = get_filtered_items(
+            "OCRDocument",
+            dataset_instance_id,
+            filter_string,
+            sampling_mode,
+            sampling_parameters,
+        )
     # converting the dataset_instance to pandas dataframe.
     ocr_data_items_df = pd.DataFrame(
         ocr_data_items,
@@ -556,7 +590,15 @@
 
 @shared_task(bind=True)
 def generate_asr_prediction_json(
-    self, dataset_instance_id, user_id, api_type, automate_missing_data_items
+    self,
+    dataset_instance_id,
+    user_id,
+    api_type,
+    automate_missing_data_items,
+    filter_string,
+    sampling_mode,
+    sampling_parameters,
+    variable_parameters,
 ):
     """Function to generate ASR prediction data and to save to the same data item.
     Args:
@@ -590,7 +632,14 @@ def generate_asr_prediction_json(
         )
     except Exception as e:
         asr_data_items = []
-
+    if filter_string and sampling_mode and sampling_parameters:
+        asr_data_items = get_filtered_items(
+            "SpeechConversation",
+            dataset_instance_id,
+            filter_string,
+            sampling_mode,
+            sampling_parameters,
+        )
     # converting the dataset_instance to pandas dataframe.
     asr_data_items_df = pd.DataFrame(
         asr_data_items,
@@ -704,7 +753,16 @@
 
 
 @shared_task(bind=True)
-def populate_draft_data_json(self, pk, user_id, fields_list):
+def populate_draft_data_json(
+    self,
+    pk,
+    user_id,
+    fields_list,
+    filter_string,
+    sampling_mode,
+    sampling_parameters,
+    variable_parameters,
+):
     task_name = "populate_draft_data_json"
     try:
         dataset_instance = DatasetInstance.objects.get(pk=pk)
@@ -713,6 +771,10 @@ def populate_draft_data_json(self, pk, user_id, fields_list):
     dataset_type = dataset_instance.dataset_type
     dataset_model = apps.get_model("dataset", dataset_type)
     dataset_items = dataset_model.objects.filter(instance_id=dataset_instance)
+    if filter_string and sampling_mode and sampling_parameters:
+        dataset_items = get_filtered_items(
+            dataset_type, pk, filter_string, sampling_mode, sampling_parameters
+        )
     cnt = 0
     for dataset_item in dataset_items:
         new_draft_data_json = {}
@@ -1696,3 +1758,38 @@ def upload_all_projects_to_blob_and_get_url(csv_files_directory):
         return "Error in generating url"
     blob_url = f"https://{account_name}.blob.{endpoint_suffix}/{CONTAINER_NAME_FOR_DOWNLOAD_ALL_PROJECTS}/{blob_client.blob_name}?{sas_token}"
     return blob_url
+
+
+def get_filtered_items(
+    dataset_model,
+    dataset_instance_id,
+    filter_string,
+    sampling_mode,
+    sampling_parameters,
+):
+    registry_helper = ProjectRegistry.get_instance()
+    project_type = registry_helper.get_project_name_from_dataset(dataset_model)
+    if not isinstance(dataset_instance_id, list):
+        dataset_instance_id = [dataset_instance_id]
+    filtered_items = filter_data_items(
+        project_type=project_type,
+        dataset_instance_ids=dataset_instance_id,
+        filter_string=filter_string,
+    )
+    # Apply sampling
+    if sampling_mode == BATCH:
+        batch_size = sampling_parameters["batch_size"]
+        try:
+            batch_number = sampling_parameters["batch_number"]
+            if len(batch_number) == 0:
+                batch_number = [1]
+        except KeyError:
+            batch_number = [1]
+        sampled_items = []
+        for batch_num in batch_number:
+            sampled_items += filtered_items[
+                batch_size * (batch_num - 1) : batch_size * batch_num
+            ]
+    else:
+        sampled_items = filtered_items
+    return sampled_items
diff --git a/backend/functions/views.py b/backend/functions/views.py
index 09608665b..ccbc14434 100644
--- a/backend/functions/views.py
+++ b/backend/functions/views.py
@@ -274,6 +274,10 @@ def schedule_sentence_text_translate_job(request):
     automate_missing_data_items = request.data.get(
         "automate_missing_data_items", "true"
     )
+    filter_string = request.data.get("filter_string", None)
+    sampling_mode = request.data.get("sampling_mode", None)
+    sampling_parameters = request.data.get("sampling_parameters_json", None)
+    variable_parameters = request.data.get("variable_parameters", None)
 
     # Convert checks for languages into boolean
     checks_for_particular_languages = checks_for_particular_languages.lower() == "true"
@@ -311,6 +315,10 @@
         input_dataset_instance_id=input_dataset_instance_id,
         output_dataset_instance_id=output_dataset_instance_id,
         batch_size=batch_size,
+        filter_string=filter_string,
+        sampling_mode=sampling_mode,
+        sampling_parameters=sampling_parameters,
+        variable_parameters=variable_parameters,
         api_type=api_type,
         checks_for_particular_languages=checks_for_particular_languages,
         automate_missing_data_items=automate_missing_data_items,
@@ -537,7 +545,10 @@ def schedule_ocr_prediction_json_population(request):
     except KeyError:
         automate_missing_data_items = True
 
-    # Calling a function asynchronously to create ocr predictions.
+    filter_string = request.data.get("filter_string")
+    sampling_mode = request.data.get("sampling_mode")
+    sampling_parameters = request.data.get("sampling_parameters_json")
+    variable_parameters = request.data.get("variable_parameters")
 
     uid = request.user.id
 
@@ -546,6 +557,10 @@
         user_id=uid,
         api_type=api_type,
         automate_missing_data_items=automate_missing_data_items,
+        filter_string=filter_string,
+        sampling_mode=sampling_mode,
+        sampling_parameters=sampling_parameters,
+        variable_parameters=variable_parameters,
     )
 
     # Returning response
@@ -574,8 +589,20 @@ def schedule_draft_data_json_population(request):
     pk = request.data["dataset_instance_id"]
     uid = request.user.id
+    filter_string = request.data.get("filter_string")
+    sampling_mode = request.data.get("sampling_mode")
+    sampling_parameters = request.data.get("sampling_parameters_json")
+    variable_parameters = request.data.get("variable_parameters")
 
-    populate_draft_data_json.delay(pk=pk, user_id=uid, fields_list=fields_list)
+    populate_draft_data_json.delay(
+        pk=pk,
+        user_id=uid,
+        fields_list=fields_list,
+        filter_string=filter_string,
+        sampling_mode=sampling_mode,
+        sampling_parameters=sampling_parameters,
+        variable_parameters=variable_parameters,
+    )
 
     ret_dict = {"message": "draft_data_json population started"}
     ret_status = status.HTTP_200_OK
 
@@ -624,7 +651,10 @@ def schedule_asr_prediction_json_population(request):
     except KeyError:
         automate_missing_data_items = True
 
-    # Calling a function asynchronously to create ocr predictions.
+    filter_string = request.data.get("filter_string")
+    sampling_mode = request.data.get("sampling_mode")
+    sampling_parameters = request.data.get("sampling_parameters_json")
+    variable_parameters = request.data.get("variable_parameters")
 
     uid = request.user.id
 
@@ -633,6 +663,10 @@
         user_id=uid,
         api_type=api_type,
         automate_missing_data_items=automate_missing_data_items,
+        filter_string=filter_string,
+        sampling_mode=sampling_mode,
+        sampling_parameters=sampling_parameters,
+        variable_parameters=variable_parameters,
     )
 
     ret_dict = {"message": "Generating ASR Predictions"}
diff --git a/backend/projects/registry_helper.py b/backend/projects/registry_helper.py
index ed1859e4c..3f8a5653a 100644
--- a/backend/projects/registry_helper.py
+++ b/backend/projects/registry_helper.py
@@ -253,3 +253,15 @@ def validate_registry(self):
             )
 
         return True
+
+    def get_project_name_from_dataset(self, dataset_name: str):
+        for project_key, project_type in self.project_types.items():
+            input_dataset = project_type.get("input_dataset", {})
+            output_dataset = project_type.get("output_dataset", {})
+
+            if (
+                input_dataset.get("class") == dataset_name
+                or output_dataset.get("class") == dataset_name
+            ):
+                return project_key
+        return None
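
Illustration (not part of the patch): a standalone sketch of the batch-sampling
arithmetic in get_filtered_items above. The item list and parameter values are
hypothetical; only the slicing mirrors the patched function, and the
`.get(...) or [1]` shorthand collapses the patch's KeyError/empty-list handling
into one line with the same effect.

    # Sketch of the BATCH branch; "filtered_items" stands in for whatever
    # filter_data_items returns after filtering.
    filtered_items = list(range(1, 11))  # ten hypothetical filtered data items
    sampling_parameters = {"batch_size": 3, "batch_number": [1, 3]}

    batch_size = sampling_parameters["batch_size"]
    batch_number = sampling_parameters.get("batch_number") or [1]

    sampled_items = []
    for batch_num in batch_number:
        # batch N covers the half-open slice [(N - 1) * batch_size, N * batch_size)
        sampled_items += filtered_items[batch_size * (batch_num - 1) : batch_size * batch_num]

    print(sampled_items)  # [1, 2, 3, 7, 8, 9]: batches 1 and 3, batch 2 skipped

A batch_number past the end of the filtered list simply yields an empty slice,
so over-asking does not raise.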
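Illustration (not part of the patch): a request body that would exercise the new
optional fields on schedule_ocr_prediction_json_population. Every value here is
made up; only the key names are real, matching what the view reads via
request.data.get(...). The accepted filter_string syntax and the concrete value
of the BATCH constant live in projects.tasks.filter_data_items and
projects.models respectively and are deliberately elided.

    payload = {
        "dataset_instance_id": 3,          # hypothetical instance id
        "api_type": "...",                 # whichever OCR API types the deployment supports
        "automate_missing_data_items": "true",
        "filter_string": "...",            # same syntax filter_data_items accepts
        "sampling_mode": "...",            # must equal projects.models.BATCH to trigger slicing;
                                           # any other value falls through to the full filtered list
        "sampling_parameters_json": {"batch_size": 100, "batch_number": [1, 2]},
        "variable_parameters": None,       # accepted and forwarded, but unused by get_filtered_items
    }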
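Illustration (not part of the patch): how the tasks use the new registry lookup
to recover a project type from a dataset class name, as get_filtered_items does
for "OCRDocument".

    registry = ProjectRegistry.get_instance()
    # Returns the key of the first registered project type whose input or
    # output dataset class is "OCRDocument", or None when nothing matches.
    project_type = registry.get_project_name_from_dataset("OCRDocument")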