
[Pallas] Support Dynamo #6477

Merged 3 commits into master from alanwaketan/pallas_dynamo on Mar 6, 2024

Conversation

alanwaketan
Collaborator

Summary:
This pull request enables Dynamo support for custom TPU calls, e.g. kernels written in Pallas.

Test Plan:
PJRT_DEVICE=TPU XLA_DISABLE_FUNCTIONALIZATION=1 python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo
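
For reference, the test wraps the mutating custom call in torch.compile with the torch_xla Dynamo backend. A minimal sketch of that flow is below; the tensor shapes, dtypes, and the backend/payload details are assumptions for illustration, not copied from the test (see test/test_operations.py for the real thing).

```python
import torch
import torch_xla.core.xla_model as xm


def add_one_pallas(output, inputs, payload):
  # Mutating custom call: writes the Pallas kernel's result into `output`.
  torch.ops.xla.tpu_custom_call_(output, inputs, payload)


# Compile through Dynamo with the XLA backend (assumed to be 'openxla' here).
compiled_add_one_pallas = torch.compile(add_one_pallas, backend='openxla')

device = xm.xla_device()
x = torch.arange(8, dtype=torch.int32, device=device)
output = torch.zeros(8, dtype=torch.int32, device=device)
payload = '...'  # serialized Pallas/Mosaic payload, elided here
compiled_add_one_pallas(output, [x], payload)
```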

@alanwaketan alanwaketan self-assigned this Feb 6, 2024
@alanwaketan
Collaborator Author

@bdhirsh Hi Brian, I'm having difficulties registering custom ops with functionalization enabled. Here is the error log; do you have any insights? Maybe the aten schema should look somewhat different?

root@t1v-n-f0938a8f-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo
/workspaces/work/transformers_pt/src/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/ptxla/.local/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor) ... WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707181259.319496  950358 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/ptxla/.local/lib/python3.8/site-packages/libtpu/libtpu.so
I0000 00:00:1707181259.319562  950358 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707181259.319572  950358 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
ERROR

======================================================================
ERROR: test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_operations.py", line 1943, in test_tpu_custom_call_pallas_add_one_dynamo
    compiled_add_one_pallas(output, [x], payload)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "test/test_operations.py", line 1939, in add_one_pallas
    def add_one_pallas(output, inputs, payload):
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_functorch/aot_autograd.py", line 903, in forward
    return compiled_fn(full_args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 95, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/workspaces/work/pytorch/torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 543, in extract_compiled_graph
    collector.run(*xla_args)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 144, in run
    self.env[node] = self.run_node(node)
  File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
    result = super().run_node(n)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 201, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 273, in call_function
    return target(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_higher_order_ops/auto_functionalize.py", line 62, in __call__
    return super().__call__(op, mutated_args_names, kwargs)
  File "/workspaces/work/pytorch/torch/_ops.py", line 364, in __call__
    return wrapper()
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_ops.py", line 360, in wrapper
    return self.dispatch(
  File "/workspaces/work/pytorch/torch/_ops.py", line 334, in dispatch
    raise NotImplementedError(
NotImplementedError: could not find kernel for HigherOrderOperator auto_functionalized at dispatch key DispatchKey.Functionalize (resolved from DispatchKey.Functionalize)

While executing %auto_functionalized : [num_users=1] = call_function[target=torch._higher_order_ops.auto_functionalize.auto_functionalized](args = (xla.tpu_custom_call_.default, [output], {output: %arg0_1, inputs: [%arg1_1], payload: {"custom_call_config": {"body": "TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==", "needs_layout_passes": true}}}), kwargs = {})
Original traceback:
  File "test/test_operations.py", line 1940, in add_one_pallas
    torch.ops.xla.tpu_custom_call_(output, inputs, payload)


----------------------------------------------------------------------
Ran 1 test in 3.185s

FAILED (errors=1)

@bdhirsh
Collaborator

bdhirsh commented Feb 6, 2024

@alanwaketan - you have a custom op that mutates some of its inputs, and recently @zou3519 added an "auto-functionalize" higher-order op that tries to automatically functionalize mutable custom ops.

I'm not sure what's causing that error, although if you're worried about trace time, you might be a bit better off with a hand-written C++ functionalization kernel (similar to the code-generated ones we have for ATen).

You can find some examples to base it off of if you build PyTorch locally and inspect some of the kernels in build/aten/src/ATen/RegisterFunctionalizeEverything.cpp.
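
For context, auto_functionalize keys on custom ops whose schemas mark an argument as mutated with an alias annotation such as Tensor(a!). The exact schema registered for tpu_custom_call_ is not shown in this thread; a hypothetical registration in that style would look roughly like:

```python
import torch

# Hypothetical namespace and op name, for illustration only; this is not the
# schema registered by the PR.
lib = torch.library.Library("mylib", "DEF")

# The (a!) annotation declares that `output` is mutated in place. During
# functionalization, calls to such an op get rewritten through the
# auto_functionalized higher-order op seen in the traceback above.
lib.define("custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()")
```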

@alanwaketan
Collaborator Author

@alanwaketan - you have a custom op that mutates some of its inputs, and recently @zou3519 added an "auto-functionalize" higher-order op that tries to automatically functionalize mutable custom ops.

I'm not sure what's causing that error, although if you're worried about trace time, you might be a bit better off with a hand-written C++ functionalization kernel (similar to the code-generated ones we have for ATen).

You can find some examples to base it off of if you build PyTorch locally and inspect some of the kernels in build/aten/src/ATen/RegisterFunctionalizeEverything.cpp.

Thanks, Brian. Will look into this. On the other hand, I guess I can also change the semantics of my custom op to not be in-place. Then all the problems should go away?
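
As a rough sketch of that out-of-place direction (hypothetical names throughout; the actual follow-up is not part of this PR), the mutation can be kept inside the op boundary so the tracer only ever sees a functional schema:

```python
import torch

lib = torch.library.Library("mylib", "DEF")
# Functional schema: no alias annotations, the result is returned rather than
# written into a caller-provided buffer.
lib.define("tpu_custom_call(Tensor[] inputs, str payload) -> Tensor")


@torch.library.impl(lib, "tpu_custom_call", "XLA")
def tpu_custom_call_xla(inputs, payload):
  # Simplification: assume the output has the same shape/dtype as inputs[0].
  output = torch.empty_like(inputs[0])
  # The in-place call is an internal detail; callers and the tracer only see
  # the returned tensor.
  torch.ops.xla.tpu_custom_call_(output, inputs, payload)
  return output
```

With a functional op like this, functionalization has nothing to rewrite, so neither the auto_functionalized higher-order op nor a hand-written functionalization kernel would be needed.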

@alanwaketan alanwaketan force-pushed the alanwaketan/pallas_dynamo branch from bc6b14c to ebccfaa on March 6, 2024 01:02
@alanwaketan
Collaborator Author

I will land this as it is and do a follow-up to make tpu_custom_call_ functional.

Thanks @qihqi for the approval.

@alanwaketan alanwaketan merged commit ce8ee38 into master Mar 6, 2024
18 checks passed