diff --git a/comfy/float.py b/comfy/float.py index 1dbdafd593c..eb4a9b26e10 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,4 +1,15 @@ import torch +import math + +def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS): + mantissa_scaled = torch.where( + normal_mask, + (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), + (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) + ) + + mantissa_scaled += torch.rand_like(mantissa_scaled) + return mantissa_scaled.floor() / (2**MANTISSA_BITS) #Not 100% sure about this def manual_stochastic_round_to_float8(x, dtype): @@ -9,40 +20,30 @@ def manual_stochastic_round_to_float8(x, dtype): else: raise ValueError("Unsupported dtype") + x = x.half() sign = torch.sign(x) abs_x = x.abs() + sign = torch.where(abs_x == 0, 0, sign) # Combine exponent calculation and clamping exponent = torch.clamp( - torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS, + torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, 0, 2**EXPONENT_BITS - 1 ) # Combine mantissa calculation and rounding - # min_normal = 2.0 ** (-EXPONENT_BIAS + 1) - # zero_mask = (abs_x == 0) - # subnormal_mask = (exponent == 0) & (abs_x != 0) normal_mask = ~(exponent == 0) - mantissa_scaled = torch.where( - normal_mask, - (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), - (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) - ) - mantissa_floor = mantissa_scaled.floor() - mantissa = torch.where( - torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor), - (mantissa_floor + 1) / (2**MANTISSA_BITS), - mantissa_floor / (2**MANTISSA_BITS) - ) - result = torch.where( + abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS) + + sign *= torch.where( normal_mask, - sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa), - sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa + (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), + (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x ) + del abs_x - result = torch.where(abs_x == 0, 0, result) - return result.to(dtype=dtype) + return sign.to(dtype=dtype)