diff --git a/backend/src/pytorch/InterpolateTorch.py b/backend/src/pytorch/InterpolateTorch.py
index 6fd2bee7..50d580d4 100644
--- a/backend/src/pytorch/InterpolateTorch.py
+++ b/backend/src/pytorch/InterpolateTorch.py
@@ -9,12 +9,15 @@
 import os
 import logging
 import gc
+import sys
 from ..utils.Util import (
     printAndLog,
     errorAndLog,
     check_bfloat16_support,
     warnAndLog,
+    log
 )
+from ..constants import HAS_SYSTEM_CUDA
 from time import sleep
 
 torch.set_float32_matmul_precision("medium")
@@ -166,9 +169,7 @@ def _load(self):
         dummyInput = torch.zeros([1, 3, self.ph, self.pw], dtype=self.dtype, device=self.device)
         dummyInput2 = torch.zeros([1, 3, self.ph, self.pw], dtype=self.dtype, device=self.device)
-        xs = torch.cat((dummyInput.unsqueeze(2), dummyInput2.unsqueeze(2)), dim=2).to(
-            self.device, non_blocking=True
-        )
+        xs = torch.cat((dummyInput.unsqueeze(2), dummyInput2.unsqueeze(2)), dim=2).to(self.device, non_blocking=True)
 
         s_shape = xs.shape[-2:]
 
         # caching the timestep tensor in a dict with the timestep as a float for the key
@@ -189,7 +190,9 @@ def _load(self):
                 ).to(non_blocking=True, dtype=self.dtype, device=self.device),None)
                 self.coordDict[timestep] = coord
-
+        log("GIMM loaded")
+        log("Scale: " + str(self.scale))
+        log("Using System CUDA: " + str(HAS_SYSTEM_CUDA))
         if self.backend == "tensorrt":
             warnAndLog(
                 "TensorRT is not implemented for GIMM yet, falling back to PyTorch"
             )
@@ -281,6 +284,9 @@ def _load(self):
             height=self.height,
         )
         self.flownet.eval().to(device=self.device, dtype=self.dtype)
+        log("GMFSS loaded")
+        log("Scale: " + str(self.scale))
+        log("Using System CUDA: " + str(HAS_SYSTEM_CUDA))
         if self.backend == "tensorrt":
             warnAndLog(
                 "TensorRT is not implemented for GMFSS yet, falling back to PyTorch"
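
Note: the added diagnostics depend on a `log` helper exported from `..utils.Util` and a `HAS_SYSTEM_CUDA` constant from `..constants`, neither of which is defined in this diff. A minimal sketch of what they might look like is below; the behavior shown is an assumption (the real helpers may log elsewhere and detect CUDA differently), included only to make the new calls self-explanatory.

```python
# Hypothetical sketch -- the actual definitions live in backend/src/utils/Util.py
# and backend/src/constants.py and may differ from this.
import logging
import shutil

logger = logging.getLogger(__name__)

def log(message: str) -> None:
    # Write an informational message to the application log.
    logger.info(message)

# Assumed detection: True when a system-wide CUDA toolkit (nvcc) is on PATH.
HAS_SYSTEM_CUDA: bool = shutil.which("nvcc") is not None
```

With helpers along these lines, `log("Using System CUDA: " + str(HAS_SYSTEM_CUDA))` records at load time whether a system CUDA toolkit was detected, which is the diagnostic the new lines add for both the GIMM and GMFSS paths.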