Skip to content

Commit

Permalink
Use git submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed May 17, 2024
1 parent 0f4e7ae commit 0e5607c
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 31 deletions.
28 changes: 26 additions & 2 deletions .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
pylint --indent-string=' ' jetstream_pt/ benchmarks/
- name: Format check with pyink
run: |
pyink --pyink-indentation 2 --line-length 80 --check --verbose .
pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .
cpu:
name: "jetstream_pt unit tests"
Expand All @@ -79,4 +79,28 @@ jobs:
JAX_PLATFORMS=cpu coverage run -m unittest -v
- name: Create test coverage report
run: |
coverage report -m
coverage report -m
interactive:
name: "jetstream_pt run interactive"
strategy:
matrix:
os: [ubuntu-20.04]
python-version: ['3.10']
runs-on: ${{ matrix.os }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
source install_everything.sh
- name: Run interactive (bf16)
run: |
JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
- name: Run interactive (int8)
run: |
JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_kv_cache=1
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# source dependencies
deps/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[submodule "deps/JetStream"]
path = deps/JetStream
url = https://github.com/google/JetStream.git
[submodule "deps/xla"]
path = deps/xla
url = https://github.com/pytorch/xla.git
1 change: 1 addition & 0 deletions deps/JetStream
Submodule JetStream added at 8128c8
1 change: 1 addition & 0 deletions deps/xla
Submodule xla added at f26c35
17 changes: 1 addition & 16 deletions install_everything.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024

# Uninstall existing jax
pip show jax && pip uninstall -y jax
pip show jaxlib && pip uninstall -y jaxlib
Expand All @@ -26,17 +23,5 @@ pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
pip install safetensors colorama coverage ray[default] humanize

mkdir -p deps
pushd deps
git clone https://github.com/google/JetStream.git
git clone https://github.com/pytorch/xla.git
pushd xla/experimental/torch_xla2
git checkout $TORCHXLA_TAG
pip install .
popd # now at the folder deps
pushd JetStream
git checkout $JETSTREAM_TAG
pip install .
popd # now at the folder deps
popd # now at the folder current file
git submodule update --init --recursive
pip install -e .
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
version = "0.2.0"
version = "0.2.1"
name = "jetstream_pt"
dependencies = [
"absl-py",
Expand All @@ -14,7 +14,12 @@ dependencies = [
"google-jetstream",
"google-cloud-storage",
"safetensors",
"torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2",
"google-jetstream @ {root:uri}/deps/JetStream",
]

requires-python = ">=3.10"
license = {file = "LICENSE"}

[tool.hatch.metadata]
allow-direct-references = true
13 changes: 4 additions & 9 deletions run_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,15 @@ def main(argv):
decode_state, result_tokens = engine.generate(params, decode_state)
result_tokens = result_tokens.convert_to_numpy()
res = result_tokens.get_result_at_slot(slot)
stop_tokens = set(tokenizer.tokenizer.stop_tokens)
stop_tokens = set(tokenizer.stop_tokens)
stop_tokens.add(tokenizer.pad_id)
token_id = res.tokens[0][0].item()
sampled_tokens_list.append(token_id)
if (
res.tokens[0][0] in stop_tokens
token_id in stop_tokens
or len(sampled_tokens_list) > max_output_length
):
break
token_id = res.tokens[0][0]
sampled_tokens_list.append(token_id)
# output_str = tokenizer.decode_str([token_id])
# print(Fore.GREEN + output_str, end="", flush=True)

# print(Style.RESET_ALL + "\n")
# print("---- Streaming decode finished.")

print("---- All output tokens.")
print(sampled_tokens_list)
Expand Down

0 comments on commit 0e5607c

Please sign in to comment.