
Commit fa1f120
make sure GPU works (#130)
* make sure GPU works
qihqi authored Jun 19, 2024
1 parent aa90b05 commit fa1f120
Showing 5 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deps/xla
Submodule xla updated 52 files
+9 −7 .github/workflows/_build_plugin.yml
+11 −10 .github/workflows/_build_torch_with_cuda.yml
+10 −19 .github/workflows/_build_torch_xla.yml
+1 −1 .github/workflows/_docs.yml
+0 −25 .github/workflows/_get_torch_commit.yml
+13 −33 .github/workflows/_test.yml
+12 −30 .github/workflows/_test_requiring_torch_cuda.yml
+8 −10 .github/workflows/_tpu_ci.yml
+8 −2 .github/workflows/build_and_test.yml
+86 −0 .github/workflows/setup/action.yml
+16 −2 README.md
+6 −2 benchmarks/benchmark_model.py
+19 −6 benchmarks/torchbench_model.py
+37 −79 docs/fori_loop.md
+82 −0 docs/plugins.md
+1 −1 docs/source/index.rst
+0 −554 docs/spmd.md
+150 −0 docs/spmd_advanced.md
+83 −0 docs/spmd_basic.md
+125 −0 docs/spmd_distributed_checkpoint.md
+3 −2 examples/decoder_only_model.py
+157 −0 experimental/torch_xla2/docs/support_a_new_model.md
+7 −0 experimental/torch_xla2/examples/eager_mode.py
+49 −0 experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py
+0 −9 experimental/torch_xla2/test/test_ops.py
+19 −8 experimental/torch_xla2/torch_xla2/__init__.py
+1 −0 experimental/torch_xla2/torch_xla2/config.py
+49 −0 experimental/torch_xla2/torch_xla2/ops/jaten.py
+2 −2 experimental/torch_xla2/torch_xla2/ops/jtorch.py
+25 −11 experimental/torch_xla2/torch_xla2/tensor.py
+1 −0 infra/ansible/config/apt.yaml
+1 −1 infra/ansible/config/env.yaml
+3 −0 plugins/cpu/README.md
+2 −2 plugins/cpu/pyproject.toml
+3 −0 plugins/cuda/README.md
+9 −10 test/debug_tool/test_pt_xla_debug.py
+35 −5 test/dynamo/test_dynamo.py
+3 −2 test/run_tests.sh
+1 −48 test/spmd/test_dynamo_spmd.py
+2 −2 test/spmd/test_sharding_strategies.py
+6 −4 test/spmd/test_xla_sharding.py
+0 −106 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+16 −17 test/test_metrics.py
+116 −0 test/test_while_loop.py
+43 −0 test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py
+1 −1 test/tpu/run_tests.sh
+2 −0 torch_xla/_internal/tpu.py
+56 −7 torch_xla/core/xla_model.py
+20 −15 torch_xla/csrc/init_python_bindings.cpp
+3 −0 torch_xla/csrc/runtime/pjrt_registry.cc
+7 −14 torch_xla/distributed/spmd/xla_sharding.py
+112 −45 torch_xla/experimental/fori_loop.py
3 changes: 2 additions & 1 deletion install_everything.sh
@@ -38,4 +38,5 @@ git submodule update --init --recursive
pip show google-jetstream && pip uninstall -y google-jetstream
pip show torch_xla2 && pip uninstall -y torch_xla2
pip install -e .
-pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
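The jax[tpu] pin moves from 0.4.29 to 0.4.30, and a CPU-only torch wheel is now installed as the final step, presumably so that pip install -e . cannot leave a different torch build behind. A minimal post-install sanity check, as a sketch (assumes a TPU VM; the expected values are inferences from this diff, not output captured from the commit):

import jax
import torch

print(jax.__version__)        # expected: 0.4.30 after this change
print(torch.__version__)      # expected: 2.3.1+cpu
print(jax.default_backend())  # expected: "tpu" when libtpu is available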
6 changes: 3 additions & 3 deletions install_everything_gpu.sh
@@ -24,18 +24,18 @@ pip show tensorboard && pip uninstall -y tensorboard
pip show tensorflow-text && pip uninstall -y tensorflow-text
pip show torch_xla2 && pip uninstall -y torch_xla2

-pip install flax==0.8.3
-pip install -U "jax[cuda12]==0.4.28"
+pip install flax==0.8.4
pip install tensorflow-text
pip install tensorflow

pip install ray[default]==2.22.0
# torch cpu
-pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
pip install safetensors colorama coverage humanize

git submodule update --init --recursive
pip show google-jetstream && pip uninstall -y google-jetstream
pip show torch_xla2 && pip uninstall -y torch_xla2
pip install -e .
+pip install -U jax[cuda12]==0.4.30
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
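After this change the GPU script leaves jax[cuda12]==0.4.30 driving the GPU while torch stays on the CPU-only 2.3.1+cpu wheel. A quick sanity check of that split, as a sketch (assumes a single-GPU CUDA machine; the expected values are assumptions, not captured output):

import jax
import torch

print(jax.devices())              # expected: [CudaDevice(id=0), ...]
print(torch.cuda.is_available())  # expected: False -- torch is the +cpu wheel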
3 changes: 2 additions & 1 deletion run_server.py
@@ -16,8 +16,9 @@
import os
from typing import Sequence

+# import torch_xla2 first!
+import torch_xla2  # pylint: disable
import jax
import jetstream_pt
from absl import app, flags
from jetstream.core import server_lib
from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
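The new comment requires torch_xla2 to be imported before jax. A plausible reason, which is an assumption rather than anything stated in this diff, is that torch_xla2 adjusts jax's global configuration at import time, so it must run before any other module touches jax. A minimal sketch of the pattern:

# Hypothetical import-order sketch: torch_xla2 first, then jax.
import torch_xla2  # pylint: disable=unused-import  (assumed to configure jax as an import side effect)
import jax

# The backend is resolved only after torch_xla2's import-time setup has run.
print(jax.default_backend())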
24 changes: 24 additions & 0 deletions tests/test_model_impl.py
@@ -23,6 +23,8 @@
from jetstream_pt.third_party.llama import model_original
from jetstream_pt.third_party.gemma import model_original as gemma_orig
from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model as mixtral
+from jetstream_pt.third_party.mixtral import config as mixtral_config
from jetstream_pt import torchjax
from jetstream_pt import layers
from jetstream_pt import cache_manager
@@ -360,6 +362,28 @@ def test_transformer(self):
print("Transformer: Diff norm", (result_torch - expected_out).norm())
self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))

+  def test_mixtral_moe(self):
+    config = mixtral_config.ModelArgs()
+    config.intermediate_size = 16
+    config.dim = 16
+    m = mixtral.ConditionalFeedForward(config)
+    # random init
+    states = m.state_dict()
+    for k, v in states.items():
+      states[k].normal_()
+    m.load_state_dict(states, assign=True)
+
+    seqlen = 3
+    num_expert = 8
+    num_active_expert = 2
+    x = torch.randn(seqlen, config.dim)
+    exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))
+
+    res1 = m.forward_for_short_seq_len(x, exp_index)
+    res2 = m.forward_for_long_seq_len(x, exp_index)
+
+    torch.testing.assert_close(res1, res2)


if __name__ == "__main__":
unittest.main()
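The new test_mixtral_moe exercises the two expert-dispatch paths of mixtral.ConditionalFeedForward with random weights and asserts they agree: forward_for_short_seq_len gathers only each token's routed experts, while forward_for_long_seq_len runs every expert densely and selects afterwards. For intuition, here is a hypothetical, self-contained sketch of such a dual-path MoE feed-forward; TinyConditionalFFN, the w1/w2/w3 weight layout, and the einsum strings are assumptions modeled on gpt-fast-style Mixtral code, not the actual jetstream_pt implementation:

import torch
import torch.nn.functional as F

class TinyConditionalFFN(torch.nn.Module):
  """Hypothetical MoE feed-forward with two equivalent dispatch paths."""

  def __init__(self, num_experts=8, dim=16, hidden=16):
    super().__init__()
    self.w1 = torch.nn.Parameter(torch.randn(num_experts, hidden, dim))
    self.w2 = torch.nn.Parameter(torch.randn(num_experts, dim, hidden))
    self.w3 = torch.nn.Parameter(torch.randn(num_experts, hidden, dim))

  def forward_short(self, x, expert_indices):
    # Gather just the experts each token routes to; cheap when tokens are few.
    w1 = self.w1[expert_indices]  # (T, A, hidden, dim)
    w2 = self.w2[expert_indices]  # (T, A, dim, hidden)
    w3 = self.w3[expert_indices]  # (T, A, hidden, dim)
    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1))
    x3 = torch.einsum("ti,taoi->tao", x, w3)
    return torch.einsum("tao,taio->tai", x1 * x3, w2)  # (T, A, dim)

  def forward_long(self, x, expert_indices):
    # Run every expert densely, then pick the routed ones; amortizes when tokens are many.
    x1 = F.silu(torch.einsum("ti,eoi->teo", x, self.w1))
    x3 = torch.einsum("ti,eoi->teo", x, self.w3)
    out = torch.einsum("teo,eio->tei", x1 * x3, self.w2)  # (T, num_experts, dim)
    idx = expert_indices.unsqueeze(-1).expand(-1, -1, out.shape[-1])
    return torch.gather(out, 1, idx)  # (T, A, dim)

# Both paths compute the same SwiGLU expert outputs, which is what the test asserts.
m = TinyConditionalFFN()
x = torch.randn(3, 16)
idx = torch.randint(0, 8, (3, 2))
torch.testing.assert_close(m.forward_short(x, idx), m.forward_long(x, idx))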
