Refactor so that environment and engine (AI-Hypercomputer#65)
* Refactor so that environment and engine don't depend on
  llama-specific stuff such as ModelArgs

* Fix lints
qihqi authored May 2, 2024
1 parent a58051d commit d507086
Showing 8 changed files with 101 additions and 105 deletions.
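In short: the engine and environment now read model-agnostic sizes (batch size, sequence lengths, layer count, KV-cache shape) from JetEngineEnvironmentData instead of reaching into llama's ModelArgs. A minimal stand-alone sketch of that access pattern follows; EnvData and Engine are hypothetical simplifications, not the real jetstream_pt classes.

from dataclasses import dataclass
from typing import Tuple


@dataclass
class EnvData:
    """Stand-in for JetEngineEnvironmentData: model-agnostic sizes only."""
    batch_size: int = 1
    max_input_sequence_length: int = 128
    num_layers: int = 0
    cache_shape: Tuple[int, ...] = ()


class Engine:
    """Stand-in for the engine after the refactor."""

    def __init__(self, env: EnvData):
        # Previously the engine kept llama-specific ModelArgs around as
        # self.param = pt_model.params; now it only needs the environment.
        self.env = env

    @property
    def max_concurrent_decodes(self) -> int:
        return self.env.batch_size  # was: self.param.max_batch_size

    @property
    def max_prefill_length(self) -> int:
        return self.env.max_input_sequence_length  # was: self.param.max_seq_len


engine = Engine(EnvData(batch_size=4, max_input_sequence_length=1024))
print(engine.max_concurrent_decodes, engine.max_prefill_length)

The real engine wires sharding and cache creation off the same environment object, as the hunks below show.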
2 changes: 2 additions & 0 deletions install_everything.sh
@@ -13,6 +13,7 @@
# limitations under the License.

TORCHXLA_TAG=jetstream-pytorch
JETSTREAM_TAG=v0.2.0

# Uninstall existing jax
pip3 show jax && pip3 uninstall -y jax
@@ -34,6 +35,7 @@ git checkout $TORCHXLA_TAG
pip install .
popd # now at the folder deps
pushd JetStream
git checkout $JETSTREAM_TAG
pip install .
popd # now at the folder deps
popd # now at the folder current file
53 changes: 29 additions & 24 deletions jetstream_pt/engine.py
@@ -29,11 +29,11 @@
from jetstream.engine import engine_api, tokenizer_pb2, token_utils
import torch_xla2
from torch.utils import _pytree as pytree
from jetstream_pt.third_party.llama2 import model_exportable, model_args

from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama2 import model_exportable, model_args


Mesh = jax.sharding.Mesh
@@ -81,9 +81,6 @@ def __init__(
self.env = env
self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32

# NOTE: this is llama2 specific now.
self.param = pt_model.params

self.y_sharding = env.sharding_by_axis(1)
self.x_sharding = env.sharding_by_axis(0)
self.replicated = env.sharding_by_axis(-1) # replicated
@@ -486,7 +483,7 @@ def generate(
mask,
decode_state.input_pos,
)
next_token = self._sampling(logits, self.param.max_batch_size)
next_token = self._sampling(logits, self.env.batch_size)
lens = decode_state.lens + 1
data = jnp.concatenate(
[
@@ -621,7 +618,7 @@ def get_prefix_sequence_ddim(self) -> Any:

@property
def max_concurrent_decodes(self) -> int:
return self.param.max_batch_size
return self.env.batch_size

@property
def samples_per_slot(self) -> int:
@@ -630,7 +627,7 @@ def samples_per_slot(self) -> int:

@property
def max_prefill_length(self) -> int:
return self.param.max_seq_len
return self.env.max_input_sequence_length

@property
def max_decode_length(self) -> int:
@@ -693,24 +690,11 @@ def create_pytorch_engine(
checkpoint_format = "safetensors"
checkpoint_path = paths[0]

env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
model_type="llama-2-" + param_size,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
)
env = JetEngineEnvironment(env_data)

tokenizer = token_utils.load_vocab(tokenizer_path)
pt_model = None
if model_name == "llama":

if model_name.startswith("llama"):

args = model_args.get_model_args(
param_size,
context_length,
@@ -720,13 +704,34 @@
)
args.device = "meta"
args.quantize = quantize_weights
env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
model_type="llama-2-" + param_size,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
num_layers=args.n_layers,
cache_shape=(
batch_size,
args.n_kv_heads,
max_cache_length,
args.dim // args.n_heads,
),
)
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)

num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
print("Number of param Gbytes:", num_params_size / (1 << 30))
print("Number of param: ", num_params)

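As an aside, the parameter accounting at the end of create_pytorch_engine boils down to the stand-alone sketch below (param_stats is a hypothetical helper name; the 1-vs-2-byte logic assumes weights are either int8 or 2-byte bf16, as in the hunk above). The jnp.int8 -> torch.int8 change matters because state_dict() yields torch tensors, so a comparison against jnp.int8 is never true and quantized weights would be counted at 2 bytes each.

import numpy as np
import torch


def param_stats(state_dict):
    """Count parameters and their approximate size in bytes."""
    num_params, num_bytes = 0, 0
    for _, v in state_dict.items():
        num_params += 1
        # int8 weights take 1 byte per element; anything else is assumed
        # to be stored as 2-byte bf16.
        num_bytes += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
    return num_params, num_bytes


num_params, num_bytes = param_stats(torch.nn.Linear(4, 4).state_dict())
print("Number of param Gbytes:", num_bytes / (1 << 30))
print("Number of param: ", num_params)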
34 changes: 7 additions & 27 deletions jetstream_pt/environment.py
@@ -21,7 +21,6 @@
import torch_xla2


from jetstream_pt.third_party.llama2 import model_args
from jetstream_pt import cache_manager


@@ -52,6 +51,11 @@ class JetEngineEnvironmentData:
"head_dim",
)

# Shape of cache len(cache_shape) == len(attention_kv_axis_names)
cache_shape: Tuple[int, ...] = ()

num_layers: int = 0

# This is the axis to shard among the number of available devices
# This string must be one of the values of attention_kv_axis_names above
kv_cache_shard_axis: str = "num_attn_heads"
@@ -73,23 +77,8 @@ class JetEngineEnvironment:

def __init__(self, data: JetEngineEnvironmentData):
self._data = data
# Get 13b
self._model_arg = model_args.get_model_args(
data.model_type.replace("llama-2-", ""),
context_length=data.max_input_sequence_length,
batch_size=data.batch_size,
vocab_size=32000, # ?
bf16_enable=data.bf16_enable,
)

self.batch_size = self._data.batch_size
self.seq_len = self._data.max_input_sequence_length
self.num_layers = self._model_arg.n_layers
self.num_kv_heads = self._model_arg.n_kv_heads
self.num_heads = self._model_arg.n_heads
self.head_dim = self._model_arg.dim // self._model_arg.n_heads
self.cache_sequence_length = self._data.cache_sequence_length
self.bf16_enable = self._data.bf16_enable

P = jax.sharding.PartitionSpec

@@ -115,11 +104,6 @@ def __getattr__(self, name):
def __getattr__(self, name):
return getattr(self._data, name)

@property
def tokenizer_path(self):
"""Return tokenizer path"""
return self._data.tokenizer_path

# This is used by model to add activation sharding.
def apply_sharding(self, tensor, *, axis: int | None):
"""Apply sharding for tensor"""
@@ -150,12 +134,8 @@ def make_caches_prefill(self):
def make_caches_generate(self):
"""Create kv caches for inference generation"""
caches = []
shape = (
self.batch_size,
self.num_kv_heads,
self._data.cache_sequence_length,
self.head_dim,
)
shape = self._data.cache_shape

for _ in range(self.num_layers):
if self.enable_kv_quantization:
caches.append(
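The effect of the new cache_shape and num_layers fields: generate-time KV caches are shaped purely from the environment data, with no llama ModelArgs lookup. A minimal sketch of that plumbing follows; TinyEnv is a hypothetical stand-in, and the real code delegates to cache_manager and optionally quantizes the cache.

import torch


class TinyEnv:
    """Hypothetical stand-in for the environment's cache plumbing."""

    def __init__(self, cache_shape, num_layers):
        # cache_shape matches the tuple built in engine.py above:
        # (batch_size, n_kv_heads, max_cache_length, head_dim)
        self.cache_shape = cache_shape
        self.num_layers = num_layers

    def make_caches_generate(self):
        # One zero-initialized (k, v) buffer pair per transformer layer.
        return [
            (torch.zeros(self.cache_shape), torch.zeros(self.cache_shape))
            for _ in range(self.num_layers)
        ]


env = TinyEnv(cache_shape=(1, 4, 128, 64), num_layers=2)
caches = env.make_caches_generate()
print(len(caches), caches[0][0].shape)  # 2 torch.Size([1, 4, 128, 64])

In the real flow, create_pytorch_engine fills cache_shape in from the model args as (batch_size, n_kv_heads, max_cache_length, dim // n_heads), so the model-specific knowledge lives in one place and the environment stays model-agnostic.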
9 changes: 3 additions & 6 deletions jetstream_pt/ray_worker.py
@@ -187,9 +187,6 @@ def __init__(
self.env = env
self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32

# NOTE: this is llama2 specific now.
self.param = pt_model.params

self.y_sharding = env.sharding_by_axis(1)
self.x_sharding = env.sharding_by_axis(0)
self.replicated = env.sharding_by_axis(-1) # replicated
@@ -682,7 +679,7 @@ def generate(
)

logits = multihost_utils.process_allgather(logits, tiled=True)
next_token = self._sampling(logits, self.param.max_batch_size)
next_token = self._sampling(logits, self.env.batch_size)

data = np.concatenate(
[
@@ -837,7 +834,7 @@ def get_prefix_sequence_ddim(self) -> Any:
@property
def max_concurrent_decodes(self) -> int:
"""Max batch size for decodes"""
return self.param.max_batch_size
return self.env.batch_size

@property
def samples_per_slot(self) -> int:
@@ -847,7 +844,7 @@ def samples_per_slot(self) -> int:
@property
def max_prefill_length(self) -> int:
"""Maximum prefill length"""
return self.param.max_seq_len
return self.env.max_input_sequence_length

@property
def max_decode_length(self) -> int:
29 changes: 29 additions & 0 deletions tests/helpers.py
@@ -0,0 +1,29 @@
import torch
import jax
from jetstream_pt.third_party.llama2 import model_args
from jetstream_pt import environment


def make_env_tiny(bf16_enable=True):
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
torch.set_default_dtype(torch_dtype)
jax.config.update("jax_dynamic_shapes", False)
jax.config.update("jax_traceback_filtering", "off")
config = model_args.get_model_args("tiny", 128, 1, 32000, True)
environment_data = environment.JetEngineEnvironmentData()
environment_data.max_input_sequence_length = 128
environment_data.max_input_sequence_length = 128
environment_data.cache_sequence_length = 128
environment_data.bf16_enable = bf16_enable
environment_data.model_type = "llama-2-tiny"
environment_data.batch_size = 1
environment_data.num_layers = config.n_layers
environment_data.cache_shape = (
1,
config.n_kv_heads,
environment_data.cache_sequence_length,
config.dim // config.n_heads,
)
env = environment.JetEngineEnvironment(environment_data)
env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
return env, config
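A hypothetical usage of the new test helper, assuming the repo root is on PYTHONPATH so tests.helpers is importable; Transformer(args, env) mirrors the constructor call in engine.py above.

from jetstream_pt.third_party.llama2 import model_exportable
from tests.helpers import make_env_tiny

# Build a tiny llama-2 config plus a matching environment, then a model.
env, config = make_env_tiny(bf16_enable=True)
model = model_exportable.Transformer(config, env)
print(env.batch_size, env.cache_shape, config.n_layers)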
File renamed without changes.
