Skip to content

Commit

Permalink
Merge pull request #964 from AI4Bharat/master
Browse files Browse the repository at this point in the history
Prod changes to develop
  • Loading branch information
aparna-aa authored Dec 27, 2024
2 parents 1e2e4c2 + 7fd44ce commit 405cea2
Show file tree
Hide file tree
Showing 31 changed files with 1,011 additions and 170 deletions.
4 changes: 2 additions & 2 deletions ai-services/align-api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ tensorboardX
rich==12.6.0
srt==3.5.2
Cython==0.29.32
urduhack==1.1.1
fastapi['all']
urduhack
fastapi
indic-nlp-library
11 changes: 8 additions & 3 deletions ai-services/align-api/src/wav2vec2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from rich.console import Console
from rich.traceback import install

import os
install()
console = Console()

Expand Down Expand Up @@ -35,8 +35,13 @@ def length(self):

class Wav2vec2:
def __init__(self, model_path, language_code, mode, device):
self.asr_path = glob(model_path + "/" + language_code + "/*.pt")[0]
self.dict_path = glob(model_path + "/" + language_code + "/*.txt")[0]
current_dir = os.path.dirname(os.path.abspath(__file__))
two_levels_up = os.path.abspath(os.path.join(current_dir, "../../"))
model_loc = os.path.join(two_levels_up, os.path.join(model_path, language_code))

self.asr_path = glob(os.path.join( model_loc ,"*.pt"))[0]
self.dict_path = glob(os.path.join(model_loc, "*.txt"))[0]

self.device = device
self.encoder = self.load_model_encoder()
self.labels = self.get_labels()
Expand Down
3 changes: 2 additions & 1 deletion backend/backend/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@

DOMAIN = os.getenv("DOMAIN")
SITE_NAME = os.getenv("DOMAIN")
PROTOCOL = "https"
DEFAULT_HTTP_PROTOCOL = 'https'

DJOSER = {
"PASSWORD_RESET_CONFIRM_URL": "forget-password/confirm/{uid}/{token}",
Expand Down Expand Up @@ -127,6 +127,7 @@

CSRF_TRUSTED_ORIGINS = [
"http://localhost:*", # for localhost (Developlemt)
"https://*.ai4bharat.org",
]
CUSTOM_CSRF_TRUSTED_ORIGINS = os.getenv("CORS_TRUSTED_ORIGINS", "")
if CUSTOM_CSRF_TRUSTED_ORIGINS:
Expand Down
2 changes: 1 addition & 1 deletion backend/backend/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_schema(self, request=None, public=False):
path("voiceover/", include("voiceover.urls")),
path("youtube/", include("youtube.urls")),
path(
"api/generic/transliteration/<str:target_language>/<str:data>/",
"xlit-api/generic/transliteration/<str:target_language>/<str:data>",
TransliterationAPIView.as_view(),
name="transliteration-api",
),
Expand Down
1 change: 1 addition & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
english_asr_url = os.getenv("ENGLISH_ASR_API_URL")
indic_asr_url = os.getenv("INDIC_ASR_API_URL")
service_id_hindi = os.getenv("SERVICE_ID_HINDI")
service_id_nepali = os.getenv("SERVICE_ID_NEPALI")
service_id_indo_aryan = os.getenv("SERVICE_ID_INDO_ARYAN")
service_id_dravidian = os.getenv("SERVICE_ID_DRAVIDIAN")
misc_tts_url = os.getenv("MISC_TTS_API_URL")
Expand Down
3 changes: 3 additions & 0 deletions backend/organization/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ def list_org_tasks(self, request, pk=None, *args, **kwargs):
task["updated_at"]
).replace(tzinfo=None):
buttons["Reopen"] = False
if "TRANSLATION_VOICEOVER" in task["task_type"]:
if task["status"] in ["SELECTED_SOURCE", "FAILED"] and task["is_active"] is False:
buttons["Regenerate"] = True
if task["status"] == "POST_PROCESS":
buttons["Update"] = True
if task["status"] == "FAILED":
Expand Down
1 change: 1 addition & 0 deletions backend/project/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ProjectAdmin(admin.ModelAdmin):

