diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 58e352da..42009ca0 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -40,6 +40,27 @@ def backward(ctx, gradY):
         )
 
 
+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast weight to fp8_e4m3fn and do linear
+    Why a new function for something that can be inlined?
+    Because we want to call torch.utils.checkpoint on this function.
+    We always want to recompute the cast of the weight to fp8 since we can trivially
+    fuse it into the transpose/contiguous of the weight during the backward pass.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight: weight tensor in higher precision
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    scale = tensor_to_scale(weight, torch.float8_e4m3fn)
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y
+
+
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):
 
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
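
For reference, below is a minimal standalone sketch (not part of the patch) of the same pattern: a cheap weight cast plus linear, wrapped in torch.utils.checkpoint.checkpoint with use_reentrant=False so the cast is re-run during backward instead of its result being saved. It assumes only stock PyTorch; the names cast_and_linear and CheckpointedCastLinear are hypothetical, and a bfloat16 cast stands in for the fp8 cast used in the patch.

# Minimal sketch, assuming only stock PyTorch; bfloat16 stands in for fp8 and the
# names below are hypothetical, not part of float8_experimental.
import torch
import torch.utils.checkpoint


def cast_and_linear(x, weight, bias):
    # Cheap cast that we prefer to recompute in backward rather than keep alive.
    w_lowp = weight.to(torch.bfloat16)
    b_lowp = bias.to(torch.bfloat16) if bias is not None else None
    return torch.nn.functional.linear(x.to(torch.bfloat16), w_lowp, b_lowp)


class CheckpointedCastLinear(torch.nn.Linear):
    def forward(self, x):
        # use_reentrant=False selects the non-reentrant checkpoint implementation,
        # as in the patch above; the wrapped function is re-executed in backward.
        return torch.utils.checkpoint.checkpoint(
            cast_and_linear, x, self.weight, self.bias, use_reentrant=False
        )


if __name__ == "__main__":
    lin = CheckpointedCastLinear(16, 8)
    out = lin(torch.randn(4, 16))
    out.float().sum().backward()  # the weight cast is recomputed here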