Skip to content

Commit

Permalink
add different softsplat
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Dec 7, 2024
1 parent 7c532f8 commit 8fbd3b5
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 269 deletions.
3 changes: 1 addition & 2 deletions backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from .gmflow.gmflow import GMFlow
from .IFNet_HDv3 import IFNet
from .MetricNet import MetricNet
from .softsplat import softsplat as warp
from ..util.softsplat_torch import softsplat as warp

torch.fx.wrap("warp")


class GMFSS(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
#!/usr/bin/env python

import collections
import os
import re
import typing

import cupy
import torch

##########################################################
Expand All @@ -14,264 +6,6 @@
objCudacache = {}


def cuda_int32(intIn: int):
return cupy.int32(intIn)


# end


def cuda_float32(fltIn: float):
return cupy.float32(fltIn)


# end


def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
if "device" not in objCudacache:
objCudacache["device"] = torch.cuda.get_device_name()
# end

strKey = strFunction

for strVariable in objVariables:
objValue = objVariables[strVariable]

strKey += strVariable

if objValue is None:
continue

elif isinstance(objValue, int):
strKey += str(objValue)

elif isinstance(objValue, float):
strKey += str(objValue)

elif isinstance(objValue, bool):
strKey += str(objValue)

elif isinstance(objValue, str):
strKey += objValue

elif type(objValue) == torch.Tensor:
strKey += str(objValue.dtype)
strKey += str(objValue.shape)
strKey += str(objValue.stride())

elif True:
print(strVariable, type(objValue))
assert False

# end
# end

strKey += objCudacache["device"]

if strKey not in objCudacache:
for strVariable in objVariables:
objValue = objVariables[strVariable]

if objValue is None:
continue

elif isinstance(objValue, int):
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

elif isinstance(objValue, float):
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

elif isinstance(objValue, bool):
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

elif isinstance(objValue, str):
strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)

elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
strKernel = strKernel.replace("{{type}}", "unsigned char")

elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
strKernel = strKernel.replace("{{type}}", "half")

elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
strKernel = strKernel.replace("{{type}}", "float")

elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
strKernel = strKernel.replace("{{type}}", "double")

elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
strKernel = strKernel.replace("{{type}}", "int")

elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
strKernel = strKernel.replace("{{type}}", "long")

elif type(objValue) == torch.Tensor:
print(strVariable, objValue.dtype)
assert False

elif True:
print(strVariable, type(objValue))
assert False

# end
# end

while True:
objMatch = re.search("(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)

if objMatch is None:
break
# end

intArg = int(objMatch.group(2))

strTensor = objMatch.group(4)
intSizes = objVariables[strTensor].size()

strKernel = strKernel.replace(
objMatch.group(),
str(
intSizes[intArg]
if torch.is_tensor(intSizes[intArg]) == False
else intSizes[intArg].item()
),
)
# end

while True:
objMatch = re.search("(OFFSET_)([0-4])(\()", strKernel)

if objMatch is None:
break
# end

intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1

while True:
intParentheses += 1 if strKernel[intStop] == "(" else 0
intParentheses -= 1 if strKernel[intStop] == ")" else 0

if intParentheses == 0:
break
# end

intStop += 1
# end

intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(",")

assert intArgs == len(strArgs) - 1

strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()

strIndex = []

for intArg in range(intArgs):
strIndex.append(
"(("
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
+ ")*"
+ str(
intStrides[intArg]
if torch.is_tensor(intStrides[intArg]) == False
else intStrides[intArg].item()
)
+ ")"
)
# end

strKernel = strKernel.replace(
"OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
"(" + str.join("+", strIndex) + ")",
)
# end

while True:
objMatch = re.search("(VALUE_)([0-4])(\()", strKernel)

if objMatch is None:
break
# end

intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1

while True:
intParentheses += 1 if strKernel[intStop] == "(" else 0
intParentheses -= 1 if strKernel[intStop] == ")" else 0

if intParentheses == 0:
break
# end

intStop += 1
# end

intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(",")

assert intArgs == len(strArgs) - 1

strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()

strIndex = []

for intArg in range(intArgs):
strIndex.append(
"(("
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
+ ")*"
+ str(
intStrides[intArg]
if torch.is_tensor(intStrides[intArg]) == False
else intStrides[intArg].item()
)
+ ")"
)
# end

strKernel = strKernel.replace(
"VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
strTensor + "[" + str.join("+", strIndex) + "]",
)
# end

objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
# end

return strKey


# end


@cupy.memoize(for_each_device=True)
def cuda_launch(strKey: str):
if "CUDA_HOME" not in os.environ:
os.environ["CUDA_HOME"] = cupy.cuda.get_cuda_path()
# end

return cupy.RawKernel(
objCudacache[strKey]["strKernel"],
objCudacache[strKey]["strFunction"],
tuple(
[
"-I " + os.environ["CUDA_HOME"],
"-I " + os.environ["CUDA_HOME"] + "/include",
]
),
)


# end


##########################################################

Expand Down
2 changes: 1 addition & 1 deletion backend/src/pytorch/InterpolateTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn.functional as F
from abc import ABCMeta, abstractmethod

from backend.src.pytorch.InterpolateArchs.GIMM import GIMM
#from backend.src.pytorch.InterpolateArchs.GIMM import GIMM
from .InterpolateArchs.DetectInterpolateArch import ArchDetect
import math
import os
Expand Down

0 comments on commit 8fbd3b5

Please sign in to comment.