Add shard on batch mode. Also update version of torchxla2 #80
Conversation
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch_xla2
import jax
import jax.numpy as jnp
import torch
Thanks for the change! It's more meaningful than wrap and unwrap...
jetstream_pt/engine.py
Outdated
)
if self.env.shard_on_batch:
  return Prefix(
      self.replicated,  # cache is replicated because bs=1 for prefill
The performance of multiple isolated prefill instances is better than this replicated sharding (or we could use a smaller VM to test, instead of 4 chips or 8 chips).
Hi Fanhai, do you mean each TPU chip does a certain length of the prefill?
Do you mean having 8 InterleavedEngine instances with one device each, vs. sharding on batch from JAX?
The 8 InterleavedEngine instances vs. sharding on batch from JAX.
Of course, because there are no collective operations involved. IIUC, the best config might then be to enable disaggregated serving and use multiple prefill engine instances instead of sharding.
jetstream_pt/engine.py
Outdated
)
else:
  return DecodeState(
      self.replicated,  # shard on batch
remove the comments?
done.
scores = torch_xla2.extra.call_jax(
    jnp.einsum, "ikjl,ikml->ikjm", xq, keys
) / math.sqrt(head_dim)
self.env.apply_sharding(scores, axis=1)
Has the matrix transpose HLO issue been fixed?
yes
Can you share the PR for the fix?
And also update b/329899712?
Thank you for the fix! Not directly related, but I see you special-cased einsum to fix the issue. Please also fix b/329899713 for matmul. That said, the conversation in this PR can be closed now.
Thank you for improving torch_xla2, adding Gemma 2 support, and cleaning up the code! If we can have one PR for each task, that would be great; it won't block you from merging.
self.cache_sharding = self.env.cache_sharding

jax.config.update("jax_enable_x64", False)
Nit: we have config set up in both engine.py and the script; in the future we should place them together.
jetstream_pt/engine.py
Outdated
for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
    list(zip(caches, cache_scales))
)
for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales)))
Nit: since PyTorch should be the focus, is it better to use "to_torch" and "from_torch" instead of "from_jax" and "to_jax"? Ideally we should not see any JAX-related terms.
done
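For illustration, a minimal sketch of how the suggested names could map onto the existing wrap/unwrap helpers referenced above; whether the project exposes them as thin aliases like this, and whether an unwrap counterpart lives alongside torch_xla2.tensor.wrap, are assumptions:

import torch_xla2

# Hypothetical PyTorch-centric aliases over the existing conversion helpers.
def to_torch(values):
  # Wrap JAX arrays as torch_xla2-backed torch tensors.
  return torch_xla2.tensor.wrap(values)

def from_torch(values):
  # Unwrap torch_xla2 tensors back into plain JAX arrays.
  return torch_xla2.tensor.unwrap(values)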
jetstream_pt/engine.py
Outdated
with self._lock:
  with torch_xla2.tensor.XLADispatchMode():
  with torch_xla2.default_env():
For PyTorch users, DispatchMode is more descriptive than default env? We also need a comment here explaining why we need TorchDispatchMode.
done
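As a sketch of the requested comment; the wording below reflects the editor's understanding of what torch_xla2.default_env() does rather than text from this PR:

with self._lock:
  # default_env() activates the torch_xla2 environment, which intercepts
  # torch operations and runs them through their JAX implementations on the
  # XLA device, replacing the older XLADispatchMode context.
  with torch_xla2.default_env():
    ...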
jetstream_pt/engine.py
Outdated
self.replicated,
)
if self.env.shard_on_batch:
  return DecodeState(
Nit, can we do

return DecodeState(
    self.x_sharding if self.env.shard_on_batch else self.replicated,  # shard on batch
    self.cache_sharding,
    self.replicated,
    self.replicated,
    self.replicated,
    self.replicated,
    self.replicated,
)
done.
jetstream_pt/layers.py
Outdated
if self.env.shard_on_batch:
  self.env.apply_sharding(output, axis=0)
else:
  self.env.apply_sharding(output, axis=1)
Nit, can we use self.env.shard_on_batch to select the axis instead of duplicating the apply_sharding call?
done.
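A minimal sketch of the suggested collapse, assuming apply_sharding accepts the axis keyword exactly as in the diff above:

# Pick the sharded axis from the mode instead of duplicating the call.
self.env.apply_sharding(output, axis=0 if self.env.shard_on_batch else 1)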
jetstream_pt/layers.py
Outdated
self.env.apply_sharding(xq, axis=2)
self.env.apply_sharding(xk, axis=2)
self.env.apply_sharding(xv, axis=2)
if self.env.shard_on_batch:
Ditto
done
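The same pattern would apply to the projections in this snippet; assuming the batch dimension is axis 0 here, a sketch could look like:

# Choose the sharded axis once and reuse it for all three projections.
axis = 0 if self.env.shard_on_batch else 2
self.env.apply_sharding(xq, axis=axis)
self.env.apply_sharding(xk, axis=axis)
self.env.apply_sharding(xv, axis=axis)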
@@ -148,15 +148,15 @@ def forward(
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

if self.num_kv_heads > 1:
  if self.env.shard_on_batch:
Ditto
done
@@ -75,6 +75,9 @@
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "config file for sharding"
)
_SHARD_ON_BATCH = flags.DEFINE_bool(
Nit, can you add "when enabled, it overwrites the sharding_config" or something similar to the flag's help text?
done
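A sketch of the flag with the suggested help text; the flag name, default value, and exact wording are assumptions based on the surrounding diff, and flags is assumed to be absl.flags as in the existing DEFINE_string call:

from absl import flags

_SHARD_ON_BATCH = flags.DEFINE_bool(
    "shard_on_batch",
    False,
    "Shard tensors along the batch dimension instead of the head dimension. "
    "When enabled, it overwrites the sharding_config.",
)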
@@ -89,6 +89,9 @@
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "config file for sharding"
)
_SHARD_ON_BATCH = flags.DEFINE_bool(
Ditto
done