-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
150 lines (100 loc) · 3.9 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
server starts:
load all models
user starts a bot:
send welcome message
user writes a text:
encode it
feed to models
create audio file
send it to that user
delete that audio file from server
use threading for updater -> done
make it git clone or docker container -> with secrets of bot_token, certificates -> done
push it to ecr/dockerhub if required -> done
use spot instances and a spot fleet so that at least one instance stays alive, with an elastic IP -> done
improve models quality:
make Rick data for synthesizer, vocoder
overfit on it
check and improve
use wsgi with multiprocess model
add status endpoint to show system health
add track endpoint to check status of current task -> queued, failed, success, uploading
"""
import os
import shutil
import json
from flask import Flask, request
from telegram import Bot, Update
from telegram.ext import Dispatcher, CommandHandler, MessageHandler, Filters
from queue import Queue
from threading import Thread
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder import inference as vocoder
from pathlib import Path
import numpy as np
import librosa
import pickle
import logging
logging.basicConfig(format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s', level = logging.INFO)

# Set by setup_bot() when this file is run as a script: the queue feeding the
# telegram Dispatcher thread, and the Bot instance used to decode updates.
dispatcher_update_queue = None
bot = None

app = Flask(__name__)
# `application` is the conventional WSGI entry-point name (mod_wsgi's default).
# Alias it to the app that actually carries the routes instead of creating a
# second, route-less Flask instance. Note that setup_bot() must still be called
# for the webhook route to work.
application = app

# SECURITY: this token was committed to the repository in plain text. Treat it
# as leaked and revoke/rotate it via @BotFather. The BOT_TOKEN environment
# variable takes precedence. The literal fallback only keeps existing
# deployments working until the token is rotated and should then be removed.
BOT_TOKEN = os.environ.get("BOT_TOKEN", "1054401279:AAESaTfz5nNuI6pgn3zbsDoSkj7LROSV0Ec")

# The encoder is not loaded (see load_models); the speaker embedding is
# precomputed and read from embedding_path instead.
synthesizer_path = "./synthesizer/saved_models/logs-pretrained/taco_pretrained"
vocoder_path = "./vocoder/saved_models/pretrained/pretrained.pt"
embedding_path = "./rick_90_morty_30.pkl"
# Directory prefix for temporary wav files; joined by plain string concatenation.
output_path = "./"
def load_embed(path=None):
    """Load the precomputed speaker embedding used for every synthesis.

    Args:
        path: Pickle file to read. Defaults to the module-level
            ``embedding_path`` when omitted.

    Returns:
        The first element of the unpickled object (Rick's embedding).
    """
    if path is None:
        path = embedding_path
    # NOTE(review): pickle.load can execute arbitrary code, so this file must
    # be trusted. Never point it at user-supplied data.
    # `with` guarantees the file handle is closed.
    with open(path, "rb") as embed_file:
        embeddings = pickle.load(embed_file)
    # Take only Rick's embedding. Presumably the pickle holds one embedding per
    # character with Rick first; confirm against the script that produced it.
    return embeddings[0]
def load_models():
    """Load the synthesizer and vocoder weights from their configured paths.

    The encoder is intentionally left unloaded, because the speaker embedding
    is read precomputed by load_embed(). The encoder module is still returned
    so callers can unpack three values.

    Returns:
        tuple: (encoder module, Synthesizer instance, vocoder module).
    """
    synth = Synthesizer(Path(synthesizer_path))
    vocoder.load_model(Path(vocoder_path))
    return encoder, synth, vocoder
def setup_bot(token):
    """Load the models, create the bot and start the update-dispatcher thread.

    Sets the module globals ``bot``, ``embedding``, ``encoder``,
    ``synthesizer`` and ``vocoder``.

    Args:
        token: Telegram bot token used to construct the Bot.

    Returns:
        Queue: the queue the webhook route pushes decoded updates into.
    """
    global bot
    global embedding
    global encoder, synthesizer, vocoder
    # Load everything before the dispatcher thread exists. A failed load then
    # raises without leaving a non-daemon thread behind to keep the process
    # alive, and no handler can ever observe unset `embedding`/`synthesizer`.
    embedding = load_embed()
    encoder, synthesizer, vocoder = load_models()

    # Honour the caller's token rather than silently reading the global.
    bot = Bot(token)
    update_queue = Queue()
    dispatcher = Dispatcher(bot, update_queue, workers = 1)
    # Register handlers: /start greets, any other text gets synthesized.
    dispatcher.add_handler(CommandHandler("start", start_callback))
    dispatcher.add_handler(MessageHandler(Filters.text, text_callback))
    thread = Thread(target=dispatcher.start, name='dispatcher')
    thread.start()
    return update_queue
def start_callback(bot, update):
    """Reply to the /start command with a short greeting in the same chat."""
    chat_id = update.effective_chat.id
    bot.send_message(text = "yo", chat_id = chat_id)
def text_callback(bot, update):
    """Synthesize the user's text in the stored voice and send it back.

    The audio is written to ``<output_path><chat_id>.wav`` and sent as a voice
    message. The file is then deleted, even if sending fails, so generated
    audio does not accumulate on the server.

    Args:
        bot: telegram Bot used to send the reply.
        update: telegram Update carrying the user's text message.
    """
    chat_id = update.effective_chat.id
    wav_path = f"{output_path}{chat_id}.wav"

    spectrogram = synthesizer.synthesize_spectrograms([update.message.text], [embedding])
    wav = vocoder.infer_waveform(spectrogram[0])
    # Append sample_rate zero samples (one second of silence) at the end.
    wav = np.pad(wav, (0, synthesizer.sample_rate), mode = "constant")
    # NOTE(review): librosa.output.write_wav was removed in librosa 0.8. Pin
    # librosa<0.8 or switch to soundfile.write when upgrading.
    librosa.output.write_wav(wav_path, wav, sr = synthesizer.sample_rate)
    try:
        # Close the handle once the upload is done.
        with open(wav_path, "rb") as voice_file:
            bot.send_voice(chat_id = chat_id, voice = voice_file, timeout = 100)
    finally:
        os.remove(wav_path)
@app.route("/", methods = ["POST"])
def root_function():
    """Webhook endpoint: decode a Telegram update and queue it for dispatch.

    Returns:
        A JSON string acknowledging receipt. Returns HTTP 400 when the body is
        not a JSON object, and HTTP 503 when setup_bot() has not been run.
    """
    # NOTE(review): nothing authenticates this endpoint. Anyone who can reach
    # it can inject updates and make the bot message arbitrary chat ids.
    # Consider a secret webhook path, or checking the
    # X-Telegram-Bot-Api-Secret-Token header set via setWebhook(secret_token=...).
    log = logging.getLogger(__name__)
    msg = request.get_json()
    log.info("msg is %s", msg)
    if not isinstance(msg, dict):
        # Update.de_json would yield None here; don't ack garbage as success.
        return json.dumps({"message" : "invalid update", "statusCode" : 400}), 400
    if dispatcher_update_queue is None:
        # e.g. imported by a WSGI server without setup_bot() having run.
        log.error("dispatcher not initialised; was setup_bot() called?")
        return json.dumps({"message" : "bot not ready", "statusCode" : 503}), 503
    decoded_msg = Update.de_json(msg, bot)
    dispatcher_update_queue.put(decoded_msg)
    return json.dumps({"message" : "success", "statusCode" : 200})
if __name__ == "__main__":
    dispatcher_update_queue = setup_bot(BOT_TOKEN)
    # The Werkzeug debugger lets a browser run arbitrary Python, so it must not
    # be exposed on 0.0.0.0 by default. Opt in with FLASK_DEBUG=1 for local work.
    debug = os.environ.get("FLASK_DEBUG") == "1"
    # The reloader is always off because it re-executes this script in a child
    # process. That would load the models a second time and start a second
    # dispatcher thread.
    app.run(host = "0.0.0.0", ssl_context=('certificates/public.pem', 'certificates/private.key'), debug=debug, use_reloader=False)