Skip to content

Commit

Permalink
cache timestep tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Aug 24, 2024
1 parent a2568fb commit b13c5b1
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
3 changes: 2 additions & 1 deletion backend/src/FFmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
# upscaleTimes will be set to the scale of the model loaded with spandrel
self.upscaleTimes = upscaleTimes
self.interpolateFactor = interpolateFactor
self.ceilInterpolateFactor = math.ceil(self.interpolateFactor)
self.encoder = encoder
self.pixelFormat = pixelFormat
self.benchmark = benchmark
Expand All @@ -106,7 +107,7 @@ def __init__(
self.shm = shm
self.inputFrameChunkSize = inputFrameChunkSize
self.outputFrameChunkSize = outputFrameChunkSize
self.ceilInterpolateFactor = math.ceil(self.interpolateFactor)

self.totalOutputFrames = self.totalInputFrames * self.ceilInterpolateFactor

self.writeOutPipe = self.outputFile == "PIPE"
Expand Down
7 changes: 5 additions & 2 deletions backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def __init__(
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
pw=1920,
ph=1088,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 8, c=192)
Expand All @@ -159,9 +161,10 @@ def __init__(
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

self.pw = pw
self.ph = ph
def forward(self, img0, img1, timestep):
# can't be cached: h and w are read from the input tensor's shape on every call


h, w = img0.shape[2], img0.shape[3]
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
Expand Down
16 changes: 12 additions & 4 deletions backend/src/InterpolateTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def frame_to_tensor(self, frame):
def __init__(
self,
interpolateModelPath: str,
ceilInterpolateFactor: int = 2,
width: int = 1920,
height: int = 1080,
device: str = "default",
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
self.device = device
self.dtype = self.handlePrecision(dtype)
self.backend = backend
self.ceilInterpolateFactor = ceilInterpolateFactor
# set up streams for async processing
self.stream = torch.cuda.Stream()
self.prepareStream = torch.cuda.Stream()
Expand All @@ -130,6 +132,14 @@ def __init__(
self.padding = (0, self.pw - self.width, 0, self.ph - self.height)
ad = ArchDetect(interpolateModelPath)
interpolateArch = ad.getArch()
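# Cache one pre-filled timestep tensor per timestep value, keyed by the float timestep.
# Keys are exactly 1 / (ceilInterpolateFactor - n); process() indexes this dict directly,
# so a timestep that is not bit-identical to one of these keys raises KeyError.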
self.timestepDict = {}
for n in range(self.ceilInterpolateFactor):
timestep = 1 / (self.ceilInterpolateFactor - n)
timestep_tens = torch.full(
(1, 1, self.ph, self.pw), timestep, dtype=self.dtype, device=self.device
)
self.timestepDict[timestep] = timestep_tens
# detect what rife arch to use
match interpolateArch.lower():
case "rife46":
Expand Down Expand Up @@ -290,10 +300,8 @@ def handlePrecision(self, precision):
@torch.inference_mode()
def process(self, img0, img1, timestep):
with torch.cuda.stream(self.stream):
timestep = torch.full(
(1, 1, self.ph, self.pw), timestep, dtype=self.dtype, device=self.device
)


timestep = self.timestepDict[timestep]
output = self.flownet(img0, img1, timestep)
output = self.tensor_to_frame(output)
self.stream.synchronize()
Expand Down
3 changes: 3 additions & 0 deletions backend/src/RenderVideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from queue import Queue, Empty
from multiprocessing import shared_memory
import os
import math

from .FFmpeg import FFMpegRender
from .SceneDetect import SceneDetect
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.precision = precision
self.upscaleTimes = 1 # if no upscaling, it will default to 1
self.interpolateFactor = interpolateFactor
self.ceilInterpolateFactor = math.ceil(self.interpolateFactor)
self.setupRender = self.returnFrame # set it to not convert the bytes to array by default, and just pass chunk through
self.frame0 = None
self.sceneDetectMethod = sceneDetectMethod
Expand Down Expand Up @@ -265,6 +267,7 @@ def setupInterpolate(self):
if self.backend == "pytorch" or self.backend == "tensorrt":
interpolateRifePytorch = InterpolateRifeTorch(
interpolateModelPath=self.interpolateModel,
ceilInterpolateFactor=self.ceilInterpolateFactor,
width=self.width,
height=self.height,
device=self.device,
Expand Down

0 comments on commit b13c5b1

Please sign in to comment.