Skip to content

Commit

Permalink
a few ideas for tensorrt
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Dec 3, 2024
1 parent 6140ce7 commit 183421e
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 57 deletions.
60 changes: 5 additions & 55 deletions backend/src/FFmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,59 +142,6 @@ def __init__(
self.readQueue = queue.Queue(maxsize=50)
self.writeQueue = queue.Queue(maxsize=50)

def get_ffmpeg_streams(self, video_file):
"""Get a list of streams from the video file using FFmpeg."""
try:
result = subprocess.run(
[FFMPEG_PATH, "-i", video_file], stderr=subprocess.PIPE, text=True
)
return result.stderr
except Exception as e:
print(f"An error occurred while running FFmpeg: {e}")
return None

def extract_subtitles(self, video_file, stream_index, subtitle_file):
self.videoPropertiesLocation = os.path.join(CWD, self.inputFile + "_VIDEODATA")
if not os.path.exists(self.videoPropertiesLocation):
os.makedirs(self.videoPropertiesLocation)
"""Extract a specific subtitle stream from the video file."""
try:
subprocess.run(
[
FFMPEG_PATH,
"-i",
video_file,
"-map",
f"0:{stream_index}",
subtitle_file,
],
check=True,
)
print(f"Extracted subtitle stream {stream_index} to {subtitle_file}")
self.subtitleFiles.append(subtitle_file)
except subprocess.CalledProcessError as e:
print(f"An error occurred while extracting subtitles: {e}")

def getVideoSubs(self, video_file):
ffmpeg_output = self.get_ffmpeg_streams(video_file)
if not ffmpeg_output:
return

subtitle_stream_pattern = re.compile(
r"Stream #0:(\d+).*?Subtitle", re.MULTILINE | re.DOTALL
)
subtitle_streams = subtitle_stream_pattern.findall(ffmpeg_output)

if not subtitle_streams:
print("No subtitle streams found in the video.")
return

for stream_index in subtitle_streams:
subtitle_file = os.path.join(
self.videoPropertiesLocation, f"subtitle_{stream_index}.srt"
)
self.extract_subtitles(video_file, stream_index, subtitle_file)

def getVideoProperties(self, inputFile: str = None):
log("Getting Video Properties...")
if inputFile is None:
Expand Down Expand Up @@ -282,13 +229,13 @@ def getFFmpegWriteCommand(self):
command.append(
f"{self.outputFile}",
)

if self.overwrite:
command.append("-y")

else:
command = [
f"{FFMPEG_PATH}",
"-y",
"-hide_banner",
"-v",
"warning",
Expand Down Expand Up @@ -431,3 +378,6 @@ def writeOutVideoFrames(self):
printAndLog(f"\nTime to complete render: {round(renderTime, 2)}")
except Exception as e:
print(f"ERROR: {e}\nPlease remove everything related to the app, and reinstall it if the problem persists across multiple input videos.")
self.shm.close()
self.shm.unlink()
os._exit(1)
93 changes: 91 additions & 2 deletions backend/src/pytorch/TensorRTHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,19 @@
SOFTWARE.
"""

from math import e
import sys
import os
import tensorrt
from io import BytesIO
import torch
import torch_tensorrt
onnx_support = True
try:
import onnx
except ImportError:
onnx_support = False
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
from torch._decomp import get_decompositions


Expand All @@ -40,7 +48,7 @@ def __init__(
debug: bool = False,
static_shape: bool = True,
):
self.tensorrt_version = tensorrt.__version__ # can just grab version from here instead of importing trt and torch trt in all related files
self.tensorrt_version = trt.__version__ # can just grab version from here instead of importing trt and torch trt in all related files
self.torch_tensorrt_version = torch_tensorrt.__version__
self.export_format = export_format
self.trt_workspace_size = trt_workspace_size
Expand Down Expand Up @@ -152,3 +160,84 @@ def load_engine(self, trt_engine_path: str) -> torch.jit.ScriptModule:
"""Loads a TensorRT engine from the specified path."""
print(f"Loading TensorRT engine from {trt_engine_path}.", file=sys.stderr)
return torch.jit.load(trt_engine_path).eval()

class TensorRTHandler:
def __init__(
self,
trt_workspace_size: int = 0,
max_aux_streams: int | None = None,
trt_optimization_level: int = 3,
static_shape: bool = True,
):
self.tensorrt_version = trt.__version__
self.trt_workspace_size = trt_workspace_size
self.max_aux_streams = max_aux_streams
self.optimization_level = trt_optimization_level
self.static_shape = static_shape

def export_onnx(
self,
model,
dtype,
device,
example_inputs: list[torch.Tensor]
):
example_inputs = [input.to(device=device, dtype=dtype) for input in example_inputs]
model.to(device=device, dtype=dtype)
with BytesIO() as f:
torch.onnx.export(
model,
tuple(example_inputs),
f,
verbose=True,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
#dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, dealing with static for testing currently
)
f.seek(0)
return f.read()

def build_tensorrt_engine(self, onnx_model, trt_engine_path):
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(0)
parser = trt.OnnxParser(network, TRT_LOGGER)
parser.parse(onnx_model)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
serialized_engine = builder.build_serialized_network(network, config)
with open(trt_engine_path, "wb") as f:
f.write(serialized_engine)

def load_engine(self, trt_engine_path: str):
runtime = trt.Runtime(TRT_LOGGER)
with open(trt_engine_path, "rb") as f:
serialized_engine = f.read()
engine = runtime.deserialize_cuda_engine(serialized_engine)
self.context = engine.create_execution_context()

def build_engine(
self,
model: torch.nn.Module,
dtype: torch.dtype,
device: torch.device,
example_inputs: list[torch.Tensor],
trt_engine_path: str,
):
onnx_model = self.export_onnx(model=model, example_inputs=example_inputs, dtype=dtype, device=device)
engine = self.build_tensorrt_engine(onnx_model, trt_engine_path)

def __call__(self): # inference here
pass

if __name__ == '__main__':
model = torch.nn.Sequential(
torch.nn.Linear(10, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
example_inputs = [torch.randn(10)]
trt_engine_path = "model.engine"
handler = TensorRTHandler()
handler.build_engine(model, torch.float32, torch.device("cuda"), example_inputs, trt_engine_path)
engine = handler.load_engine(trt_engine_path)

0 comments on commit 183421e

Please sign in to comment.