From 9c0d2ac730505f85d76eb0a404b1841c1154f7db Mon Sep 17 00:00:00 2001
From: Fanhai Lu <154379058+FanhaiLu1@users.noreply.github.com>
Date: Thu, 9 May 2024 15:52:47 -0700
Subject: [PATCH] Add gemma and update recent changes to multiple host (#74)

add gemma and update recent changes to multiple host
---
 jetstream_pt/ray_engine.py       |  7 +++++
 jetstream_pt/ray_worker.py       | 50 ++++++++++++++++++++++++--------
 run_interactive_multiple_host.py | 10 +++++++
 3 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py
index 03f6b830..1b394829 100644
--- a/jetstream_pt/ray_engine.py
+++ b/jetstream_pt/ray_engine.py
@@ -152,8 +152,14 @@ def create_pytorch_ray_engine(
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
+    sharding_config=None,
 ) -> PyTorchRayEngine:
 
+  supported_models = ["llama-2", "llama-3", "gemma"]
+  if model_name not in supported_models:
+    raise NotImplementedError(
+        f"Model name should be one of{','.join(supported_models)}"
+    )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
   num_hosts = tpu.get_current_pod_worker_count()
@@ -183,6 +189,7 @@ def create_pytorch_ray_engine(
         quantize_weights=quantize_weights,
         quantize_kv=quantize_kv,
         max_cache_length=max_cache_length,
+        sharding_config=sharding_config,
     )
     engine_workers.append(engine_worker)
   engine_master = PyTorchRayEngine(
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py
index f289dd57..ff49e624 100644
--- a/jetstream_pt/ray_worker.py
+++ b/jetstream_pt/ray_worker.py
@@ -17,6 +17,7 @@
 from typing import Any, List, Optional, Tuple, Union
 import threading
 import functools
+import os
 
 import humanize
 
@@ -39,6 +40,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
+from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
 
 Mesh = jax.sharding.Mesh
 
@@ -103,6 +105,7 @@ def __init__(
       quantize_weights=False,
       quantize_kv=False,
       max_cache_length=1024,
+      sharding_config=None,
   ):
 
     jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -144,11 +147,13 @@ def __init__(
       checkpoint_format = "safetensors"
       checkpoint_path = paths[0]
 
+    if not sharding_config:
+      sharding_config = os.path.join("default_shardings", model_name + ".yaml")
+
     env_data = JetEngineEnvironmentData(
         tokenizer_path=tokenizer_path,
         checkpoint_path=checkpoint_path,
         checkpoint_format=checkpoint_format,
-        model_type="llama-2-" + param_size,
         batch_size=batch_size,
         max_decode_length=max_decode_length,
         max_input_sequence_length=context_length,
@@ -156,26 +161,47 @@ def __init__(
         enable_kv_quantization=quantize_kv,
         cache_sequence_length=max_cache_length,
         bf16_enable=bf16_enable,
+        sharding_config_path=sharding_config,
     )
     env = JetEngineEnvironment(env_data)
 
-    pt_model = None
-    if "llama" in model_name:
+    if model_name.startswith("llama"):
+
       args = model_args.get_model_args(
-          model_name + "-" + param_size,
-          context_length,
-          batch_size,
-          bf16_enable,
+          model_name + "-" + param_size, context_length, batch_size, bf16_enable
       )
       args.device = "meta"
       args.quantize = quantize_weights
+      env_data.cache_shape = (
+          batch_size,
+          args.n_kv_heads,
+          max_cache_length,
+          args.dim // args.n_heads,
+      )
+      env_data.model_type = "llama-2-" + param_size
+      env_data.num_layers = args.n_layers
+      env = JetEngineEnvironment(env_data)
       pt_model = model_exportable.Transformer(args, env)
+    elif model_name == "gemma":
+      args = gemma_config.get_model_config(param_size)
+      env_data.cache_shape = (
+          batch_size,
+          args.num_key_value_heads,
+          max_cache_length,
+          args.head_dim,
+      )
+      env_data.model_type = "gemma-" + param_size
+      env_data.num_layers = args.num_hidden_layers
+      env = JetEngineEnvironment(env_data)
+      pt_model = gemma_model.GemmaModel(args, env)
+    else:
+      raise RuntimeError(f"Model with name {model_name} not found")
 
-      num_params_size = 0
-      num_params = 0
-      for _, v in pt_model.state_dict().items():
-        num_params += 1
-        num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
+    num_params_size = 0
+    num_params = 0
+    for _, v in pt_model.state_dict().items():
+      num_params += 1
+      num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
 
     print("Number of param Gbytes:", num_params_size / (1 << 30))
     print("Number of param: ", num_params)
diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py
index fadf3d41..9de0c492 100644
--- a/run_interactive_multiple_host.py
+++ b/run_interactive_multiple_host.py
@@ -65,6 +65,14 @@
     "max_cache_length", 1024, "kv_cache_quantize"
 )
 
+_MODEL_NAME = flags.DEFINE_string(
+    "model_name", None, "model type", required=False
+)
+
+_SHARDING_CONFIG = flags.DEFINE_string(
+    "sharding_config", "", "config file for sharding"
+)
+
 
 def create_engine():
   """create a pytorch engine"""
@@ -73,6 +81,7 @@ def create_engine():
 
   start = time.perf_counter()
   engine = ray_engine.create_pytorch_ray_engine(
+      model_name=_MODEL_NAME.value,
       tokenizer_path=_TOKENIZER_PATH.value,
       ckpt_path=_CKPT_PATH.value,
       bf16_enable=True,
@@ -82,6 +91,7 @@ def create_engine():
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
+      sharding_config=_SHARDING_CONFIG.value,
   )
   print("Initialize engine", time.perf_counter() - start)
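
For reference, a minimal sketch of how a caller could exercise the new model_name and sharding_config arguments after this patch. Only the keyword arguments visible in the hunks above are taken from the patch; the placeholder paths, the example values, and the param_size / context_length / batch_size arguments (elided from these hunks) are assumptions based on the flags used in run_interactive_multiple_host.py.

# Minimal sketch, assuming a TPU pod with Ray reachable from this process.
# Paths and sizes below are placeholders; param_size, context_length and
# batch_size are assumed arguments not shown in the hunks above.
from jetstream_pt import ray_engine

engine = ray_engine.create_pytorch_ray_engine(
    model_name="gemma",  # must be one of "llama-2", "llama-3", "gemma"
    tokenizer_path="/tmp/tokenizer.model",
    ckpt_path="/tmp/gemma_ckpt",
    bf16_enable=True,
    param_size="7b",
    context_length=1024,
    batch_size=32,
    quantize_weights=False,
    quantize_kv=False,
    max_cache_length=1024,
    sharding_config=None,  # falsy value -> worker falls back to default_shardings/gemma.yaml
)

Leaving sharding_config unset relies on the per-model default added in ray_worker.py; passing an explicit YAML path overrides it on every worker.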