forked from sdonoso/podcast-transcription
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
117 lines (87 loc) · 3.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import shutil
import argparse
import yaml
from glob import glob
from multiprocessing import Process, set_start_method, Manager
import subprocess
from src.scraper import download_audio_from_channel
from src.whisper import process_files, chunk_list, save_json
def parse_arguments():
    """Parse the command-line arguments of the script.

    Exactly one positional argument is accepted: ``config_file``, the path
    to the YAML configuration file.

    Returns:
        argparse.Namespace: The parsed arguments; the path is available as
        the ``config_file`` attribute.
    """
    cli = argparse.ArgumentParser(description="Download audio from a YouTube channel.")
    cli.add_argument(
        "config_file",
        type=str,
        help="The YAML configuration file.",
    )
    parsed = cli.parse_args()
    return parsed
def load_config(config_file):
    """Read a YAML configuration file and return its parsed content.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document: a ``dict``
        for a mapping document, ``None`` for an empty file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which would garble non-ASCII values (titles, paths) on
    # systems whose locale is not UTF-8.
    with open(config_file, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def validate_config(config):
    """Check that the parsed configuration holds every required setting.

    Args:
        config: The object parsed from the YAML configuration file.

    Returns:
        tuple: ``(channel_url, output_folder, output_json)`` exactly as
        stored in ``config``.

    Raises:
        ValueError: If ``config`` is not a mapping (for example ``None``
            from an empty YAML file), or if any required key is missing or
            has an empty/falsy value.
    """
    # An empty or non-mapping YAML document would otherwise fail on ``.get``
    # with AttributeError, which callers catching ValueError do not handle.
    if not isinstance(config, dict):
        raise ValueError(
            "Error: The YAML file must contain a mapping of configuration keys."
        )

    required_keys = ("channel_url", "output_folder", "output_json")
    missing_keys = [key for key in required_keys if not config.get(key)]
    if missing_keys:
        raise ValueError(
            f"Error: The YAML file must contain {', '.join(missing_keys)}."
        )
    return config["channel_url"], config["output_folder"], config["output_json"]
def get_available_gpus(max_gpus=4, max_memory_used_mib=10):
    """Return the indices of idle NVIDIA GPUs as reported by ``nvidia-smi``.

    A GPU is considered idle when its used memory, as printed by
    ``nvidia-smi`` (MiB when ``nounits`` is requested), is below
    ``max_memory_used_mib``.

    Args:
        max_gpus: Maximum number of GPU indices to return. Values of zero or
            less yield an empty list.
        max_memory_used_mib: Used-memory threshold below which a GPU counts
            as idle.

    Returns:
        list[int]: Up to ``max_gpus`` GPU indices, in the order reported by
        ``nvidia-smi``. The list is empty when ``nvidia-smi`` is missing,
        fails, or reports no idle GPU; the problem is printed, not raised.
    """
    # The limit used to be checked only after appending, so a non-positive
    # limit still returned one GPU.
    if max_gpus <= 0:
        return []

    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=index,memory.used,memory.total",
                "--format=csv,noheader,nounits",
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # Typically FileNotFoundError: no NVIDIA driver/tools on this host.
        print(f"Error detecting GPUs: {e}")
        return []

    if result.returncode != 0:
        print(f"Error detecting GPUs: {result.stderr}")
        return []

    available_gpus = []
    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        fields = line.split(",")
        try:
            # int() tolerates the padding spaces nvidia-smi puts after commas.
            index, memory_used = int(fields[0]), int(fields[1])
        except (IndexError, ValueError):
            # Skip only the unreadable line (e.g. "[N/A]" memory values)
            # instead of discarding every GPU found so far.
            print(f"Error detecting GPUs: unexpected nvidia-smi output {line!r}")
            continue
        if memory_used < max_memory_used_mib:
            available_gpus.append(index)
            if len(available_gpus) >= max_gpus:
                break
    return available_gpus
def download_and_transcribe(channel_url, output_folder, output_json):
    """Download a channel's audio, transcribe it on idle GPUs, save as JSON.

    Steps: download the audio into ``output_folder``, collect the ``*.mp3``
    files found there, start one worker process per available GPU, gather
    the results, write them to ``output_json`` and finally delete
    ``output_folder``.

    The downloaded audio is kept (the folder is NOT deleted) when no GPU is
    available or when any worker process exits with a non-zero code, so that
    untranscribed audio is never thrown away.

    Args:
        channel_url: Channel URL handed to ``download_audio_from_channel``.
        output_folder: Folder receiving the downloaded audio; removed after
            a fully successful run.
        output_json: Destination handed to ``save_json``.
    """
    download_audio_from_channel(channel_url, output_folder)

    audio_list = glob(f"{output_folder}/*.mp3")
    gpus = get_available_gpus()
    if not gpus:
        # With no worker nothing would be transcribed, yet the cleanup at the
        # end would still delete every downloaded file.
        print(f"Error: no available GPU; keeping downloaded audio in {output_folder}.")
        return

    # The context manager shuts the manager's server process down on exit;
    # otherwise one such process lingers for every channel processed.
    with Manager() as manager:
        shared_list = manager.list()
        try:
            set_start_method("spawn", force=True)
        except RuntimeError as e:
            print(f"Error setting start method: {e}")

        # NOTE(review): chunk_list presumably splits the files into
        # len(gpus) groups, one per worker - confirm in src.whisper.
        audio_chunks = chunk_list(audio_list, len(gpus))
        processes = [
            Process(target=process_files, args=(gpu, audio_chunk, shared_list))
            for gpu, audio_chunk in zip(gpus, audio_chunks)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        # Copy the results out while the manager is still alive; the proxy
        # list becomes unusable once the manager has shut down.
        dict_list = [
            {"name": transcript[0], "transcription": transcript[1]}
            for transcript in shared_list
        ]

    save_json({"transcripts": dict_list}, output_json)

    failed_workers = sum(1 for p in processes if p.exitcode != 0)
    if failed_workers:
        print(
            f"Error: {failed_workers} transcription process(es) failed; "
            f"keeping downloaded audio in {output_folder}."
        )
        return
    shutil.rmtree(output_folder)
def main():
    """Run the download-and-transcribe pipeline for every configured channel.

    Reads the YAML file named on the command line, validates it, and calls
    ``download_and_transcribe`` once per (channel URL, output folder, output
    JSON) triple. Each of the three settings may be a list (entries are
    paired by position) or a single string (treated as a one-entry list).
    A validation error is printed and ends the run without raising.
    """
    args = parse_arguments()
    config = load_config(args.config_file)
    try:
        channel_url, output_folder, output_json = validate_config(config)
    except ValueError as e:
        print(e)
        return

    # A bare string would be zipped character by character, launching one
    # bogus job per character; wrap it so it counts as a single entry.
    settings = [
        [value] if isinstance(value, str) else list(value)
        for value in (channel_url, output_folder, output_json)
    ]
    if len({len(entries) for entries in settings}) > 1:
        # zip() stops at the shortest list; say so instead of dropping the
        # surplus entries silently.
        print(
            "Warning: channel_url, output_folder and output_json differ in "
            "length; extra entries are ignored."
        )

    for url, folder_path, json_path in zip(*settings):
        download_and_transcribe(url, folder_path, json_path)


# Worker processes started with the "spawn" method re-import this module, so
# the entry point must stay behind this guard to avoid re-running main().
if __name__ == "__main__":
    main()