Commit

code review feedback
ehartford committed Mar 15, 2024
1 parent 301cc4c commit 9c221a6
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions src/axolotl/monkeypatch/moe/moe.py
@@ -25,24 +25,33 @@ def _post_training(self, model, name):
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # Recreate the MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            expert = nn.Sequential(
-                nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False),
-                self.experts.activation,
-                nn.Linear(self.ffn_dim, self.hidden_dim, bias=False),
-            )
-            expert[0].weight.data = torch.cat([w1s[i], w3s[i]], dim=0)
-            expert[2].weight.data = w2s[i]
-            experts.append(expert)
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
 
-        # Create a new MoE module with the recreated experts
-        moe = nn.ModuleList(experts)
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
 
-        # Replace the fused experts with the recreated MoE module
-        setattr(model, name.replace("experts", "moe"), moe)
-        delattr(model, name)
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+
+            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
+            expert.w2.weight.data = w2s[expert_idx]
+
+            original_moe.experts.append(expert)
+
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
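Note: the added code reassembles per-expert modules from the fused expert tensors used during training. Below is a minimal, standalone sketch of that unfusing idea; the fused-weight shapes, the tensor names, and the SiLU activation are illustrative assumptions for this sketch, not axolotl's or scattermoe's actual API.

# Minimal, self-contained sketch of the weight-unfusing idea (illustrative only;
# shapes and names are assumptions, not the project's API).
import torch
import torch.nn as nn

num_experts, hidden_dim, ffn_dim = 4, 32, 64

# Assumed fused layout: w1 and w3 up-projections stacked per expert, plus
# per-expert down-projections, with the expert axis first.
fused_up = torch.randn(num_experts, 2 * ffn_dim, hidden_dim)   # [E, 2*ffn, hidden]
fused_down = torch.randn(num_experts, hidden_dim, ffn_dim)     # [E, hidden, ffn]

experts = nn.ModuleList()
for e in range(num_experts):
    # Split the stacked up-projection back into separate w1 / w3 weights.
    w1_weight, w3_weight = torch.split(fused_up[e], ffn_dim, dim=0)

    expert = nn.Module()
    expert.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
    expert.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
    expert.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
    expert.act_fn = nn.SiLU()  # assumed activation for the sketch

    expert.w1.weight.data = w1_weight
    expert.w3.weight.data = w3_weight
    expert.w2.weight.data = fused_down[e]
    experts.append(expert)

# Each recreated expert then behaves like a standard gated MLP: w2(act(w1(x)) * w3(x)).
x = torch.randn(1, hidden_dim)
y = experts[0].w2(experts[0].act_fn(experts[0].w1(x)) * experts[0].w3(x))
assert y.shape == (1, hidden_dim)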

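The hunk truncates the forward pass. For reference only, a generic Mixtral-style top-k routing over a gate and a ModuleList of experts looks roughly like the sketch below; this is the standard published algorithm, not necessarily this file's exact forward, and it reuses the illustrative experts, hidden_dim, and num_experts from the previous sketch.

# Generic top-k MoE routing sketch (standard Mixtral-style algorithm; not this
# file's exact forward). Reuses `experts`, `hidden_dim`, `num_experts` from the
# previous sketch.
top_k = 2
gate = nn.Linear(hidden_dim, num_experts, bias=False)

def moe_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden = hidden_states.shape
    flat = hidden_states.view(-1, hidden)                       # [tokens, hidden]
    router_logits = gate(flat)                                  # [tokens, num_experts]
    routing_weights = torch.softmax(router_logits, dim=-1)
    routing_weights, selected = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(flat)
    for k in range(top_k):
        for e, expert in enumerate(experts):
            mask = selected[:, k] == e
            if mask.any():
                x = flat[mask]
                y = expert.w2(expert.act_fn(expert.w1(x)) * expert.w3(x))
                out[mask] += routing_weights[mask, k].unsqueeze(-1) * y
    return out.view(batch_size, sequence_length, hidden)

# Example: route a small batch through the sketch.
print(moe_forward(torch.randn(2, 5, hidden_dim)).shape)  # torch.Size([2, 5, 32])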