Commit

Address comments
qihqi committed May 14, 2024
1 parent b6bb1fa commit b0d861a
Showing 12 changed files with 109 additions and 109 deletions.
2 changes: 1 addition & 1 deletion install_everything.sh
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-TORCHXLA_TAG=6dccf0a02d7828516bdb589f2ae0dc79b64488fa # updated May 10, 2024
+TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
JETSTREAM_TAG=v0.2.1

# Uninstall existing jax
20 changes: 10 additions & 10 deletions jetstream_pt/cache_manager.py
@@ -15,7 +15,7 @@
import jax
import jax.numpy as jnp
import torch
-from jetstream_pt.torchjax import from_jax, to_jax
+from jetstream_pt import torchjax


# pylint: disable-next=all
@@ -49,7 +49,7 @@ def update(self, key, value):
self.cache_v = value
if self.kv_quantize: # pretend to be quantized
bsz, _, seq, _ = key.shape
-ones = from_jax(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16))
+ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16))
return key, value, ones, ones

return key, value
@@ -62,15 +62,15 @@ def state(self):
# pylint: disable-next=all
def KVCachePrefill_flatten(cache):
return (
-to_jax((cache.cache_k, cache.cache_v)),
+torchjax.from_torch((cache.cache_k, cache.cache_v)),
cache.kv_quantize,
)


# pylint: disable-next=all
def KVCachePrefill_unflatten(auxdata, data):
cache = KVCachePrefill(auxdata)
-cache_k, cache_v = to_jax(data)
+cache_k, cache_v = torchjax.from_torch(data)
cache.cache_k = cache_k
cache.cache_v = cache_v

@@ -100,7 +100,7 @@ def __init__(

def update(self, key, value):
"""Update kv cache"""
-keyj, valuej = from_jax((key, value))
+keyj, valuej = torchjax.to_torch((key, value))
# pylint: disable-next=all
self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
# pylint: disable-next=all
@@ -118,7 +118,7 @@ def empty(cls, shape, device, bf16_enable):
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
k = jnp.zeros(shape, device=device, dtype=default_dtype)
v = jnp.zeros(shape, device=device, dtype=default_dtype)
-k, v = from_jax((k, v))
+k, v = torchjax.to_torch((k, v))
return cls(k, v, 0, device)


@@ -133,7 +133,7 @@ def KVCacheGenerate_flatten(cache):
# pylint: disable-next=all
def KVCacheGenerate_unflatten(auxdata, data):
position, sharding = auxdata
-cache_k, cache_v = from_jax(data)
+cache_k, cache_v = torchjax.to_torch(data)
cache = KVCacheGenerate(cache_k, cache_v, position, sharding)
return cache

@@ -166,11 +166,11 @@ def __init__(

def state(self):
"""Get kv cache state"""
-return to_jax((self.cache_k, self.cache_v))
+return torchjax.from_torch((self.cache_k, self.cache_v))

def scalers(self):
"""Get kv cache scalers"""
-return to_jax((self.k_scaler, self.v_scaler))
+return torchjax.from_torch((self.k_scaler, self.v_scaler))

@classmethod
# pylint: disable-next=all
@@ -182,7 +182,7 @@ def empty(cls, shape, device, bf16_enable):
kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)
vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)

-cache_k, cache_v, kscaler, vscaler = from_jax(
+cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
(cache_k, cache_v, kscaler, vscaler)
)
return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
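
Note on the rename (not part of the diff): the helpers formerly imported as from_jax/to_jax are now used through the module as torchjax.to_torch and torchjax.from_torch. A minimal sketch of the assumed correspondence, inferred only from the call sites in this file and not from the library's documentation:

# Sketch only; semantics are assumed from the call sites in this diff.
# to_torch (was from_jax): jax arrays -> torch-facing tensors
# from_torch (was to_jax): torch tensors -> jax arrays
# Both appear to accept pytrees such as tuples.
import jax.numpy as jnp
from jetstream_pt import torchjax

k = jnp.zeros((1, 8, 16, 4), dtype=jnp.bfloat16)
v = jnp.zeros((1, 8, 16, 4), dtype=jnp.bfloat16)

k_t, v_t = torchjax.to_torch((k, v))        # formerly from_jax((k, v))
k_j, v_j = torchjax.from_torch((k_t, v_t))  # formerly to_jax((k_t, v_t))
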
68 changes: 27 additions & 41 deletions jetstream_pt/engine.py
@@ -33,7 +33,7 @@

from jetstream_pt import cache_manager
from jetstream_pt import quantize
-from jetstream_pt.torchjax import from_jax, to_jax
+from jetstream_pt import torchjax
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama import model_exportable, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -151,27 +151,31 @@ def _call_model_generate(
if self.env.enable_kv_quantization:
caches_obj = [
cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales)))
+for (k, v), (ks, vs) in torchjax.to_torch(
+list(zip(caches, cache_scales))
+)
]
else:
caches_obj = [
cache_manager.KVCacheGenerate(
k, v, input_indexes, self.cache_sharding
)
-for k, v in from_jax(caches)
+for k, v in torchjax.to_torch(caches)
]
mask = jnp.expand_dims(mask, (1, 2))

args = (tokens, input_pos, caches_obj, mask)
-paramst, argst = from_jax((weights, args))
+paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
-with torch_xla2.default_env():
+with torchjax.jax_mode:
+# The mode is needed so that tensors created inside of
+# the model (such as via torch.ones etc) also have the right type
res = torch.func.functional_call(self.pt_model, paramst, argst)
updated_caches = [c.state() for c in caches_obj]
scales = []
if self.env.enable_kv_quantization:
scales = [c.scalers() for c in caches_obj]
-return to_jax((res, updated_caches, scales))
+return torchjax.from_torch((res, updated_caches, scales))

@functools.partial(
jax.jit,
@@ -190,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
mask = jnp.triu(mask, k=1)
args = (tokens, input_indexes, caches, mask)

-paramst, argst = from_jax((weights, args))
+paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
with torch_xla2.default_env():
res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
caches_res = [c.state() for c in caches]
-return to_jax((res, caches_res))
+return torchjax.from_torch((res, caches_res))

def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
if len(logits.shape) == 2:
@@ -561,7 +565,7 @@ def _load_from_state_dict(self, path):
for key, model_weights in self.pt_model.state_dict().items():
assert key in state_dict, f"key: {key} not found"
arr = jax.device_put(
-to_jax(state_dict[key]), self.env.sharding_by_name(key)
+torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
)
assert tuple(model_weights.shape) == tuple(
arr.shape
@@ -602,41 +606,23 @@ def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:

def get_prefix_destination_sharding(self) -> Prefix:
"""Returns the shardings necessary to transfer data between engines."""
-if self.env.shard_on_batch:
-return Prefix(
-self.replicated, # cache is replicated because bs=1 for prefill
-self.replicated,
-self.replicated,
-)
-else:
-return Prefix(
-self.replicated,
-self.cache_sharding,
-self.replicated,
-)
+return Prefix(
+self.replicated,
+self.replicated if self.env.shard_on_batch else self.cache_sharding,
+self.replicated,
+)

def get_decode_state_sharding(self) -> DecodeState:
"""Gets the shardings corresponding to the decode state."""
-if self.env.shard_on_batch:
-return DecodeState(
-self.x_sharding, # shard on batch
-self.cache_sharding,
-self.replicated,
-self.replicated,
-self.replicated,
-self.replicated,
-self.replicated,
-)
-else:
-return DecodeState(
-self.replicated, # shard on batch
-self.cache_sharding,
-self.replicated,
-self.replicated,
-self.replicated,
-self.replicated,
-self.replicated,
-)
+return DecodeState(
+self.x_sharding if self.env.shard_on_batch else self.replicated,
+self.cache_sharding,
+self.replicated,
+self.replicated,
+self.replicated,
+self.replicated,
+self.replicated,
+)

def get_prefix_sequence_ddim(self) -> Any:
"""Returns the index of the sequence dim in the prefix type."""
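
Note (not part of the diff): the generate path above also swaps torch_xla2.default_env() for torchjax.jax_mode around the functional call. A condensed, self-contained sketch of that pattern follows; jax_mode is assumed to be the re-exported context manager and to_torch/from_torch the pytree converters, inferred from this diff only:

# Hypothetical standalone example mirroring _call_model_generate above.
import jax.numpy as jnp
import torch
from jetstream_pt import torchjax

model = torch.nn.Linear(4, 2, bias=False)
weights = {"weight": jnp.ones((2, 4), dtype=jnp.float32)}  # jax-side parameters
args = (jnp.ones((1, 4), dtype=jnp.float32),)              # jax-side inputs

paramst, argst = torchjax.to_torch((weights, args))  # jax -> torch
with torchjax.jax_mode:
    # the mode is needed so tensors created inside the model (e.g. torch.ones)
    # also have the right type
    res = torch.func.functional_call(model, paramst, argst)
out = torchjax.from_torch(res)  # torch -> jax
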
24 changes: 8 additions & 16 deletions jetstream_pt/layers.py
@@ -172,10 +172,8 @@ def __call__(self, xq, xk, xv, mask, cache):
output = output[:, :, 0:1, :]
# For XLA matmul performance boost
# output = torch.matmul(scores, values)
-if self.env.shard_on_batch:
-self.env.apply_sharding(output, axis=0)
-else:
-self.env.apply_sharding(output, axis=1)
+shard_axis = 0 if self.env.shard_on_batch else 1
+self.env.apply_sharding(output, axis=shard_axis)
return output


@@ -225,10 +223,8 @@ def __call__(self, xq, xk, xv, mask, cache):
if seqlen == 1:
output = output[:, :, 0:1, :]
# output = torch.matmul(scores, values)
-if self.env.shard_on_batch:
-self.env.apply_sharding(output, axis=0)
-else:
-self.env.apply_sharding(output, axis=1)
+shard_axis = 0 if self.env.shard_on_batch else 1
+self.env.apply_sharding(output, axis=shard_axis)
return output


@@ -319,14 +315,10 @@ def forward(
xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

-if self.env.shard_on_batch:
-self.env.apply_sharding(xq, axis=0)
-self.env.apply_sharding(xk, axis=0)
-self.env.apply_sharding(xv, axis=0)
-else:
-self.env.apply_sharding(xq, axis=2)
-self.env.apply_sharding(xk, axis=2)
-self.env.apply_sharding(xv, axis=2)
+shard_axis = 0 if self.env.shard_on_batch else 2
+self.env.apply_sharding(xq, axis=shard_axis)
+self.env.apply_sharding(xk, axis=shard_axis)
+self.env.apply_sharding(xv, axis=shard_axis)

with jax.named_scope("attn_rope"):
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
24 changes: 13 additions & 11 deletions jetstream_pt/ray_worker.py
@@ -39,7 +39,7 @@

from jetstream_pt import cache_manager
from jetstream_pt import quantize
-from jetstream_pt.torchjax import from_jax, to_jax
+from jetstream_pt import torchjax
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model

@@ -336,21 +336,23 @@ def _call_model_generate(
if self.env.enable_kv_quantization:
caches_obj = [
cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales)))
+for (k, v), (ks, vs) in torchjax.to_torch(
+list(zip(caches, cache_scales))
+)
]
else:
caches_obj = [
cache_manager.KVCacheGenerate(
k, v, input_indexes, self.cache_sharding
)
-for k, v in from_jax(caches)
+for k, v in torchjax.to_torch(caches)
]
mask = jnp.expand_dims(mask, (1, 2))

args = (tokens, input_pos, caches_obj, mask)
-paramst, argst = from_jax((weights, args))
+paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
-with torch_xla2.default_env():
+with torchjax.jax_mode():
res = torch.func.functional_call(self.pt_model, paramst, argst)
updated_caches = [c.state() for c in caches_obj]
scales = []
Expand All @@ -360,7 +362,7 @@ def _call_model_generate(
current_position + 1
) % self.env.cache_sequence_length

-return to_jax(
+return torchjax.from_torch(
(
res,
updated_caches,
@@ -389,12 +391,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
mask = jnp.triu(mask, k=1)
args = (tokens, input_indexes, caches, mask)

-paramst, argst = from_jax((weights, args))
+paramst, argst = torchjax.to_torch((weights, args))
with self._lock:
-with torch_xla2.default_env():
+with torchjax.jax_mode:
res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
caches_res = [c.state() for c in caches]
-return to_jax((res, caches_res))
+return torchjax.from_torch((res, caches_res))

def _sampling(self, logits: Any, batch_size: int) -> np.ndarray:
if len(logits.shape) == 2:
@@ -503,7 +505,7 @@ def insert(cache, new_entry):
@functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
def insert(cache, scaler, new_entry):
reduce_axis = (1, 3)
-vals, scales = torch_xla2.interop.call_torch(
+vals, scales = torchjax.call_torch(
quantize.quantize_torch_int8, new_entry, reduce_axis
)
new_scaler = jax.lax.dynamic_update_slice(
@@ -602,7 +604,7 @@ def insert(cache, new_entry):
def insert(cache, scaler, new_entry):
new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
reduce_axis = (1, 2)
-vals, scales = torch_xla2.interop.call_torch(
+vals, scales = torchjax.call_torch(
quantize.quantize_torch_int8, new_entry, reduce_axis
)
new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
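
Note (not part of the diff): torch_xla2.interop.call_torch is likewise routed through torchjax.call_torch in this file. A minimal sketch, assuming call_torch applies a torch function to jax operands and returns jax results, as the insert() helpers above suggest:

# Sketch only; the behaviour of call_torch is assumed, not taken from documentation.
import jax.numpy as jnp
from jetstream_pt import quantize, torchjax

new_entry = jnp.ones((1, 8, 16, 4), dtype=jnp.bfloat16)
reduce_axis = (1, 3)
# formerly: torch_xla2.interop.call_torch(quantize.quantize_torch_int8, ...)
vals, scales = torchjax.call_torch(quantize.quantize_torch_int8, new_entry, reduce_axis)
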
13 changes: 4 additions & 9 deletions jetstream_pt/third_party/gemma/model.py
@@ -148,15 +148,10 @@ def forward(
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

-if self.env.shard_on_batch:
-# Gemma 2B
-self.env.apply_sharding(xq, axis=0)
-self.env.apply_sharding(xk, axis=0)
-self.env.apply_sharding(xv, axis=0)
-else:
-self.env.apply_sharding(xq, axis=2)
-self.env.apply_sharding(xk, axis=2)
-self.env.apply_sharding(xv, axis=2)
+shard_axis = 0 if self.env.shard_on_batch else 2
+self.env.apply_sharding(xq, axis=shard_axis)
+self.env.apply_sharding(xk, axis=shard_axis)
+self.env.apply_sharding(xv, axis=shard_axis)

# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)