Add shard on batch mode. Also update version of torchxla2 (#80)
* Add shard on batch mode. Also update version of torchxla2

* Address comments
qihqi authored May 14, 2024
1 parent 648bf48 commit 776c1c4
Showing 15 changed files with 144 additions and 91 deletions.
1 change: 1 addition & 0 deletions benchmarks/prefill_offline.py
@@ -17,6 +17,7 @@
import functools
import humanize

# pylint: disable-next=all
from absl import app
from absl import flags
import numpy as np
1 change: 1 addition & 0 deletions benchmarks/run_offline.py
@@ -15,6 +15,7 @@
import logging
import os
import time
# pylint: disable-next=all
from absl import app
from absl import flags

2 changes: 1 addition & 1 deletion install_everything.sh
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

TORCHXLA_TAG=jetstream-pytorch
TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
JETSTREAM_TAG=v0.2.1

# Uninstall existing jax
30 changes: 14 additions & 16 deletions jetstream_pt/cache_manager.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch_xla2
import jax
import jax.numpy as jnp
import torch
from jetstream_pt import torchjax


# pylint: disable-next=all
@@ -49,9 +49,7 @@ def update(self, key, value):
self.cache_v = value
if self.kv_quantize: # pretend to be quantized
bsz, _, seq, _ = key.shape
ones = torch_xla2.tensor.wrap(
jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)
)
ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16))
return key, value, ones, ones

return key, value
@@ -64,15 +62,15 @@ def state(self):
# pylint: disable-next=all
def KVCachePrefill_flatten(cache):
return (
torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)),
torchjax.from_torch((cache.cache_k, cache.cache_v)),
cache.kv_quantize,
)


# pylint: disable-next=all
def KVCachePrefill_unflatten(auxdata, data):
cache = KVCachePrefill(auxdata)
cache_k, cache_v = torch_xla2.tensor.wrap(data)
cache_k, cache_v = torchjax.from_torch(data)
cache.cache_k = cache_k
cache.cache_v = cache_v

@@ -102,7 +100,7 @@ def __init__(

def update(self, key, value):
"""Update kv cache"""
keyj, valuej = torch_xla2.tensor.unwrap((key, value))
keyj, valuej = torchjax.to_torch((key, value))
# pylint: disable-next=all
self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
# pylint: disable-next=all
@@ -112,30 +110,30 @@ def update(self, key, value):
def state(self):
"""Get kv cache state"""
# pylint: disable-next=all
return self.cache_k._elem, self.cache_v._elem
return self.cache_k.jax(), self.cache_v.jax()

@classmethod
def empty(cls, shape, device, bf16_enable):
"""Create empty kv caches"""
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
k = jnp.zeros(shape, device=device, dtype=default_dtype)
v = jnp.zeros(shape, device=device, dtype=default_dtype)
k, v = torch_xla2.tensor.wrap((k, v))
k, v = torchjax.to_torch((k, v))
return cls(k, v, 0, device)


# pylint: disable-next=all
def KVCacheGenerate_flatten(cache):
return torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)), (
cache.pos,
cache.sharding,
return ((cache.cache_k.jax(), cache.cache_v.jax())), (
cache.pos.jax(),
cache.sharding.jax(),
)


# pylint: disable-next=all
def KVCacheGenerate_unflatten(auxdata, data):
position, sharding = auxdata
cache_k, cache_v = torch_xla2.tensor.wrap(data)
cache_k, cache_v = torchjax.to_torch(data)
cache = KVCacheGenerate(cache_k, cache_v, position, sharding)
return cache

@@ -168,11 +166,11 @@ def __init__(

def state(self):
"""Get kv cache state"""
return torch_xla2.tensor.unwrap((self.cache_k, self.cache_v))
return torchjax.from_torch((self.cache_k, self.cache_v))

def scalers(self):
"""Get kv cache scalers"""
return torch_xla2.tensor.unwrap((self.k_scaler, self.v_scaler))
return torchjax.from_torch((self.k_scaler, self.v_scaler))

@classmethod
# pylint: disable-next=all
@@ -184,7 +182,7 @@ def empty(cls, shape, device, bf16_enable):
kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)
vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)

cache_k, cache_v, kscaler, vscaler = torch_xla2.tensor.wrap(
cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
(cache_k, cache_v, kscaler, vscaler)
)
return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
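Most of the renames in this file follow a single convention: torchjax.to_torch wraps jax values as torch-visible tensors (where the old code called torch_xla2.tensor.wrap), and torchjax.from_torch converts torch tensors back to jax arrays (where it called torch_xla2.tensor.unwrap). A minimal round-trip sketch of that convention; the shapes are made up, and single-value support for from_torch is assumed since the call sites above only pass tuples:

import jax.numpy as jnp
from jetstream_pt import torchjax

# A jax array standing in for a KV-cache buffer (shape chosen for illustration).
k = jnp.zeros((2, 8, 1024, 128), dtype=jnp.bfloat16)

k_torch = torchjax.to_torch(k)        # jax -> torch view (was torch_xla2.tensor.wrap)
k_jax = torchjax.from_torch(k_torch)  # torch -> jax array (was torch_xla2.tensor.unwrap)

# Nested containers convert the same way, as in KVCachePrefill_flatten above.
k_t, v_t = torchjax.to_torch((k, k))
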
41 changes: 26 additions & 15 deletions jetstream_pt/engine.py
@@ -33,6 +33,7 @@

from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt import torchjax
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama import model_exportable, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -86,8 +87,11 @@ def __init__(
self.y_sharding = env.sharding_by_axis(1)
self.x_sharding = env.sharding_by_axis(0)
self.replicated = env.sharding_by_axis(-1) # replicated

self.cache_sharding = self.env.cache_sharding

jax.config.update("jax_enable_x64", False)

self.prefill = jax.jit(
self.prefill, out_shardings=self.get_prefix_destination_sharding()
)
@@ -147,7 +151,7 @@ def _call_model_generate(
if self.env.enable_kv_quantization:
caches_obj = [
cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
for (k, v), (ks, vs) in torchjax.to_torch(
list(zip(caches, cache_scales))
)
]
@@ -156,20 +160,22 @@
cache_manager.KVCacheGenerate(
k, v, input_indexes, self.cache_sharding
)
for k, v in torch_xla2.tensor.wrap(caches)
for k, v in torchjax.to_torch(caches)
]
mask = jnp.expand_dims(mask, (1, 2))

args = (tokens, input_pos, caches_obj, mask)
paramst, argst = torch_xla2.tensor.wrap((weights, args))
paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
with torch_xla2.tensor.XLADispatchMode():
with torchjax.jax_mode:
# The mode is needed so that tensors created inside of
# the model (such as via torch.ones etc) also have the right type
res = torch.func.functional_call(self.pt_model, paramst, argst)
updated_caches = [c.state() for c in caches_obj]
scales = []
if self.env.enable_kv_quantization:
scales = [c.scalers() for c in caches_obj]
return torch_xla2.tensor.unwrap((res, updated_caches, scales))
return torchjax.from_torch((res, updated_caches, scales))

@functools.partial(
jax.jit,
@@ -188,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
mask = jnp.triu(mask, k=1)
args = (tokens, input_indexes, caches, mask)

paramst, argst = torch_xla2.tensor.wrap((weights, args))
paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
with torch_xla2.tensor.XLADispatchMode():
with torchjax.jax_mode:
res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
caches_res = [c.state() for c in caches]
return torch_xla2.tensor.unwrap((res, caches_res))
return torchjax.from_torch((res, caches_res))

def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
if len(logits.shape) == 2:
@@ -287,20 +293,20 @@ def insert(cache, new_entry):
@functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
def insert(cache, scaler, new_entry):
reduce_axis = (1, 3)
vals, scales = torch_xla2.extra.call_torch(
vals, scales = torch_xla2.interop.call_torch(
quantize.quantize_torch_int8, new_entry, reduce_axis
)
new_scaler = jax.lax.dynamic_update_slice(
scaler,
scales,
scales.jax(),
[slot, 0, pos, 0],
)
new_scaler = jax.lax.with_sharding_constraint(
new_scaler, self.replicated
)
res = jax.lax.dynamic_update_slice(
cache,
vals,
vals.jax(),
[slot, 0, pos, 0],
)
res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
@@ -386,7 +392,7 @@ def insert(cache, new_entry):
def insert(cache, scaler, new_entry):
new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
reduce_axis = (1, 2)
vals, scales = torch_xla2.extra.call_torch(
vals, scales = torch_xla2.interop.call_torch(
quantize.quantize_torch_int8, new_entry, reduce_axis
)
new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
@@ -559,7 +565,7 @@ def _load_from_state_dict(self, path):
for key, model_weights in self.pt_model.state_dict().items():
assert key in state_dict, f"key: {key} not found"
arr = jax.device_put(
torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
)
assert tuple(model_weights.shape) == tuple(
arr.shape
@@ -602,14 +608,14 @@ def get_prefix_destination_sharding(self) -> Prefix:
"""Returns the shardings necessary to transfer data between engines."""
return Prefix(
self.replicated,
self.cache_sharding,
self.replicated if self.env.shard_on_batch else self.cache_sharding,
self.replicated,
)

def get_decode_state_sharding(self) -> DecodeState:
"""Gets the shardings corresponding to the decode state."""
return DecodeState(
self.replicated,
self.x_sharding if self.env.shard_on_batch else self.replicated,
self.cache_sharding,
self.replicated,
self.replicated,
@@ -663,6 +669,7 @@ def create_pytorch_engine(
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
shard_on_batch=False,
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

@@ -718,8 +725,12 @@
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
sharding_config_path=sharding_config,
shard_on_batch=shard_on_batch,
)

if shard_on_batch and sharding_config:
print("WARNING: with sharding_on_batch sharding config is ignored.")

if model_name.startswith("llama"):

args = model_args.get_model_args(
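The conversion helpers above are always used in the same pattern inside _call_model_prefill and _call_model_generate: move jax weights and inputs to torch with torchjax.to_torch, run the model functionally under torchjax.jax_mode so tensors created inside the model are jax-backed too, then bring the outputs back with torchjax.from_torch. A minimal sketch with a toy model; TinyModel, the shapes, and the single-tensor from_torch call are invented for illustration, only the wrapper names and the functional_call pattern come from this diff:

import jax.numpy as jnp
import torch
from jetstream_pt import torchjax

class TinyModel(torch.nn.Module):
  """Stand-in for the real transformer: a single linear layer."""

  def __init__(self):
    super().__init__()
    self.proj = torch.nn.Linear(16, 16, bias=False)

  def forward(self, x):
    return self.proj(x)

model = TinyModel()
# Keep the weights on the jax side, as the engine does.
jax_weights = {
    name: jnp.asarray(tensor.detach().numpy())
    for name, tensor in model.state_dict().items()
}
jax_x = jnp.ones((1, 16), dtype=jnp.float32)

def call_model(weights, x):
  paramst, xt = torchjax.to_torch((weights, x))  # jax -> torch views
  with torchjax.jax_mode:
    # The mode makes tensors created inside the model (torch.ones etc.)
    # jax-backed as well.
    res = torch.func.functional_call(model, paramst, (xt,))
  return torchjax.from_torch(res)                # torch -> jax

out = call_model(jax_weights, jax_x)             # jax array of shape (1, 16)
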
15 changes: 12 additions & 3 deletions jetstream_pt/environment.py
@@ -75,6 +75,9 @@ class JetEngineEnvironmentData:

sharding_config_path: str = ""

  # Whether to shard on the batch dimension, i.e. data parallelism.
shard_on_batch: bool = False


# pylint: disable-next=all
class JetEngineEnvironment:
Expand All @@ -97,9 +100,12 @@ def __init__(self, data: JetEngineEnvironmentData):
self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
self.replicated = jsharding.NamedSharding(self._mesh, P())

cache_sharding_axis = self.attention_kv_axis_names.index(
self.kv_cache_shard_axis
)
if data.shard_on_batch:
cache_sharding_axis = 0
else:
cache_sharding_axis = self.attention_kv_axis_names.index(
self.kv_cache_shard_axis
)

if self.cache_shape[cache_sharding_axis] == 1:
# cannot shard on an axis that is 1
@@ -169,6 +175,9 @@ def make_caches_generate(self):

def sharding_by_name(self, name):
"""Create sharding specified in the config."""
if self.shard_on_batch:
      return self.sharding_by_axis(0)  # batch dimension

if name in self._sharding_config:
return self.sharding_by_axis(self._sharding_config[name])

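For readers of the sharding change above: with shard_on_batch set, the KV cache is sharded along axis 0 (the batch dimension) instead of the axis named by kv_cache_shard_axis. A small standalone sketch of that selection; the axis names are illustrative assumptions, not taken from this diff:

from typing import Sequence

def pick_cache_sharding_axis(
    shard_on_batch: bool,
    attention_kv_axis_names: Sequence[str],
    kv_cache_shard_axis: str,
) -> int:
  """Mirrors the axis selection added to JetEngineEnvironment.__init__."""
  if shard_on_batch:
    return 0  # batch dimension, i.e. data parallel
  return attention_kv_axis_names.index(kv_cache_shard_axis)

# Axis names below are assumptions for illustration only.
axes = ("batch", "num_attn_heads", "sequence_length", "head_dim")
assert pick_cache_sharding_axis(True, axes, "num_attn_heads") == 0
assert pick_cache_sharding_axis(False, axes, "num_attn_heads") == 1
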