fix batch_norm amp autocast
yaochengji committed Dec 17, 2024
1 parent d3ed982 commit 0b0b91b
Showing 2 changed files with 40 additions and 19 deletions.
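In short: under AMP autocast, batch norm often sees a low-precision input (BF16 on TPU, F16, or an FP8 variant) while the affine weight stays F32, and XLA's batch-norm ops require the operand and scale to share an element type. The old lowering special-cased only F16 inputs; this commit generalizes the check to the other lower-precision float types, widening the input to F32 around the XLA op and casting the result back. The pattern in isolation, as a minimal sketch (hypothetical WidenedBatchNormInference helper, not the committed code; assumes XLA's client-builder ops and the repo's ShapeHelper):

// Minimal sketch of the widen/compute/narrow pattern; a hypothetical
// helper, not the committed code.
xla::XlaOp WidenedBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                     xla::XlaOp bias, xla::XlaOp mean,
                                     xla::XlaOp variance, float eps) {
  // Remember the caller's element type before `input` is rebound below.
  xla::PrimitiveType origin_type =
      ShapeHelper::ShapeOfXlaOp(input).element_type();
  // Widen so the operand matches the F32 scale/offset.
  input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
  xla::XlaOp output = xla::BatchNormInference(
      input, weight, bias, mean, variance, eps, /*feature_index=*/1);
  // Narrow the result back to the caller's type.
  return xla::ConvertElementType(output, origin_type);
}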
11 changes: 11 additions & 0 deletions test/test_autocast.py
@@ -484,6 +484,17 @@ def test_autocast_tpu_check_dtype(self):
     assert not torch.is_autocast_xla_enabled()
 
 
+class TestOtherOps(unittest.TestCase):
+
+  def test_batch_norm(self):
+    device = xm.xla_device()
+    data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
+    with autocast(device, dtype=torch.bfloat16):
+      output = torch.nn.BatchNorm2d(16)(data)
+    xm.mark_step()
+    self.assertEqual(output.dtype, torch.bfloat16)
+
+
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   sys.exit(0 if test.result.wasSuccessful() else 1)
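The test drives exactly the mixed case: BatchNorm2d keeps its default F32 parameters while the activations are BF16 under autocast, so the lowering in batch_norm.cpp below takes the widen/narrow path. Note that the assertion inspects the torch-level dtype of the output, not the element type of the underlying HLO value.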
48 changes: 29 additions & 19 deletions torch_xla/csrc/batch_norm.cpp
@@ -8,10 +8,17 @@
 namespace torch_xla {
 namespace {
 
-bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input,
-                                  const xla::XlaOp& weight) {
-  if (ShapeHelper::ShapeOfXlaOp(input).element_type() ==
-          xla::PrimitiveType::F16 &&
+bool IsF32BatchNormWithLowerFPInputs(const xla::XlaOp& input,
+                                     const xla::XlaOp& weight) {
+  static constexpr std::array<xla::PrimitiveType, 9> lowerPrecisionTypes = {
+      xla::PrimitiveType::F8E5M2,     xla::PrimitiveType::F8E4M3,
+      xla::PrimitiveType::F8E4M3FN,   xla::PrimitiveType::F8E4M3B11FNUZ,
+      xla::PrimitiveType::F8E3M4,     xla::PrimitiveType::F8E5M2FNUZ,
+      xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
+      xla::PrimitiveType::BF16};
+  if (std::find(lowerPrecisionTypes.begin(), lowerPrecisionTypes.end(),
+                ShapeHelper::ShapeOfXlaOp(input).element_type()) !=
+      lowerPrecisionTypes.end() &&
       ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
           xla::PrimitiveType::F32) {
     return true;
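The hard-coded list covers the FP8 variants plus F16 and BF16. An equivalent predicate could be phrased by bit width instead; a sketch only (assumes xla::primitive_util::IsFloatingPointType and BitWidth, which the commit does not use):

// Sketch: the same check phrased via bit width rather than an explicit
// type list; assumes the xla::primitive_util helpers are available.
bool IsF32BatchNormWithLowerFPInputsAlt(const xla::XlaOp& input,
                                        const xla::XlaOp& weight) {
  xla::PrimitiveType input_type =
      ShapeHelper::ShapeOfXlaOp(input).element_type();
  return xla::primitive_util::IsFloatingPointType(input_type) &&
         xla::primitive_util::BitWidth(input_type) < 32 &&
         ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
             xla::PrimitiveType::F32;
}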
@@ -39,37 +46,39 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {

 BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
                                        xla::XlaOp bias, float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
                                               /*feature_index=*/1);
   xla::XlaOp output = xla::GetTupleElement(outputs, 0);
   xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
   xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
-  if (is_batchnorm_with_fp16_inputs) {
-    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    output = xla::ConvertElementType(
+        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return {output, batch_mean, batch_variance};
 }
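One subtlety worth flagging: by the time the closing ConvertElementType runs, `input` has been rebound to its F32-widened op, so ShapeHelper::ShapeOfXlaOp(input).element_type() evaluates to F32 and the cast-back is effectively a no-op (the torch-level output dtype is presumably still BF16, which is what the new test asserts). If the HLO output is meant to come back in the caller's original type, as the replaced F16-specific line did, the type would need to be captured before the widening cast. A minimal sketch (hypothetical origin_type and BuildBatchNormTrainingSketch names, reusing the repo's BatchNormOutput and ShapeHelper):

// Sketch of capturing the pre-cast type; not the committed code.
BatchNormOutput BuildBatchNormTrainingSketch(xla::XlaOp input,
                                             xla::XlaOp weight,
                                             xla::XlaOp bias,
                                             float eps_value) {
  bool mixed = IsF32BatchNormWithLowerFPInputs(input, weight);
  // Capture the element type BEFORE `input` is rebound below; after the
  // widening cast, ShapeOfXlaOp(input) reports F32.
  xla::PrimitiveType origin_type =
      ShapeHelper::ShapeOfXlaOp(input).element_type();
  if (mixed) {
    input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
  }
  xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
                                              /*feature_index=*/1);
  xla::XlaOp output = xla::GetTupleElement(outputs, 0);
  xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
  xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
  if (mixed) {
    // Narrow back to the caller's original dtype.
    output = xla::ConvertElementType(output, origin_type);
  }
  return {output, batch_mean, batch_variance};
}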

 xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                    xla::XlaOp bias, xla::XlaOp mean,
                                    xla::XlaOp variance, float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp output =
       xla::BatchNormInference(input, weight, bias, mean, variance, eps_value,
                               /*feature_index=*/1);
-  if (is_batchnorm_with_fp16_inputs) {
-    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    output = xla::ConvertElementType(
+        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return output;
 }
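The same rebinding applies here: by the closing ConvertElementType, `input` already carries F32, so the narrowing cast is again a no-op; the capture-before-widening fix sketched above would apply unchanged.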
@@ -78,10 +87,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
                                       xla::XlaOp weight, xla::XlaOp save_mean,
                                       xla::XlaOp save_invstd, bool training,
                                       float eps_value) {
-  bool is_batchnorm_with_fp16_inputs =
-      IsF32BatchNormWithFP16Inputs(input, weight);
+  bool is_batchnorm_with_lower_fp_inputs =
+      IsF32BatchNormWithLowerFPInputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_fp16_inputs) {
+  if (is_batchnorm_with_lower_fp_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
     grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32);
   }
@@ -91,8 +100,9 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
   xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
   xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
   xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
-  if (is_batchnorm_with_fp16_inputs) {
-    grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16);
+  if (is_batchnorm_with_lower_fp_inputs) {
+    grad_input = xla::ConvertElementType(
+        grad_input, ShapeHelper::ShapeOfXlaOp(input).element_type());
   }
   return {grad_input, grad_weight, grad_bias};
 }
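Note that only `grad_input` is cast back toward the input's type (with the same rebinding caveat as above); `grad_weight` and `grad_bias` are returned in F32, matching the F32 parameters the optimizer updates under AMP.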
