enable generic dimensionality for input
vchiley committed Dec 1, 2023
1 parent 059ae20 commit 40e4918
Showing 3 changed files with 4 additions and 7 deletions.
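
All three files apply the same change: the explicit sl, bs, hs = x.size() unpacking, which assumed a 3-D (sequence length, batch size, hidden size) input, is replaced with rank-agnostic reshapes. x.view(-1, x.shape[-1]) flattens all leading dimensions into a single token axis, and moe.py's forward records in_shape = x.size() so it can restore the original shape afterwards with x.view(in_shape).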
3 changes: 1 addition & 2 deletions megablocks/layers/dmoe.py
@@ -132,10 +132,9 @@ def sparse_forward_once(self, x, expert_weights, top_experts):
         with torch.no_grad():
             indices, bin_ids, bins, padded_bins, tokens_per_expert = (
                 self.indices_and_padded_bins(top_experts))
-        sl, bs, hs = x.size()

         # Route the tokens for MoE computation.
-        x = x.view(sl * bs, hs)
+        x = x.view(-1, x.shape[-1])
         x = ops.padded_gather(
             x,
             indices,
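
For illustration (not part of the commit), a minimal standalone sketch of why the new view is rank-agnostic, assuming only torch:

import torch

# The old code unpacked sl, bs, hs = x.size() and called x.view(sl * bs, hs),
# which only works for 3-D inputs. The new view flattens every leading
# dimension into one token axis, whatever the rank.
for shape in [(16, 4, 8), (64, 8), (2, 16, 4, 8)]:  # 3-D, 2-D, and 4-D inputs
    x = torch.randn(shape)
    flat = x.view(-1, x.shape[-1])
    assert flat.shape == (x.numel() // x.shape[-1], x.shape[-1])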
5 changes: 2 additions & 3 deletions megablocks/layers/moe.py
@@ -419,13 +419,13 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         return x, tokens_per_expert.flatten()

     def forward(self, x, scores, expert_weights, top_experts):
-        sl, bs, hs = x.size()
+        in_shape = x.size()

         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
         save_load_balancing_loss((tokens_per_expert, scores))
-        x = x.view(sl, bs, hs)
+        x = x.view(in_shape)
         if self.bias is not None:
             if self.args.return_bias:
                 return x, self.bias
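
For illustration (not from the commit), a minimal sketch of the save/flatten/restore round-trip that forward now performs, assuming only torch:

import torch

x = torch.randn(2, 16, 4, 8)    # arbitrary leading dims, hidden size last
in_shape = x.size()             # remember the full input shape
x = x.view(-1, x.shape[-1])     # (2 * 16 * 4, 8) token matrix for the experts
# ... expert computation happens here in the real forward() ...
x = x.view(in_shape)            # restore the original shape exactly
assert x.shape == in_shape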
@@ -448,7 +448,6 @@ def forward(self, x):
         # NOTE: If we're going to cast the activations to lower precision
         # do it before we permute the tokens to save bandwidth.
         x = common.cast_if_autocast_enabled(x)
-        sl, bs, hs = x.size()

         # Compute the expert scores and assignments.
         scores, expert_weights, top_experts = self.router(x)
3 changes: 1 addition & 2 deletions megablocks/layers/router.py
@@ -53,8 +53,7 @@ def forward(self, x):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)

-        sl, bs, hs = x.size()
-        scores = self.layer(x.view(-1, hs)).softmax(dim=-1)
+        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)

         expert_indices = (
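
For illustration (not from the commit), a sketch of the router change: the scoring layer now sees one row per token for inputs of any rank. The Linear construction below is a hypothetical stand-in for self.layer, which the real Router builds from its args:

import torch

num_experts, hs = 4, 8
layer = torch.nn.Linear(hs, num_experts)   # hypothetical stand-in for self.layer
x = torch.randn(3, 5, 7, hs)               # 4-D input instead of (sl, bs, hs)
scores = layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
assert scores.shape == (3 * 5 * 7, num_experts)   # one score row per token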
