
do less in the backwards
drisspg committed Jan 19, 2024
1 parent 20da1c0 commit e7a6aa3
Showing 2 changed files with 21 additions and 11 deletions.
16 changes: 5 additions & 11 deletions float8_experimental/float8_ops.py
@@ -8,7 +8,7 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
@@ -192,7 +192,7 @@ def forward(
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
             ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
+                x_fp8, original_weight, weight_scale
             )
         else:
             # Does this interact properly with activation checkpointing?
@@ -211,19 +211,13 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
-            )
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate)
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
 
         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
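Note on the view -> reshape change in the backward above: the incoming gradient go_fp8 is not guaranteed to be contiguous, and .view() refuses to flatten a tensor whose strides don't permit it, whereas .reshape() falls back to a copy when necessary. A minimal standalone sketch (not part of the commit; the transposed tensor below is only a stand-in for go_fp8):

import torch

# A transposed tensor is a simple stand-in for a non-contiguous incoming gradient.
go = torch.randn(4, 8, 16).transpose(0, 1)

try:
    go.view(-1, go.size(-1))        # fails: strides are not compatible with this view
except RuntimeError as err:
    print("view failed:", err)

flat = go.reshape(-1, go.size(-1))  # succeeds; copies only when it has to
print(flat.shape)                   # torch.Size([32, 16])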
16 changes: 16 additions & 0 deletions float8_experimental/float8_tensor.py
@@ -42,6 +42,22 @@ def backward(ctx, g):
         return g, None, None, None, None
 
 
+@torch._dynamo.allow_in_graph
+def re_construct_float8_weight(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False):
+    """In the backward of float8_linear we don't need to fill the amax buffer
+    for the weight tensor, since that was already done during the forward; we just
+    recast the original-precision tensor using the scale from the forward.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor, from the forward
+        float8_dtype: the float8 dtype to use
+        emulate: if true, use fp32 emulation for the matmuls; helpful
+            if you don't have access to H100 hardware.
+    """
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
 @torch._dynamo.allow_in_graph
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
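The idea behind re_construct_float8_weight, as a standalone sketch: the forward has already computed the weight's scale (and updated its amax buffer), so the backward can simply re-quantize the saved high-precision weight with that scale. The names below (saturated_cast, forward_scale) are local stand-ins, not the library's API, and the scale formula shown is just one common convention (max representable value divided by the tensor's amax); the real helper additionally wraps the result in a Float8Tensor that records the original dtype and the emulate flag.

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def saturated_cast(t: torch.Tensor, dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Clamp to the representable range, then cast; a stand-in for to_fp8_saturated.
    return t.clamp(min=-E4M3_MAX, max=E4M3_MAX).to(dtype)

weight = torch.randn(32, 32)                    # original-precision weight
forward_scale = E4M3_MAX / weight.abs().max()   # scale as computed during the forward

# Backward-style recast: reuse the saved scale, no amax bookkeeping needed.
w_fp8_bits = saturated_cast(weight * forward_scale)
print(w_fp8_bits.dtype)                         # torch.float8_e4m3fn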
