Triton #6798
test/test_operations.py
Outdated
# y = tl.load(y_ptr + offsets, mask=mask)
# output = x + y
# tl.store(output_ptr + offsets, output, mask=mask)
payload = b'x\x9c\xcdW]o\xdb6\x14E\x8b=\xcc|\n\xb6\x87\xf5\x91\xebf \xce\\\x99\xa4>,\xb9iP\x0c\x18\xf6\xb0\xae\xd8CQ\x0c\x18\nA\xb6\x98D\xa8,i\x12\x9d%3\xfc?\xf6#\xf6\xbe_\xb1\xff\xb4KR\x92E\xc5v\xd2`\x05\n\x18\xb2uy\xee9\xf7\x1e\x92\x12\x8d\xfe9A\x7f\x9d\xa0/\xa38\x0e\xdf\xf32\xe3iHb\x1a3n\xc7G\x9f=\xfd\xf7\xab\xc9\x04M&\xf8G\x9e\xf12\x12<\xc6\xf3\x1b\xfc\xea\xd5\xdb\x9f\xf1\xeb\xb7\xbf\xbc\xf9\x15\x7f\x1f-\xde?\xfb!\x8b\x01\x83\x90u\xc5\xcb*\xc93\xec[\x0cY"*/\xb8\xc0\xd52\x9c\xba\xc8\x02\xfa\x92WUX%\x7fr\xec9\x08\r\x80\xd5\xbaH\xf3y:\xd8!-\xc9\x92*\x99\xa7\x1c[<\x13\xe5\r\xde\x01:F\x03\xab\x88\xcah\x89\xad\x95\xe7\xecB\x84j8$\xe3\xfb"i\x17i\xb3\x03HvoN\x1b\x8d\x90\xb5\x8c\xae3\x91\xc4\x982\x7f\x8c)|\xd0\x1a\xf2K~\x81\xad\xa2\x04[\x07\xc3\xe2\xd4={\xde\xc4\xe6\xa0=\x18\x96\xa7\x94vb\xa0\x02\xb1\xf8\xd4W\xb14_\x0c(\xf60A\xdf\xbe\n\xc3\xf3U\xb6\x08\xe7\xfc"\xc9\xc8\xcc\x18D\x834\xd6u\xaa2%\x813\xc6\xbf\x1d0\xeb\xdd\xf3\x1d9\xee\xc1\x1c\n9\xb2\n\xb1,:\xf2\xd4\xc3v\x80\x06\xcb\xfcJy\t4\xc02\x04\x17\xack\x90\x88\xb2X\xb5)\xe3\x1e\xc4\xe5\xe0\xd4\x90\xd69\xd3\x83\xcalg\xb5\xde\xc1\x1c\xfb\xdd\xd6?\xea`\xc6\xb65\x0eK\x98\x9b\xe1BDu\x91\r\xcaU\xa8\xea2mf\xc6\x1fk\xac\xfd\xbc\xdb.\xf3\xd0 
/\xdb\xb6\x02\tRH\xaf\x03\x9bb\n\xaeT\\\x14V*\xacJ\xf1\x15RV\xe3\xa7\x1d\xa8\xaf\x8b[\xa5\xd6\x1fI\xcckl\x19Ok\xb0#m\x8cc\xab\xaa\xdbV$rz%\xc6\xa0\xa1N\xb7G6\xc6\xe4\x9a\x00\xe0%\x08cpO\xee\xc3H\xf7\xb6\x96\xe3x\x03\x06J\x16\x8a\xbf\xc3\x04w\xfd\nTI]U\xa6\xf4\xdc[\xaaAO\xd5\xbeC\xd5nUY_\x95\x11\x98\x82ZU{\xa0\x9a\xb4\xe5\x85u`\xf0q\xcc\xe2\x14$\xf6\xfa\xc5\x01\xd2\xb6\xcd5HI\xbd\x08)\x93E\xaa\xf9\xe1\xbf7\xf3\xa3\xf4$\x84\xd4KW\xedZ9"\x15\xd4pAuo6\xaeD\xb77\xd5\x92\xad[\x1a\xabN\x1d\xbc1\x8bF\x83\x92\x8bv\x0b\xd1\xd9vK\xf3,\x86\x1d\x856\x00?OR\x0e\xf8\xa7\x93\xcb|\xc9\'\xf3yt\x99N\x04\x07\xad\xe2\xe6)\x0cW|!\xe0\xe1;\xb0b>_]\x84\xd1|^\xf2+4X#k\x0e\xf3\xaf\xaf\xd3\xce\x8d\xado|\x1d\n\xd4\x97\xabG\xba\x03\x9e\xfa\xd2Wf\xa4\xf8D\x0f\xe91\xcanKP\xbfsC:W\x8du\xbcNh\x7f"\xbc-\xd4M\x8d\xb3\xddn\xd4\xbfU\xb1[\xe7\xead7\xe8\xdeyv\xb7Vb\\\x07\x9b\xdb&&\xd9y\xae-\x84y\xa4\x8cv\x8a\':h\xba\xddq\xa7\x96\xad\xcd\xa1M\x0fn/\xde\xe0n\xbbCL$i\x90&\x83\xd3\x10\xd5\xf63\xba\xa3\xba4\xc9\xb8F\xd7&\x13\xc7\x94\'\x81!R\xc3\x02\xdf\xf8\xea\xe5\x12\xbf[&R\xaf\xa7\xde\x9b\xa8\x17\x94k\xb9m\xf0\xfe\xe8\xbaF\xd73\xcb \xc6wP\xdbB\xa6\xc8\xb4\xab\xe9\x94\x98\xf1\xba\xfa:\xcd\xf1\r2\'0n]3\xd7\xa5\xc6(\xf9\x94\xcb\xa2\x9d\xdd\xdb\xac\x8c]\x0b\xbdX\xcd\xb3h\xc9\xabv\xb1\xcbi\x80\xe0k\x19TS\xf1\xcc\x88Tp\xba\x13\xfa\xe4\xd1\x8b\xcd\xf6\xee\x11\xb5\x99\xea\x9d\xe4\xea\x1f\x1e\xfbt\xcd\x83\xf2z\r\xeag\xf1\x1e\xfb\xc4Mq\xdb\xbe72h\xd8\xa7#\xa6}\xdd\xd8\x07\xd9\xd7c\xd8_\x9f|\xd3\xaca\xe0\xe4\xef/\xbe\x81\xdf\xf8\x05\x86\xeb\xf1\xce7\xc9\xcc\x9b\x91\x11Z\xe6\xf1\n\xce\xdek\x84\xb1\x10\x96\xdc\x8f\x18d\xd2d\x81_\xee:\x88\x0f\xe1\xacOf\xf8k\xc0\x16\xa2<Ml8\x13\xd03\xbc\x86\xfb8Q\x07\xf9$M\xc4\r\xe8\xc2!i\x86a|sG\x05\xf0>\x05N\xfa\x118\x99\xc2*\x1e8\x94\x87=.\xffC\xa8\xec\xff\xb5\xbc\x11\x8e\x84(\x93\xf9\n\xa2x\x9d\xe5I&\x9f\xdb@p\x1e\xa5\x15\xdf\xa8\xc9\xc0xH 
\x02"\xf0\xdf*,\xca\xfcB\x1eh\xe1\x7f\xc5\xb5\xd6P\x12r\x8a\xe9H\xa3\x17~(\xc3/pT&\xe2\xd2Z\xe4\x19\xac\xb3L4}\xb6xV\xe3i\x0b\x85Sg\x02j\xe3\x96cw\x02\xd3\xe5,\xa3\xf7<,\xa3\xec\x02\x16\r\xac\xc3\xad\x93c\xac\x166\x04Hc\xc3\x0c\x0b\x9eUyy\xea_\xc3\xfdYKi\xd7\x94\xb6\xa6\xac\x8a4\x12\xb2\xa2\x19>\x06\xdc\x08?;\xdb\x93\xe8\xd4\x89N[<,R(^\x1e\xcd\xd8^\xb9&\xcb5\xe4\xe4\n\xb9[\xd1\xads\xbd\xad\xb3\xcb"\xc1U*@R\x1e\x06\xdd\xbd\xb2M\xea\xb4/+\xfd9\xee\xad\'\xb3\x84\xde\xe0\x96\xd3\xab9}\xcd\t\xdd\x03\x0c$d1\xddB\xfa\x04\xe3=56|\x81\xe6K\xf3(\x06\xf2\xb1\xecw\xbd\x88\x16\x97rU\xd2f~\xf9U\xb2\x10\xdd@R\xbd\xcd\xa1)8\xaen\x17\xef>7\xa6\xcd\xba#};\xe8\xc3\xed\xf0\x1bRj\xfa\xa1\x8e\xf8\x0f2\xa4ed]G(\xfd\x18\x96\x04\x8d\x96\xdd[\xce\xf2\xef!\xdd\xbf\x9e)i\x12\x9d\xbe\x97\xf6\xc3\xbd\xa4\xcd\x93\x84\xba=3\x9d\x87\x9a\xd9P\xca\x12E^r\xc9-[\xb3\xef\xe7\xe6~\xe7(k\x89\xe1\xff\xd4\xaa\xcc\xb6#\xea\xd9\xb2i\xefG\xa8\xfb[!\x0e\xbe\x14\xa93cL\x03\xd9a\xa0\xdb\x02\xed\xc3@of\x07\x1a\xe8\xdc\x05d\x9e\x06\xba\x87\x81\xd3\x19\xad\x19\xbd\xc3@\xbf\xadqz\x17\x90:\x1a\xe8\x1f\x06\x06-cp\x17\xb0aT;~?\x92\x11\xf0\xb2F\x1e\x9e\x1bFg\xac\xe1<<9\x80\xb4\xed\x1ayxv$\xe7\x08\x91\x9f\x8e\x1e=y\x84\x1f\x9d|\x86\x1e\x1f\x1d\xd5\xd7\xc7O>\xd7\xbf\xfe\x03\x97\x03\xa4b' |
I expect we'll have more than one test kernel going forward. What do you think about putting all kernel payloads for Triton (and Pallas) in a separate YAML file?
cc @jiawenliu64
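A rough sketch of what the suggestion above could look like, not the PR's implementation. It assumes the payloads are zlib-compressed bytes (the literal in the diff starts with `x\x9c`, the usual zlib header) and stores them base64-encoded so they fit in a text file. JSON from the standard library stands in for YAML here to keep the example self-contained; the helper names and the kernel names are hypothetical.

```python
import base64
import json
import zlib


def encode_payload(raw: bytes) -> str:
    """Compress raw kernel bytes and return ASCII text safe for a config file."""
    return base64.b64encode(zlib.compress(raw)).decode("ascii")


def decode_payload(text: str) -> bytes:
    """Reverse encode_payload: base64-decode, then zlib-decompress."""
    return zlib.decompress(base64.b64decode(text))


def dump_payloads(payloads: dict) -> str:
    """Serialize a {kernel_name: raw_bytes} mapping to a text document."""
    encoded = {name: encode_payload(raw) for name, raw in payloads.items()}
    return json.dumps(encoded, indent=2, sort_keys=True)


def load_payloads(document: str) -> dict:
    """Parse the text document back into a {kernel_name: raw_bytes} mapping."""
    return {name: decode_payload(text) for name, text in json.loads(document).items()}


# Hypothetical kernel names with placeholder bytes, not real compiled kernels.
example = {
    "triton_add": b"placeholder bytes for a Triton add kernel",
    "pallas_add": b"placeholder bytes for a Pallas add kernel",
}
document = dump_payloads(example)
restored = load_payloads(document)
```

With a layout like this, each test would look up its payload by name instead of embedding a long bytes literal in `test/test_operations.py`.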
The GPU CI test failed because of a CUDA version issue. I'll check whether the CI tests pass with a newer CUDA version.
LGTM. Thanks @bhavya01. It has been a long journey!
.circleci/triton.sh
Outdated
@@ -0,0 +1,33 @@
#!/bin/bash
Please add a comment at the top of the file explaining why it is needed. @will-cromar can you review this part? Appreciate it.
.github/workflows/triton.yml
Outdated
@@ -0,0 +1,99 @@
on:
Same comment as for triton.sh. @will-cromar
Add Triton kernel support to PyTorch/XLA
torch_triton.py is adapted from https://github.com/jax-ml/jax-triton/blob/main/jax_triton/triton_lib.py