From d507086f0358535a4fd6ce235192e3407ed08101 Mon Sep 17 00:00:00 2001
From: qihqi
Date: Thu, 2 May 2024 11:17:10 -0700
Subject: [PATCH] Refactor environment and engine (#65)

* Refactor environment and engine so that they don't depend on llama-specific
  details such as ModelArgs

* Fix lints
---
 install_everything.sh                     |  2 +
 jetstream_pt/engine.py                    | 53 +++++++++++++----------
 jetstream_pt/environment.py               | 34 +++------------
 jetstream_pt/ray_worker.py                |  9 ++--
 tests/helpers.py                          | 29 +++++++++++++
 tests/{jax_test.py => jax_experiments.py} |  0
 tests/test_llama_e2e.py                   | 42 +++++++++---------
 tests/test_model_impl.py                  | 37 ++++------------
 8 files changed, 101 insertions(+), 105 deletions(-)
 create mode 100644 tests/helpers.py
 rename tests/{jax_test.py => jax_experiments.py} (100%)

diff --git a/install_everything.sh b/install_everything.sh
index 980edc5b..aca732b3 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 TORCHXLA_TAG=jetstream-pytorch
+JETSTREAM_TAG=v0.2.0
 
 # Uninstall existing jax
 pip3 show jax && pip3 uninstall -y jax
@@ -34,6 +35,7 @@ git checkout $TORCHXLA_TAG
 pip install .
 popd # now at the folder deps
 pushd JetStream
+git checkout $JETSTREAM_TAG
 pip install .
 popd # now at the folder deps
 popd # now at the folder current file
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 440fe6e2..16d0038a 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -29,11 +29,11 @@
 from jetstream.engine import engine_api, tokenizer_pb2, token_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
+from jetstream_pt.third_party.llama2 import model_exportable, model_args
 
 
 Mesh = jax.sharding.Mesh
@@ -81,9 +81,6 @@ def __init__(
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
 
-    # NOTE: this is llama2 specific now.
-    self.param = pt_model.params
-
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
@@ -486,7 +483,7 @@ def generate(
         mask,
         decode_state.input_pos,
     )
-    next_token = self._sampling(logits, self.param.max_batch_size)
+    next_token = self._sampling(logits, self.env.batch_size)
     lens = decode_state.lens + 1
     data = jnp.concatenate(
         [
@@ -621,7 +618,7 @@ def get_prefix_sequence_ddim(self) -> Any:
 
   @property
   def max_concurrent_decodes(self) -> int:
-    return self.param.max_batch_size
+    return self.env.batch_size
 
   @property
   def samples_per_slot(self) -> int:
@@ -630,7 +627,7 @@ def samples_per_slot(self) -> int:
 
   @property
   def max_prefill_length(self) -> int:
-    return self.param.max_seq_len
+    return self.env.max_input_sequence_length
 
   @property
   def max_decode_length(self) -> int:
@@ -693,24 +690,11 @@ def create_pytorch_engine(
       checkpoint_format = "safetensors"
       checkpoint_path = paths[0]
 
-  env_data = JetEngineEnvironmentData(
-      tokenizer_path=tokenizer_path,
-      checkpoint_path=checkpoint_path,
-      checkpoint_format=checkpoint_format,
-      model_type="llama-2-" + param_size,
-      batch_size=batch_size,
-      max_decode_length=max_decode_length,
-      max_input_sequence_length=context_length,
-      enable_weight_quantization=quantize_weights,
-      enable_kv_quantization=quantize_kv,
-      cache_sequence_length=max_cache_length,
-      bf16_enable=bf16_enable,
-  )
-  env = JetEngineEnvironment(env_data)
-
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None
-  if model_name == "llama":
+
+  if model_name.startswith("llama"):
+
     args = model_args.get_model_args(
         param_size,
         context_length,
@@ -720,13 +704,34 @@ def create_pytorch_engine(
     )
     args.device = "meta"
     args.quantize = quantize_weights
+    env_data = JetEngineEnvironmentData(
+        tokenizer_path=tokenizer_path,
+        checkpoint_path=checkpoint_path,
+        checkpoint_format=checkpoint_format,
+        model_type="llama-2-" + param_size,
+        batch_size=batch_size,
+        max_decode_length=max_decode_length,
+        max_input_sequence_length=context_length,
+        enable_weight_quantization=quantize_weights,
+        enable_kv_quantization=quantize_kv,
+        cache_sequence_length=max_cache_length,
+        bf16_enable=bf16_enable,
+        num_layers=args.n_layers,
+        cache_shape=(
+            batch_size,
+            args.n_kv_heads,
+            max_cache_length,
+            args.dim // args.n_heads,
+        ),
+    )
+    env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
 
     num_params_size = 0
     num_params = 0
     for _, v in pt_model.state_dict().items():
       num_params += 1
-      num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
+      num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
   print("Number of param Gbytes:", num_params_size / (1 << 30))
   print("Number of param: ", num_params)
 
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 4231193d..f223f837 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -21,7 +21,6 @@
 
 import torch_xla2
 
-from jetstream_pt.third_party.llama2 import model_args
 from jetstream_pt import cache_manager
 
 
@@ -52,6 +51,11 @@ class JetEngineEnvironmentData:
       "head_dim",
   )
 
+  # Shape of cache len(cache_shape) == len(attention_kv_axis_names)
+  cache_shape: Tuple[int, ...] = ()
+
+  num_layers: int = 0
+
   # This is the axis to shard among the number of available devices
   # This string must be one of the values of attention_kv_axis_names above
   kv_cache_shard_axis: str = "num_attn_heads"
@@ -73,23 +77,8 @@ class JetEngineEnvironment:
 
   def __init__(self, data: JetEngineEnvironmentData):
     self._data = data
-    # Get 13b
-    self._model_arg = model_args.get_model_args(
-        data.model_type.replace("llama-2-", ""),
-        context_length=data.max_input_sequence_length,
-        batch_size=data.batch_size,
-        vocab_size=32000,  # ?
-        bf16_enable=data.bf16_enable,
-    )
-
     self.batch_size = self._data.batch_size
     self.seq_len = self._data.max_input_sequence_length
-    self.num_layers = self._model_arg.n_layers
-    self.num_kv_heads = self._model_arg.n_kv_heads
-    self.num_heads = self._model_arg.n_heads
-    self.head_dim = self._model_arg.dim // self._model_arg.n_heads
-    self.cache_sequence_length = self._data.cache_sequence_length
-    self.bf16_enable = self._data.bf16_enable
 
     P = jax.sharding.PartitionSpec
 
@@ -115,11 +104,6 @@ def __init__(self, data: JetEngineEnvironmentData):
   def __getattr__(self, name):
     return getattr(self._data, name)
 
-  @property
-  def tokenizer_path(self):
-    """Return tokenizer path"""
-    return self._data.tokenizer_path
-
   # This is used by model to add activation sharding.
   def apply_sharding(self, tensor, *, axis: int | None):
     """Apply sharding for tensor"""
@@ -150,12 +134,8 @@ def make_caches_prefill(self):
   def make_caches_generate(self):
     """Create kv caches for inference generation"""
     caches = []
-    shape = (
-        self.batch_size,
-        self.num_kv_heads,
-        self._data.cache_sequence_length,
-        self.head_dim,
-    )
+    shape = self._data.cache_shape
+
     for _ in range(self.num_layers):
       if self.enable_kv_quantization:
         caches.append(
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py
index edd782e4..71ca873b 100644
--- a/jetstream_pt/ray_worker.py
+++ b/jetstream_pt/ray_worker.py
@@ -187,9 +187,6 @@ def __init__(
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
 
-    # NOTE: this is llama2 specific now.
-    self.param = pt_model.params
-
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
@@ -682,7 +679,7 @@ def generate(
     )
     logits = multihost_utils.process_allgather(logits, tiled=True)
 
-    next_token = self._sampling(logits, self.param.max_batch_size)
+    next_token = self._sampling(logits, self.env.batch_size)
 
     data = np.concatenate(
         [
@@ -837,7 +834,7 @@ def get_prefix_sequence_ddim(self) -> Any:
   @property
   def max_concurrent_decodes(self) -> int:
     """Max batch size for decodes"""
-    return self.param.max_batch_size
+    return self.env.batch_size
 
   @property
   def samples_per_slot(self) -> int:
@@ -847,7 +844,7 @@ def samples_per_slot(self) -> int:
   @property
   def max_prefill_length(self) -> int:
     """Maximum prefill length"""
-    return self.param.max_seq_len
+    return self.env.max_input_sequence_length
 
   @property
   def max_decode_length(self) -> int:
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 00000000..764bc4b0
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,29 @@
+import torch
+import jax
+from jetstream_pt.third_party.llama2 import model_args
+from jetstream_pt import environment
+
+
+def make_env_tiny(bf16_enable=True):
+  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+  torch.set_default_dtype(torch_dtype)
+  jax.config.update("jax_dynamic_shapes", False)
+  jax.config.update("jax_traceback_filtering", "off")
+  config = model_args.get_model_args("tiny", 128, 1, 32000, True)
+  environment_data = environment.JetEngineEnvironmentData()
+  environment_data.max_input_sequence_length = 128
+  environment_data.max_input_sequence_length = 128
+  environment_data.cache_sequence_length = 128
+  environment_data.bf16_enable = bf16_enable
+  environment_data.model_type = "llama-2-tiny"
+  environment_data.batch_size = 1
+  environment_data.num_layers = config.n_layers
+  environment_data.cache_shape = (
+      1,
+      config.n_kv_heads,
+      environment_data.cache_sequence_length,
+      config.dim // config.n_heads,
+  )
+  env = environment.JetEngineEnvironment(environment_data)
+  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+  return env, config
diff --git a/tests/jax_test.py b/tests/jax_experiments.py
similarity index 100%
rename from tests/jax_test.py
rename to tests/jax_experiments.py
diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py
index 312294fc..31a8c36f 100644
--- a/tests/test_llama_e2e.py
+++ b/tests/test_llama_e2e.py
@@ -22,10 +22,11 @@
 import torch
 import torch_xla2
 from torch.utils import _pytree as pytree
+from . import helpers
 
 from jetstream_pt.engine import PyTorchEngine
-from jetstream_pt.third_party.llama2 import model_exportable
+from jetstream_pt.third_party.llama2 import model_exportable, model_args
 from jetstream_pt.third_party.llama2.generation_original import LlamaOriginal
 from jetstream_pt import environment
 
@@ -45,6 +46,7 @@ def _make_env(self, bf16_enable=True):
     torch.set_default_dtype(torch_dtype)
     jax.config.update("jax_dynamic_shapes", False)
     jax.config.update("jax_traceback_filtering", "off")
+    config = model_args.get_model_args("tiny", 128, 1, 32000, True)
     environment_data = environment.JetEngineEnvironmentData()
     environment_data.max_input_sequence_length = 128
     environment_data.max_input_sequence_length = 128
@@ -52,24 +54,30 @@ def _make_env(self, bf16_enable=True):
     environment_data.bf16_enable = bf16_enable
     environment_data.model_type = "llama-2-tiny"
     environment_data.batch_size = 1
+    environment_data.num_layers = config.n_layers
+    environment_data.cache_shape = (
+        1,
+        config.n_kv_heads,
+        environment_data.cache_sequence_length,
+        config.dim // config.n_heads,
+    )
     env = environment.JetEngineEnvironment(environment_data)
     env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
-    return env
+    return env, config
 
   def test_original_llama2_seed(self):
     """test original llama2 output with different seed"""
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
     torch.set_default_dtype(torch.bfloat16)
-    env = self._make_env()
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     file_dir = os.path.dirname(__file__)
     tokenizer_path = os.path.join(
         file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model"
     )
     output_tokens_multiple = []
+    model_arg = model_args.get_model_args("tiny", 128, 1, 32000, True)
     for i in [1, 999, 99999]:
       llama_original = LlamaOriginal.build(tokenizer_path, model_arg, i)
       prompt_tokens = [tokens]
@@ -91,9 +99,8 @@ def test_jetstream_llama2_seed(self):
     torch.set_default_dtype(torch.bfloat16)
 
     # pylint: disable-next=all
-    env = self._make_env()
+    env, model_arg = helpers.make_env_tiny()
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
     padded_tokens = np.pad(tokens, (0, 6))
@@ -155,9 +162,8 @@ def test_jetstream_llama2_seed(self):
     )
 
   # pylint: disable-next=all
-  def _llama_e2e(self, env):
+  def _llama_e2e(self, env, model_arg):
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
     padded_tokens = np.pad(tokens, (0, 6))
@@ -218,8 +224,8 @@ def test_llama_e2e_float32(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    env = self._make_env(bf16_enable=False)
-    out_tokens, expected_output_tokens = self._llama_e2e(env)
+    env, model_arg = helpers.make_env_tiny(bf16_enable=False)
+    out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertEqual(out_tokens, expected_output_tokens)
 
   def test_llama_e2e_bfloat16(self):
@@ -228,8 +234,8 @@ def test_llama_e2e_bfloat16(self):
     jax.config.update("jax_default_matmul_precision", jax.lax.Precision.HIGHEST)
     print(f"---------> {jax.devices()}")
 
-    env = self._make_env(bf16_enable=True)
-    out_tokens, expected_output_tokens = self._llama_e2e(env)
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+    out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertNotEqual(out_tokens, expected_output_tokens)
 
   # pylint: disable-next=all
@@ -240,9 +246,8 @@ def test_llama_e2e_two_addtional_tokens(self):
     torch.set_default_dtype(torch.bfloat16)
 
     # pylint: disable-next=all
-    env = self._make_env()
+    env, model_arg = helpers.make_env_tiny()
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503], axis=-1)
     true_length = tokens.shape[-1]
@@ -314,9 +319,8 @@ def test_llama_e2e_four_addtional_tokens(self):
     torch.set_default_dtype(torch.bfloat16)
 
     # pylint: disable-next=all
-    env = self._make_env()
+    env, model_arg = helpers.make_env_tiny()
     # pylint: disable-next=all
-    model_arg = env._model_arg
    tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503, 11833, 28551], axis=-1)
     true_length = tokens.shape[-1]
@@ -385,9 +389,8 @@ def test_llama_with_original_prefill_decode_32(self):
     print(f"---------> {jax.devices()}")
     torch.set_default_dtype(torch.float32)
 
-    env = self._make_env(bf16_enable=False)
+    env, model_arg = helpers.make_env_tiny(bf16_enable=False)
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
     padded_tokens = np.pad(tokens, (0, 6))
@@ -462,9 +465,8 @@ def test_llama_with_original_prefill_decode(self):
     print(f"---------> {jax.devices()}")
     torch.set_default_dtype(torch.float32)
 
-    env = self._make_env()
+    env, model_arg = helpers.make_env_tiny()
     # pylint: disable-next=all
-    model_arg = env._model_arg
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
     padded_tokens = np.pad(tokens, (0, 6))
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index 9620ae87..23b9d14d 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -18,6 +18,7 @@
 import torch
 from torch.utils import _pytree as pytree
 import torch_xla2
+from . import helpers
 
 from jetstream_pt.third_party.llama2 import model_exportable
 from jetstream_pt.third_party.llama2 import model_original
@@ -32,7 +33,8 @@ class ModelComponentTest(unittest.TestCase):
 
   def setup(self):
     """setup torch env"""
-    torch.set_default_dtype(torch.bfloat16)
+    jax.config.update("jax_platform_name", "cpu")
+    torch.set_default_dtype(torch.float32)
 
   def _prefill_mask(self, seqlen, start_pos):
     mask = torch.full((seqlen, seqlen), float("-inf"))
@@ -63,19 +65,6 @@ def _to_xla_tensor(self, tree):
         torch.Tensor, torch_xla2.tensor.move_to_device, tree
     )
 
-  def _make_env(self):
-    jax.config.update("jax_platform_name", "cpu")
-    torch.set_default_dtype(torch.float32)
-    env_data = environment.JetEngineEnvironmentData()
-    env_data.max_input_sequence_length = 128
-    env_data.max_input_sequence_length = 128
-    env_data.cache_sequence_length = 128
-    env_data.model_type = "llama-2-tiny"
-    env_data.batch_size = 1
-    env = environment.JetEngineEnvironment(env_data)
-    env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
-    return env
-
   def _call_xla_model(self, model, weights, args):
     with jax.default_device(jax.devices("cpu")[0]):
       xla_weights, xla_inputs = self._to_xla_tensor((weights, args))
@@ -96,12 +85,9 @@ def _compare_cache(self, cache_torch, cache_jax):
       print("diff ", (cache_torch[0, s] - cache_j[0, :, s]).norm())
 
   def _make_one_cache_for_generate(self, env, pos):
-    cache_array_k = jnp.zeros(
-        (1, env.num_heads, env.cache_sequence_length, env.head_dim)
-    )
-    cache_array_v = jnp.zeros(
-        (1, env.num_heads, env.cache_sequence_length, env.head_dim)
-    )
+    cache_array_k = jnp.zeros(env.cache_shape)
+
+    cache_array_v = jnp.zeros(env.cache_shape)
     cache_array_k, cache_array_v = torch_xla2.tensor.wrap(
         (cache_array_k, cache_array_v)
     )
@@ -112,8 +98,7 @@
 
   # pylint: disable-next=all
   def test_attention(self):
-    env = self._make_env()
-    model_arg = env._model_arg
+    env, model_arg = helpers.make_env_tiny(False)
 
     attention_orig = model_original.Attention(model_arg)
     attention_ours = layers.Attention(model_arg, env)
@@ -186,9 +171,7 @@ def test_attention(self):
 
   # pylint: disable-next=all
   def test_transformer_block(self):
-    env = self._make_env()
-    # pylint: disable-next=all
-    model_arg = env._model_arg
+    env, model_arg = helpers.make_env_tiny(False)
 
     block_orig = model_original.TransformerBlock(0, model_arg)
     block_ours = model_exportable.TransformerBlock(0, model_arg, env)
@@ -260,9 +243,7 @@ def test_transformer_block(self):
   # pylint: disable-next=all
   def test_transformer(self):
     """test transformer diff between original model vs xla_model"""
-    env = self._make_env()
-    # pylint: disable-next=all
-    model_arg = env._model_arg
+    env, model_arg = helpers.make_env_tiny(False)
     model_orig = model_original.Transformer(model_arg)
 
     state_dict = dict(model_orig.state_dict())
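
The refactor above inverts a dependency: JetEngineEnvironment no longer builds llama ModelArgs internally; instead the caller derives num_layers and cache_shape from its own model config and passes them in through JetEngineEnvironmentData. The following is a minimal sketch (not part of the patch) of that configuration pattern; it mirrors the values used in the new tests/helpers.py, and the "tiny"/128/1/32000 sizes are illustrative only.

# Sketch, assuming jetstream_pt at this commit: configure the environment
# without the environment importing any llama-specific ModelArgs.
from jetstream_pt import environment
from jetstream_pt.third_party.llama2 import model_args

# The caller still knows its model; sizes mirror tests/helpers.py.
config = model_args.get_model_args("tiny", 128, 1, 32000, True)

env_data = environment.JetEngineEnvironmentData(
    model_type="llama-2-tiny",
    batch_size=1,
    max_input_sequence_length=128,
    cache_sequence_length=128,
    bf16_enable=True,
    # Model-derived values are now passed in explicitly:
    num_layers=config.n_layers,
    cache_shape=(  # (batch, kv heads, cache length, head dim)
        1,
        config.n_kv_heads,
        128,
        config.dim // config.n_heads,
    ),
)
env = environment.JetEngineEnvironment(env_data)

# The engine reads sizes from the env rather than from llama params,
# e.g. max_concurrent_decodes -> env.batch_size.
assert env.batch_size == 1
assert env.max_input_sequence_length == 128  # served by __getattr__ from env_data

With this split, make_caches_generate() can size per-layer KV caches from cache_shape alone, and a non-llama model only needs to fill in the same two fields.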