[RUNTIME] TPU Backend Support #433
Comments
What is the status of this issue? |
@OhadRubin It should be supported very soon. See our tpu-support branch and the official tensorflow code |
Is TPU+pipeshard still in our scope? @merrymercy @ZYHowell |
No. I updated the second todo item as "reproduce benchmark results on TPU" |
I guess free TPU is not good enough to run a proper benchmark...😅 wonder if Google would sponsor many TPUs for research purposes |
Pipeshard isn't going to be supported on TPU? |
We have no plan, mainly because 1) the TPU backend of XLA is closed-source; 2) unlike NCCL for GPU, TPU has no exposed communication library |
Thanks!
Edit: Actually I had conflicts with tpulib. I might have done something wrong; errors pop up while trying to run on TPU. Do we have to build our own jaxlib with the enable_tpu flag on? |
Yes, I think you need to compile it yourself (we can't do so, mainly because compiling requires a TPU backend with the TPU library). But since we only support shard parallelism on the TPU side, and that part is already in the upstream official jax/jaxlib, I'd suggest you try the upstream jax/jaxlib. |
It's been upstreamed into jax? Wow, I didn't know that. Thank you so much! |
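For reference, a minimal sketch of running shard parallelism on TPU with upstream jax/jaxlib only (this is plain JAX, not Alpa's API), assuming a recent jax release that provides the jax.sharding module and a TPU VM with jax[tpu] installed:

```python
# Minimal shard-parallel sketch with upstream jax/jaxlib (not Alpa):
# shard a batched matmul across all local TPU cores.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((jax.device_count(),))  # e.g. 8 TPU cores
mesh = Mesh(devices, axis_names=("data",))

x = jnp.ones((8 * jax.device_count(), 1024))  # batch, sharded across cores
w = jnp.ones((1024, 1024))                    # weights, replicated

x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P()))

@jax.jit
def forward(x, w):
    return jnp.tanh(x @ w)

y = forward(x, w)  # XLA compiles a sharded program; on a TPU VM this uses the TPU backend
print(y.shape, y.sharding)
```

The same script runs unmodified on GPU or CPU; the backend is picked up from the installed jaxlib.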
Background
Currently, Alpa only supports GPU. With XLA as Alpa's backend, it is easy to support TPU as well: we can call the auto-sharding pass to generate the sharding-annotated HLO and then call the closed-source TPU compiler.
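As a rough illustration of the artifacts involved (plain JAX, not Alpa's actual auto-sharding pass), one can lower a function to its HLO/StableHLO text, which is the kind of module a sharding pass would annotate before the backend compiler, the TPU compiler on a TPU VM, consumes it:

```python
# Hedged illustration (plain JAX, not Alpa internals): lower a function to
# (Stable)HLO text, then compile it with whatever backend jax is running on.
import jax
import jax.numpy as jnp

def forward(x, w):
    return jnp.tanh(x @ w)

x = jnp.ones((128, 512))
w = jnp.ones((512, 256))

lowered = jax.jit(forward).lower(x, w)
print(lowered.as_text()[:500])   # inspect the lowered module a sharding pass could annotate
compiled = lowered.compile()     # invokes the backend compiler (TPU compiler on TPU)
print(compiled(x, w).shape)      # (128, 256)
```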
TODO