Skip to content

Commit

Permalink
simplify engine building
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Nov 3, 2024
1 parent 75fd44b commit ddec64b
Show file tree
Hide file tree
Showing 17 changed files with 167 additions and 357 deletions.
7 changes: 4 additions & 3 deletions REAL-Video-Enhancer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import os

# patch for macos
if sys.platform == "darwin":
os.chdir(os.path.dirname(os.path.abspath(__file__)))
Expand All @@ -15,7 +16,7 @@
)
from PySide6.QtGui import QIcon
from src.Util import printAndLog
from mainwindow import Ui_MainWindow
from mainwindow import Ui_MainWindow
from PySide6 import QtSvg # Import the QtSvg module so svg icons can be used on windows
from src.version import version
from src.InputHandler import VideoInputHandler
Expand Down Expand Up @@ -464,11 +465,11 @@ def closeEvent(self, event):
app = QApplication(sys.argv)

# setting the pallette

app.setPalette(Palette())
window = MainWindow()
if len(sys.argv) > 1:
if sys.argv[1] == '--fullscreen':
if sys.argv[1] == "--fullscreen":
window.showFullScreen()
window.show()
sys.exit(app.exec())
2 changes: 0 additions & 2 deletions backend/rve-backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ def checkArguments(self):
raise ValueError(
"Interpolation factor must be 1 if no interpolation model is used.\nPlease use --interpolateFactor 1 for no interpolation!"
)




if __name__ == "__main__":
Expand Down
19 changes: 0 additions & 19 deletions backend/src/FFmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,22 +377,6 @@ def writeOutInformation(self, fcs):

time.sleep(0.1)

def openMPVProc(self):
self.mpv_process = subprocess.Popen(
[
"mpv",
"--no-correct-pts",
f"--fps={self.fps * self.ceilInterpolateFactor}",
"--demuxer-thread=no",
"--",
"-",
],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=False,
)

def writeOutVideoFrames(self):
"""
Writes out frames either to ffmpeg or to pipe
Expand All @@ -401,9 +385,6 @@ def writeOutVideoFrames(self):
ffmpeg -f rawvideo -pix_fmt rgb24 -s 1920x1080 -framerate 24 -i - -c:v libx264 -crf 18 -pix_fmt yuv420p -c:a copy out.mp4
"""
log("Rendering")
#

# self.openMPVProc()
self.startTime = time.time()
self.framesRendered: int = 1
self.last_length: int = 0
Expand Down
1 change: 1 addition & 0 deletions backend/src/InterpolateArchs/DetectInterpolateArch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch


class RIFE46:
def __init__():
pass
Expand Down
1 change: 1 addition & 0 deletions backend/src/InterpolateArchs/GMFSS/FeatureNet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn
from .util import MyPReLU


class FeatureNet(nn.Module):
"""The quadratic model"""

Expand Down
23 changes: 16 additions & 7 deletions backend/src/InterpolateArchs/GMFSS/GMFSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@


class GMFSS(nn.Module):
def __init__(self, model_path, model_type:str="union", scale:int=1, ensemble:bool=False, width:int=1920, height:int=1080):
def __init__(
self,
model_path,
model_type: str = "union",
scale: int = 1,
ensemble: bool = False,
width: int = 1920,
height: int = 1080,
):
super(GMFSS, self).__init__()
from .FusionNet_u import GridNet

# get gmfss from here, as its a combination of all the models https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/GMFSS.pkl
self.width = width
self.height = height
Expand All @@ -25,12 +34,12 @@ def __init__(self, model_path, model_type:str="union", scale:int=1, ensemble:boo
self.fusionnet = GridNet()
combined_state_dict = torch.load(model_path, map_location="cpu")
if model_type != "base":
self.ifnet.load_state_dict(combined_state_dict['rife'])
self.flownet.load_state_dict(combined_state_dict['flownet'])
self.metricnet.load_state_dict(combined_state_dict['metricnet'])
self.feat_ext.load_state_dict(combined_state_dict['feat_ext'])
self.fusionnet.load_state_dict(combined_state_dict['fusionnet'])
self.ifnet.load_state_dict(combined_state_dict["rife"])
self.flownet.load_state_dict(combined_state_dict["flownet"])
self.metricnet.load_state_dict(combined_state_dict["metricnet"])
self.feat_ext.load_state_dict(combined_state_dict["feat_ext"])
self.fusionnet.load_state_dict(combined_state_dict["fusionnet"])

self.model_type = model_type
self.scale = scale

Expand Down
2 changes: 1 addition & 1 deletion backend/src/InterpolateArchs/GMFSS/softsplat.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,4 +642,4 @@ def backward(self, tenOutgrad):
# end


# end
# end
6 changes: 4 additions & 2 deletions backend/src/InterpolateArchs/RIFE/warplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
tenInput = tenInput.to(torch.float)
tenFlow = tenFlow.to(torch.float)

tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1)
tenFlow = torch.cat(
[tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
)
g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
return torch.ops.aten.grid_sampler_2d(tenInput, g, 0, 1, True).to(dtype)
return torch.ops.aten.grid_sampler_2d(tenInput, g, 0, 1, True).to(dtype)
2 changes: 1 addition & 1 deletion backend/src/InterpolateNCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def process_fast(
self.height, self.width, self.channels
)


class RIFE(Rife): ...


Expand Down
Loading

0 comments on commit ddec64b

Please sign in to comment.