diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 121009e2..551e5a0f 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -55,7 +55,7 @@ jobs: pylint --indent-string=' ' jetstream_pt/ benchmarks/ - name: Format check with pyink run: | - pyink --pyink-indentation 2 --line-length 80 --check --verbose . + pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps . cpu: name: "jetstream_pt unit tests" @@ -79,4 +79,28 @@ jobs: JAX_PLATFORMS=cpu coverage run -m unittest -v - name: Create test coverage report run: | - coverage report -m \ No newline at end of file + coverage report -m + + interactive: + name: "jetstream_pt run interactive" + strategy: + matrix: + os: [ubuntu-20.04] + python-version: ['3.10'] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + source install_everything.sh + - name: Run interactive (bf16) + run: | + JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0 + - name: Run interactive (int8) + run: | + JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_kv_cache=1 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 71d7399e..d3c23fe0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# source dependencies -deps/ - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.gitmodules b/.gitmodules new file 
mode 100644 index 00000000..e7cfe4dd --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "deps/JetStream"] + path = deps/JetStream + url = https://github.com/google/JetStream.git +[submodule "deps/xla"] + path = deps/xla + url = https://github.com/pytorch/xla.git diff --git a/deps/JetStream b/deps/JetStream new file mode 160000 index 00000000..8128c8a5 --- /dev/null +++ b/deps/JetStream @@ -0,0 +1 @@ +Subproject commit 8128c8a59f859e7691726180b7953e92c6ad4b2b diff --git a/deps/xla b/deps/xla new file mode 160000 index 00000000..f26c35c2 --- /dev/null +++ b/deps/xla @@ -0,0 +1 @@ +Subproject commit f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 diff --git a/install_everything.sh b/install_everything.sh index aa59011c..e81cd16f 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024 -JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024 - # Uninstall existing jax pip show jax && pip uninstall -y jax pip show jaxlib && pip uninstall -y jaxlib @@ -26,17 +23,5 @@ pip install torch --index-url https://download.pytorch.org/whl/cpu pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage pip install safetensors colorama coverage ray[default] humanize -mkdir -p deps -pushd deps -git clone https://github.com/google/JetStream.git -git clone https://github.com/pytorch/xla.git -pushd xla/experimental/torch_xla2 -git checkout $TORCHXLA_TAG -pip install . -popd # now at the folder deps -pushd JetStream -git checkout $JETSTREAM_TAG -pip install . -popd # now at the folder deps -popd # now at the folder current file +git submodule update --init --recursive pip install -e . 
diff --git a/pyproject.toml b/pyproject.toml index 10f895fe..c68c1500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -version = "0.2.0" +version = "0.2.1" name = "jetstream_pt" dependencies = [ "absl-py", @@ -14,7 +14,11 @@ dependencies = [ -    "google-jetstream", "google-cloud-storage", "safetensors", +    "torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2", +    "google-jetstream @ {root:uri}/deps/JetStream", ] requires-python = ">=3.10" license = {file = "LICENSE"} + +[tool.hatch.metadata] +allow-direct-references = true diff --git a/run_interactive.py b/run_interactive.py index 740ab00d..3ab7e59a 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -158,20 +158,15 @@ def main(argv): decode_state, result_tokens = engine.generate(params, decode_state) result_tokens = result_tokens.convert_to_numpy() res = result_tokens.get_result_at_slot(slot) -    stop_tokens = set(tokenizer.tokenizer.stop_tokens) +    stop_tokens = set(tokenizer.stop_tokens) stop_tokens.add(tokenizer.pad_id) +    token_id = res.tokens[0][0].item() +    sampled_tokens_list.append(token_id) if ( -        res.tokens[0][0] in stop_tokens +        token_id in stop_tokens or len(sampled_tokens_list) > max_output_length ): break -      token_id = res.tokens[0][0] -      sampled_tokens_list.append(token_id) -      # output_str = tokenizer.decode_str([token_id]) -      # print(Fore.GREEN + output_str, end="", flush=True) - -      # print(Style.RESET_ALL + "\n") -      # print("---- Streaming decode finished.") print("---- All output tokens.") print(sampled_tokens_list)