Add shard on batch mode. Also update version of torchxla2 #80

Merged
2 commits merged into main on May 14, 2024

Conversation

@qihqi (Collaborator) commented May 13, 2024

No description provided.

@qihqi requested review from lsy323, wang2yn84 and FanhaiLu1 on May 13, 2024 at 23:20
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch_xla2
import jax
import jax.numpy as jnp
import torch
Collaborator:
Thanks for the change! It's more meaningful than wrap and unwrap...

)
if self.env.shard_on_batch:
return Prefix(
self.replicated, # cache is replicated because bs=1 for prefill
Collaborator:
Multiple isolated prefill instances perform better than this replicated sharding (or we could use a smaller VM to test, instead of 4 or 8 chips).

Collaborator:
Hi Fanhai, do you mean each TPU chip handles a certain length of the prefill?

Collaborator (author):
Do you mean having 8 InterleavedEngine instances with one device each, vs. sharding on batch from JAX?

Collaborator:
The 8 InterleavedEngine instances vs. sharding on batch from JAX.

Collaborator:
Of course, because there are no collective operations involved. IIUC, the best config might then be to enable disaggregated serving and use multiple prefill engine instances instead of sharding.
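
For readers following this thread, a minimal sketch of what "sharding on batch from JAX" means here; the mesh, shapes, and axis name are made up for illustration and are not the engine's actual code:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Split the decode batch across the devices of a 1-D mesh, so each chip
    # holds a slice of the batch. The alternative discussed above is N
    # independent single-device engines, which needs no cross-device collectives.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    tokens = jnp.zeros((8, 1), dtype=jnp.int32)  # (batch, 1) decode-step input
    tokens = jax.device_put(tokens, NamedSharding(mesh, P("x", None)))
    print(tokens.sharding)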

)
else:
return DecodeState(
self.replicated, # shard on batch
Collaborator:
Remove the comment? The "# shard on batch" note doesn't match self.replicated in this branch.

Collaborator (author):
done.

scores = torch_xla2.extra.call_jax(
    jnp.einsum, "ikjl,ikml->ikjm", xq, keys
) / math.sqrt(head_dim)
self.env.apply_sharding(scores, axis=1)
Collaborator:
Has the matrix transpose HLO issue been fixed?

Collaborator (author):
Yes.

Collaborator:
Can you share the PR for the fix?

Collaborator:
And also update b/329899712?

Collaborator:
Thank you for the fix! Not directly related, but I see you special-cased einsum to fix the issue. Please also fix b/329899713 for matmul. The conversation in this PR can be closed now.
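
For context on the einsum change discussed above, a stand-alone sketch of the score computation using plain jnp.einsum (instead of routing through torch_xla2.extra.call_jax); the shapes are made up:

    import math
    import jax.numpy as jnp

    # xq:   (batch, heads, q_len,  head_dim) -> subscripts i k j l
    # keys: (batch, heads, kv_len, head_dim) -> subscripts i k m l
    batch, heads, q_len, kv_len, head_dim = 2, 4, 8, 16, 32
    xq = jnp.ones((batch, heads, q_len, head_dim))
    keys = jnp.ones((batch, heads, kv_len, head_dim))

    # Contract over head_dim (l): one score per query/key pair, scaled by
    # 1/sqrt(head_dim), giving shape (batch, heads, q_len, kv_len).
    scores = jnp.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
    print(scores.shape)  # (2, 4, 8, 16)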

@wang2yn84 (Collaborator) left a comment:
Thank you for improving torch_xla2, adding Gemma 2 support, and cleaning up the code! If we can have one PR for each task that would be great, though this won't block you from merging.

self.cache_sharding = self.env.cache_sharding

jax.config.update("jax_enable_x64", False)
Collaborator:
Nit: we have config set up in both engine.py and the script; in the future we should place them together.

for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
    list(zip(caches, cache_scales))
)
for (k, v), (ks, vs) in from_jax(list(zip(caches, cache_scales)))
Collaborator:
Nit: since PyTorch should be the focus, would it be better to use "to_torch" and "from_torch" instead of "from_jax" and "to_jax"? Ideally we should not see any JAX-related terms.

Collaborator (author):
done

with self._lock:
with torch_xla2.tensor.XLADispatchMode():
with torch_xla2.default_env():
Collaborator:
For PyTorch users, isn't DispatchMode more descriptive than default_env? We also need a comment here explaining why TorchDispatchMode is needed.

Collaborator (author):
done
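
To illustrate why a dispatch-mode context manager is used here, a generic TorchDispatchMode example (not torch_xla2's actual implementation): every ATen op executed inside the with block is routed through __torch_dispatch__, which is the mechanism torch_xla2's environment relies on to redirect torch ops to JAX:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingMode(TorchDispatchMode):
        # Intercepts every ATen op executed while the mode is active.
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("dispatching", func)
            return func(*args, **kwargs)

    with LoggingMode():
        torch.ones(2) + torch.ones(2)  # the add is intercepted and logged above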

self.replicated,
)
if self.env.shard_on_batch:
return DecodeState(
Collaborator:
Nit, can we do

    return DecodeState(
        self.x_sharding if self.env.shard_on_batch else self.replicated,  # shard on batch
        self.cache_sharding,
        self.replicated,
        self.replicated,
        self.replicated,
        self.replicated,
        self.replicated,
    )

Collaborator (author):
done.

if self.env.shard_on_batch:
self.env.apply_sharding(output, axis=0)
else:
self.env.apply_sharding(output, axis=1)
Collaborator:
Nit: can we use self.env.shard_on_batch to select the axis instead of duplicating the apply_sharding call?

Collaborator (author):
done.
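
A sketch of the axis-selection pattern discussed in this thread, written against plain JAX; the helper, mesh, and shapes here are assumptions, not the engine's real apply_sharding:

    from functools import partial

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

    def apply_sharding(tensor, axis):
        # Constrain a single axis of the tensor to be sharded over the mesh.
        spec = [None] * tensor.ndim
        spec[axis] = "x"
        return jax.lax.with_sharding_constraint(tensor, NamedSharding(mesh, P(*spec)))

    @partial(jax.jit, static_argnames="shard_on_batch")
    def shard_output(output, shard_on_batch):
        # One call with the axis chosen by the flag, instead of branching
        # around two duplicated apply_sharding calls.
        return apply_sharding(output, axis=0 if shard_on_batch else 1)

    out = shard_output(jnp.zeros((8, 32, 128)), shard_on_batch=True)  # (batch, heads, dim)
    print(out.sharding)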

self.env.apply_sharding(xq, axis=2)
self.env.apply_sharding(xk, axis=2)
self.env.apply_sharding(xv, axis=2)
if self.env.shard_on_batch:
Collaborator:
Ditto

Collaborator (author):
done

@@ -148,15 +148,15 @@ def forward(
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

if self.num_kv_heads > 1:
if self.env.shard_on_batch:
Collaborator:
Ditto

Collaborator (author):
done

@@ -75,6 +75,9 @@
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "config file for sharding"
)
_SHARD_ON_BATCH = flags.DEFINE_bool(
Collaborator:
Nit: can you add to the help text that, when enabled, it overrides sharding_config, or something similar?

Collaborator (author):
done
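
As an illustration of the requested flag documentation, a hedged sketch of how the definition could read; the flag name follows the variable name in the diff, but the default and the exact help wording are assumptions:

    from absl import flags

    _SHARD_ON_BATCH = flags.DEFINE_bool(
        "shard_on_batch",
        False,
        "Shard tensors on the batch dimension instead of the head dimension. "
        "When enabled, this overrides sharding_config.",
    )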

@@ -89,6 +89,9 @@
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "config file for sharding"
)
_SHARD_ON_BATCH = flags.DEFINE_bool(
Collaborator:
Ditto

Collaborator (author):
done

@FanhaiLu1 merged commit 776c1c4 into main on May 14, 2024
3 checks passed