diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 349f13f312..5c705b3b8b 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -390,6 +390,14 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn,w:del,go:dyn" in s
 
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    def test_inference_mode(self):
+        x = torch.randn(32, 32, device='cuda')
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        with torch.inference_mode(mode=True):
+            y = m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(
@@ -718,8 +726,6 @@ def test_fp8_tensor_statistics(self):
         (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
         self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))
 
-# ghstack test 1
-# ghstack test 2
 
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index d3c3b405b3..f8115649b3 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -41,6 +41,7 @@ def decorator(func):
         aten.slice.Tensor,
         aten.transpose.int,
         aten.fill_.Scalar,
+        aten.reshape.default,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
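
Note on the functional change: the likely reason aten.reshape.default must be registered with float8_desugar_op is that under torch.inference_mode the autograd-layer decomposition of reshape (into view/_unsafe_view) is skipped, so aten.reshape.default reaches the Float8Tensor's __torch_dispatch__ directly, as exercised by the new test_inference_mode test. The desugar handler then just reruns the shape-only op on the raw float8 data and re-wraps it with the existing scale. The sketch below illustrates that pattern on a toy wrapper subclass; WrappedTensor and DESUGAR_OPS are made-up names for illustration, not torchao APIs.

# Illustration only: not torchao code. Sketches the "desugar" pattern that
# float8_desugar_op applies: run a shape-only aten op on the wrapped raw data
# and re-wrap the result with the original scale.
import torch
from torch.utils._pytree import tree_map_only

aten = torch.ops.aten


class WrappedTensor(torch.Tensor):
    """Toy wrapper subclass holding an inner tensor plus a scale."""

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    # Shape-only ops are "desugared": run them on the inner data, keep the scale.
    # reshape.default is listed alongside view.default because, under
    # torch.inference_mode, reshape can reach __torch_dispatch__ without being
    # decomposed into view first.
    DESUGAR_OPS = {
        aten.view.default,
        aten._unsafe_view.default,
        aten.reshape.default,
        aten.detach.default,
    }

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped_args = tree_map_only(WrappedTensor, lambda t: t._data, args)
        unwrapped_kwargs = tree_map_only(WrappedTensor, lambda t: t._data, kwargs)
        out = func(*unwrapped_args, **unwrapped_kwargs)
        if func in cls.DESUGAR_OPS:
            # Re-wrap so the scale travels with the reshaped data.
            return WrappedTensor(out, args[0]._scale)
        return out  # fall back to a plain tensor for everything else


if __name__ == "__main__":
    w = WrappedTensor(torch.randn(4, 8), scale=torch.tensor(2.0))
    with torch.inference_mode():
        y = w.reshape(8, 4)
    print(type(y).__name__, y.shape)  # WrappedTensor torch.Size([8, 4])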