
Move where clear pending IR is called to avoid crash #5552

Merged: 3 commits merged on Sep 14, 2023

Conversation

@JackCaoG (Collaborator) commented on Sep 9, 2023:

Without this patch

python benchmarks/dynamo/torchbench.py --randomize-input --performance --training --trace-on-xla --backend=openxla --only hf_Bert

will crash with

RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:168 : Check failed: HasValue() 

We expect FallBackNodeCollector to introduce in-place operations on xla_args, so we want to replace each arg with its cloned copy and clear the pending IR. What I found is that CapabilityBasedPartitioner (or something close to that region of the code) also introduces IRs. On top of that, after the Partitioner the xla_args passed to extract_internal are not the same xla_args passed to extract_compiled_graph. If we call ClearPendingIr after the Partitioner, we might remove the pending IR of the copied xla_args with no way to restore those values, and then dynamo will crash.

I chose to move the ClearPendingIr call to an earlier point in the code and use mark_step to turn the pending IRs into device data. I have some concerns about the correctness of this approach, but I need more time to debug why the Partitioner introduces IRs to begin with.
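
To make the intent concrete, here is a minimal sketch of the idea, assuming a simplified stand-in for the bridge code (the helper name and the argument handling are illustrative, not the actual dynamo_bridge implementation):

import torch
import torch_xla.core.xla_model as xm

def sync_and_clone_args(xla_args):
    # Flush any pending IR on the inputs into device data before the
    # partitioner runs, so later stages see plain device tensors instead
    # of unexecuted graphs.
    xm.mark_step()
    # Clone afterwards so the original values can be restored if fallback
    # handling mutates the args in place.
    return [a.clone() if isinstance(a, torch.Tensor) else a for a in xla_args]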

@wonjoolee95 (Collaborator) commented:
Thanks! Seems like the dynamo unit tests fail due to metric comparisons, probably because of moving the clear_pending_ir. Just throwing out some thoughts, but maybe it's because the IRs introduced by the CapabilityBasedPartitioner are not being cleared anymore?
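
For context, the dynamo unit tests compare torch_xla runtime counters and metrics. A hedged way to inspect them locally is via torch_xla.debug.metrics; the calls below are general debug helpers, not necessarily the exact checks the failing test performs:

import torch_xla.debug.metrics as met

# Dump the full runtime metrics report, or list individual counters/metrics.
print(met.metrics_report())
print(met.counter_names())
print(met.metric_names())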

@JackCaoG (Collaborator, Author) commented:
I am kind of glad that CI failed, so I can figure out what IR is produced by CapabilityBasedPartitioner, and why, on a smaller scale.
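
One way to see what pending IR a tensor is carrying at a given point is the IR text dump; the snippet below is a hedged sketch using torch_xla._XLAC._get_xla_tensors_text, which I believe is the relevant debug hook (the example tensor and op are made up for illustration):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(3, device=xm.xla_device()) + 1  # the add is now pending IR on t
print(torch_xla._XLAC._get_xla_tensors_text([t]))  # text dump of the unexecuted IR graph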

@JackCaoG (Collaborator, Author) commented:
OK, I am able to repro locally; will try to take a look today or tomorrow.

@JackCaoG (Collaborator, Author) commented:
Hmm, with my most recent fix, hf_Bert crashed again... looking into it.

@JackCaoG (Collaborator, Author) commented:
Ah OK, DynamoCpuFallbackTest.test_fallback_multiple_submodules was complaining that we execute one more graph, but the extra graph is:

[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
  mark_step (/src/pytorch/xla/torch_xla/core/xla_model.py:816)
  extract_internal (/src/pytorch/xla/torch_xla/core/dynamo_bridge.py:328)
  extract_compiled_graph (/src/pytorch/xla/torch_xla/core/dynamo_bridge.py:525)
  fwd (/src/pytorch/torch/_dynamo/backends/torchxla.py:49)
  g (/src/pytorch/torch/_functorch/aot_autograd.py:1482)
  rng_functionalization_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:1594)
  call_func_with_args (/src/pytorch/torch/_functorch/aot_autograd.py:1506)
  runtime_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:2533)
  g (/src/pytorch/torch/_functorch/aot_autograd.py:1482)
  forward (/src/pytorch/torch/_functorch/aot_autograd.py:3905)
  inner (/src/pytorch/torch/_dynamo/external_utils.py:17)
  _fn (/src/pytorch/torch/_dynamo/eval_frame.py:338)
  fn_fallback (test_dynamo.py:242)
  _fn (/src/pytorch/torch/_dynamo/eval_frame.py:338)
  test_fallback_multiple_submodules (test_dynamo.py:259)
  _callTestMethod (/usr/local/lib/python3.8/unittest/case.py:633)
  run (/usr/local/lib/python3.8/unittest/case.py:676)
  __call__ (/usr/local/lib/python3.8/unittest/case.py:736)
  run (/usr/local/lib/python3.8/unittest/suite.py:122)
  __call__ (/usr/local/lib/python3.8/unittest/suite.py:84)
  run (/usr/local/lib/python3.8/unittest/suite.py:122)
  __call__ (/usr/local/lib/python3.8/unittest/suite.py:84)
  run (/usr/local/lib/python3.8/unittest/runner.py:176)
  runTests (/usr/local/lib/python3.8/unittest/main.py:271)
  __init__ (/usr/local/lib/python3.8/unittest/main.py:101)
  <module> (test_dynamo.py:535)

Hashes: (efa8de56ce4a47dfd1977caef0297d85)

## BEGIN_GRAPH
HloModule IrToHlo.3, entry_computation_layout={(f32[7]{0})->(f32[7]{0})}

ENTRY %IrToHlo.3 (p0.1: f32[7]) -> (f32[7]) {
  %p0.1 = f32[7]{0} parameter(0)
  ROOT %tuple.2 = (f32[7]{0}) tuple(f32[7]{0} %p0.1)
}

I think this is OK; I will bump up the expected counter in the test.
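
In other words, the extra execution is the trivial one-parameter HLO above, produced by the mark_step this change adds, so the test's expected graph-execution count goes up by one. A hedged sketch of that kind of check follows; the helper and the metric name are illustrative, not the actual test_dynamo.py code:

import torch_xla.debug.metrics as met

def graph_execution_count():
    data = met.metric_data('ExecuteTime')  # (count, total, samples) when recorded
    return data[0] if data else 0

# Before this change the test expected N executions for the fallback scenario;
# with the extra mark_step graph the expectation becomes N + 1.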

@JackCaoG (Collaborator, Author) commented:
@wonjoolee95 can I get a review for this one?

@wonjoolee95 (Collaborator) left a comment:
LGTM, thanks!

@JackCaoG merged commit 9b12009 into master on Sep 14, 2023
7 checks passed
JackCaoG added a commit that referenced this pull request Sep 15, 2023
* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages
will-cromar pushed a commit that referenced this pull request on Sep 15, 2023, and another on Sep 18, 2023, with the same commit message as above.
will-cromar added a commit that referenced this pull request Sep 19, 2023
* Handle dynamo function without input (#5565) (#5577)

* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)

* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)

* Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron

* fixing linter issues

* fixing spacing

* apply comments and fix compilation errors

* add test for new apis

* fix linter

* update test

* update test

* modify test

* reverse back to GetIrValue()

* update test inputs with random numbers

* skip unittest because it only fails in CI

---------

Co-authored-by: aws-kingrj <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: seanlatias <[email protected]>

* fixing num_local_processes typo (#5573) (#5579)

Co-authored-by: aws-kingrj <[email protected]>

* Move where clear pending IR is called to avoid crash (#5552) (#5582)

* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages

* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)

* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)

* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <[email protected]>

* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)

* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <[email protected]>

* Cherry-pick: Don't trigger CI build on release tag push (#5595)

Copy of #5594 on release branch

* formatting

---------

Co-authored-by: JackCaoG <[email protected]>
Co-authored-by: Wonjoo Lee <[email protected]>
Co-authored-by: aws-kingrj <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: seanlatias <[email protected]>
Co-authored-by: Manfei <[email protected]>
Co-authored-by: Yeounoh Chung <[email protected]>