diff --git a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py
index 6cf083e0..67eeacdc 100644
--- a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py
+++ b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py
@@ -6,9 +6,8 @@
 from .gmflow.gmflow import GMFlow
 from .IFNet_HDv3 import IFNet
 from .MetricNet import MetricNet
-from .softsplat import softsplat as warp
+from ..util.softsplat_torch import softsplat as warp
 
-torch.fx.wrap("warp")
 
 
 class GMFSS(nn.Module):
diff --git a/backend/src/pytorch/InterpolateArchs/GMFSS/softsplat.py b/backend/src/pytorch/InterpolateArchs/util/softsplat_cupy.py
similarity index 100%
rename from backend/src/pytorch/InterpolateArchs/GMFSS/softsplat.py
rename to backend/src/pytorch/InterpolateArchs/util/softsplat_cupy.py
diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/softsplat.py b/backend/src/pytorch/InterpolateArchs/util/softsplat_torch.py
similarity index 60%
rename from backend/src/pytorch/InterpolateArchs/GIMM/softsplat.py
rename to backend/src/pytorch/InterpolateArchs/util/softsplat_torch.py
index 136ead8a..4ff66a19 100644
--- a/backend/src/pytorch/InterpolateArchs/GIMM/softsplat.py
+++ b/backend/src/pytorch/InterpolateArchs/util/softsplat_torch.py
@@ -1,11 +1,3 @@
-#!/usr/bin/env python
-
-import collections
-import os
-import re
-import typing
-
-import cupy
 import torch
 
 ##########################################################
@@ -14,264 +6,6 @@
 objCudacache = {}
 
 
-def cuda_int32(intIn: int):
-    return cupy.int32(intIn)
-
-
-# end
-
-
-def cuda_float32(fltIn: float):
-    return cupy.float32(fltIn)
-
-
-# end
-
-
-def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
-    if "device" not in objCudacache:
-        objCudacache["device"] = torch.cuda.get_device_name()
-    # end
-
-    strKey = strFunction
-
-    for strVariable in objVariables:
-        objValue = objVariables[strVariable]
-
-        strKey += strVariable
-
-        if objValue is None:
-            continue
-
-        elif isinstance(objValue, int):
-            strKey += str(objValue)
-
-        elif isinstance(objValue, float):
-            strKey += str(objValue)
-
-        elif isinstance(objValue, bool):
-            strKey += str(objValue)
-
-        elif isinstance(objValue, str):
-            strKey += objValue
-
-        elif type(objValue) == torch.Tensor:
-            strKey += str(objValue.dtype)
-            strKey += str(objValue.shape)
-            strKey += str(objValue.stride())
-
-        elif True:
-            print(strVariable, type(objValue))
-            assert False
-
-        # end
-    # end
-
-    strKey += objCudacache["device"]
-
-    if strKey not in objCudacache:
-        for strVariable in objVariables:
-            objValue = objVariables[strVariable]
-
-            if objValue is None:
-                continue
-
-            elif isinstance(objValue, int):
-                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
-
-            elif isinstance(objValue, float):
-                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
-
-            elif isinstance(objValue, bool):
-                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
-
-            elif isinstance(objValue, str):
-                strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
-                strKernel = strKernel.replace("{{type}}", "unsigned char")
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
-                strKernel = strKernel.replace("{{type}}", "half")
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
-                strKernel = strKernel.replace("{{type}}", "float")
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
-                strKernel = strKernel.replace("{{type}}", "double")
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
-                strKernel = strKernel.replace("{{type}}", "int")
-
-            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
-                strKernel = strKernel.replace("{{type}}", "long")
-
-            elif type(objValue) == torch.Tensor:
-                print(strVariable, objValue.dtype)
-                assert False
-
-            elif True:
-                print(strVariable, type(objValue))
-                assert False
-
-            # end
-        # end
-
-        while True:
-            objMatch = re.search("(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
-
-            if objMatch is None:
-                break
-            # end
-
-            intArg = int(objMatch.group(2))
-
-            strTensor = objMatch.group(4)
-            intSizes = objVariables[strTensor].size()
-
-            strKernel = strKernel.replace(
-                objMatch.group(),
-                str(
-                    intSizes[intArg]
-                    if torch.is_tensor(intSizes[intArg]) == False
-                    else intSizes[intArg].item()
-                ),
-            )
-        # end
-
-        while True:
-            objMatch = re.search("(OFFSET_)([0-4])(\()", strKernel)
-
-            if objMatch is None:
-                break
-            # end
-
-            intStart = objMatch.span()[1]
-            intStop = objMatch.span()[1]
-            intParentheses = 1
-
-            while True:
-                intParentheses += 1 if strKernel[intStop] == "(" else 0
-                intParentheses -= 1 if strKernel[intStop] == ")" else 0
-
-                if intParentheses == 0:
-                    break
-                # end
-
-                intStop += 1
-            # end
-
-            intArgs = int(objMatch.group(2))
-            strArgs = strKernel[intStart:intStop].split(",")
-
-            assert intArgs == len(strArgs) - 1
-
-            strTensor = strArgs[0]
-            intStrides = objVariables[strTensor].stride()
-
-            strIndex = []
-
-            for intArg in range(intArgs):
-                strIndex.append(
-                    "(("
-                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
-                    + ")*"
-                    + str(
-                        intStrides[intArg]
-                        if torch.is_tensor(intStrides[intArg]) == False
-                        else intStrides[intArg].item()
-                    )
-                    + ")"
-                )
-            # end
-
-            strKernel = strKernel.replace(
-                "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
-                "(" + str.join("+", strIndex) + ")",
-            )
-        # end
-
-        while True:
-            objMatch = re.search("(VALUE_)([0-4])(\()", strKernel)
-
-            if objMatch is None:
-                break
-            # end
-
-            intStart = objMatch.span()[1]
-            intStop = objMatch.span()[1]
-            intParentheses = 1
-
-            while True:
-                intParentheses += 1 if strKernel[intStop] == "(" else 0
-                intParentheses -= 1 if strKernel[intStop] == ")" else 0
-
-                if intParentheses == 0:
-                    break
-                # end
-
-                intStop += 1
-            # end
-
-            intArgs = int(objMatch.group(2))
-            strArgs = strKernel[intStart:intStop].split(",")
-
-            assert intArgs == len(strArgs) - 1
-
-            strTensor = strArgs[0]
-            intStrides = objVariables[strTensor].stride()
-
-            strIndex = []
-
-            for intArg in range(intArgs):
-                strIndex.append(
-                    "(("
-                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
-                    + ")*"
-                    + str(
-                        intStrides[intArg]
-                        if torch.is_tensor(intStrides[intArg]) == False
-                        else intStrides[intArg].item()
-                    )
-                    + ")"
-                )
-            # end
-
-            strKernel = strKernel.replace(
-                "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
-                strTensor + "[" + str.join("+", strIndex) + "]",
-            )
-        # end
-
-        objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
-    # end
-
-    return strKey
-
-
-# end
-
-
-@cupy.memoize(for_each_device=True)
-def cuda_launch(strKey: str):
-    if "CUDA_HOME" not in os.environ:
-        os.environ["CUDA_HOME"] = cupy.cuda.get_cuda_path()
-    # end
-
-    return cupy.RawKernel(
-        objCudacache[strKey]["strKernel"],
-        objCudacache[strKey]["strFunction"],
-        tuple(
-            [
-                "-I " + os.environ["CUDA_HOME"],
-                "-I " + os.environ["CUDA_HOME"] + "/include",
-            ]
-        ),
-    )
-
-
-# end
-
 
 
 ##########################################################
diff --git a/backend/src/pytorch/InterpolateTorch.py b/backend/src/pytorch/InterpolateTorch.py
index d2deaa0f..fab8b55a 100644
--- a/backend/src/pytorch/InterpolateTorch.py
+++ b/backend/src/pytorch/InterpolateTorch.py
@@ -2,7 +2,7 @@
 import torch.nn.functional as F
 from abc import ABCMeta, abstractmethod
 
-from backend.src.pytorch.InterpolateArchs.GIMM import GIMM
+#from backend.src.pytorch.InterpolateArchs.GIMM import GIMM
 from .InterpolateArchs.DetectInterpolateArch import ArchDetect
 import math
 import os
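
Note: the sketch below is not part of the patch. It illustrates one way downstream code could choose between the two backends the renames above create: softsplat_cupy (the original cupy/CUDA kernel, moved unchanged) and softsplat_torch (the pure-PyTorch port that GMFSS.py now imports). Both modules expose a softsplat function, as the "softsplat as warp" imports in the diff show; the package root backend.src.pytorch and the try/except selection logic are assumptions made for illustration only.

    # Hypothetical backend selection; not code from this repository.
    try:
        # If cupy is importable, the fused CUDA kernel path can be used.
        import cupy  # noqa: F401
        from backend.src.pytorch.InterpolateArchs.util.softsplat_cupy import softsplat
    except ImportError:
        # No cupy on this system: fall back to the pure-PyTorch implementation,
        # which is the backend GMFSS.py switches to in this patch.
        from backend.src.pytorch.InterpolateArchs.util.softsplat_torch import softsplat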