
Commit
checkpoint to reduce memory usage, only do dynamic for now
drisspg committed Jan 23, 2024
1 parent 713d2db commit e9cc745
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions float8_experimental/float8_dynamic_linear.py
@@ -40,6 +40,27 @@ def backward(ctx, gradY):
)


+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast the weight to fp8_e4m3fn and run linear.
+
+    Why a new function for something that could be inlined? Because we want
+    to call torch.utils.checkpoint on this function. We always want to
+    recompute the cast of the weight to fp8, since we can trivially fuse it
+    into the transpose/contiguous of the weight during the backward pass.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight: weight tensor in higher precision
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    scale = tensor_to_scale(weight, torch.float8_e4m3fn)
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y


class Float8DynamicLinear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):

    def forward(self, x):
        x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )

        # Cast gradY to float8_e5m2 during backward
        y = self.cast_to_float8e5m2_bw(y)
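
For context beyond the diff: the tensor_to_scale plus Float8Tensor.to_float8 pair above implements per-tensor dynamic scaling. The sketch below approximates that arithmetic with plain PyTorch ops; dynamic_cast_to_e4m3 is an illustrative name, not this library's API, and the real Float8Tensor carries its scale through the matmul rather than dequantizing.

import torch

# A minimal sketch (assumed names, not the library's) of dynamic
# per-tensor scaling into torch.float8_e4m3fn.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def dynamic_cast_to_e4m3(t: torch.Tensor):
    amax = t.abs().max().clamp(min=1e-12)        # guard against all-zero tensors
    scale = E4M3_MAX / amax                      # map amax onto the fp8 max
    t_fp8 = (t * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return t_fp8, scale

w = torch.randn(16, 32)
w_fp8, w_scale = dynamic_cast_to_e4m3(w)
w_round_trip = w_fp8.to(torch.float32) / w_scale  # dequantize to inspect error
print((w - w_round_trip).abs().max())             # small quantization error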
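The forward change is the point of the commit: wrapping the weight cast in torch.utils.checkpoint.checkpoint(..., use_reentrant=False) means the fp8 copy of the weight is not stashed for backward but recomputed from the high-precision parameter, which stays alive anyway. A toy sketch of the same trade, with a hypothetical scaled_linear standing in for the fp8 cast:

import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for the fp8 weight cast: cheap to recompute, and its
# output would otherwise be saved by autograd between forward and backward.
def scaled_linear(x, weight, bias):
    scale = weight.abs().max().detach()   # dynamic per-tensor scale
    w_scaled = weight / scale             # stand-in for the fp8 cast
    return torch.nn.functional.linear(x, w_scaled, bias) * scale

x = torch.randn(8, 32, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)
b = torch.zeros(16, requires_grad=True)

# With use_reentrant=False, only the inputs (x, w, b) are kept alive;
# scaled_linear is re-run during backward to rebuild w_scaled, mirroring
# how the commit recomputes w_fp8 instead of saving it.
y = checkpoint(scaled_linear, x, w, b, use_reentrant=False)
y.sum().backward()
print(w.grad.shape)  # gradients flow through the recomputed region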

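The cast_to_float8e5m2_bw call at the end follows the usual dynamic-linear pattern: a no-op in forward that quantizes the incoming gradient to torch.float8_e5m2 in backward. A hedged sketch of that pattern (the class name and the dequantize-back step are illustrative; the real version returns a Float8Tensor carrying its scale):

import torch

class _CastGradToE5M2(torch.autograd.Function):
    """Identity in forward; quantize gradY to fp8_e5m2 in backward (sketch)."""

    @staticmethod
    def forward(ctx, y: torch.Tensor) -> torch.Tensor:
        return y  # no-op: only the backward pass is interesting

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        amax = grad_y.abs().max().clamp(min=1e-12)
        scale = torch.finfo(torch.float8_e5m2).max / amax
        g_fp8 = (grad_y * scale).to(torch.float8_e5m2)
        # Dequantize so the rest of this sketch's autograd graph stays in
        # high precision; the real code keeps the fp8 tensor + scale instead.
        return g_fp8.to(grad_y.dtype) / scale

# Usage (sketch): y = _CastGradToE5M2.apply(y)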