diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py
index a6db15ab..03bf4180 100644
--- a/benchmarks/prefill_offline.py
+++ b/benchmarks/prefill_offline.py
@@ -17,6 +17,7 @@
 import functools
 import humanize
 
+# pylint: disable-next=all
 from absl import app
 from absl import flags
 import numpy as np
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index 2f480930..df788591 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -15,6 +15,7 @@
 import logging
 import os
 import time
 
+# pylint: disable-next=all
 from absl import app
 from absl import flags
diff --git a/install_everything.sh b/install_everything.sh
index 46482302..82687ed9 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-TORCHXLA_TAG=jetstream-pytorch
+TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
 JETSTREAM_TAG=v0.2.1
 
 # Uninstall existing jax
diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py
index d3f2d284..75bd75c0 100644
--- a/jetstream_pt/cache_manager.py
+++ b/jetstream_pt/cache_manager.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch_xla2
 import jax
 import jax.numpy as jnp
 import torch
+from jetstream_pt import torchjax
 
 
 # pylint: disable-next=all
@@ -49,9 +49,7 @@ def update(self, key, value):
     self.cache_v = value
     if self.kv_quantize:  # pretend to be quantized
       bsz, _, seq, _ = key.shape
-      ones = torch_xla2.tensor.wrap(
-          jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)
-      )
+      ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16))
       return key, value, ones, ones
     return key, value
 
@@ -64,7 +62,7 @@ def state(self):
 # pylint: disable-next=all
 def KVCachePrefill_flatten(cache):
   return (
-      torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)),
+      torchjax.from_torch((cache.cache_k, cache.cache_v)),
       cache.kv_quantize,
   )
 
@@ -72,7 +70,7 @@ def KVCachePrefill_flatten(cache):
 # pylint: disable-next=all
 def KVCachePrefill_unflatten(auxdata, data):
   cache = KVCachePrefill(auxdata)
-  cache_k, cache_v = torch_xla2.tensor.wrap(data)
+  cache_k, cache_v = torchjax.to_torch(data)
   cache.cache_k = cache_k
   cache.cache_v = cache_v
 
@@ -102,7 +100,7 @@ def __init__(
 
   def update(self, key, value):
     """Update kv cache"""
-    keyj, valuej = torch_xla2.tensor.unwrap((key, value))
+    keyj, valuej = torchjax.from_torch((key, value))
     # pylint: disable-next=all
     self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
     # pylint: disable-next=all
@@ -112,7 +110,7 @@
   def state(self):
     """Get kv cache state"""
     # pylint: disable-next=all
-    return self.cache_k._elem, self.cache_v._elem
+    return self.cache_k.jax(), self.cache_v.jax()
 
   @classmethod
   def empty(cls, shape, device, bf16_enable):
@@ -120,22 +118,22 @@ def empty(cls, shape, device, bf16_enable):
     default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
-    k, v = torch_xla2.tensor.wrap((k, v))
+    k, v = torchjax.to_torch((k, v))
     return cls(k, v, 0, device)
 
 
 # pylint: disable-next=all
 def KVCacheGenerate_flatten(cache):
-  return torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)), (
+  return ((cache.cache_k.jax(), cache.cache_v.jax())), (
       cache.pos,
      cache.sharding,
   )
 
 
 # pylint: disable-next=all
 def KVCacheGenerate_unflatten(auxdata, data):
   position, sharding = auxdata
-  cache_k, cache_v = torch_xla2.tensor.wrap(data)
+  cache_k, cache_v = torchjax.to_torch(data)
   cache = KVCacheGenerate(cache_k, cache_v, position, sharding)
   return cache
 
@@ -168,11 +166,11 @@ def __init__(
 
   def state(self):
     """Get kv cache state"""
-    return torch_xla2.tensor.unwrap((self.cache_k, self.cache_v))
+    return torchjax.from_torch((self.cache_k, self.cache_v))
 
   def scalers(self):
     """Get kv cache scalers"""
-    return torch_xla2.tensor.unwrap((self.k_scaler, self.v_scaler))
+    return torchjax.from_torch((self.k_scaler, self.v_scaler))
 
   @classmethod
   # pylint: disable-next=all
@@ -184,7 +182,7 @@ def empty(cls, shape, device, bf16_enable):
     kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)
     vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)
 
-    cache_k, cache_v, kscaler, vscaler = torch_xla2.tensor.wrap(
+    cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
         (cache_k, cache_v, kscaler, vscaler)
     )
     return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 895c8e27..792d2d5b 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -33,6 +33,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
+from jetstream_pt import torchjax
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -86,8 +87,11 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
+    self.cache_sharding = self.env.cache_sharding
+
+    jax.config.update("jax_enable_x64", False)
+
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
     )
@@ -147,7 +151,7 @@ def _call_model_generate(
     if self.env.enable_kv_quantization:
       caches_obj = [
           cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-          for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
+          for (k, v), (ks, vs) in torchjax.to_torch(
              list(zip(caches, cache_scales))
          )
       ]
@@ -156,20 +160,22 @@ def _call_model_generate(
           cache_manager.KVCacheGenerate(
               k, v, input_indexes, self.cache_sharding
           )
-          for k, v in torch_xla2.tensor.wrap(caches)
+          for k, v in torchjax.to_torch(caches)
       ]
     mask = jnp.expand_dims(mask, (1, 2))
 
     args = (tokens, input_pos, caches_obj, mask)
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
+        # The mode is needed so that tensors created inside of
+        # the model (such as via torch.ones etc) also have the right type
         res = torch.func.functional_call(self.pt_model, paramst, argst)
     updated_caches = [c.state() for c in caches_obj]
     scales = []
     if self.env.enable_kv_quantization:
       scales = [c.scalers() for c in caches_obj]
-    return torch_xla2.tensor.unwrap((res, updated_caches, scales))
+    return torchjax.from_torch((res, updated_caches, scales))
 
   @functools.partial(
       jax.jit,
@@ -188,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     mask = jnp.triu(mask, k=1)
     args = (tokens, input_indexes, caches, mask)
 
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
     caches_res = [c.state() for c in caches]
-    return torch_xla2.tensor.unwrap((res, caches_res))
+    return torchjax.from_torch((res, caches_res))
 
   def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
@@ -287,12 +293,12 @@ def insert(cache, new_entry):
       @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
       def insert(cache, scaler, new_entry):
         reduce_axis = (1, 3)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torch_xla2.interop.call_torch(
             quantize.quantize_torch_int8, new_entry, reduce_axis
         )
         new_scaler = jax.lax.dynamic_update_slice(
             scaler,
             scales,
             [slot, 0, pos, 0],
         )
         new_scaler = jax.lax.with_sharding_constraint(
@@ -300,7 +306,7 @@ def insert(cache, scaler, new_entry):
         )
         res = jax.lax.dynamic_update_slice(
             cache,
             vals,
             [slot, 0, pos, 0],
         )
         res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
@@ -386,7 +392,7 @@ def insert(cache, new_entry):
       def insert(cache, scaler, new_entry):
         new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
         reduce_axis = (1, 2)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torch_xla2.interop.call_torch(
             quantize.quantize_torch_int8, new_entry, reduce_axis
         )
         new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
@@ -559,7 +565,7 @@ def _load_from_state_dict(self, path):
     for key, model_weights in self.pt_model.state_dict().items():
       assert key in state_dict, f"key: {key} not found"
       arr = jax.device_put(
-          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+          torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
       )
       assert tuple(model_weights.shape) == tuple(
           arr.shape
@@ -602,14 +608,14 @@ def get_prefix_destination_sharding(self) -> Prefix:
     """Returns the shardings necessary to transfer data between engines."""
     return Prefix(
         self.replicated,
-        self.cache_sharding,
+        self.replicated if self.env.shard_on_batch else self.cache_sharding,
         self.replicated,
     )
 
   def get_decode_state_sharding(self) -> DecodeState:
     """Gets the shardings corresponding to the decode state."""
     return DecodeState(
-        self.replicated,
+        self.x_sharding if self.env.shard_on_batch else self.replicated,
         self.cache_sharding,
         self.replicated,
         self.replicated,
@@ -663,6 +669,7 @@ def create_pytorch_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
+    shard_on_batch=False,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
@@ -718,8 +725,12 @@ def create_pytorch_engine(
       cache_sequence_length=max_cache_length,
       bf16_enable=bf16_enable,
       sharding_config_path=sharding_config,
+      shard_on_batch=shard_on_batch,
   )
 
+  if shard_on_batch and sharding_config:
+    print("WARNING: when shard_on_batch is set, sharding_config is ignored.")
+
   if model_name.startswith("llama"):
 
     args = model_args.get_model_args(
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 7458636e..5d128990 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -75,6 +75,9 @@ class JetEngineEnvironmentData:
 
   sharding_config_path: str = ""
 
+  # Whether to shard on the batch dimension, i.e. data parallel.
+  shard_on_batch: bool = False
+
 
 # pylint: disable-next=all
 class JetEngineEnvironment:
@@ -97,9 +100,12 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
     self.replicated = jsharding.NamedSharding(self._mesh, P())
 
-    cache_sharding_axis = self.attention_kv_axis_names.index(
-        self.kv_cache_shard_axis
-    )
+    if data.shard_on_batch:
+      cache_sharding_axis = 0
+    else:
+      cache_sharding_axis = self.attention_kv_axis_names.index(
+          self.kv_cache_shard_axis
+      )
 
     if self.cache_shape[cache_sharding_axis] == 1:
       # cannot shard on an axis that is 1
@@ -169,6 +175,9 @@ def make_caches_generate(self):
 
   def sharding_by_name(self, name):
     """Create sharding specified in the config."""
+    if self.shard_on_batch:
+      return self.sharding_by_axis(0)  # batch dimension
+
     if name in self._sharding_config:
       return self.sharding_by_axis(self._sharding_config[name])
 
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index a3d5260b..f4dea094 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -23,7 +23,6 @@
 import torch.nn.functional as F
 import jax
 import jax.numpy as jnp
-import torch_xla2
 
 
 class Int8Embedding(torch.nn.Module):
@@ -156,10 +155,7 @@ def __call__(self, xq, xk, xv, mask, cache):
     with jax.named_scope("attn_mat1"):
       ## Attention start
       # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
-      scores = torch_xla2.extra.call_jax(
-          jnp.einsum, "ikjl,ikml->ikjm", xq, keys
-      ) / math.sqrt(head_dim)
-      self.env.apply_sharding(scores, axis=1)
+      scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
       if mask is not None:
         # if mask.shape != (1,1,16,16):
         #   breakpoint()
@@ -171,14 +167,13 @@ def __call__(self, xq, xk, xv, mask, cache):
       # output = torch.einsum(
      #    "ikjm,ikml->ikjl", scores, values
       # )  # (bs, n_local_heads, seqlen, head_dim)
-      output = torch_xla2.extra.call_jax(
-          jnp.einsum, "ikjm,ikml->ikjl", scores, values
-      )
+      output = torch.einsum("ikjm,ikml->ikjl", scores, values)
       if seqlen == 1:
         output = output[:, :, 0:1, :]  # For XLA matmul performance boost
       # output = torch.matmul(scores, values)
-      self.env.apply_sharding(output, axis=1)
+      shard_axis = 0 if self.env.shard_on_batch else 1
+      self.env.apply_sharding(output, axis=shard_axis)
     return output
 
@@ -210,29 +205,26 @@ def __call__(self, xq, xk, xv, mask, cache):
       ## Attention start
       # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
       scores = (
-          torch_xla2.extra.call_jax(jnp.einsum, "ikjl,ikml->ikjm", xq, keys)
+          torch.einsum("ikjl,ikml->ikjm", xq, keys)
           / math.sqrt(head_dim)
           * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
       )
-      self.env.apply_sharding(scores, axis=1)
       if mask is not None:
         scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
     with jax.named_scope("attn_soft"):
       scores = F.softmax(scores.float(), dim=-1).type_as(xq)
       scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
-      self.env.apply_sharding(scores, axis=1)
 
     with jax.named_scope("attn_mat2"):
       # output = torch.einsum(
       #    "ikjm,ikml->ikjl", scores, values
       # )  # (bs, n_local_heads, seqlen, head_dim)
-      output = torch_xla2.extra.call_jax(
-          jnp.einsum, "ikjm,ikml->ikjl", scores, values
-      )
+      output = torch.einsum("ikjm,ikml->ikjl", scores, values)
       if seqlen == 1:
         output = output[:, :, 0:1, :]
       # output = torch.matmul(scores, values)
-      self.env.apply_sharding(output, axis=1)
+      shard_axis = 0 if self.env.shard_on_batch else 1
+      self.env.apply_sharding(output, axis=shard_axis)
     return output
@@ -323,9 +315,10 @@ def forward(
     xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
     xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
 
-    self.env.apply_sharding(xq, axis=2)
-    self.env.apply_sharding(xk, axis=2)
-    self.env.apply_sharding(xv, axis=2)
+    shard_axis = 0 if self.env.shard_on_batch else 2
+    self.env.apply_sharding(xq, axis=shard_axis)
+    self.env.apply_sharding(xk, axis=shard_axis)
+    self.env.apply_sharding(xv, axis=shard_axis)
 
     with jax.named_scope("attn_rope"):
       xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py
index ff49e624..947acf59 100644
--- a/jetstream_pt/ray_worker.py
+++ b/jetstream_pt/ray_worker.py
@@ -39,6 +39,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
+from jetstream_pt import torchjax
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -335,7 +336,7 @@ def _call_model_generate(
     if self.env.enable_kv_quantization:
       caches_obj = [
           cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-          for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
+          for (k, v), (ks, vs) in torchjax.to_torch(
              list(zip(caches, cache_scales))
          )
       ]
@@ -344,14 +345,14 @@ def _call_model_generate(
           cache_manager.KVCacheGenerate(
               k, v, input_indexes, self.cache_sharding
           )
-          for k, v in torch_xla2.tensor.wrap(caches)
+          for k, v in torchjax.to_torch(caches)
       ]
     mask = jnp.expand_dims(mask, (1, 2))
 
     args = (tokens, input_pos, caches_obj, mask)
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)
     updated_caches = [c.state() for c in caches_obj]
     scales = []
@@ -361,7 +362,7 @@ def _call_model_generate(
         current_position + 1
     ) % self.env.cache_sequence_length
 
-    return torch_xla2.tensor.unwrap(
+    return torchjax.from_torch(
         (
             res,
             updated_caches,
@@ -390,12 +391,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     mask = jnp.triu(mask, k=1)
     args = (tokens, input_indexes, caches, mask)
 
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
     caches_res = [c.state() for c in caches]
-    return torch_xla2.tensor.unwrap((res, caches_res))
+    return torchjax.from_torch((res, caches_res))
 
   def _sampling(self, logits: Any, batch_size: int) -> np.ndarray:
     if len(logits.shape) == 2:
@@ -504,7 +505,7 @@ def insert(cache, new_entry):
       @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
       def insert(cache, scaler, new_entry):
         reduce_axis = (1, 3)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torchjax.call_torch(
             quantize.quantize_torch_int8, new_entry, reduce_axis
         )
         new_scaler = jax.lax.dynamic_update_slice(
@@ -603,7 +604,7 @@ def insert(cache, new_entry):
       def insert(cache, scaler, new_entry):
         new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
         reduce_axis = (1, 2)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torchjax.call_torch(
             quantize.quantize_torch_int8, new_entry, reduce_axis
         )
         new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
@@ -778,6 +779,7 @@ def _weight_sharding(self, weight, sharding):
 
   def _load_from_safetensors(self, path):
     weights = {}
+    # pylint: disable-next=all
     with safetensors.safe_open(path, framework="flax", device="cpu") as f:
       for key, model_weights in self.pt_model.state_dict().items():
         if key == "freqs_cis":
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index 32647949..fa7806c3 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -148,15 +148,10 @@ def forward(
     xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
     xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
 
-    if self.num_kv_heads > 1:
-      self.env.apply_sharding(xq, axis=2)
-      self.env.apply_sharding(xk, axis=2)
-      self.env.apply_sharding(xv, axis=2)
-    else:
-      # Gemma 2B
-      self.env.apply_sharding(xq, axis=3)
-      self.env.apply_sharding(xk, axis=3)
-      self.env.apply_sharding(xv, axis=3)
+    shard_axis = 0 if self.env.shard_on_batch else 2
+    self.env.apply_sharding(xq, axis=shard_axis)
+    self.env.apply_sharding(xk, axis=shard_axis)
+    self.env.apply_sharding(xv, axis=shard_axis)
 
     # Positional embedding.
     xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
diff --git a/jetstream_pt/torchjax.py b/jetstream_pt/torchjax.py
new file mode 100644
index 00000000..b0afd56e
--- /dev/null
+++ b/jetstream_pt/torchjax.py
@@ -0,0 +1,38 @@
+"""Proxy APIs for the torch_xla2 API.
+
+This module serves two purposes:
+
+1. The torch_xla2 APIs are not stable yet, and changing them would
+   otherwise mean edits scattered throughout this repo. By routing
+   through this proxy, future torch_xla2 API changes only require
+   editing this one file.
+
+2. We can iterate on the look and feel of the API in this file and
+   then influence how it ends up looking in torch_xla2.
+"""
+
+import torch_xla2
+import torch_xla2.interop
+
+jax_mode = torch_xla2.default_env()
+
+call_jax = torch_xla2.interop.call_jax
+call_torch = torch_xla2.interop.call_torch
+
+
+def to_torch(tensors):
+  """Wrap a jax Array into an XLATensor."""
+  return jax_mode.j2t_iso(tensors)
+
+
+def from_torch(tensors):
+  """Unwrap an XLATensor into a jax Array."""
+  return jax_mode.t2j_iso(tensors)
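+
+
+# Example (illustrative; assumes the caller imports torch and jax.numpy as jnp):
+#
+#   x = to_torch(jnp.ones((2, 2)))  # jax Array -> torch tensor backed by jax
+#   with jax_mode:
+#     y = torch.relu(x)             # torch ops dispatch to jax under jax_mode
+#   y_jax = from_torch(y)           # back to a jax Array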
+ "If set true, sharding_config will be ignored.", +) def create_engine(): @@ -97,6 +103,7 @@ def create_engine(): quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, sharding_config=_SHARDING_CONFIG.value, + shard_on_batch=_SHARD_ON_BATCH.value, ) print("Initialize engine", time.perf_counter() - start) diff --git a/run_server.py b/run_server.py index 161af9bd..e10f802c 100644 --- a/run_server.py +++ b/run_server.py @@ -89,6 +89,12 @@ _SHARDING_CONFIG = flags.DEFINE_string( "sharding_config", "", "config file for sharding" ) +_SHARD_ON_BATCH = flags.DEFINE_bool( + "shard_on_batch", + False, + "whether to shard on batch dimension" + "If set true, sharding_config will be ignored.", +) # pylint: disable-next=all @@ -116,6 +122,7 @@ def main(argv: Sequence[str]): quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, sharding_config=sharding_config_path, + shard_on_batch=_SHARD_ON_BATCH.value, ) server_config = ServerConfig( interleaved_slices=(_PLATFORM.value,), diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 0b27b28b..f538f818 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -34,7 +34,7 @@ class LlamaE2ETest(unittest.TestCase): """This test class includes all E2E test for llama2""" - def _to_jax(self, tree): + def _from_torch(self, tree): return pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, tree) def _make_env(self, bf16_enable=True): @@ -116,7 +116,7 @@ def test_jetstream_llama2_seed(self): state_dict = dict(model_orig.state_dict()) state_dict["freqs_cis"] = model_orig.freqs_cis - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) output_tokens_multiple = [] for i in [1, 2, 3]: @@ -189,7 +189,7 @@ def _llama_e2e(self, env, model_arg): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 # pylint: disable-next=all @@ -273,7 +273,7 @@ def test_llama_e2e_two_addtional_tokens(self): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 @@ -345,7 +345,7 @@ def test_llama_e2e_four_addtional_tokens(self): engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) decode_state = engine.init_decode_state() slot = 0 @@ -404,7 +404,7 @@ def test_llama_with_original_prefill_decode_32(self): state_dict["freqs_cis"] = model_orig.freqs_cis model_ours = model_exportable.Transformer(model_arg, env) engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) slot = 0 out_tokens = [] @@ -479,7 +479,7 @@ def test_llama_with_original_prefill_decode(self): state_dict["freqs_cis"] = model_orig.freqs_cis model_ours = model_exportable.Transformer(model_arg, env) engine = PyTorchEngine(pt_model=model_ours, env=env) - params = self._to_jax(state_dict) + params = self._from_torch(state_dict) slot = 0 out_tokens = [] diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 775c780e..4a3d87b5 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -13,11 +13,9 @@ # limitations under the License. import unittest -import random import jax import jax.numpy as jnp import torch -from torch.utils import _pytree as pytree import torch_xla2 from . 
@@ -25,6 +23,7 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt import torchjax
 from jetstream_pt import layers
 from jetstream_pt import cache_manager
 
@@ -36,6 +35,7 @@ class ModelComponentTest(unittest.TestCase):
   def setUp(self):
     """setup torch env"""
     jax.config.update("jax_platform_name", "cpu")
+    jax.config.update("jax_enable_x64", False)
     torch.set_default_dtype(torch.float32)
 
   def _prefill_mask(self, seqlen, start_pos):
@@ -63,9 +63,7 @@ def _make_freqs_cis(self, model_arg, seqlen, start_pos):
     return freqs_cis
 
   def _to_xla_tensor(self, tree):
-    return pytree.tree_map_only(
-        torch.Tensor, torch_xla2.tensor.move_to_device, tree
-    )
+    return torch_xla2.default_env().to_xla(tree)
 
   def _call_xla_model(self, model, weights, args):
     with jax.default_device(jax.devices("cpu")[0]):
@@ -78,7 +76,7 @@ def _generate_mask(self, cache_length, pos, seqlen):
     x = jnp.arange(0, cache_length)
     cond = jnp.logical_and(x <= pos, x >= pos - seqlen)
     res = jnp.where(cond, 0, float("-inf"))
-    return torch_xla2.tensor.wrap(res)
+    return torchjax.to_torch(res)
 
   def _compare_cache(self, cache_torch, cache_jax):
     _, seq, _, _ = cache_torch.shape
@@ -90,7 +88,7 @@ def _make_one_cache_for_generate(self, env, pos):
     cache_array_k = jnp.zeros(env.cache_shape)
     cache_array_v = jnp.zeros(env.cache_shape)
 
-    cache_array_k, cache_array_v = torch_xla2.tensor.wrap(
+    cache_array_k, cache_array_v = torchjax.to_torch(
         (cache_array_k, cache_array_v)
     )
     cache_decode = cache_manager.KVCacheGenerate(
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index 6b59ba18..a2fee0bd 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -20,6 +20,7 @@
 import torch_xla2
 
 from jetstream_pt import cache_manager, layers, quantize
+from jetstream_pt import torchjax
 
 
 class QuantizationTest(unittest.TestCase):
@@ -27,7 +28,7 @@ class QuantizationTest(unittest.TestCase):
 
   def _xla_tensor(self, shape):
     res = torch.randn(shape, dtype=torch.bfloat16)
-    return torch_xla2.tensor.move_to_device(res)
+    return torch_xla2.default_env().to_xla(res)
 
   def test_kv_cache(self):
     """test kv cache quantization"""
@@ -60,7 +61,7 @@ def test_kv_kernel(self):
     cache_k_jax = jax.random.normal(key, cache_shape)
     cache_v_jax = jax.random.normal(key2, cache_shape)
 
-    cache_k, cache_v = torch_xla2.tensor.wrap((cache_k_jax, cache_v_jax))
+    cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))
 
     cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None)
@@ -69,14 +70,14 @@ def test_kv_kernel(self):
     xk = jax.random.normal(key, (3, 2, 1, 2))
     xv = jax.random.normal(key, (3, 2, 1, 2))
 
-    xq, xk, xv = torch_xla2.tensor.wrap((xq, xk, xv))
+    xq, xk, xv = torchjax.to_torch((xq, xk, xv))
 
     attention_float = layers.AttentionKernel(env)
     float_res = attention_float(xq, xk, xv, None, cache)
 
     # ==
-    cache_k, cache_v = torch_xla2.tensor.wrap((cache_k_jax, cache_v_jax))
+    cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))
     cache_k_int, cache_k_scaler = quantize.quantize_torch_int8(
         cache_k, (1, 3)
     )