Skip to content

Commit

Permalink
format a bit more
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Dec 14, 2024
1 parent 21deca8 commit b4f6ad0
Showing 1 changed file with 58 additions and 53 deletions.
111 changes: 58 additions & 53 deletions backend/rve-backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,70 @@
check_bfloat16_support,
checkForDirectML,
checkForDirectMLHalfPrecisionSupport,
checkForGMFSS,
get_pytorch_vram,

)


class HandleApplication:
def __init__(self):
self.args = self.handleArguments()
if not self.args.list_backends:
self.checkArguments()
Render(
self.renderVideo()
else:
self.listBackends()

def listBackends(self):
half_prec_supp = False
availableBackends = []
printMSG = ""

if checkForTensorRT():
"""
checks for tensorrt availability, and the current gpu works with it (if half precision is supported)
Trt 10 only supports RTX 20 series and up.
Half precision is only availaible on RTX 20 series and up
"""
import torch
half_prec_supp = check_bfloat16_support()
if half_prec_supp:
import tensorrt

availableBackends.append("tensorrt")
printMSG += f"TensorRT Version: {tensorrt.__version__}\n"
else:
printMSG += "ERROR: Cannot use tensorrt backend, as it is not supported on your current GPU"

if checkForPytorchCUDA():
import torch
availableBackends.append("pytorch (cuda)")
printMSG += f"PyTorch Version: {torch.__version__}\n"
half_prec_supp = check_bfloat16_support()

if checkForPytorchROCM():
availableBackends.append("pytorch (rocm)")
import torch
printMSG += f"PyTorch Version: {torch.__version__}\n"
half_prec_supp = check_bfloat16_support()

if checkForNCNN():
availableBackends.append("ncnn")
printMSG += f"NCNN Version: 20220729\n"
from rife_ncnn_vulkan_python import Rife

if checkForDirectML():
availableBackends.append("directml")
import onnxruntime as ort

printMSG += f"ONNXruntime Version: {ort.__version__}\n"
half_prec_supp = checkForDirectMLHalfPrecisionSupport()

printMSG += f"Half precision support: {half_prec_supp}\n"
print("Available Backends: " + str(availableBackends))
print(printMSG)

def renderVideo(self):
self.checkArguments()
Render(
# model settings
inputFile=self.args.input,
outputFile=self.args.output,
Expand Down Expand Up @@ -50,55 +103,7 @@ def __init__(self):
dynamic_scaled_optical_flow=self.args.dynamic_scaled_optical_flow,
ensemble=self.args.ensemble,
)
else:
half_prec_supp = False
availableBackends = []
printMSG = ""

if checkForTensorRT():
"""
checks for tensorrt availability, and the current gpu works with it (if half precision is supported)
Trt 10 only supports RTX 20 series and up.
Half precision is only availaible on RTX 20 series and up
"""
import torch

half_prec_supp = check_bfloat16_support()
if half_prec_supp:
import tensorrt

availableBackends.append("tensorrt")
printMSG += f"TensorRT Version: {tensorrt.__version__}\n"
else:
printMSG += "ERROR: Cannot use tensorrt backend, as it is not supported on your current GPU"
if checkForPytorchCUDA():
import torch

availableBackends.append("pytorch (cuda)")
printMSG += f"PyTorch Version: {torch.__version__}\n"
half_prec_supp = check_bfloat16_support()
if checkForPytorchROCM():
availableBackends.append("pytorch (rocm)")
import torch

printMSG += f"PyTorch Version: {torch.__version__}\n"
half_prec_supp = check_bfloat16_support()

if checkForNCNN():
availableBackends.append("ncnn")
printMSG += f"NCNN Version: 20220729\n"
from rife_ncnn_vulkan_python import Rife
if checkForDirectML():
availableBackends.append("directml")
import onnxruntime as ort

printMSG += f"ONNXruntime Version: {ort.__version__}\n"
half_prec_supp = checkForDirectMLHalfPrecisionSupport()
printMSG += f"Half precision support: {half_prec_supp}\n"

print("Available Backends: " + str(availableBackends))
print(printMSG)


def handleArguments(self) -> argparse.ArgumentParser:
"""_summary_
Expand Down

0 comments on commit b4f6ad0

Please sign in to comment.