diff --git a/REAL-Video-Enhancer.py b/REAL-Video-Enhancer.py
index e5636e50..b16fbcb2 100644
--- a/REAL-Video-Enhancer.py
+++ b/REAL-Video-Enhancer.py
@@ -339,6 +339,8 @@ def startRender(self):
             videoWidth=self.videoWidth,
             videoHeight=self.videoHeight,
             videoFps=self.videoFps,
+            tilingEnabled=self.tilingCheckBox.isChecked(),
+            tilesize=self.tileSizeComboBox.currentText(),
             videoFrameCount=self.videoFrameCount,
             method=method,
             backend=self.backendComboBox.currentText(),
diff --git a/backend/src/SceneDetect.py b/backend/src/SceneDetect.py
index 04b302b8..445b8b0d 100644
--- a/backend/src/SceneDetect.py
+++ b/backend/src/SceneDetect.py
@@ -77,7 +77,7 @@ def getPySceneDetectTransitions(self) -> Queue:
         for frame_num in tqdm(range(self.totalInputFrames - 1)):
-            frame = self.getFrameTo100x100img(self.readQueue.get())
+            frame = bytesTo100x100img(self.readQueue.get(), width=self.width, height=self.height)
             detectedFrameList = adaptiveDetector.process_frame(
                 frame_num=frame_num, frame_img=frame
             )
@@ -96,7 +96,7 @@ def getMeanTransitions(self):
         sceneChangeQueue = Queue()
         detector = NPMeanSequential()
         for frame_num in tqdm(range(self.totalInputFrames - 1)):
-            frame = bytesTo100x100img(self.readQueue.get())
+            frame = bytesTo100x100img(self.readQueue.get(), width=self.width, height=self.height)
             if detector.sceneDetect(frame):
                 sceneChangeQueue.put(frame_num-1)
         return sceneChangeQueue
diff --git a/backend/src/UpscaleTorch.py b/backend/src/UpscaleTorch.py
index 486fd978..d84880c5 100644
--- a/backend/src/UpscaleTorch.py
+++ b/backend/src/UpscaleTorch.py
@@ -3,6 +3,7 @@
 import numpy as np
 import cv2
 import torch as torch
+import torch.nn.functional as F
 
 
 from src.Util import (
@@ -55,7 +56,7 @@ def __init__(
         self,
         modelPath: str,
         device="default",
-        tile_pad: int = 10,
+        tile_pad: int = 0,
         precision: str = "auto",
         width: int = 1920,
         height: int = 1080,
@@ -81,9 +82,23 @@ def __init__(
         self.device = device
 
         model = self.loadModel(modelPath=modelPath, device=device, dtype=self.dtype)
-        self.width = width
-        self.height = height
+        self.videoWidth = width
+        self.videoHeight = height
         self.tilesize = tilesize
+        self.tile = [self.tilesize, self.tilesize]
+        match self.scale:
+            case 1:
+                modulo = 4
+            case 2:
+                modulo = 2
+            case _:
+                modulo = 1
+        if all(t > 0 for t in self.tile):
+            self.pad_w = math.ceil(min(self.tile[0] + 2 * tile_pad, width) / modulo) * modulo
+            self.pad_h = math.ceil(min(self.tile[1] + 2 * tile_pad, height) / modulo) * modulo
+        else:
+            self.pad_w = width
+            self.pad_h = height
 
         if backend == "tensorrt":
             import tensorrt as trt
@@ -93,7 +108,7 @@ def __init__(
                 os.path.realpath(trt_cache_dir),
                 (
                     f"{os.path.basename(modelPath)}"
-                    + f"_{width}x{height}"
+                    + f"_{self.pad_w}x{self.pad_h}"
                     + f"_{'fp16' if self.dtype == torch.float16 else 'fp32'}"
                     + f"_{torch.cuda.get_device_name(device)}"
                     + f"_trt-{trt.__version__}"
@@ -109,7 +124,7 @@ def __init__(
             if not os.path.isfile(trt_engine_path):
                 inputs = [
                     torch.zeros(
-                        (1, 3, self.height, self.width),
+                        (1, 3, self.pad_h, self.pad_w),
                         dtype=self.dtype,
                         device=device,
                     )
@@ -170,7 +185,7 @@ def loadModel(
     def bytesToFrame(self, frame):
         return (
             torch.frombuffer(frame, dtype=torch.uint8)
-            .reshape(self.height, self.width, 3)
+            .reshape(self.videoHeight, self.videoWidth, 3)
             .to(self.device, dtype=self.dtype)
             .permute(2, 0, 1)
             .unsqueeze(0)
@@ -212,76 +227,69 @@ def getScale(self):
     @torch.inference_mode()
     def renderTiledImage(
         self,
-        image: torch.Tensor,
+        img: torch.Tensor,
     ) -> torch.Tensor:
-        """It will first crop input images to tiles, and then process each tile.
-        Finally, all the processed tiles are merged into one images.
+        scale = self.scale
+        tile = self.tile
+        tile_pad = self.tile_pad
 
-        Modified from: https://github.com/ata4/esrgan-launcher
-        """
-        batch, channel, height, width = image.shape
-        output_height = height * self.scale
-        output_width = width * self.scale
-        output_shape = (batch, channel, output_height, output_width)
+        batch, channel, height, width = img.shape
+        output_shape = (batch, channel, height * scale, width * scale)
 
         # start with black image
-        output = image.new_zeros(output_shape)
-        tiles_x = math.ceil(width / self.tilesize)
-        tiles_y = math.ceil(height / self.tilesize)
+        output = img.new_zeros(output_shape)
+
+        tiles_x = math.ceil(width / tile[0])
+        tiles_y = math.ceil(height / tile[1])
 
         # loop over all tiles
         for y in range(tiles_y):
             for x in range(tiles_x):
                 # extract tile from input image
-                ofs_x = x * self.tilesize
-                ofs_y = y * self.tilesize
+                ofs_x = x * tile[0]
+                ofs_y = y * tile[1]
+
                 # input tile area on total image
                 input_start_x = ofs_x
-                input_end_x = min(ofs_x + self.tilesize, width)
+                input_end_x = min(ofs_x + tile[0], width)
                 input_start_y = ofs_y
-                input_end_y = min(ofs_y + self.tilesize, height)
+                input_end_y = min(ofs_y + tile[1], height)
 
                 # input tile area on total image with padding
-                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
-                input_end_x_pad = min(input_end_x + self.tile_pad, width)
-                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
-                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+                input_start_x_pad = max(input_start_x - tile_pad, 0)
+                input_end_x_pad = min(input_end_x + tile_pad, width)
+                input_start_y_pad = max(input_start_y - tile_pad, 0)
+                input_end_y_pad = min(input_end_y + tile_pad, height)
 
                 # input tile dimensions
                 input_tile_width = input_end_x - input_start_x
                 input_tile_height = input_end_y - input_start_y
-                tile_idx = y * tiles_x + x + 1
-                input_tile = image[
-                    :,
-                    :,
-                    input_start_y_pad:input_end_y_pad,
-                    input_start_x_pad:input_end_x_pad,
-                ]
-                # upscale tile
-                with torch.no_grad():
-                    output_tile = self.renderImage(input_tile)
+                input_tile = img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+                h, w = input_tile.shape[2:]
+                input_tile = F.pad(input_tile, (0, self.pad_w - w, 0, self.pad_h - h), "replicate")
+
+                # process tile
+                output_tile = self.model(input_tile)
+
+                output_tile = output_tile[:, :, : h * scale, : w * scale]
 
                 # output tile area on total image
-                output_start_x = input_start_x * self.scale
-                output_end_x = input_end_x * self.scale
-                output_start_y = input_start_y * self.scale
-                output_end_y = input_end_y * self.scale
+                output_start_x = input_start_x * scale
+                output_end_x = input_end_x * scale
+                output_start_y = input_start_y * scale
+                output_end_y = input_end_y * scale
 
                 # output tile area without padding
-                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
-                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
-                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
-                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+                output_start_x_tile = (input_start_x - input_start_x_pad) * scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * scale
 
                 # put tile into output image
-                output[
-                    :, :, output_start_y:output_end_y, output_start_x:output_end_x
-                ] = output_tile[
-                    :,
-                    :,
-                    output_start_y_tile:output_end_y_tile,
-                    output_start_x_tile:output_end_x_tile,
+                output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
+                    :, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile
                 ]
-        return output
+
+        return output
\ No newline at end of file
diff --git a/backend/src/Util.py b/backend/src/Util.py
index 140b76db..8e092e28 100644
--- a/backend/src/Util.py
+++ b/backend/src/Util.py
@@ -50,8 +50,8 @@ def log(message: str):
     with open(os.path.join(cwd, "backend_log.txt"), "a") as f:
         f.write(message + "\n")
 
-def bytesTo100x100img(self, image: bytes) -> np.ndarray:
-    frame = np.frombuffer(image,dtype=np.uint8).reshape(self.height, self.width, 3)
+def bytesTo100x100img(image: bytes, width, height) -> np.ndarray:
+    frame = np.frombuffer(image,dtype=np.uint8).reshape(height, width, 3)
     frame = cv2.resize(
         frame, dsize=(100, 100)
     )
diff --git a/src/ui/ProcessTab.py b/src/ui/ProcessTab.py
index 18496732..819e28a9 100644
--- a/src/ui/ProcessTab.py
+++ b/src/ui/ProcessTab.py
@@ -82,7 +82,12 @@ def QConnect(self):
             combobox.currentIndexChanged.connect(
                 self.switchInterpolationAndUpscale
            )
-
+        # hide the tile size container by default
+        self.parent.tileSizeContainer.setVisible(False)
+        # toggle tile size container visibility with the tiling checkbox
+        self.parent.tilingCheckBox.stateChanged.connect(
+            lambda: self.parent.tileSizeContainer.setVisible(self.parent.tilingCheckBox.isChecked())
+        )
         self.parent.inputFileText.textChanged.connect(self.parent.updateVideoGUIDetails)
 
         self.parent.interpolationMultiplierSpinBox.valueChanged.connect(
@@ -90,6 +95,8 @@ def QConnect(self):
         )
         self.parent.modelComboBox.currentIndexChanged.connect(self.parent.updateVideoGUIDetails)
 
+
+
     def killRenderProcess(self):
         try:  # kills render process if necessary
             self.renderProcess.terminate()
@@ -114,8 +121,10 @@ def switchInterpolationAndUpscale(self):
 
         if method.lower() == "interpolate":
             self.parent.interpolationContainer.setVisible(True)
+            self.parent.upscaleContainer.setVisible(False)
         else:
             self.parent.interpolationContainer.setVisible(False)
+            self.parent.upscaleContainer.setVisible(True)
 
         self.parent.updateVideoGUIDetails()
 
@@ -127,6 +136,8 @@ def run(
         videoHeight: int,
         videoFps: float,
         videoFrameCount: int,
+        tilesize: int,
+        tilingEnabled: bool,
         method: str,
         backend: str,
         interpolationTimes: int,
@@ -138,6 +149,8 @@ def run(
         self.videoWidth = videoWidth
         self.videoHeight = videoHeight
         self.videoFps = videoFps
+        self.tilingEnabled = tilingEnabled
+        self.tilesize = tilesize
         self.videoFrameCount = videoFrameCount
 
         models = self.getTotalModels(method=method, backend=backend)
@@ -230,6 +243,11 @@ def renderToPipeThread(self, method: str, backend: str, interpolateTimes: int):
             "--interpolateFactor",
             "1",
         ]
+        if self.tilingEnabled:
+            command += [
+                "--tilesize",
+                f"{self.tilesize}",
+            ]
         if method == "Interpolate":
             command += [
                 "--interpolateModel",
diff --git a/testRVEInterface.ui b/testRVEInterface.ui
index 970b525c..ab0ddd94 100644
--- a/testRVEInterface.ui
+++ b/testRVEInterface.ui
@@ -692,7 +692,7 @@ li.checked::marker { content: "\2612"; }
-          <number>0</number>
+          <number>6</number>
           <number>0</number>
@@ -747,7 +747,7 @@ li.checked::marker { content: "\2612"; }
-          <html><head/><body><p>Perform processing without outputing new video. This tests the raw performance of the inference.</p></body></html>
+          <html><head/><body><p>Splits the processing of upscaled frames into tiles.</p><p>Lowers VRAM usage, but also slows down rendering.</p><p>Only use when rendering fails due to VRAM limits.</p></body></html>
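
Note (not part of the patch): the TensorRT engine name and its static input shape now key off the padded tile size (pad_w x pad_h) instead of the full frame size. Below is a minimal, illustrative sketch of that calculation, mirroring the new logic in backend/src/UpscaleTorch.py's __init__; the helper name padded_tile_size and the example values are hypothetical, not identifiers from this repository.

import math

def padded_tile_size(tile_w, tile_h, width, height, tile_pad, scale):
    # lower scale factors round the padded tile up to a larger multiple,
    # matching the match/case in the patched __init__
    modulo = {1: 4, 2: 2}.get(scale, 1)
    if tile_w > 0 and tile_h > 0:
        # tile plus overlap on both sides, clamped to the frame, rounded up to modulo
        pad_w = math.ceil(min(tile_w + 2 * tile_pad, width) / modulo) * modulo
        pad_h = math.ceil(min(tile_h + 2 * tile_pad, height) / modulo) * modulo
    else:
        # tiling disabled: the whole frame is treated as a single tile
        pad_w, pad_h = width, height
    return pad_w, pad_h

# hypothetical example: 1920x1080 frame, 256 px tiles, 10 px overlap, 2x model
print(padded_tile_size(256, 256, 1920, 1080, tile_pad=10, scale=2))  # (276, 276)

Because every tile is replicate-padded up to this fixed size before inference (see the F.pad call in renderTiledImage), the engine can be built once for a single input shape regardless of where a tile sits in the frame or how much of it is real image data.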