e2e run fix
sixiang-google committed Dec 13, 2024
1 parent 45d2849 commit c35fbd3
Showing 2 changed files with 30 additions and 30 deletions.
56 changes: 28 additions & 28 deletions MaxText/inference_mlperf/offline_inference.py
@@ -80,30 +80,30 @@ def init_decode_state(self):

def warmup(self, max_length, warmup_samples):
self.init_decode_state()
-    # interesting_buckets = [
-    #     64,
-    #     128,
-    #     256,
-    #     512,
-    #     1024,
-    #     2048,
-    #     4096,
-    # ]
-    # for length in interesting_buckets:
-    #   if length > max_length:
-    #     break
-    #   log.info(f"Compiling prefill: {length}")
-    #   input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
-    #   self._cached_pref[length] = (
-    #       jax.jit(self._prefill_insert, donate_argnums=(4,))
-    #       .lower(
-    #           self.params,
-    #           tokens=input_data,
-    #           slot=0,
-    #           true_length=length - 1,
-    #           decode_state=self.decode_state)
-    #       .compile()
-    #   )
+    interesting_buckets = [
+        64,
+        128,
+        256,
+        512,
+        1024,
+        2048,
+        4096,
+    ]
+    for length in interesting_buckets:
+      if length > max_length:
+        break
+      log.info(f"Compiling prefill: {length}")
+      input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
+      self._cached_pref[length] = (
+          jax.jit(self._prefill_insert, donate_argnums=(4,))
+          .lower(
+              self.params,
+              tokens=input_data,
+              slot=0,
+              true_length=length - 1,
+              decode_state=self.decode_state)
+          .compile()
+      )
# input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
# example_seq_len=16
# num_prompts = max_length//length
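For context, the warmup block re-enabled above is JAX's ahead-of-time compilation path: jit(...).lower(...).compile() traces and compiles the prefill for every bucket length up front, so requests on the serving path never pay XLA compilation cost. Below is a minimal sketch of the same pattern, with a toy prefill function and made-up shapes standing in for self._prefill_insert and the real model state:

# Minimal AOT-warmup sketch (toy function and shapes, not MaxText's real prefill).
import jax
import jax.numpy as jnp

def prefill(params, tokens):
  # Stand-in for the real prefill: gather token embeddings, one matmul.
  return jnp.dot(jnp.take(params, tokens, axis=0), params.T)

params = jnp.ones((32, 16))
cached = {}
for length in (64, 128, 256):
  # ShapeDtypeStruct describes the argument without allocating it, so each
  # bucket is traced and compiled before any real request arrives.
  spec = jax.ShapeDtypeStruct((length,), jnp.int32)
  cached[length] = jax.jit(prefill).lower(params, spec).compile()

out = cached[64](params, jnp.zeros((64,), jnp.int32))  # no compile at call time
print(out.shape)  # (64, 32)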
@@ -201,8 +201,8 @@ def prefill(prefill_bucket, prefill_len):
return prefill_result
else:
prefill_fn = self._prefill_insert_batch
-      # if (cached := self._cached_pref_batch.get(prefill_len)) is not None:
-      #   prefill_fn = cached
+      if (cached := self._cached_pref_batch.get(prefill_len)) is not None:
+        prefill_fn = cached
positions = np.concatenate([np.arange(0, row.tokens.shape[0]) for (slot, row) in prefill_bucket])
positions = jnp.array(positions)

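The lookup re-enabled in this hunk routes a request to the executable compiled during warmup when its padded length matches a bucket, and falls back to the plain function otherwise. A small sketch of that fallback, with names borrowed from the diff and simplified to a free function:

# Hypothetical helper mirroring the cached-executable fallback above.
def pick_prefill_fn(cached_pref_batch, default_fn, prefill_len):
  # Hits for bucket sizes compiled during warmup; misses use the default
  # function and pay one compilation on first use.
  if (cached := cached_pref_batch.get(prefill_len)) is not None:
    return cached
  return default_fn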
@@ -350,14 +350,14 @@ def detokenize():
if len(self.prefill_buckets[num_tokens]) * num_tokens == 1024:
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for (first_token, slot, row) in prefill_results:
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} to detokenize backlog")
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block = True)
self.prefill_buckets[num_tokens] = []

for num_tokens in self.prefill_buckets.keys():
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for (first_token, slot, row) in prefill_results:
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} to detokenize backlog")
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block = True)
self.prefill_buckets = defaultdict(list)
while slot_to_id:
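Around the two amended log lines, incoming rows are grouped by padded token length, and a bucket is batch-prefilled once it holds 1024 tokens; each result's first token is handed to the detokenize thread through a bounded queue. A toy reconstruction of that flow, with a simplified Row and no real prefill:

# Toy bucketing/flush loop (simplified stand-ins for Row, prefill, and slots).
import queue
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class Row:
  id: int
  tokens: list

detokenize_backlog = queue.Queue(maxsize=10)
prefill_buckets = defaultdict(list)

def flush(bucket):
  for slot, row in bucket:
    first_token = row.tokens[0]  # the real code takes this from the prefill result
    detokenize_backlog.put((first_token, True, row.id, slot), block=True)

for slot, row in enumerate(Row(i, [1] * 128) for i in range(8)):
  num_tokens = len(row.tokens)
  prefill_buckets[num_tokens].append((slot, row))
  if len(prefill_buckets[num_tokens]) * num_tokens == 1024:  # 8 rows * 128 tokens
    flush(prefill_buckets[num_tokens])
    prefill_buckets[num_tokens] = []

print(detokenize_backlog.qsize())  # 8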
4 changes: 2 additions & 2 deletions MaxText/maxengine.py
@@ -488,7 +488,7 @@ def insert(
unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)

unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])
-
+    # jax.debug.print("Inserting cache slot {} start_idx {} seq_len {}", slot, start_idx, seq_len)
# example = unboxed_prefix["cache"]["decoder"]['layers_0']['self_attention']['AttentionOp_0']
# for key in example.keys():
# jax.debug.print("{} shape: {}", key, example[key].shape)
@@ -549,7 +549,7 @@ def copy(path, partial_cache, full_cache, annotations):
slice_size = list(partial_cache.shape)
slice_size[seqlen_index] = seq_len

-      slice_size = tuple(slice_size)
+      slice_size = tuple(slice_size)
partial_cache = jax.lax.dynamic_slice(
partial_cache,
start_indices,
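The reindented line above feeds jax.lax.dynamic_slice, which carves the first seq_len entries out of the larger cache buffer along the sequence axis. An illustrative sketch with made-up shapes, not MaxText's actual cache layout:

# dynamic_slice sketch: keep only seq_len positions along the sequence axis.
import jax
import jax.numpy as jnp

partial_cache = jnp.arange(2 * 8 * 4).reshape(2, 8, 4)  # (batch, seq, heads) -- illustrative
seqlen_index = 1
seq_len = 3

slice_size = list(partial_cache.shape)
slice_size[seqlen_index] = seq_len
slice_size = tuple(slice_size)  # dynamic_slice needs a static size tuple
start_indices = (0, 0, 0)
sliced = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size)
print(sliced.shape)  # (2, 3, 4)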
