diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py index a7fb68a0..0ec699a5 100644 --- a/backend/src/RenderVideo.py +++ b/backend/src/RenderVideo.py @@ -102,15 +102,15 @@ def __init__( self.getVideoProperties(inputFile) printAndLog("Using backend: " + self.backend) - - if interpolateModel: - self.setupInterpolate() - printAndLog("Using Interpolation Model: " + self.interpolateModel) - + # upscale has to be called first to get the scale of the upscale model if upscaleModel: self.setupUpscale() + printAndLog("Using Upscaling Model: " + self.upscaleModel) - + if interpolateModel: + self.setupInterpolate() + + printAndLog("Using Interpolation Model: " + self.interpolateModel) super().__init__( inputFile=inputFile, diff --git a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py index f286ea76..0f539216 100644 --- a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py +++ b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py @@ -25,7 +25,7 @@ def __init__( ensemble: bool = False, width: int = 1920, height: int = 1080, - trt=False, + trt=True, dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): @@ -71,14 +71,14 @@ def __init__( if trt: from ...TensorRTHandler import TorchTensorRTHandler trtHandler = TorchTensorRTHandler(multi_precision_engine=False,trt_optimization_level=5) - trtHandler.build_engine(self.ifnet, dtype=dtype, device=device, example_inputs=self.rife_example_input(), trt_engine_path="IFNet.engine") - trtHandler.build_engine(self.feat_ext, dtype=dtype, device=device, example_inputs=self.img0_example_input(), trt_engine_path="Feat.engine") trtHandler.build_engine(self.flownet, dtype=dtype, device=device, example_inputs=self.flownet_example_input(), trt_engine_path="Flownet.engine") - trtHandler.build_engine(self.fusionnet, dtype=dtype, device=device, example_inputs=self.flownet_example_input(), trt_engine_path="Flownet.engine") + #trtHandler.build_engine(self.ifnet, dtype=dtype, device=device, example_inputs=self.rife_example_input(), trt_engine_path="IFNet.engine") + #trtHandler.build_engine(self.feat_ext, dtype=dtype, device=device, example_inputs=self.img0_example_input(), trt_engine_path="Feat.engine") + #trtHandler.build_engine(self.fusionnet, dtype=dtype, device=device, example_inputs=self.flownet_example_input(), trt_engine_path="Flownet.engine") self.ifnet = None self.feat_ext = None #self.fusionnet = None - #self.flownet = None + self.flownet = None import gc gc.collect() torch.cuda.empty_cache() @@ -86,7 +86,7 @@ def __init__( torch.cuda.reset_max_memory_cached() self.ifnet = trtHandler.load_engine("IFNet.engine") self.feat_ext = trtHandler.load_engine("Feat.engine") - #self.flownet = trtHandler.load_engine("Flownet.engine") + self.flownet = trtHandler.load_engine("Flownet.engine")