diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py
index ee3d9bb283..7c651d0a27 100644
--- a/src/axolotl/monkeypatch/moe/moe.py
+++ b/src/axolotl/monkeypatch/moe/moe.py
@@ -25,24 +25,33 @@ def _post_training(self, model, name):
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # Recreate the MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            expert = nn.Sequential(
-                nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False),
-                self.experts.activation,
-                nn.Linear(self.ffn_dim, self.hidden_dim, bias=False),
-            )
-            expert[0].weight.data = torch.cat([w1s[i], w3s[i]], dim=0)
-            expert[2].weight.data = w2s[i]
-            experts.append(expert)
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
 
-        # Create a new MoE module with the recreated experts
-        moe = nn.ModuleList(experts)
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
 
-        # Replace the fused experts with the recreated MoE module
-        setattr(model, name.replace("experts", "moe"), moe)
-        delattr(model, name)
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+
+            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
+            expert.w2.weight.data = w2s[expert_idx]
+
+            original_moe.experts.append(expert)
+
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
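
For context, a minimal standalone sketch (not part of the patch) of the unfuse idea the new _post_training body relies on: a fused gate/up projection whose weight stacks w1 (gate) on top of w3 (up) can be split back into two separate nn.Linear layers without changing the forward result. The SiLU gating formula and the [w1; w3] row layout below are assumptions based on the Mixtral MLP convention referenced in the patch's comments, not code taken from this file.

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim, ffn_dim = 16, 32

# Fused gate/up projection: weight shape [2*ffn_dim, hidden_dim], assumed to
# stack w1 (gate) rows on top of w3 (up) rows.
fused_w13 = nn.Linear(hidden_dim, 2 * ffn_dim, bias=False)
w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
act_fn = nn.SiLU()

# Unfuse: split the fused weight back into two per-projection Linears.
w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
w1_weight, w3_weight = torch.split(fused_w13.weight.data, ffn_dim, dim=0)
w1.weight.data = w1_weight
w3.weight.data = w3_weight

x = torch.randn(4, hidden_dim)
gate_up = fused_w13(x)  # [4, 2*ffn_dim]: first half is gate, second half is up
fused_out = w2(act_fn(gate_up[:, :ffn_dim]) * gate_up[:, ffn_dim:])
unfused_out = w2(act_fn(w1(x)) * w3(x))
assert torch.allclose(fused_out, unfused_out, atol=1e-6)

The same split is what torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0) undoes in reverse inside the hunk when the fused per-expert weights are reassembled.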