diff --git a/comfy/float.py b/comfy/float.py
index eb4a9b26e10..57fd070995e 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -1,18 +1,18 @@
 import torch
 import math
 
-def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS):
+def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
     mantissa_scaled = torch.where(
         normal_mask,
         (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
         (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
     )
 
-    mantissa_scaled += torch.rand_like(mantissa_scaled)
+    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
     return mantissa_scaled.floor() / (2**MANTISSA_BITS)
 
 #Not 100% sure about this
-def manual_stochastic_round_to_float8(x, dtype):
+def manual_stochastic_round_to_float8(x, dtype, generator=None):
     if dtype == torch.float8_e4m3fn:
         EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
     elif dtype == torch.float8_e5m2:
@@ -34,7 +34,7 @@ def manual_stochastic_round_to_float8(x, dtype):
     # Combine mantissa calculation and rounding
     normal_mask = ~(exponent == 0)
 
-    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS)
+    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
 
     sign *= torch.where(
         normal_mask,
@@ -47,7 +47,7 @@
 
 
 
-def stochastic_rounding(value, dtype):
+def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.float32:
         return value.to(dtype=torch.float32)
     if dtype == torch.float16:
@@ -55,6 +55,8 @@ def stochastic_rounding(value, dtype):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        return manual_stochastic_round_to_float8(value, dtype)
+        generator = torch.Generator(device=value.device)
+        generator.manual_seed(seed)
+        return manual_stochastic_round_to_float8(value, dtype, generator=generator)
 
     return value.to(dtype=dtype)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 59c50541382..3f5d90273c4 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -30,6 +30,18 @@
 import comfy.lora
 from comfy.types import UnetWrapperFunction
 
+def string_to_seed(data):
+    crc = 0xFFFFFFFF
+    for byte in data:
+        if isinstance(byte, str):
+            byte = ord(byte)
+        crc ^= byte
+        for _ in range(8):
+            if crc & 1:
+                crc = (crc >> 1) ^ 0xEDB88320
+            else:
+                crc >>= 1
+    return crc ^ 0xFFFFFFFF
 
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
@@ -309,7 +321,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         else:
             temp_weight = weight.to(torch.float32, copy=True)
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
+        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if inplace_update:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else: