From b13c5b1cc792c688bcdb745ab6d02fc9c27ed9bf Mon Sep 17 00:00:00 2001 From: tntwise Date: Sat, 24 Aug 2024 22:08:39 +0000 Subject: [PATCH] cache timestep tensor --- backend/src/FFmpeg.py | 3 ++- .../InterpolateArchs/RIFE/rife422_liteIFNET.py | 7 +++++-- backend/src/InterpolateTorch.py | 16 ++++++++++++---- backend/src/RenderVideo.py | 3 +++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/backend/src/FFmpeg.py b/backend/src/FFmpeg.py index 4be61700..395fac26 100644 --- a/backend/src/FFmpeg.py +++ b/backend/src/FFmpeg.py @@ -92,6 +92,7 @@ def __init__( # upsacletimes will be set to the scale of the loaded model with spandrel self.upscaleTimes = upscaleTimes self.interpolateFactor = interpolateFactor + self.ceilInterpolateFactor = math.ceil(self.interpolateFactor) self.encoder = encoder self.pixelFormat = pixelFormat self.benchmark = benchmark @@ -106,7 +107,7 @@ def __init__( self.shm = shm self.inputFrameChunkSize = inputFrameChunkSize self.outputFrameChunkSize = outputFrameChunkSize - self.ceilInterpolateFactor = math.ceil(self.interpolateFactor) + self.totalOutputFrames = self.totalInputFrames * self.ceilInterpolateFactor self.writeOutPipe = self.outputFile == "PIPE" diff --git a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py index 0fc47883..1185e067 100644 --- a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py @@ -142,6 +142,8 @@ def __init__( height=1080, backwarp_tenGrid=None, tenFlow_div=None, + pw=1920, + ph=1088, ): super(IFNet, self).__init__() self.block0 = IFBlock(7 + 8, c=192) @@ -159,9 +161,10 @@ def __init__( self.backwarp_tenGrid = backwarp_tenGrid self.tenFlow_div = tenFlow_div + self.pw = pw + self.ph = ph def forward(self, img0, img1, timestep): - # cant be cached - + h, w = img0.shape[2], img0.shape[3] imgs = torch.cat([img0, img1], dim=1) imgs_2 = torch.reshape(imgs, (2, 3, h, w)) diff --git 
a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py index 5e9790b0..f1e07ee1 100644 --- a/backend/src/InterpolateTorch.py +++ b/backend/src/InterpolateTorch.py @@ -81,6 +81,7 @@ def frame_to_tensor(self, frame): def __init__( self, interpolateModelPath: str, + ceilInterpolateFactor: int = 2, width: int = 1920, height: int = 1080, device: str = "default", @@ -113,6 +114,7 @@ def __init__( self.device = device self.dtype = self.handlePrecision(dtype) self.backend = backend + self.ceilInterpolateFactor = ceilInterpolateFactor # set up streams for async processing self.stream = torch.cuda.Stream() self.prepareStream = torch.cuda.Stream() @@ -130,6 +132,14 @@ def __init__( self.padding = (0, self.pw - self.width, 0, self.ph - self.height) ad = ArchDetect(interpolateModelPath) interpolateArch = ad.getArch() + # caching the timestep tensor in a dict with the timestep as a float for the key + self.timestepDict = {} + for n in range(self.ceilInterpolateFactor): + timestep = 1 / (self.ceilInterpolateFactor - n) + timestep_tens = torch.full( + (1, 1, self.ph, self.pw), timestep, dtype=self.dtype, device=self.device + ) + self.timestepDict[timestep] = timestep_tens # detect what rife arch to use match interpolateArch.lower(): case "rife46": @@ -290,10 +300,8 @@ def handlePrecision(self, precision): @torch.inference_mode() def process(self, img0, img1, timestep): with torch.cuda.stream(self.stream): - timestep = torch.full( - (1, 1, self.ph, self.pw), timestep, dtype=self.dtype, device=self.device - ) - + + timestep = self.timestepDict[timestep] output = self.flownet(img0, img1, timestep) output = self.tensor_to_frame(output) self.stream.synchronize() diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py index 01db8ceb..67312521 100644 --- a/backend/src/RenderVideo.py +++ b/backend/src/RenderVideo.py @@ -2,6 +2,7 @@ from queue import Queue, Empty from multiprocessing import shared_memory import os +import math from .FFmpeg import FFMpegRender from 
.SceneDetect import SceneDetect @@ -78,6 +79,7 @@ def __init__( self.precision = precision self.upscaleTimes = 1 # if no upscaling, it will default to 1 self.interpolateFactor = interpolateFactor + self.ceilInterpolateFactor = math.ceil(self.interpolateFactor) self.setupRender = self.returnFrame # set it to not convert the bytes to array by default, and just pass chunk through self.frame0 = None self.sceneDetectMethod = sceneDetectMethod @@ -265,6 +267,7 @@ def setupInterpolate(self): if self.backend == "pytorch" or self.backend == "tensorrt": interpolateRifePytorch = InterpolateRifeTorch( interpolateModelPath=self.interpolateModel, + ceilInterpolateFactor=self.ceilInterpolateFactor, width=self.width, height=self.height, device=self.device,