Add ray multiple host support (AI-Hypercomputer#63)
* Add ray multiple host support

* add dependencies

* add dependencies

* add assertion check on pod_name and num_hosts

* Update ray engine and worker

* update interactive

* update ray worker

* add comments

* update comments
FanhaiLu1 authored Apr 30, 2024
1 parent 2b1a527 commit a58051d
Showing 5 changed files with 1,218 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
-disable=C0114,R0801,E1102,W0613
+disable=C0114,R0801,E1102,W0613,R1711
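(R1711 is pylint's useless-return check; presumably it is disabled because the new engine methods below end with an explicit return None while the Ray workers hold the actual state.)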
3 changes: 2 additions & 1 deletion install_everything.sh
@@ -22,7 +22,8 @@ pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# torch cpu
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage safetensors colorama coverage
+pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
+pip3 install safetensors colorama coverage ray[default] humanize

mkdir -p deps
pushd deps
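The ray[default] and humanize packages added above are the new dependencies for the multi-host engine introduced in this commit. A minimal sanity-check sketch (not part of this commit) for confirming that the Ray runtime can see the TPU pod before serving; it assumes it is run on a TPU VM that belongs to the pod and reuses the same Ray TPU helpers as jetstream_pt/ray_engine.py:

# Sanity-check sketch (assumption: run on a TPU VM inside the pod; not part
# of this commit). It reuses the Ray TPU helpers used by ray_engine.py below.
import ray
from ray.util.accelerators import tpu

ray.init(ignore_reinit_error=True)
print("pod name:", tpu.get_current_pod_name())
print("hosts in pod:", tpu.get_current_pod_worker_count())
# The engine below requests resources={"TPU": 4} per worker, so a "TPU"
# resource should appear here on a correctly configured cluster.
print("cluster resources:", ray.cluster_resources())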
194 changes: 194 additions & 0 deletions jetstream_pt/ray_engine.py
@@ -0,0 +1,194 @@
from typing import Any, Iterable, Optional, Union

import numpy as np
import jax
import ray
from ray.util.accelerators import tpu

from jetstream.engine import engine_api, tokenizer_pb2
from jetstream_pt.ray_worker import PyTorchRayWorker

Params = Any
Prefix = Any
DecodeState = Any


class PyTorchRayEngine(engine_api.Engine):
"""Ray PyTorch Engine Implementation for Multi-Host Inference Serving.
Key Features:
1. Manages all Ray workers.
2. Initializes model parameters for each Ray worker.
3. Routes incoming inference requests to Ray workers.
4. Collects token responses from the Ray workers.
"""

def __init__(
self,
engine_workers: Iterable[PyTorchRayWorker],
tokenizer_path: str,
context_length: int,
batch_size: int,
):
self.engine_workers = engine_workers
self.tokenizer_path = tokenizer_path
self.context_length = context_length
self.batch_size = batch_size

# pylint: disable-next=all
def load_params(self) -> Params:
all_outputs = []
for worker in self.engine_workers:
output = worker.load_params_ray.remote()
all_outputs.append(output)
_ = ray.get(all_outputs)
return None

# pylint: disable-next=all
def init_decode_state(
self,
) -> DecodeState:
all_outputs = []
for worker in self.engine_workers:
output = worker.init_decode_state_ray.remote()
all_outputs.append(output)
_ = ray.get(all_outputs)
return None

def prefill(
self,
*,
params: Any, # Weights
existing_prefix: Optional[Prefix] = None,
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
true_length: int,
) -> Prefix:
all_outputs = []
for worker in self.engine_workers:
output = worker.prefill_ray.remote(
params=params,
existing_prefix=existing_prefix,
padded_tokens=padded_tokens,
true_length=true_length,
)
all_outputs.append(output)
_ = ray.get(all_outputs)
# The prefill function does not return any values;
# the worker itself manages and maintains the prefill states.
return None

def insert(
self,
prefix: Prefix,
decode_state: DecodeState,
slot: int,
) -> DecodeState:
all_outputs = []
for worker in self.engine_workers:
output = worker.insert_ray.remote(
prefix=prefix, decode_state=decode_state, slot=slot
)
all_outputs.append(output)
_ = ray.get(all_outputs)
# The insert function does not return any values;
# the worker itself manages and maintains the DecodeState.
return None

def generate(
self, params: Any, decode_state: DecodeState
) -> tuple[None, engine_api.ResultTokens]:
all_outputs = []
for worker in self.engine_workers:
output = worker.generate_ray.remote(
params=params, decode_state=decode_state
)
all_outputs.append(output)
# All workers performed an all_gather operation. Since the results are
# identical across all workers, the result from worker 0 is returned.
state, result_tokens = ray.get(all_outputs)[0]
return state, result_tokens

# pylint: disable-next=all
def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
# pylint: disable-next=all
return tokenizer_pb2.TokenizerParameters(path=self.tokenizer_path)

@property
def max_concurrent_decodes(self) -> int:
return self.batch_size

@property
def samples_per_slot(self) -> int:
return 1

@property
def max_prefill_length(self) -> int:
return self.context_length

@property
def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
return jax.devices("cpu")[0]

def get_prefix_destination_sharding(self) -> Prefix:
"No implementation"
return None

@property
def mesh(self):
"No implementation"
return None


# pylint: disable-next=all
def create_pytorch_ray_engine(
    tokenizer_path: str,
    ckpt_path: Optional[str] = None,
    samples_per_slot: int = 1,
    bf16_enable: bool = False,
    param_size: str = "7b",
    context_length: int = 1024,
    batch_size: int = 1,
    max_decode_length: int = 4096,
    model_name="llama",
    quantize_weights=False,
    quantize_kv=False,
    max_cache_length=1024,
) -> PyTorchRayEngine:

  ray.init(ignore_reinit_error=True)
  pod_name = tpu.get_current_pod_name()
  num_hosts = tpu.get_current_pod_worker_count()
  print(f"pod_name: {pod_name}, number of hosts: {num_hosts}")
  assert (
      pod_name is not None
  ), f"TPU pod name (current value: {pod_name}) cannot be None"
  assert (
      num_hosts > 0
  ), f"num_hosts (current value: {num_hosts}) should be a positive number"
  # pylint: disable-next=all
  engine_worker_with_tpu_resource = PyTorchRayWorker.options(
      resources={"TPU": 4}
  )
  engine_workers = []
  for _ in range(num_hosts):
    engine_worker = engine_worker_with_tpu_resource.remote(
        tokenizer_path=tokenizer_path,
        ckpt_path=ckpt_path,
        samples_per_slot=samples_per_slot,
        bf16_enable=bf16_enable,
        param_size=param_size,
        context_length=context_length,
        batch_size=batch_size,
        max_decode_length=max_decode_length,
        model_name=model_name,
        quantize_weights=quantize_weights,
        quantize_kv=quantize_kv,
        max_cache_length=max_cache_length,
    )
    engine_workers.append(engine_worker)
  engine_master = PyTorchRayEngine(
      engine_workers=engine_workers,
      tokenizer_path=tokenizer_path,
      context_length=context_length,
      batch_size=batch_size,
  )
  return engine_master
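
For orientation, a minimal driver sketch for the engine above (not part of this commit). The tokenizer and checkpoint paths and the prompt tokens are placeholders, and it assumes a Ray cluster with TPU hosts is already reachable from the driver:

# Hypothetical driver for PyTorchRayEngine; paths and prompt are placeholders.
import numpy as np

from jetstream_pt.ray_engine import create_pytorch_ray_engine

engine = create_pytorch_ray_engine(
    tokenizer_path="/tmp/tokenizer.model",  # placeholder path
    ckpt_path="/tmp/llama-7b-ckpt",         # placeholder path
    context_length=1024,
    batch_size=1,
)

# Each call fans out to every Ray worker; the workers hold the real state,
# so params, prefix, and decode_state come back as None.
params = engine.load_params()
decode_state = engine.init_decode_state()

# Placeholder prompt: a padded token buffer of max_prefill_length entries.
padded_tokens = np.zeros(engine.max_prefill_length, dtype=np.int32)
prefix = engine.prefill(
    params=params, padded_tokens=padded_tokens, true_length=8
)
decode_state = engine.insert(prefix, decode_state, slot=0)

# generate() returns the (identical) result gathered from worker 0.
_, result_tokens = engine.generate(params, decode_state)

In practice these calls are issued by the JetStream serving loop rather than by a hand-written driver.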