-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
78 lines (58 loc) · 2.37 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import asyncio
import base64
import json
from concurrent.futures import ThreadPoolExecutor
from os.path import os, join, dirname

import httpx
import requests
from dotenv import load_dotenv, find_dotenv
from fastapi import FastAPI, Request, BackgroundTasks

from utils.utils import bytes_to_wav, reduce_noise_mfcc_up
from utils.model_utils import (
    get_audio_classification_class,
    get_keyword_similarity,
    get_audio_direction,
    get_oscvm_result,
)
# Path of the .env file that sits next to this module (kept as a module-level
# name in case other code imports it).
dotenv_path = join(dirname(__file__), ".env")
# Load that file explicitly when it exists. find_dotenv() locates a .env by
# searching upward from the calling module (or the CWD), so it is used only as
# a fallback when no .env sits beside this file.
load_dotenv(dotenv_path if os.path.isfile(dotenv_path) else find_dotenv())

# FastAPI application; the routes in this module are registered on it.
app = FastAPI()
# The audio pipeline below is synchronous (file I/O plus model calls). Running it
# on a dedicated single-worker pool keeps the event loop free to accept new
# requests, while still executing one pipeline at a time -- the same serialized
# behaviour the inline version had.
# NOTE(review): if any model in utils.model_utils is bound to the thread that
# loaded it (e.g. a TF1-style graph/session), confirm it works from this worker.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="inference")


def _decode_channel(encoded):
    """Return the raw bytes of one base64-encoded (str) audio channel."""
    return base64.b64decode(encoded.encode("utf-8"))


def _run_inference_pipeline(req):
    """Run the blocking audio pipeline for one request.

    Returns the payload dict to forward to the central server, or ``None``
    when the clip is rejected by any stage.
    """
    # Read every required field up front so a malformed request fails with
    # KeyError before anything is written to disk.
    top_channel = req["top_channel"]
    bottom_channel = req["bottom_channel"]
    uid = req["uid"]
    filtered_class = req["filtered_class"]
    websocket_idx = req["websocket_idx"]
    data = {"keyword": "unknown", "websocket_idx": websocket_idx}

    # Decode both channels before writing either file, so a bad payload on
    # either channel writes nothing.
    top_bytes = _decode_channel(top_channel)
    bottom_bytes = _decode_channel(bottom_channel)

    # bytes_to_wav is called only for its side effect. It presumably writes
    # <uid>/<name>, which reduce_noise_mfcc_up then reads back (TODO confirm).
    bytes_to_wav(uid, top_bytes, "top_channel.wav")
    bytes_to_wav(uid, bottom_bytes, "bottom_channel.wav")
    top_audio = reduce_noise_mfcc_up(uid + "/top_channel.wav")
    bottom_audio = reduce_noise_mfcc_up(uid + "/bottom_channel.wav")

    # First gate: an OSCVM result of -1 rejects the clip outright.
    if get_oscvm_result(top_audio) == -1:
        return None

    class_prediction = get_audio_classification_class(top_audio)
    # -1 and 5 are never forwarded. Reject them before indexing filtered_class
    # so the filter is only consulted for classes that could actually be sent
    # (the caller's filter may have no entry for 5).
    if class_prediction in (-1, 5):
        return None
    # A falsy filter entry means the client does not want this class.
    if not filtered_class[class_prediction]:
        return None

    # Class 4 additionally requires a keyword match; flag == 0 means no match.
    if class_prediction == 4:
        keyword_prediction, flag = get_keyword_similarity(uid, top_audio)
        if flag == 0:
            return None
        data["keyword"] = keyword_prediction

    data["prediction_class"] = str(class_prediction)
    data["direction"] = get_audio_direction(bottom_audio, top_audio)
    return data


async def get_model_inference(req):
    """Classify a two-channel audio clip and forward accepted results.

    Args:
        req: decoded JSON body of ``/prediction``. Its keys are:
            ``top_channel`` and ``bottom_channel`` (base64 strings),
            ``uid``,
            ``filtered_class`` (indexed by predicted class id; a falsy entry
            drops that class),
            ``websocket_idx``.

    When every stage accepts the clip, the resulting payload is POSTed to
    ``$CENTRAL_SERVER_URL/get_model_prediction``. Otherwise nothing is sent.

    Raises:
        KeyError: if ``req`` lacks a required field.
        RuntimeError: if ``CENTRAL_SERVER_URL`` is not set.
        httpx.HTTPStatusError: if the central server responds with an error.
    """
    loop = asyncio.get_running_loop()
    data = await loop.run_in_executor(_INFERENCE_EXECUTOR, _run_inference_pipeline, req)
    if data is None:
        return

    central_url = os.getenv("CENTRAL_SERVER_URL")
    if not central_url:
        raise RuntimeError("CENTRAL_SERVER_URL is not set; cannot forward the prediction")

    async with httpx.AsyncClient() as client:
        response = await client.post(central_url + "/get_model_prediction", json=data)
    # Surface delivery failures in the server log instead of dropping them silently.
    response.raise_for_status()
@app.post("/prediction")
async def return_prediction(audio_data: Request, background_tasks: BackgroundTasks):
    """Acknowledge a prediction request and defer the actual inference.

    The JSON body is passed unchanged to ``get_model_inference``, which is
    queued as a FastAPI background task (run after this response is sent).
    The caller only ever receives the fixed acknowledgement below.
    """
    payload = await audio_data.json()
    background_tasks.add_task(get_model_inference, payload)
    acknowledgement = {"message": "Prediction request received."}
    return acknowledgement