diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py
index 837e31ed..3eebec0c 100644
--- a/megablocks/layers/dmoe.py
+++ b/megablocks/layers/dmoe.py
@@ -132,10 +132,9 @@ def sparse_forward_once(self, x, expert_weights, top_experts):
         with torch.no_grad():
             indices, bin_ids, bins, padded_bins, tokens_per_expert = (
                 self.indices_and_padded_bins(top_experts))
-        sl, bs, hs = x.size()
 
         # Route the tokens for MoE computation.
-        x = x.view(sl * bs, hs)
+        x = x.view(-1, x.shape[-1])
         x = ops.padded_gather(
             x,
             indices,
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index a74036a0..01db9bf8 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -419,13 +419,13 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         return x, tokens_per_expert.flatten()
 
     def forward(self, x, scores, expert_weights, top_experts):
-        sl, bs, hs = x.size()
+        in_shape = x.size()
 
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
         save_load_balancing_loss((tokens_per_expert, scores))
-        x = x.view(sl, bs, hs)
+        x = x.view(in_shape)
         if self.bias is not None:
             if self.args.return_bias:
                 return x, self.bias
@@ -448,7 +448,6 @@ def forward(self, x):
         # NOTE: If we're going to cast the activations to lower precision
         # do it before we permute the tokens to save bandwidth.
         x = common.cast_if_autocast_enabled(x)
-        sl, bs, hs = x.size()
 
         # Compute the expert scores and assignments.
         scores, expert_weights, top_experts = self.router(x)
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index b4720439..0714a573 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -53,8 +53,7 @@ def forward(self, x):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
 
-        sl, bs, hs = x.size()
-        scores = self.layer(x.view(-1, hs)).softmax(dim=-1)
+        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         expert_indices = (