diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 8093c271..72d30ef4 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -8,7 +8,7 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
 
@@ -192,7 +192,7 @@ def forward(
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
             ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
+                x_fp8, original_weight, weight_scale
             )
         else:
             # Does this interact properly with activation checkpointing?
@@ -211,19 +211,13 @@
 
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
-            )
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate)
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 4450fce8..4bfd72d4 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -42,6 +42,22 @@ def backward(ctx, g):
         return g, None, None, None, None
 
 
+@torch._dynamo.allow_in_graph
+def re_construct_float8_weight(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False):
+    """In the backward of float8_linear we don't need to fill the amax buffer
+    for the weight tensor, since that was done during the forward; we just need to
+    recast the original-precision tensor using the scale from the forward.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor, computed in the forward
+        float8_dtype: the float8 dtype to use
+        emulate: if True, use fp32 emulation for the matmuls; helpful
+            if you don't have access to H100 hardware.
+    """
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
 @torch._dynamo.allow_in_graph
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
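
For context, here is a minimal, self-contained sketch of the recast-from-saved-scale idea the patch implements, written against plain PyTorch rather than the library's Float8Tensor wrapper. The helper name recast_with_saved_scale and the standalone per-tensor scaling are illustrative assumptions, not part of float8_experimental.

# Illustrative sketch only: plain-PyTorch approximation of re-quantizing a weight
# in the backward pass using the scale saved from the forward, without touching
# any amax buffer. Requires a PyTorch build that ships torch.float8_e4m3fn.
import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def recast_with_saved_scale(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Scale into the fp8 dynamic range, saturate, then cast to the fp8 dtype.
    tensor_scaled = tensor * scale
    saturated = tensor_scaled.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    return saturated.to(float8_dtype)


if __name__ == "__main__":
    w = torch.randn(64, 64)
    # Per-tensor scale as it would have been computed in the forward pass.
    scale = E4M3_MAX_POS / w.abs().max()
    w_fp8 = recast_with_saved_scale(w, scale)
    print(w_fp8.dtype, w_fp8.shape)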