
Triton #6798

Merged
merged 67 commits into from
Jun 7, 2024
Conversation

@bhavya01 (Collaborator) commented Mar 21, 2024

Add Triton kernel support to PyTorch/XLA

torch_triton.py is adapted from jax-triton's triton_lib.py: https://github.com/jax-ml/jax-triton/blob/main/jax_triton/triton_lib.py

# y = tl.load(y_ptr + offsets, mask=mask)
# output = x + y
# tl.store(output_ptr + offsets, output, mask=mask)
payload = b'x\x9c\xcdW]o\xdb6\x14E\x8b=\xcc|\n\xb6\x87\xf5\x91\xebf \xce\\\x99\xa4>,\xb9iP\x0c\x18\xf6\xb0\xae\xd8CQ\x0c\x18\nA\xb6\x98D\xa8,i\x12\x9d%3\xfc?\xf6#\xf6\xbe_\xb1\xff\xb4KR\x92E\xc5v\xd2`\x05\n\x18\xb2uy\xee9\xf7\x1e\x92\x12\x8d\xfe9A\x7f\x9d\xa0/\xa38\x0e\xdf\xf32\xe3iHb\x1a3n\xc7G\x9f=\xfd\xf7\xab\xc9\x04M&\xf8G\x9e\xf12\x12<\xc6\xf3\x1b\xfc\xea\xd5\xdb\x9f\xf1\xeb\xb7\xbf\xbc\xf9\x15\x7f\x1f-\xde?\xfb!\x8b\x01\x83\x90u\xc5\xcb*\xc93\xec[\x0cY"*/\xb8\xc0\xd52\x9c\xba\xc8\x02\xfa\x92WUX%\x7fr\xec9\x08\r\x80\xd5\xbaH\xf3y:\xd8!-\xc9\x92*\x99\xa7\x1c[<\x13\xe5\r\xde\x01:F\x03\xab\x88\xcah\x89\xad\x95\xe7\xecB\x84j8$\xe3\xfb"i\x17i\xb3\x03HvoN\x1b\x8d\x90\xb5\x8c\xae3\x91\xc4\x982\x7f\x8c)|\xd0\x1a\xf2K~\x81\xad\xa2\x04[\x07\xc3\xe2\xd4={\xde\xc4\xe6\xa0=\x18\x96\xa7\x94vb\xa0\x02\xb1\xf8\xd4W\xb14_\x0c(\xf60A\xdf\xbe\n\xc3\xf3U\xb6\x08\xe7\xfc"\xc9\xc8\xcc\x18D\x834\xd6u\xaa2%\x813\xc6\xbf\x1d0\xeb\xdd\xf3\x1d9\xee\xc1\x1c\n9\xb2\n\xb1,:\xf2\xd4\xc3v\x80\x06\xcb\xfcJy\t4\xc02\x04\x17\xack\x90\x88\xb2X\xb5)\xe3\x1e\xc4\xe5\xe0\xd4\x90\xd69\xd3\x83\xcalg\xb5\xde\xc1\x1c\xfb\xdd\xd6?\xea`\xc6\xb65\x0eK\x98\x9b\xe1BDu\x91\r\xcaU\xa8\xea2mf\xc6\x1fk\xac\xfd\xbc\xdb.\xf3\xd0 
/\xdb\xb6\x02\tRH\xaf\x03\x9bb\n\xaeT\\\x14V*\xacJ\xf1\x15RV\xe3\xa7\x1d\xa8\xaf\x8b[\xa5\xd6\x1fI\xcckl\x19Ok\xb0#m\x8cc\xab\xaa\xdbV$rz%\xc6\xa0\xa1N\xb7G6\xc6\xe4\x9a\x00\xe0%\x08cpO\xee\xc3H\xf7\xb6\x96\xe3x\x03\x06J\x16\x8a\xbf\xc3\x04w\xfd\nTI]U\xa6\xf4\xdc[\xaaAO\xd5\xbeC\xd5nUY_\x95\x11\x98\x82ZU{\xa0\x9a\xb4\xe5\x85u`\xf0q\xcc\xe2\x14$\xf6\xfa\xc5\x01\xd2\xb6\xcd5HI\xbd\x08)\x93E\xaa\xf9\xe1\xbf7\xf3\xa3\xf4$\x84\xd4KW\xedZ9"\x15\xd4pAuo6\xaeD\xb77\xd5\x92\xad[\x1a\xabN\x1d\xbc1\x8bF\x83\x92\x8bv\x0b\xd1\xd9vK\xf3,\x86\x1d\x856\x00?OR\x0e\xf8\xa7\x93\xcb|\xc9\'\xf3yt\x99N\x04\x07\xad\xe2\xe6)\x0cW|!\xe0\xe1;\xb0b>_]\x84\xd1|^\xf2+4X#k\x0e\xf3\xaf\xaf\xd3\xce\x8d\xado|\x1d\n\xd4\x97\xabG\xba\x03\x9e\xfa\xd2Wf\xa4\xf8D\x0f\xe91\xcanKP\xbfsC:W\x8du\xbcNh\x7f"\xbc-\xd4M\x8d\xb3\xddn\xd4\xbfU\xb1[\xe7\xead7\xe8\xdeyv\xb7Vb\\\x07\x9b\xdb&&\xd9y\xae-\x84y\xa4\x8cv\x8a\':h\xba\xddq\xa7\x96\xad\xcd\xa1M\x0fn/\xde\xe0n\xbbCL$i\x90&\x83\xd3\x10\xd5\xf63\xba\xa3\xba4\xc9\xb8F\xd7&\x13\xc7\x94\'\x81!R\xc3\x02\xdf\xf8\xea\xe5\x12\xbf[&R\xaf\xa7\xde\x9b\xa8\x17\x94k\xb9m\xf0\xfe\xe8\xbaF\xd73\xcb \xc6wP\xdbB\xa6\xc8\xb4\xab\xe9\x94\x98\xf1\xba\xfa:\xcd\xf1\r2\'0n]3\xd7\xa5\xc6(\xf9\x94\xcb\xa2\x9d\xdd\xdb\xac\x8c]\x0b\xbdX\xcd\xb3h\xc9\xabv\xb1\xcbi\x80\xe0k\x19TS\xf1\xcc\x88Tp\xba\x13\xfa\xe4\xd1\x8b\xcd\xf6\xee\x11\xb5\x99\xea\x9d\xe4\xea\x1f\x1e\xfbt\xcd\x83\xf2z\r\xeag\xf1\x1e\xfb\xc4Mq\xdb\xbe72h\xd8\xa7#\xa6}\xdd\xd8\x07\xd9\xd7c\xd8_\x9f|\xd3\xaca\xe0\xe4\xef/\xbe\x81\xdf\xf8\x05\x86\xeb\xf1\xce7\xc9\xcc\x9b\x91\x11Z\xe6\xf1\n\xce\xdek\x84\xb1\x10\x96\xdc\x8f\x18d\xd2d\x81_\xee:\x88\x0f\xe1\xacOf\xf8k\xc0\x16\xa2<Ml8\x13\xd03\xbc\x86\xfb8Q\x07\xf9$M\xc4\r\xe8\xc2!i\x86a|sG\x05\xf0>\x05N\xfa\x118\x99\xc2*\x1e8\x94\x87=.\xffC\xa8\xec\xff\xb5\xbc\x11\x8e\x84(\x93\xf9\n\xa2x\x9d\xe5I&\x9f\xdb@p\x1e\xa5\x15\xdf\xa8\xc9\xc0xH 
\x02"\xf0\xdf*,\xca\xfcB\x1eh\xe1\x7f\xc5\xb5\xd6P\x12r\x8a\xe9H\xa3\x17~(\xc3/pT&\xe2\xd2Z\xe4\x19\xac\xb3L4}\xb6xV\xe3i\x0b\x85Sg\x02j\xe3\x96cw\x02\xd3\xe5,\xa3\xf7<,\xa3\xec\x02\x16\r\xac\xc3\xad\x93c\xac\x166\x04Hc\xc3\x0c\x0b\x9eUyy\xea_\xc3\xfdYKi\xd7\x94\xb6\xa6\xac\x8a4\x12\xb2\xa2\x19>\x06\xdc\x08?;\xdb\x93\xe8\xd4\x89N[<,R(^\x1e\xcd\xd8^\xb9&\xcb5\xe4\xe4\n\xb9[\xd1\xads\xbd\xad\xb3\xcb"\xc1U*@R\x1e\x06\xdd\xbd\xb2M\xea\xb4/+\xfd9\xee\xad\'\xb3\x84\xde\xe0\x96\xd3\xab9}\xcd\t\xdd\x03\x0c$d1\xddB\xfa\x04\xe3=56|\x81\xe6K\xf3(\x06\xf2\xb1\xecw\xbd\x88\x16\x97rU\xd2f~\xf9U\xb2\x10\xdd@R\xbd\xcd\xa1)8\xaen\x17\xef>7\xa6\xcd\xba#};\xe8\xc3\xed\xf0\x1bRj\xfa\xa1\x8e\xf8\x0f2\xa4ed]G(\xfd\x18\x96\x04\x8d\x96\xdd[\xce\xf2\xef!\xdd\xbf\x9e)i\x12\x9d\xbe\x97\xf6\xc3\xbd\xa4\xcd\x93\x84\xba=3\x9d\x87\x9a\xd9P\xca\x12E^r\xc9-[\xb3\xef\xe7\xe6~\xe7(k\x89\xe1\xff\xd4\xaa\xcc\xb6#\xea\xd9\xb2i\xefG\xa8\xfb[!\x0e\xbe\x14\xa93cL\x03\xd9a\xa0\xdb\x02\xed\xc3@of\x07\x1a\xe8\xdc\x05d\x9e\x06\xba\x87\x81\xd3\x19\xad\x19\xbd\xc3@\xbf\xadqz\x17\x90:\x1a\xe8\x1f\x06\x06-cp\x17\xb0aT;~?\x92\x11\xf0\xb2F\x1e\x9e\x1bFg\xac\xe1<<9\x80\xb4\xed\x1ayxv$\xe7\x08\x91\x9f\x8e\x1e=y\x84\x1f\x9d|\x86\x1e\x1f\x1d\xd5\xd7\xc7O>\xd7\xbf\xfe\x03\x97\x03\xa4b'
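The commented fragment in the PR description is the masked load/add/store body of a standard Triton vector-add kernel. As a minimal sketch of what that indexing and masking logic does, here is the same computation emulated in plain NumPy, so it runs without a GPU or Triton installed. The function name, `block_size` default, and array sizes are illustrative, not taken from the PR:

```python
import numpy as np

def add_kernel_emulated(x, y, block_size=8):
    """Emulate a blocked, masked Triton vector add on the CPU."""
    n = x.shape[0]
    out = np.empty_like(x)
    # Grid size: cdiv(n, BLOCK_SIZE), as a Triton launch would compute it.
    num_blocks = (n + block_size - 1) // block_size
    for pid in range(num_blocks):                        # pid ~ tl.program_id(0)
        offsets = pid * block_size + np.arange(block_size)  # ~ tl.arange(0, BLOCK_SIZE)
        mask = offsets < n          # guard the ragged final block, as mask= does
        valid = offsets[mask]
        out[valid] = x[valid] + y[valid]  # ~ tl.load + add + tl.store with mask
    return out

x = np.arange(10, dtype=np.float32)
y = np.ones(10, dtype=np.float32)
print(add_kernel_emulated(x, y))  # elementwise x + y
```

The mask is the part the fragment highlights: without it, the last block would read and write past the end of the buffers whenever the length is not a multiple of the block size.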
A collaborator commented:

I expect we'll have more than one test kernel going forward. What do you think about putting all kernel payloads for Triton (and Pallas) in a separate YAML file?
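One possible layout for such a shared payload file, as a sketch. The file name, keys, and entries below are purely illustrative (nothing in the PR specifies this structure); the serialized kernels would be stored base64-encoded under per-backend sections:

```yaml
# kernel_payloads.yaml (hypothetical)
triton:
  add_kernel:
    description: Elementwise vector add test kernel
    payload: !!binary |
      <base64-encoded serialized kernel>
pallas:
  # Pallas test payloads would live alongside the Triton ones
```

Keeping the payloads out of the test sources would avoid multi-kilobyte byte strings like the one above being embedded directly in test files.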

cc @jiawenliu64

@bhavya01 (Collaborator, Author) commented:
The GPU CI test failed due to a CUDA version issue. I'll check whether the CI tests pass with a newer CUDA version.

@alanwaketan alanwaketan self-requested a review March 28, 2024 18:34
@bhavya01 bhavya01 added the DO_NOT_MERGE_YET For PRs which cannot be merged, despite tests passing label Mar 29, 2024
@bhavya01 bhavya01 requested a review from alanwaketan May 15, 2024 00:58

@alanwaketan alanwaketan left a comment


LGTM. Thanks @bhavya01. It has been a long journey!

@@ -0,0 +1,33 @@
#!/bin/bash

Add a comment at the top explaining why. @will-cromar, can you review this part? Appreciate it.

@@ -0,0 +1,99 @@
on:

Same comment as for triton.sh. @will-cromar

Review threads (outdated, resolved):
- torch_xla/csrc/init_python_bindings.cpp
- torch_xla/csrc/ops/gpu_custom_call.h
- torch_xla/csrc/xla_lower_util.cpp
- torch_xla/experimental/torch_triton.py (2 threads)
- test/test_triton.py
@bhavya01 bhavya01 requested a review from will-cromar May 20, 2024 20:27
@bhavya01 bhavya01 requested a review from vanbasten23 June 5, 2024 21:42
Review thread (outdated, resolved): test/test_operations.py
@bhavya01 bhavya01 merged commit 4a30ea7 into master Jun 7, 2024
22 checks passed
7 participants