Use less memory in float8 lora patching by doing calculations in fp16.
comfyanonymous committed Aug 26, 2024
1 parent c681294 · commit 7985ff8
Showing 1 changed file with 21 additions and 20 deletions.
comfy/float.py: 41 changes (21 additions, 20 deletions)
@@ -1,4 +1,15 @@
 import torch
 import math
 
+def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS):
+    mantissa_scaled = torch.where(
+        normal_mask,
+        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
+        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
+    )
+
+    mantissa_scaled += torch.rand_like(mantissa_scaled)
+    return mantissa_scaled.floor() / (2**MANTISSA_BITS)
+
 #Not 100% sure about this
 def manual_stochastic_round_to_float8(x, dtype):
@@ -9,40 +20,30 @@ def manual_stochastic_round_to_float8(x, dtype):
     else:
         raise ValueError("Unsupported dtype")
 
+    x = x.half()
     sign = torch.sign(x)
     abs_x = x.abs()
+    sign = torch.where(abs_x == 0, 0, sign)
 
     # Combine exponent calculation and clamping
     exponent = torch.clamp(
-        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
+        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
         0, 2**EXPONENT_BITS - 1
     )
 
     # Combine mantissa calculation and rounding
-    # min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
-    # zero_mask = (abs_x == 0)
-    # subnormal_mask = (exponent == 0) & (abs_x != 0)
     normal_mask = ~(exponent == 0)
 
-    mantissa_scaled = torch.where(
-        normal_mask,
-        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
-        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
-    )
-    mantissa_floor = mantissa_scaled.floor()
-    mantissa = torch.where(
-        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
-        (mantissa_floor + 1) / (2**MANTISSA_BITS),
-        mantissa_floor / (2**MANTISSA_BITS)
-    )
-    result = torch.where(
+    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS)
+
+    sign *= torch.where(
         normal_mask,
-        sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
-        sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
+        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
+        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
     )
+    del abs_x
 
-    result = torch.where(abs_x == 0, 0, result)
-    return result.to(dtype=dtype)
+    return sign.to(dtype=dtype)
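
The rewritten mantissa rounding replaces the old compare-with-the-fractional-part formulation with an add-uniform-noise-then-floor formulation; both round up with probability equal to the fractional part, so they agree in distribution while the new form needs fewer temporaries. A minimal sketch of the equivalence (the value 3.25 and the sample count are illustrative assumptions, not part of the commit):

import torch

torch.manual_seed(0)
m = torch.full((100_000,), 3.25)  # value with fractional part 0.25

# Old formulation: floor, then round up when a uniform sample falls
# below the fractional part.
floor_m = m.floor()
old = torch.where(torch.rand_like(m) < (m - floor_m), floor_m + 1, floor_m)

# New formulation (as in calc_mantissa): add uniform noise, then floor.
new = (m + torch.rand_like(m)).floor()

# Both round up with probability 0.25, so both means are close to 3.25.
print(old.mean().item(), new.mean().item())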



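For context, a rough usage sketch of the patched function. The tensor shape and the float8_e4m3fn target are illustrative assumptions; in ComfyUI the call is driven by the LoRA weight-patching path rather than invoked directly like this.

import torch
from comfy.float import manual_stochastic_round_to_float8

# Hypothetical patched weight; the shape and values are illustrative.
w = torch.randn(4096, 4096, dtype=torch.float32)

# The function now casts to fp16 up front and reuses its intermediates
# (abs_x, sign) in place, so the temporaries it allocates are roughly
# half the size they would be if the math stayed in fp32.
w_fp8 = manual_stochastic_round_to_float8(w, torch.float8_e4m3fn)
assert w_fp8.dtype == torch.float8_e4m3fn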