list_display = (
"id",
"title",
"organization_id",
"default_task_types",
"default_target_languages",
Expand Down
3 changes: 3 additions & 0 deletions backend/project/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,9 @@ def list_project_tasks(self, request, pk=None, *args, **kwargs):
data["updated_at"]
).replace(tzinfo=None):
buttons["Reopen"] = False
if "TRANSLATION_VOICEOVER" in data["task_type"]:
if data["status"] in ["SELECTED_SOURCE", "FAILED"] and data["is_active"] is False:
buttons["Regenerate"] = True
if data["status"] == "POST_PROCESS":
buttons["Update"] = True
if data["status"] == "FAILED":
Expand Down
117 changes: 58 additions & 59 deletions backend/task/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def celery_nmt_tts_call(task_id):
task_obj.status = "FAILED"
task_obj.is_active = False
task_obj.save()
logging.info("Generating translation payload failed for %s", str(task_id))
return
else:
if (
type(translation_obj.payload) == dict
Expand All @@ -308,69 +310,66 @@ def celery_nmt_tts_call(task_id):
task_obj.status = "SELECTED_SOURCE"
# task_obj.is_active = True
task_obj.save()
tts_payload = process_translation_payload(
translation_obj, task_obj.target_language
)
if type(tts_payload) == dict and "message" in tts_payload.keys():
message = tts_payload["message"]
logging.info("Error from TTS API")
voice_over_task.status = "FAILED"
voice_over_task.save()
# set_fail_for_translation_task(task)
return message

(
tts_input,
target_language,
translation,
translation_id,
empty_sentences,
) = tts_payload
tts_payload = process_translation_payload(
translation_obj, task_obj.target_language
)
if type(tts_payload) == dict and "message" in tts_payload.keys():
message = tts_payload["message"]
logging.info("Error from TTS API")
voice_over_task.status = "FAILED"
voice_over_task.save()
# set_fail_for_translation_task(task)
return message

generate_audio = task_obj.video.project_id.pre_generate_audio
tts_payload = generate_tts_output(
tts_input,
target_language,
translation,
translation_obj,
empty_sentences,
generate_audio,
)
payloads = tts_payload
(
tts_input,
target_language,
translation,
translation_id,
empty_sentences,
) = tts_payload

existing_voiceover = VoiceOver.objects.filter(task=task_obj).first()
generate_audio = task_obj.video.project_id.pre_generate_audio
tts_payload = generate_tts_output(
tts_input,
target_language,
translation,
translation_obj,
empty_sentences,
generate_audio,
)
payloads = tts_payload

print("Fetched voiceover", existing_voiceover)
existing_voiceover = VoiceOver.objects.filter(task=task_obj).first()

if existing_voiceover == None:
voiceover_obj = VoiceOver(
video=task_obj.video,
user=task_obj.user,
translation=translation_obj,
payload=tts_payload,
target_language=task_obj.target_language,
task=task_obj,
voice_over_type="MACHINE_GENERATED",
status="VOICEOVER_SELECT_SOURCE",
)
voiceover_obj.save()
else:
existing_voiceover.payload = tts_payload
existing_voiceover.translation = translation_obj
existing_voiceover.save()
task_obj.is_active = True
task_obj.status = "SELECTED_SOURCE"
task_obj.save()
logging.info("Payload generated for TTS API for %s", str(task_id))
if "message" in tts_payload:
task_obj.is_active = False
task_obj.status = "FAILED"
task_obj.save()
try:
send_mail_to_user(task_obj)
except:
logging.info("Error in sending mail")
print("Fetched voiceover", existing_voiceover)

# send_mail_to_user(task_obj)
if existing_voiceover == None:
voiceover_obj = VoiceOver(
video=task_obj.video,
user=task_obj.user,
translation=translation_obj,
payload=tts_payload,
target_language=task_obj.target_language,
task=task_obj,
voice_over_type="MACHINE_GENERATED",
status="VOICEOVER_SELECT_SOURCE",
)
voiceover_obj.save()
else:
logging.info("Translation already exists")
existing_voiceover.payload = tts_payload
existing_voiceover.translation = translation_obj
existing_voiceover.save()
task_obj.is_active = True
task_obj.status = "SELECTED_SOURCE"
task_obj.save()
logging.info("Payload generated for TTS API for %s", str(task_id))
if "message" in tts_payload:
task_obj.is_active = False
task_obj.status = "FAILED"
task_obj.save()
try:
send_mail_to_user(task_obj)
except:
logging.info("Error in sending mail")
66 changes: 34 additions & 32 deletions backend/task/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from rest_framework.decorators import parser_classes
from rest_framework.parsers import MultiPartParser, FormParser
import regex

from translation.views import regenerate_translation_voiceover

def get_export_translation(request, task_id, export_type):
new_request = HttpRequest()
Expand Down Expand Up @@ -3193,6 +3193,15 @@ def inspect_queue(self, request):
)
elif elem["name"] == "task.tasks.celery_nmt_call":
task_obj["task_id"] = eval(elem["kwargs"])["task_id"]
elif elem["name"] == "task.tasks.celery_nmt_tts_call":
try:
task_obj["task_id"] = eval(elem["kwargs"])["task_id"]
except:
task_obj["task_id"] = eval(elem["args"].split(",")[0].split("(")[1])
elif elem["name"] == "voiceover.tasks.celery_integration":
task_obj["task_id"] = eval(elem["args"].split(",")[2])
elif elem["name"] == "voiceover.tasks.export_voiceover_async":
task_obj["task_id"] = eval(elem["args"].split(",")[0].split("(")[1])
else:
task_obj["task_id"] = ""

Expand Down Expand Up @@ -3229,19 +3238,17 @@ def inspect_queue(self, request):
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
else:
if queue == "nmt":
if queue == "nmt" or queue == "nmt_tts":
queue_type = "celery@nmt_worker"
elif queue == "tts":
queue_type = "celery@asr_tts_worker"
else:
queue_type = "celery@asr_tts_worker"

try:
task_list = []
status_list = []
url = f"{flower_url}/api/tasks"
params = {
"state": "STARTED",
"sort_by": "received",
"sort_by": "-received",
"workername": queue_type,
}
if flower_username and flower_password:
Expand All @@ -3255,42 +3262,27 @@ def inspect_queue(self, request):
for elem in task_data:
if queue == "asr" and elem["name"] == "task.tasks.celery_asr_call":
task_list.append(eval(elem["kwargs"])["task_id"])
status_list.append(elem["state"])
elif (
queue == "tts" and elem["name"] == "task.tasks.celery_tts_call"
):
# task_list.append(eval(elem["kwargs"])["task_id"])
task_list.append(eval(elem["args"].split(",")[0].split("(")[1]))
status_list.append(elem["state"])
elif (
queue == "nmt" and elem["name"] == "task.tasks.celery_nmt_call"
):
task_list.append(eval(elem["kwargs"])["task_id"])
else:
pass
params = {
"state": "RECEIVED",
"sort_by": "received",
"workername": queue_type,
}
if flower_username and flower_password:
res = requests.get(
url, params=params, auth=(flower_username, flower_password)
)
else:
res = requests.get(url, params=params)
data = res.json()
task_data = list(data.values())
for elem in task_data:
if queue == "asr" and elem["name"] == "task.tasks.celery_asr_call":
task_list.append(eval(elem["kwargs"])["task_id"])
elif (
queue == "tts" and elem["name"] == "task.tasks.celery_tts_call"
):
# task_list.append(eval(elem["kwargs"])["task_id"])
task_list.append(eval(elem["args"].split(",")[0].split("(")[1]))
status_list.append(elem["state"])
elif (
queue == "nmt" and elem["name"] == "task.tasks.celery_nmt_call"
queue == "nmt_tts" and elem["name"] == "task.tasks.celery_nmt_tts_call"
):
task_list.append(eval(elem["kwargs"])["task_id"])
try:
task_list.append(eval(elem["kwargs"])["task_id"])
status_list.append(elem["state"])
except:
task_list.append(eval(elem["args"].split(",")[0].split("(")[1]))
status_list.append(elem["state"])
else:
pass
if task_list:
Expand All @@ -3314,8 +3306,12 @@ def inspect_queue(self, request):
"video_duration": str(elem["video__duration"]),
}
i = task_list.index(elem["id"])
task_dict["status"] = status_list[i]
task_list[i] = task_dict

for i in task_list:
if type(i) == int:
j = task_list.index(i)
task_list[j] = {"task_id": i, "status": "Not Found"}
return Response(
{"message": "successful", "data": task_list},
status=status.HTTP_200_OK,
Expand Down Expand Up @@ -3392,6 +3388,12 @@ def regenerate_response(self, request, pk, *args, **kwargs):
elif task.task_type == "VOICEOVER_EDIT":
celery_tts_call.delay(task_id=task.id)
api = "TTS"
elif task.task_type == "TRANSLATION_VOICEOVER_EDIT":
if regenerate_translation_voiceover(task.id) is False:
return Response(
{"message": "Transcription task is not complete yet"}, status=status.HTTP_400_BAD_REQUEST
)
api = "NMT-TTS"
else:
return Response(
{"message": "Invalid task"}, status=status.HTTP_400_BAD_REQUEST
Expand Down
2 changes: 1 addition & 1 deletion backend/transcript/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Show particular fields in the admin panel
class TranscriptAdmin(admin.ModelAdmin):
list_display = ("task", "video", "language", "transcript_type", "updated_at", "id")
list_display = ("task", "video", "language", "transcript_type", "updated_at", "id", "status")
list_filter = ("video", "language", "transcript_type")
search_fields = ("video", "language", "transcript_type")
ordering = ("-updated_at",)
Expand Down
2 changes: 2 additions & 0 deletions backend/transcript/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
("te", "Telugu"),
("sa", "Sanskrit"),
("ur", "Urdu"),
("ne", "Nepali"),
]

TRANSCRIPTION_SUPPORTED_LANGUAGES = {
Expand All @@ -29,4 +30,5 @@
"Tamil": "ta",
"Telugu": "te",
"Urdu": "ur",
"Nepali": "ne",
}
3 changes: 3 additions & 0 deletions backend/transcript/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
name="generate_original_transcript",
),
path("save/", views.save_transcription, name="save_transcript"),
path("reopen_completed_transcription_task/", views.reopen_completed_transcription_task, name="reopen_completed_transcription_task"),
path("get_transcription_status/", views.fetch_transcript_status, name="get_transcription_status"),
path("set_transcription_status/", views.update_transcript_status, name="set_transcription_status"),
path(
"save_full_transcript/",
views.save_full_transcription,
Expand Down
Loading

0 comments on commit 405cea2

Please sign in to comment.