Add autocast support for XlaPatchedLinear #8421

Merged: 9 commits into pytorch:master on Dec 10, 2024

Conversation

@aws-nm9 (Contributor) commented Nov 27, 2024

This PR adds autocast support for XlaPatchedLinear. This layer is currently ignored by autocast because of its custom forward and backward functions. Adding decorators as per https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops brings it within the autocast context (when autocast is enabled). A test is also added.

Related to #8420 (einsum autocast support); this PR extends that work to XlaPatchedLinear, which uses einsum internally.

Relevant issue: #8405
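
For context, the decorator pattern from the linked AMP note looks roughly like the sketch below. This is a minimal illustration, not the PR's exact diff: MyEinsumLinear is a hypothetical stand-in for XlaPatchedLinear, the input is assumed 2D, and it assumes a PyTorch version where torch.amp.custom_fwd/custom_bwd accept a device_type argument.

import torch
from torch.amp import custom_fwd, custom_bwd


class MyEinsumLinear(torch.autograd.Function):
  """Hypothetical stand-in for XlaPatchedLinear; input assumed 2D [batch, in]."""

  @staticmethod
  @custom_fwd(device_type="xla")  # forward records the ambient autocast state
  def forward(ctx, input, weight, bias=None):
    ctx.save_for_backward(input, weight, bias)
    output = torch.einsum("bn,mn->bm", input, weight)  # autocast applies here
    return output + bias if bias is not None else output

  @staticmethod
  @custom_bwd(device_type="xla")  # backward re-enters the saved autocast state
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = torch.einsum("bm,mn->bn", grad_output, weight)
    grad_weight = torch.einsum("bm,bn->mn", grad_output, input)
    grad_bias = grad_output.sum(0) if bias is not None else None
    return grad_input, grad_weight, grad_bias

With these decorators, the einsum calls in both forward and backward run under the autocast policy, which is exactly what the HLO checks below look for.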

@jeffhataws (Collaborator) commented:

Need linter fix.

@aws-nm9 (Contributor, Author) commented Dec 6, 2024

> Need linter fix

@jeffhataws @rpsilva-aws Thanks for the comments; made the changes and switched to YAPF formatting (was using Ruff earlier), so the linter should be fixed. Also consolidated all tests into a single file.

@aws-nm9 (Contributor, Author) commented Dec 6, 2024

> Why can't we assert the actual count instead? Can you also include the entire HLO in the commit message / as a comment for reference?
>
> I think the actual counts would be a good check to have: if a future bug leaves a single f32->bf16 conversion in place but drops the rest, exact counts will help identify the issue. You can add the actual HLO to this PR; it will help as a future reference.

Attaching the HLO for the PatchedLinear test here:

HloModule IrToHlo.97, entry_computation_layout={(f32[5,10]{1,0}, f32[15]{0}, f32[15,10]{1,0}, f32[])->(f32[15,10]{1,0})}

ENTRY %IrToHlo.97 (p0.1: f32[5,10], p1.5: f32[15], p2.7: f32[15,10], p3.42: f32[]) -> (f32[15,10]) {
  %p0.1 = f32[5,10]{1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/shared_new/nmokashi/workspace/opensource_repos/pytorch_xla/xla/test/test_bf16_autocast.py" source_line=46}
  %convert.2 = bf16[5,10]{1,0} convert(f32[5,10]{1,0} %p0.1), metadata={op_type="xla__cast" op_name="xla__cast" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/amp/autocast_mode.py" source_line=443}
  %p2.7 = f32[15,10]{1,0} parameter(2), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/nn/modules/module.py" source_line=1329}
  %convert.8 = bf16[15,10]{1,0} convert(f32[15,10]{1,0} %p2.7), metadata={op_type="xla__cast" op_name="xla__cast" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/amp/autocast_mode.py" source_line=443}
  %dot.9 = bf16[5,15]{1,0} dot(bf16[5,10]{1,0} %convert.2, bf16[15,10]{1,0} %convert.8), lhs_contracting_dims={1}, rhs_contracting_dims={1}, frontend_attributes={grad_x="false",grad_y="false"}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/functional.py" source_line=407}
  %transpose.10 = bf16[5,15]{1,0} transpose(bf16[5,15]{1,0} %dot.9), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/functional.py" source_line=407}
  %p1.5 = f32[15]{0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/nn/modules/module.py" source_line=1329}
  %convert.6 = bf16[15]{0} convert(f32[15]{0} %p1.5), metadata={op_type="xla__cast" op_name="xla__cast" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/amp/autocast_mode.py" source_line=443}
  %constant.4 = bf16[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py" source_line=751}
  %broadcast.11 = bf16[15]{0} broadcast(bf16[] %constant.4), dimensions={}, metadata={op_type="aten__add" op_name="aten__add.1/aten__add" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py" source_line=751}
  %multiply.12 = bf16[15]{0} multiply(bf16[15]{0} %convert.6, bf16[15]{0} %broadcast.11), metadata={op_type="aten__add" op_name="aten__add.1/aten__add" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py" source_line=751}
  %broadcast.13 = bf16[5,15]{1,0} broadcast(bf16[15]{0} %multiply.12), dimensions={1}, metadata={op_type="aten__add" op_name="aten__add.1/aten__add" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py" source_line=751}
  %add.14 = bf16[5,15]{1,0} add(bf16[5,15]{1,0} %transpose.10, bf16[5,15]{1,0} %broadcast.13), metadata={op_type="aten__add" op_name="aten__add.1/aten__add" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py" source_line=751}
  %convert.85 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %add.14), metadata={op_type="aten__eq" op_name="aten__eq"}
  %constant.84 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %broadcast.86 = f32[5,15]{1,0} broadcast(f32[] %constant.84), dimensions={}, metadata={op_type="aten__eq" op_name="aten__eq"}
  %compare.87 = pred[5,15]{1,0} compare(f32[5,15]{1,0} %convert.85, f32[5,15]{1,0} %broadcast.86), direction=EQ, metadata={op_type="aten__eq" op_name="aten__eq"}
  %broadcast.88 = pred[5,15]{1,0} broadcast(pred[5,15]{1,0} %compare.87), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %constant.78 = bf16[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %broadcast.79 = bf16[] broadcast(bf16[] %constant.78), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.80 = bf16[1,1]{1,0} reshape(bf16[] %broadcast.79), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.81 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.80), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.82 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.81), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.83 = bf16[5,15]{1,0} broadcast(bf16[] %reshape.82), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %constant.43 = bf16[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/autograd/__init__.py" source_line=220}
  %broadcast.44 = bf16[] broadcast(bf16[] %constant.43), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/autograd/__init__.py" source_line=220}
  %reshape.45 = bf16[1,1]{1,0} reshape(bf16[] %broadcast.44), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.46 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.45), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.47 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.46), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.48 = bf16[5,15]{1,0} broadcast(bf16[] %reshape.47), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %convert.49 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %broadcast.48), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.42 = f32[] parameter(3), metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %broadcast.50 = f32[5,15]{1,0} broadcast(f32[] %p3.42), dimensions={}, metadata={op_type="aten__div" op_name="aten__div"}
  %divide.51 = f32[5,15]{1,0} divide(f32[5,15]{1,0} %convert.49, f32[5,15]{1,0} %broadcast.50), metadata={op_type="aten__div" op_name="aten__div"}
  %convert.52 = bf16[5,15]{1,0} convert(f32[5,15]{1,0} %divide.51), metadata={op_type="xla__cast" op_name="xla__cast"}
  %convert.74 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %convert.52), metadata={op_type="xla__cast" op_name="xla__cast"}
  %convert.70 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %add.14), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.61 = bf16[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.62 = bf16[1,1]{1,0} reshape(bf16[] %constant.61), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.63 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.62), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.64 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.63), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.65 = bf16[5,15]{1,0} broadcast(bf16[] %reshape.64), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %constant.56 = bf16[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.57 = bf16[1,1]{1,0} reshape(bf16[] %constant.56), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.58 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.57), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.59 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.58), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.60 = bf16[5,15]{1,0} broadcast(bf16[] %reshape.59), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %multiply.66 = bf16[5,15]{1,0} multiply(bf16[5,15]{1,0} %broadcast.65, bf16[5,15]{1,0} %broadcast.60), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub"}
  %subtract.67 = bf16[5,15]{1,0} subtract(bf16[5,15]{1,0} %add.14, bf16[5,15]{1,0} %multiply.66), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub"}
  %power.68 = bf16[5,15]{1,0} power(bf16[5,15]{1,0} %add.14, bf16[5,15]{1,0} %subtract.67), metadata={op_type="aten__pow" op_name="aten__pow"}
  %convert.69 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %power.68), metadata={op_type="xla__cast" op_name="xla__cast"}
  %multiply.71 = f32[5,15]{1,0} multiply(f32[5,15]{1,0} %convert.70, f32[5,15]{1,0} %convert.69), metadata={op_type="aten__mul" op_name="aten__mul.4/aten__mul"}
  %convert.72 = bf16[5,15]{1,0} convert(f32[5,15]{1,0} %multiply.71), metadata={op_type="xla__cast" op_name="xla__cast"}
  %convert.73 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %convert.72), metadata={op_type="xla__cast" op_name="xla__cast"}
  %multiply.75 = f32[5,15]{1,0} multiply(f32[5,15]{1,0} %convert.74, f32[5,15]{1,0} %convert.73), metadata={op_type="aten__mul" op_name="aten__mul.5/aten__mul"}
  %convert.76 = bf16[5,15]{1,0} convert(f32[5,15]{1,0} %multiply.75), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.77 = bf16[5,15]{1,0} broadcast(bf16[5,15]{1,0} %convert.76), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %select.89 = bf16[5,15]{1,0} select(pred[5,15]{1,0} %broadcast.88, bf16[5,15]{1,0} %broadcast.83, bf16[5,15]{1,0} %broadcast.77), metadata={op_type="aten__where" op_name="aten__where"}
  %convert.53 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %convert.52), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.32 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %convert.33 = bf16[] convert(s64[] %constant.32), metadata={op_type="aten__eq" op_name="aten__eq"}
  %broadcast.34 = bf16[5,15]{1,0} broadcast(bf16[] %convert.33), dimensions={}, metadata={op_type="aten__eq" op_name="aten__eq"}
  %compare.35 = pred[5,15]{1,0} compare(bf16[5,15]{1,0} %add.14, bf16[5,15]{1,0} %broadcast.34), direction=EQ, metadata={op_type="aten__eq" op_name="aten__eq"}
  %convert.36 = pred[5,15]{1,0} convert(pred[5,15]{1,0} %compare.35), metadata={op_type="aten__logical_and" op_name="aten__logical_and"}
  %constant.28 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %convert.29 = bf16[] convert(s64[] %constant.28), metadata={op_type="aten__ge" op_name="aten__ge"}
  %broadcast.30 = bf16[5,15]{1,0} broadcast(bf16[] %convert.29), dimensions={}, metadata={op_type="aten__ge" op_name="aten__ge"}
  %compare.31 = pred[5,15]{1,0} compare(bf16[5,15]{1,0} %add.14, bf16[5,15]{1,0} %broadcast.30), direction=GE, metadata={op_type="aten__ge" op_name="aten__ge"}
  %convert.37 = pred[5,15]{1,0} convert(pred[5,15]{1,0} %compare.31), metadata={op_type="aten__logical_and" op_name="aten__logical_and"}
  %and.38 = pred[5,15]{1,0} and(pred[5,15]{1,0} %convert.36, pred[5,15]{1,0} %convert.37), metadata={op_type="aten__logical_and" op_name="aten__logical_and"}
  %broadcast.39 = pred[5,15]{1,0} broadcast(pred[5,15]{1,0} %and.38), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %constant.22 = bf16[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %broadcast.23 = bf16[] broadcast(bf16[] %constant.22), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.24 = bf16[1,1]{1,0} reshape(bf16[] %broadcast.23), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.25 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.24), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.26 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.25), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.27 = bf16[5,15]{1,0} broadcast(bf16[] %reshape.26), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %power.17 = bf16[5,15]{1,0} power(bf16[5,15]{1,0} %add.14, bf16[5,15]{1,0} %add.14), metadata={op_type="aten__pow" op_name="aten__pow" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/_tensor.py" source_line=39}
  %convert.18 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %power.17), metadata={op_type="xla__cast" op_name="xla__cast"}
  %log.15 = bf16[5,15]{1,0} log(bf16[5,15]{1,0} %add.14), metadata={op_type="aten__log" op_name="aten__log"}
  %convert.16 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %log.15), metadata={op_type="xla__cast" op_name="xla__cast"}
  %multiply.19 = f32[5,15]{1,0} multiply(f32[5,15]{1,0} %convert.18, f32[5,15]{1,0} %convert.16), metadata={op_type="aten__mul" op_name="aten__mul.1/aten__mul"}
  %convert.20 = bf16[5,15]{1,0} convert(f32[5,15]{1,0} %multiply.19), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.21 = bf16[5,15]{1,0} broadcast(bf16[5,15]{1,0} %convert.20), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %select.40 = bf16[5,15]{1,0} select(pred[5,15]{1,0} %broadcast.39, bf16[5,15]{1,0} %broadcast.27, bf16[5,15]{1,0} %broadcast.21), metadata={op_type="aten__where" op_name="aten__where"}
  %convert.41 = f32[5,15]{1,0} convert(bf16[5,15]{1,0} %select.40), metadata={op_type="xla__cast" op_name="xla__cast"}
  %multiply.54 = f32[5,15]{1,0} multiply(f32[5,15]{1,0} %convert.53, f32[5,15]{1,0} %convert.41), metadata={op_type="aten__mul" op_name="aten__mul.2/aten__mul"}
  %convert.55 = bf16[5,15]{1,0} convert(f32[5,15]{1,0} %multiply.54), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.3 = bf16[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %broadcast.90 = bf16[5,15]{1,0} broadcast(bf16[] %constant.3), dimensions={}, metadata={op_type="aten__add" op_name="aten__add.6/aten__add"}
  %multiply.91 = bf16[5,15]{1,0} multiply(bf16[5,15]{1,0} %convert.55, bf16[5,15]{1,0} %broadcast.90), metadata={op_type="aten__add" op_name="aten__add.6/aten__add"}
  %add.92 = bf16[5,15]{1,0} add(bf16[5,15]{1,0} %select.89, bf16[5,15]{1,0} %multiply.91), metadata={op_type="aten__add" op_name="aten__add.6/aten__add"}
  %dot.93 = bf16[15,10]{1,0} dot(bf16[5,15]{1,0} %add.92, bf16[5,10]{1,0} %convert.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/functional.py" source_line=407}
  %transpose.94 = bf16[15,10]{1,0} transpose(bf16[15,10]{1,0} %dot.93), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="/shared_new/nmokashi/testenv_ptGSPMD_6271032615/lib/python3.10/site-packages/torch/functional.py" source_line=407}
  %convert.95 = f32[15,10]{1,0} convert(bf16[15,10]{1,0} %transpose.94), metadata={op_type="xla__cast" op_name="xla__cast"}
  ROOT %tuple.96 = (f32[15,10]{1,0}) tuple(f32[15,10]{1,0} %convert.95)
}
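
For anyone re-deriving the exact conversion counts asserted in the test below, a small helper along these lines reproduces them from the dump above (a sketch, not part of the PR; count_converts is a hypothetical name, and its regexes are stricter than the test's):

import re

def count_converts(hlo: str):
  # Count dtype-conversion ops in an HLO text dump by matching lines like
  #   %convert.N = f32[...]{...} convert(bf16[...]{...} %operand)
  bf16_to_f32 = len(re.findall(r"= f32\[.*convert\(bf16\[", hlo))
  f32_to_bf16 = len(re.findall(r"= bf16\[.*convert\(f32\[", hlo))
  return bf16_to_f32, f32_to_bf16

On the HLO above this yields (11, 8): 11 bf16->f32 converts and 8 f32->bf16 converts, matching the assertions in the test.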

For reference, the code is:

# Assumed module-level context for this snippet: import re, torch, torch_xla;
# import torch_xla.distributed.spmd.xla_sharding as xs; device is the XLA device.
def test_patchedlinear_autocast(self):
    hidden_size = 10
    intermediate_size = 15
    input_tensor = torch.randn(
        5, hidden_size, requires_grad=True)  # batch_size=5 as an example
    linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True)

    with torch.autocast("xla", enabled=True):
      linear = linear.to(device)
      input_tensor = input_tensor.to(device)
      output = xs.xla_patched_nn_linear_forward(linear, input_tensor)
      result = (output**output).mean()
      # A simple sum would suffice, but then the HLO for linear.weight.grad
      # would contain only the backward pass, since the grad of a sum is
      # constant. output**output keeps the whole execution graph in the HLO.

    result.backward()
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad])
    # Verify that the matrix multiplication runs in bfloat16 precision, not f32.
    # XLAPatchedLinear uses einsum instead of matmul, so this checks that the
    # einsum op in this layer is being autocast.
    self.assertRegex(hlo, r".*dot.*bf16")
    self.assertNotRegex(hlo, r".*dot.*f32")

    bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo))
    f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo))

    # Verify that precision conversions are happening at all.
    self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0)
    # Verify there are more bf16->f32 conversions than f32->bf16, as expected
    # during the backward pass for gradient computation.
    self.assertEqual(bf16_to_f32, 11)
    self.assertEqual(f32_to_bf16, 8)
    # Redundant given the exact counts above, but this is the property we
    # actually want to verify.
    self.assertTrue(bf16_to_f32 > f32_to_bf16)
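
Note that result.backward() is deliberately called outside the autocast block, matching the AMP examples: autocast should wrap only the forward pass, and the custom_bwd decorator replays the saved autocast state during backward.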

Also adding exact-count assertions in a new commit below, per @rpsilva-aws's and @avizon-aws's suggestions.

@bhavya01 bhavya01 self-requested a review December 6, 2024 21:28
@aws-nm9 aws-nm9 requested a review from bhavya01 December 9, 2024 05:27
@bhavya01 (Collaborator) commented Dec 9, 2024

Waiting for lint fixes and tests to pass before merging

@jeffhataws merged commit 2021c0f into pytorch:master on Dec 10, 2024
12 checks passed