diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py index a6db15ab..03bf4180 100644 --- a/benchmarks/prefill_offline.py +++ b/benchmarks/prefill_offline.py @@ -17,6 +17,7 @@ import functools import humanize +# pylint: disable-next=all from absl import app from absl import flags import numpy as np diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 2f480930..df788591 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -15,6 +15,7 @@ import logging import os import time +# pylint: disable-next=all from absl import app from absl import flags diff --git a/install_everything.sh b/install_everything.sh index c4ae1d03..82687ed9 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -TORCHXLA_TAG=6dccf0a02d7828516bdb589f2ae0dc79b64488fa # updated May 10, 2024 +TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024 JETSTREAM_TAG=v0.2.1 # Uninstall existing jax diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 0f11f79f..75bd75c0 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -15,7 +15,7 @@ import jax import jax.numpy as jnp import torch -from jetstream_pt.torchjax import from_jax, to_jax +from jetstream_pt import torchjax # pylint: disable-next=all @@ -49,7 +49,7 @@ def update(self, key, value): self.cache_v = value if self.kv_quantize: # pretend to be quantized bsz, _, seq, _ = key.shape - ones = from_jax(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) + ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) return key, value, ones, ones return key, value @@ -62,7 +62,7 @@ def state(self): # pylint: disable-next=all def KVCachePrefill_flatten(cache): return ( - to_jax((cache.cache_k, cache.cache_v)), + torchjax.from_torch((cache.cache_k, cache.cache_v)), cache.kv_quantize, ) @@ -70,7 +70,7 @@ def KVCachePrefill_flatten(cache): # pylint: disable-next=all def KVCachePrefill_unflatten(auxdata, data): cache = KVCachePrefill(auxdata) - cache_k, cache_v = to_jax(data) + cache_k, cache_v = torchjax.from_torch(data) cache.cache_k = cache_k cache.cache_v = cache_v @@ -100,7 +100,7 @@ def __init__( def update(self, key, value): """Update kv cache""" - keyj, valuej = from_jax((key, value)) + keyj, valuej = torchjax.to_torch((key, value)) # pylint: disable-next=all self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) # pylint: disable-next=all @@ -118,7 +118,7 @@ def empty(cls, shape, device, bf16_enable): default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 k = jnp.zeros(shape, device=device, dtype=default_dtype) v = jnp.zeros(shape, device=device, dtype=default_dtype) - k, v = from_jax((k, v)) + k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device) @@ -133,7 +133,7 @@ def KVCacheGenerate_flatten(cache): # pylint: disable-next=all def KVCacheGenerate_unflatten(auxdata, data): position, sharding = auxdata - cache_k, cache_v = from_jax(data) + cache_k, cache_v = torchjax.to_torch(data) cache = KVCacheGenerate(cache_k, cache_v, position, sharding) return cache @@ -166,11 +166,11 @@ def __init__( def state(self): """Get kv cache state""" - return to_jax((self.cache_k, self.cache_v)) + return torchjax.from_torch((self.cache_k, self.cache_v)) def scalers(self): """Get kv cache scalers""" - return to_jax((self.k_scaler, self.v_scaler)) + return torchjax.from_torch((self.k_scaler, self.v_scaler)) @classmethod # pylint: disable-next=all @@ -182,7 +182,7 @@ def empty(cls, shape, device, bf16_enable): kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - cache_k, cache_v, kscaler, vscaler = from_jax( + cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) ) return cls(cache_k, cache_v, kscaler, vscaler, 0, device) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 3d18f737..4f9c8256 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -33,7 +33,7 @@ from jetstream_pt import cache_manager from jetstream_pt import quantize -from jetstream_pt.torchjax import from_jax, to_jax +from jetstream_pt import torchjax from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model @@ -151,27 +151,31 @@ def _call_model_generate( if self.env.enable_kv_quantization: caches_obj = [ cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes) - for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales))) + for (k, v), (ks, vs) in torchjax.to_torch( + list(zip(caches, cache_scales)) + ) ] else: caches_obj = [ cache_manager.KVCacheGenerate( k, v, input_indexes, self.cache_sharding ) - for k, v in from_jax(caches) + for k, v in torchjax.to_torch(caches) ] mask = jnp.expand_dims(mask, (1, 2)) args = (tokens, input_pos, caches_obj, mask) - paramst, argst = from_jax((weights, args)) + paramst, argst = torchjax.to_torch((weights, args)) with self._lock: - with torch_xla2.default_env(): + with torchjax.jax_mode: + # The mode is needed so that tensors created inside of + # the model (such as via torch.ones etc) also have the right type res = torch.func.functional_call(self.pt_model, paramst, argst) updated_caches = [c.state() for c in caches_obj] scales = [] if self.env.enable_kv_quantization: scales = [c.scalers() for c in caches_obj] - return to_jax((res, updated_caches, scales)) + return torchjax.from_torch((res, updated_caches, scales)) @functools.partial( jax.jit, @@ -190,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes): mask = jnp.triu(mask, k=1) args = (tokens, input_indexes, caches, mask) - paramst, argst = from_jax((weights, args)) + paramst, argst = torchjax.to_torch((weights, args)) with self._lock: - with torch_xla2.default_env(): + with torchjax.jax_mode: res = torch.func.functional_call(self.pt_model, paramst, argst)[0] caches_res = [c.state() for c in caches] - return to_jax((res, caches_res)) + return torchjax.from_torch((res, caches_res)) def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray: if len(logits.shape) == 2: @@ -561,7 +565,7 @@ def _load_from_state_dict(self, path): for key, model_weights in self.pt_model.state_dict().items(): assert key in state_dict, f"key: {key} not found" arr = jax.device_put( - to_jax(state_dict[key]), self.env.sharding_by_name(key) + torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key) ) assert tuple(model_weights.shape) == tuple( arr.shape @@ -602,41 +606,23 @@ def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]: def get_prefix_destination_sharding(self) -> Prefix: """Returns the shardings necessary to transfer data between engines.""" - if self.env.shard_on_batch: - return Prefix( - self.replicated, # cache is replicated because bs=1 for prefill - self.replicated, - self.replicated, - ) - else: - return Prefix( - self.replicated, - self.cache_sharding, - self.replicated, - ) + return Prefix( + self.replicated, + self.replicated if self.env.shard_on_batch else self.cache_sharding, + self.replicated, + ) def get_decode_state_sharding(self) -> DecodeState: """Gets the shardings corresponding to the decode state.""" - if self.env.shard_on_batch: - return DecodeState( - self.x_sharding, # shard on batch - self.cache_sharding, - self.replicated, - self.replicated, - self.replicated, - self.replicated, - self.replicated, - ) - else: - return DecodeState( - self.replicated, # shard on batch - self.cache_sharding, - self.replicated, - self.replicated, - self.replicated, - self.replicated, - self.replicated, - ) + return DecodeState( + self.x_sharding if self.env.shard_on_batch else self.replicated, + self.cache_sharding, + self.replicated, + self.replicated, + self.replicated, + self.replicated, + self.replicated, + ) def get_prefix_sequence_ddim(self) -> Any: """Returns the index of the sequence dim in the prefix type.""" diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 5ef9dfa6..f4dea094 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -172,10 +172,8 @@ def __call__(self, xq, xk, xv, mask, cache): output = output[:, :, 0:1, :] # For XLA matmul performance boost # output = torch.matmul(scores, values) - if self.env.shard_on_batch: - self.env.apply_sharding(output, axis=0) - else: - self.env.apply_sharding(output, axis=1) + shard_axis = 0 if self.env.shard_on_batch else 1 + self.env.apply_sharding(output, axis=shard_axis) return output @@ -225,10 +223,8 @@ def __call__(self, xq, xk, xv, mask, cache): if seqlen == 1: output = output[:, :, 0:1, :] # output = torch.matmul(scores, values) - if self.env.shard_on_batch: - self.env.apply_sharding(output, axis=0) - else: - self.env.apply_sharding(output, axis=1) + shard_axis = 0 if self.env.shard_on_batch else 1 + self.env.apply_sharding(output, axis=shard_axis) return output @@ -319,14 +315,10 @@ def forward( xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) - if self.env.shard_on_batch: - self.env.apply_sharding(xq, axis=0) - self.env.apply_sharding(xk, axis=0) - self.env.apply_sharding(xv, axis=0) - else: - self.env.apply_sharding(xq, axis=2) - self.env.apply_sharding(xk, axis=2) - self.env.apply_sharding(xv, axis=2) + shard_axis = 0 if self.env.shard_on_batch else 2 + self.env.apply_sharding(xq, axis=shard_axis) + self.env.apply_sharding(xk, axis=shard_axis) + self.env.apply_sharding(xv, axis=shard_axis) with jax.named_scope("attn_rope"): xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 64a67532..947acf59 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -39,7 +39,7 @@ from jetstream_pt import cache_manager from jetstream_pt import quantize -from jetstream_pt.torchjax import from_jax, to_jax +from jetstream_pt import torchjax from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model @@ -336,21 +336,23 @@ def _call_model_generate( if self.env.enable_kv_quantization: caches_obj = [ cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes) - for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales))) + for (k, v), (ks, vs) in torchjax.to_torch( + list(zip(caches, cache_scales)) + ) ] else: caches_obj = [ cache_manager.KVCacheGenerate( k, v, input_indexes, self.cache_sharding ) - for k, v in from_jax(caches) + for k, v in torchjax.to_torch(caches) ] mask = jnp.expand_dims(mask, (1, 2)) args = (tokens, input_pos, caches_obj, mask) - paramst, argst = from_jax((weights, args)) + paramst, argst = torchjax.to_torch((weights, args)) with self._lock: - with torch_xla2.default_env(): + with torchjax.jax_mode(): res = torch.func.functional_call(self.pt_model, paramst, argst) updated_caches = [c.state() for c in caches_obj] scales = [] @@ -360,7 +362,7 @@ def _call_model_generate( current_position + 1 ) % self.env.cache_sequence_length - return to_jax( + return torchjax.from_torch( ( res, updated_caches, @@ -389,12 +391,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes): mask = jnp.triu(mask, k=1) args = (tokens, input_indexes, caches, mask) - paramst, argst = from_jax((weights, args)) + paramst, argst = torchjax.to_torch((weights, args)) with self._lock: - with torch_xla2.default_env(): + with torchjax.jax_mode: res = torch.func.functional_call(self.pt_model, paramst, argst)[0] caches_res = [c.state() for c in caches] - return to_jax((res, caches_res)) + return torchjax.from_torch((res, caches_res)) def _sampling(self, logits: Any, batch_size: int) -> np.ndarray: if len(logits.shape) == 2: @@ -503,7 +505,7 @@ def insert(cache, new_entry): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry): reduce_axis = (1, 3) - vals, scales = torch_xla2.interop.call_torch( + vals, scales = torchjax.call_torch( quantize.quantize_torch_int8, new_entry, reduce_axis ) new_scaler = jax.lax.dynamic_update_slice( @@ -602,7 +604,7 @@ def insert(cache, new_entry): def insert(cache, scaler, new_entry): new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2)) reduce_axis = (1, 2) - vals, scales = torch_xla2.interop.call_torch( + vals, scales = torchjax.call_torch( quantize.quantize_torch_int8, new_entry, reduce_axis ) new_scaler = scaler.at[slot, :, update_indexes, :].set(scales) @@ -777,6 +779,7 @@ def _weight_sharding(self, weight, sharding): def _load_from_safetensors(self, path): weights = {} + # pylint: disable-next=all with safetensors.safe_open(path, framework="flax", device="cpu") as f: for key, model_weights in self.pt_model.state_dict().items(): if key == "freqs_cis": diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 2f0ebb2d..fa7806c3 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -148,15 +148,10 @@ def forward( xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - if self.env.shard_on_batch: - # Gemma 2B - self.env.apply_sharding(xq, axis=0) - self.env.apply_sharding(xk, axis=0) - self.env.apply_sharding(xv, axis=0) - else: - self.env.apply_sharding(xq, axis=2) - self.env.apply_sharding(xk, axis=2) - self.env.apply_sharding(xv, axis=2) + shard_axis = 0 if self.env.shard_on_batch else 2 + self.env.apply_sharding(xq, axis=shard_axis) + self.env.apply_sharding(xk, axis=shard_axis) + self.env.apply_sharding(xv, axis=shard_axis) # Positional embedding. xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) diff --git a/jetstream_pt/torchjax.py b/jetstream_pt/torchjax.py index 85e1135c..b0afd56e 100644 --- a/jetstream_pt/torchjax.py +++ b/jetstream_pt/torchjax.py @@ -1,13 +1,30 @@ +"""This file will serve as proxy APIs for torch_xla2 API. + +It serves 2 purposes: + +1. torch_xla2 APIs are not + stable yet, and changes of it means lots of code edits throughout + this repo. So future changes of torch_xla2 API we only need to edit + this one file. + +2. We can iterate API look and feel in this file and the influence + how it looks like in torch_xla2. +""" + import torch_xla2 +import torch_xla2.interop + +jax_mode = torch_xla2.default_env() -env = torch_xla2.default_env() +call_jax = torch_xla2.interop.call_jax +call_torch = torch_xla2.interop.call_torch -def from_jax(tensors): +def to_torch(tensors): """Wrap a jax Array into XLATensor.""" - return env.j2t_iso(tensors) + return jax_mode.j2t_iso(tensors) -def to_jax(tensors): +def from_torch(tensors): """Unwrap a XLATensor into jax Array.""" - return env.t2j_iso(tensors) + return jax_mode.t2j_iso(tensors) diff --git a/run_interactive.py b/run_interactive.py index e4b9a775..1e81dce5 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -76,7 +76,10 @@ "sharding_config", "", "config file for sharding" ) _SHARD_ON_BATCH = flags.DEFINE_bool( - "shard_on_batch", False, "whether to shard on batch dimension" + "shard_on_batch", + False, + "whether to shard on batch dimension." + "If set true, sharding_config will be ignored.", ) diff --git a/run_server.py b/run_server.py index 60b8e6af..e10f802c 100644 --- a/run_server.py +++ b/run_server.py @@ -90,7 +90,10 @@ "sharding_config", "", "config file for sharding" ) _SHARD_ON_BATCH = flags.DEFINE_bool( - "shard_on_batch", False, "whether to shard on batch dimension" + "shard_on_batch", + False, + "whether to shard on batch dimension" + "If set true, sharding_config will be ignored.", ) diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 0b27b28b..f538f818 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -34,7 +34,7 @@ class LlamaE2ETest(unittest.TestCase): """This test class includes all E2E test for llama2""" - def _to_jax(self, tree): + def _from_torch(self, tree): return pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, tree) def _make_env(self, bf16_enable=True): @@ -116,7 +116,7 @@ def test_jetstream_llama2_seed(self): state_dict = dict(model_orig.state_dict()) state_dict["freqs_cis"] = model_orig.freqs_cis - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) output_tokens_multiple = [] for i in [1, 2, 3]: @@ -189,7 +189,7 @@ def _llama_e2e(self, env, model_arg): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 # pylint: disable-next=all @@ -273,7 +273,7 @@ def test_llama_e2e_two_addtional_tokens(self): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 @@ -345,7 +345,7 @@ def test_llama_e2e_four_addtional_tokens(self): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 @@ -404,7 +404,7 @@ def test_llama_with_original_prefill_decode_32(self): state_dict["freqs_cis"] = model_orig.freqs_cis model_ours = model_exportable.Transformer(model_arg, env) engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) slot = 0 out_tokens = [] @@ -479,7 +479,7 @@ def test_llama_with_original_prefill_decode(self): state_dict["freqs_cis"] = model_orig.freqs_cis model_ours = model_exportable.Transformer(model_arg, env) engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) slot = 0 out_tokens = [] diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 2b4450a8..4a3d87b5 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -23,7 +23,7 @@ from jetstream_pt.third_party.llama import model_original from jetstream_pt.third_party.gemma import model_original as gemma_orig from jetstream_pt.third_party.gemma import model as gemma -from jetstream_pt.torchjax import from_jax, to_jax +from jetstream_pt import torchjax from jetstream_pt import layers from jetstream_pt import cache_manager @@ -76,7 +76,7 @@ def _generate_mask(self, cache_length, pos, seqlen): x = jnp.arange(0, cache_length) cond = jnp.logical_and(x <= pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) - return from_jax(res) + return torchjax.to_torch(res) def _compare_cache(self, cache_torch, cache_jax): _, seq, _, _ = cache_torch.shape @@ -88,7 +88,9 @@ def _make_one_cache_for_generate(self, env, pos): cache_array_k = jnp.zeros(env.cache_shape) cache_array_v = jnp.zeros(env.cache_shape) - cache_array_k, cache_array_v = from_jax((cache_array_k, cache_array_v)) + cache_array_k, cache_array_v = torchjax.to_torch( + (cache_array_k, cache_array_v) + ) cache_decode = cache_manager.KVCacheGenerate( cache_array_k, cache_array_v, pos, None ) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index fbd64003..a2fee0bd 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -20,7 +20,7 @@ import torch_xla2 from jetstream_pt import cache_manager, layers, quantize -from jetstream_pt.torchjax import from_jax, to_jax +from jetstream_pt import torchjax class QuantizationTest(unittest.TestCase): @@ -61,7 +61,7 @@ def test_kv_kernel(self): cache_k_jax = jax.random.normal(key, cache_shape) cache_v_jax = jax.random.normal(key2, cache_shape) - cache_k, cache_v = from_jax((cache_k_jax, cache_v_jax)) + cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None) @@ -70,14 +70,14 @@ def test_kv_kernel(self): xk = jax.random.normal(key, (3, 2, 1, 2)) xv = jax.random.normal(key, (3, 2, 1, 2)) - xq, xk, xv = from_jax((xq, xk, xv)) + xq, xk, xv = torchjax.to_torch((xq, xk, xv)) attention_float = layers.AttentionKernel(env) float_res = attention_float(xq, xk, xv, None, cache) # == - cache_k, cache_v = from_jax((cache_k_jax, cache_v_jax)) + cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) cache_k_int, cache_k_scaler = quantize.quantize_torch_int8( cache_k, (1, 3) )