This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[RUNTIME] TPU Backend Support #433

Open · 1 of 2 tasks
merrymercy opened this issue May 7, 2022 · 11 comments
Labels: enhancement (New feature)

@merrymercy (Member) commented May 7, 2022

Background

Currently, Alpa only supports GPUs. With XLA as Alpa's backend, it is easy to support TPUs as well: we can call the auto-sharding pass to generate the sharding-annotated HLO and then pass it to the closed-source TPU compiler.
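
For context, this is roughly what shard_parallel usage looks like on GPUs today; the idea above is that the same sharding-annotated HLO could be handed to the TPU compiler instead. A minimal sketch, assuming a working Alpa GPU install; the model and update rule are illustrative, not from this issue:

```python
# Sketch of shard_parallel on GPU; the proposal is to feed the same
# sharding-annotated HLO to the (closed-source) TPU compiler instead.
import alpa
import jax
import jax.numpy as jnp

@alpa.parallelize(method=alpa.ShardParallel())  # runs the auto-sharding pass
def train_step(params, batch):
    def loss_fn(p):
        pred = jnp.tanh(batch["x"] @ p["w"])
        return jnp.mean((pred - batch["y"]) ** 2)
    grads = jax.grad(loss_fn)(params)
    # Plain SGD update; a real script would use an optimizer library.
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

params = {"w": jnp.ones((1024, 1024))}
batch = {"x": jnp.ones((64, 1024)), "y": jnp.ones((64, 1024))}
params = train_step(params, batch)  # executes sharded across local devices
```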

TODO

  • Support shard_parallel on TPU
  • Reproduce benchmark results on TPU
@OhadRubin commented:

What is the status of this issue?
It would be nice to run larger models on multiple v3-8 TPUs.

@merrymercy (Member, Author) commented:

@OhadRubin It should be supported very soon. See our tpu-support branch and the official TensorFlow code.

@merrymercy (Member, Author) commented:

#764

@zhisbug (Member) commented Nov 15, 2022

Is TPU+pipeshard still in our scope? @merrymercy @ZYHowell

@merrymercy (Member, Author) commented Nov 17, 2022

No. I updated the second todo item to "reproduce benchmark results on TPU".

@merrymercy added the "enhancement (New feature)" label on Dec 20, 2022
@jon-chuang commented:

> Reproduce benchmark results on TPU

I guess the free TPUs are not good enough to run a proper benchmark... 😅 I wonder if Google would sponsor many TPUs for research purposes.

@Lime-Cakes commented:

Pipeshard isn't going to be supported on TPU?

@ZYHowell (Collaborator) commented May 6, 2023

We have no plan to, mainly because 1) the TPU backend of XLA is closed-source, and 2) unlike NCCL for GPUs, TPUs have no exposed communication library.

@Lime-Cakes commented May 6, 2023

Thanks!

Is there going to be a non-CUDA jaxlib (the Alpa fork)? And is there a guide for installing Alpa on TPU? It seems all the guides are for CUDA machines.

The CUDA wheels seemed to work fine.

Edit: Actually, they had conflicts with libtpu. I might have done something wrong; errors pop up while trying to run on TPU.

Do we have to build our own jaxlib with the enable_tpu flag on?

@ZYHowell (Collaborator) commented May 9, 2023

Yes, I think you need to compile it yourself (we can't do so, mainly because compiling requires a machine with the TPU backend and TPU libraries). But since we only support shard parallel on the TPU side, and that part is already in the upstream official jax/jaxlib, I'd suggest you try pjit's auto-sharding mode.
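
For anyone landing here, this is a minimal sketch of what "pjit auto" refers to: letting XLA's auto-sharding pass choose the partitioning instead of hand-written PartitionSpecs. The AUTO spelling, import paths, and keyword names have moved between jax releases, so treat this as illustrative of a ~0.4-era jax/jaxlib with TPU support, not a pinned recipe:

```python
# Sketch: XLA auto-sharding via pjit's AUTO mode on a TPU VM (e.g. a v3-8).
# Assumes a jax ~0.4-era API; exact AUTO spelling varies across versions.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.experimental.pjit import pjit, AUTO

# A 2D logical mesh over the 8 local TPU cores of a v3-8.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=("x", "y"))

def step(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((4096, 4096))
x = jnp.ones((512, 4096))

with mesh:
    # AUTO asks XLA to pick the shardings at compile time.
    f = pjit(step, in_axis_resources=AUTO, out_axis_resources=AUTO)
    compiled = f.lower(w, x).compile()   # the auto-sharding pass runs here
    print(compiled.input_shardings)      # inspect what XLA chose
    # Place the inputs according to compiled.input_shardings
    # (e.g. with jax.device_put) before calling compiled(w, x).
```

This mirrors what Alpa's shard_parallel does on GPU, but entirely through upstream jax.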

@Lime-Cakes commented:

It's upstreamed into jax? Wow, I didn't know that. Thank you so much!
