From 0b0b91b17ae314cf4710ccaaf2424b49dd3c5e07 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Tue, 17 Dec 2024 23:47:43 +0000
Subject: [PATCH] fix batch_norm amp autocast

---
 test/test_autocast.py         | 11 ++++++++
 torch_xla/csrc/batch_norm.cpp | 48 +++++++++++++++++++++--------------
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/test/test_autocast.py b/test/test_autocast.py
index acbd0e03be3..c91b69c97dd 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -484,6 +484,17 @@ def test_autocast_tpu_check_dtype(self):
     assert not torch.is_autocast_xla_enabled()
 
 
+class TestOtherOps(unittest.TestCase):
+
+  def test_batch_norm(self):
+    device = xm.xla_device()
+    data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
+    with autocast(device, dtype=torch.bfloat16):
+      output = torch.nn.BatchNorm2d(16)(data)
+    xm.mark_step()
+    self.assertEqual(output.dtype, torch.bfloat16)
+
+
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp
index c33fc0c523c..fa9a365e774 100644
--- a/torch_xla/csrc/batch_norm.cpp
+++ b/torch_xla/csrc/batch_norm.cpp
@@ -8,10 +8,17 @@
 namespace torch_xla {
 namespace {
 
-bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input,
-                                  const xla::XlaOp& weight) {
-  if (ShapeHelper::ShapeOfXlaOp(input).element_type() ==
-          xla::PrimitiveType::F16 &&
+bool IsF32BatchNormWithLowerFPInputs(const xla::XlaOp& input,
+                                     const xla::XlaOp& weight) {
+  static constexpr std::array lowerPrecistionTypes = {
+      xla::PrimitiveType::F8E5M2, xla::PrimitiveType::F8E4M3,
+      xla::PrimitiveType::F8E4M3FN, xla::PrimitiveType::F8E4M3B11FNUZ,
+      xla::PrimitiveType::F8E3M4, xla::PrimitiveType::F8E5M2FNUZ,
+      xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
+      xla::PrimitiveType::BF16};
+  if (std::find(lowerPrecistionTypes.begin(), lowerPrecistionTypes.end(),
+                ShapeHelper::ShapeOfXlaOp(input).element_type()) !=
+          lowerPrecistionTypes.end() &&
       ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
           xla::PrimitiveType::F32) {
     return true;
@@ -39,10 +46,10 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {
 
 BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
                                        xla::XlaOp bias, float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
@@ -50,8 +57,9 @@ BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
   xla::XlaOp output = xla::GetTupleElement(outputs, 0);
   xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
   xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
-  if (is_batchnorm_with_fp16_inputs) {
-    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    output = xla::ConvertElementType(
+        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return {output, batch_mean, batch_variance};
 }
@@ -59,17 +67,18 @@ BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
 xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                    xla::XlaOp bias, xla::XlaOp mean,
                                    xla::XlaOp variance, float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp output =
       xla::BatchNormInference(input, weight, bias, mean, variance, eps_value,
                               /*feature_index=*/1);
-  if (is_batchnorm_with_fp16_inputs) {
-    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    output = xla::ConvertElementType(
+        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return output;
 }
@@ -78,10 +87,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
                                       xla::XlaOp weight, xla::XlaOp save_mean,
                                       xla::XlaOp save_invstd, bool training,
                                       float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
     grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32);
   }
@@ -91,8 +100,9 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
   xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
   xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
   xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
-  if (is_batchnorm_with_fp16_inputs) {
-    grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    grad_input = xla::ConvertElementType(
+        grad_input, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return {grad_input, grad_weight, grad_bias};
 }
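
Usage sketch (illustration only, not part of the commit): the scenario this patch targets is a model whose batch-norm layers keep float32 parameters while autocast feeds them lower-precision activations such as bfloat16. The model, shapes, and the torch_xla.amp.autocast import below are assumptions in the spirit of test_autocast.py, not code from this diff; the dtype expectation mirrors the new test above.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast  # assumed import, mirroring test_autocast.py

    device = xm.xla_device()

    # BatchNorm2d keeps float32 affine parameters and running stats, while the
    # preceding conv emits bfloat16 under autocast, which exercises the
    # mixed-precision path handled in batch_norm.cpp above.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(16),
    ).to(device)

    data = torch.randn(4, 3, 32, 32, device=device)
    with autocast(device, dtype=torch.bfloat16):
        out = model(data)
    xm.mark_step()

    # As asserted by the new test, the batch-norm output is expected to report
    # the autocast dtype rather than float32.
    print(out.dtype)  # torch.bfloat16

The C++ side keeps a single helper, IsF32BatchNormWithLowerFPInputs, so the same upcast-to-F32, compute, convert-back handling applies uniformly to the training, inference, and backward lowerings.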