
Refactor Dynamo (custom op) integration code #5805

Merged: 3 commits from refactor_spmd_dynamo into master on Nov 27, 2023
Conversation

@yeounoh (Contributor) commented Nov 15, 2023

Refactor after #5712 cc @wonjoolee95 @JackCaoG

@yeounoh marked this pull request as draft November 15, 2023 07:22
@yeounoh (Contributor, Author) commented Nov 15, 2023

@JackCaoG I see this test failing, but the other dynamo tests are still passing:

======================================================================
FAIL: test_dynamo_input_sharding_threashold (__main__.DynamoSpmdInferenceTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_dynamo_spmd.py", line 174, in test_dynamo_input_sharding_threashold
    self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
AssertionError: False is not true

----------------------------------------------------------------------

Not sure what's happening; this fails consistently with and without the dynamo custom op. cc @wonjoolee95

@yeounoh self-assigned this Nov 15, 2023
@yeounoh changed the title from "Refactor and clean SPMD+Dynamo integration code" to "Refactor and clean Dynamo (custom op) integration code" Nov 15, 2023
@yeounoh changed the title from "Refactor and clean Dynamo (custom op) integration code" to "Refactor Dynamo (custom op) integration code" Nov 15, 2023
@JackCaoG (Collaborator)

Let me take a look later today

@yeounoh force-pushed the refactor_spmd_dynamo branch from 105e4fe to b4e7a5a on November 15, 2023 19:33
@yeounoh (Contributor, Author) commented Nov 15, 2023

> Let me take a look later today

The test requires multiple devices and a non-GPU setup, so it won't run on the CI -- it needs to be tested locally.

@yeounoh force-pushed the refactor_spmd_dynamo branch from d3a391f to 46a7eb8 on November 15, 2023 19:56
@JackCaoG (Collaborator)

OK, I can tell you what's supposed to happen. XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD controls how many times dynamo will check whether the input sharding has changed.

```python
# if the input sharding was the same for skip_checking_input_sharding_threashold times
# we will skip checking the input sharding since it can be expensive.
if skip_checking_input_sharding_threashold > 0:
  if torch_xla._XLAC._get_xla_sharding_specs(
      args) != xla_args_sharding_spec:
    # update the xla_args with the input with new sharding and retrace
    xla_model.xla_args = args
    (xla_args_sharding_spec, args_and_ou_copy, graph_hash,
     arg_index_to_need_update_index, none_remover, graph_input_matcher,
     dumb_return_handler,
     xla_args_need_update) = extract_graph_helper(xla_model)
    skip_checking_input_sharding_threashold = xu.getenv_as(
        'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
  else:
    skip_checking_input_sharding_threashold -= 1
```

After the threshold is reached, dynamo won't check the input sharding anymore. If we then change the input sharding, we will try to execute a compiled program with inputs that have a different sharding, so it will crash. The check in the test was to make sure the crash actually happened. I think you can just step through the test; you should see some C++ exception log during the try/catch.
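
In other words, a test exercising this path would roughly do the following. The sketch below is a hypothetical illustration, not the actual test_dynamo_input_sharding_threashold body; the 'openxla' backend string, the SPMD/mesh setup, and the loop structure are assumptions layered on top of the APIs already referenced in this thread.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # assumed: the real test enables SPMD in its setup
device = xm.xla_device()
linear = torch.nn.Linear(128, 128).to(device)
dynamo_linear = torch.compile(linear, backend='openxla')  # backend name assumed

x = torch.randn(128, 128, device=device)
threshold = xu.getenv_as('XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

# Run with an unchanged input sharding until dynamo stops re-checking it.
for _ in range(threshold + 1):
  dynamo_res = dynamo_linear(x)
  xm.mark_step()

# Now change the input sharding. The cached executable no longer matches, so
# execution is expected to crash (caught in C++) and leave the output tensor
# as a placeholder instead of silently retracing.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices, 1))
xs.mark_sharding(x, mesh, (0, 1))
dynamo_res = dynamo_linear(x)
assert torch_xla._XLAC._is_placecholder(dynamo_res)
```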

@yeounoh force-pushed the refactor_spmd_dynamo branch 2 times, most recently from 1dfd000 to 4e506b7 on November 22, 2023 01:11
@yeounoh force-pushed the refactor_spmd_dynamo branch from 4e506b7 to a6dac44 on November 22, 2023 01:16
@yeounoh marked this pull request as ready for review November 22, 2023 01:16
@yeounoh (Contributor, Author) commented Nov 22, 2023

> OK, I can tell you what's supposed to happen. XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD controls how many times dynamo will check whether the input sharding has changed. [...]

Synced offline with @JackCaoG; he will help follow up. TL;DR: it's recompiling, and we need to check whether the threshold is being enforced.

```python
# TODO(yeounoh) - this actually returns False, which means that the program was
# recompiled with the new sharding change. We expect it to be True after a crash
# without recompilation. Disabling the test until we debug.
# self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
```

@yeounoh (Contributor, Author) commented Nov 22, 2023

@JackCaoG @wonjoolee95 we can review and land this refactoring PR; I won't address the test issue here (also, the test is not run in the CPU/GPU CI).

@yeounoh force-pushed the refactor_spmd_dynamo branch 3 times, most recently from c998142 to 0664a30 on November 22, 2023 21:42
@yeounoh (Contributor, Author) commented Nov 22, 2023

Ok, found another test regression -- test_mark_sharding_inside_compile works with the torch nightly from 11/14/2023 but started failing with the latest (11/22/2023):

======================================================================
FAIL: test_mark_sharding_inside_compile (__main__.DynamoSpmdInferenceTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_dynamo_spmd.py", line 232, in test_mark_sharding_inside_compile
    dynamo_res = dynamo_linear(xla_x)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "test/spmd/test_dynamo_spmd.py", line 32, in forward
    xs.mark_sharding(
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/runtime.py", line 78, in wrapper
    if not using_pjrt():
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/runtime.py", line 82, in resume_in_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 499, in mark_sharding
    num_devices = xr.global_runtime_device_count()
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 499, in resume_in_mark_sharding
    num_devices = xr.global_runtime_device_count()
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 510, in resume_in_mark_sharding
    tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 413, in _extract_op_sharding_specs
    def _extract_op_sharding_specs(mesh: Mesh, partition_spec: Tuple):
  File "/usr/local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 4960, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2017, in g
    return f(*args)
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 3164, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2041, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2145, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/usr/local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2017, in g
    return f(*args)
  File "/usr/local/lib/python3.8/site-packages/torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/core/dynamo_bridge.py", line 540, in extract_compiled_graph
    extract_internal(fused_module), node.args, None)
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/core/dynamo_bridge.py", line 338, in extract_internal
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git4e506b7-py3.8-linux-x86_64.egg/torch_xla/core/dynamo_bridge.py", line 212, in extract_graph_helper
    assert all(
AssertionError: All tensors should be on xla

cc @JackCaoG @wonjoolee95
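
The assertion that trips here is the device check in dynamo_bridge.extract_graph_helper. Below is a minimal sketch of that kind of check; only the error message comes from the traceback, and the helper name and exact condition are assumptions for illustration, not the actual torch_xla code.

```python
import torch

def _check_all_args_on_xla(xla_args):
  # Every input handed to the XLA dynamo bridge must already live on an XLA
  # device; a CPU tensor here usually means something upstream produced a
  # non-XLA tensor before the bridge tried to trace the graph.
  assert all(
      arg.device.type == 'xla'
      for arg in xla_args
      if isinstance(arg, torch.Tensor)), 'All tensors should be on xla'
```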

@yeounoh force-pushed the refactor_spmd_dynamo branch 3 times, most recently from 7cbefd4 to acdd21b on November 22, 2023 23:29
@yeounoh force-pushed the refactor_spmd_dynamo branch from acdd21b to 1009476 on November 23, 2023 00:39
@yeounoh (Contributor, Author) commented Nov 27, 2023

> Ok, found another test regression -- test_mark_sharding_inside_compile works with the torch nightly from 11/14/2023 but started failing with the latest (11/22/2023): [...]

This works now.

@wonjoolee95 (Collaborator) left a comment

LGTM

@yeounoh merged commit 3385bd6 into master Nov 27, 2023
18 checks passed
lsy323 pushed a commit to lsy323/xla that referenced this pull request Nov 28, 2023
* Refactor and clean SPMD+Dynamo integration code
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Refactor and clean SPMD+Dynamo integration code
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Refactor and clean SPMD+Dynamo integration code
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Refactor and clean SPMD+Dynamo integration code