Where's the second expert per token per layer? #2261
Closed
enn-nafnlaus started this conversation in General
Replies: 1 comment
-
I'm looking into the code right now:
b5f882c
I was under the impression that Mixtral uses 8 experts, with two chosen per token per layer, but I'm not seeing that in the code. In BlockSparseMoE's forward function, it runs `weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)`, then flattens the experts, which as far as I can tell yields one expert per token (per layer). And there's only one BlockSparseMoE per MixtralDecoderLayer. Where is it supposedly running two separate experts for each token, to then add and norm? Maybe I'm just missing it...
I was thinking about extending it to (per layer) add a dot product between the experts, to assess the similarity of their output vectors, on the hypothesis that this could serve as a proxy measurement for hallucination: when the model is good at a task, different experts should reach roughly the same conclusion, but when the model is bad at a task, different experts may reach wildly different conclusions. That dot product could then be used as a scalar on a cross product of the output vector and a new (learned) vector, followed by add + norm, then training with examples of proper responses to situations with varying levels of uncertainty.
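Roughly what I have in mind, as a sketch (the names here are made up, and it assumes the two experts' outputs can be captured separately before they're combined; I use cosine similarity as the normalized dot product, and an elementwise product as a stand-in for the cross product):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def uncertainty_update(expert_out_a, expert_out_b, learned_vec, norm):
    # expert_out_a, expert_out_b: (seq_len, model_dim) -- the two chosen
    # experts' outputs for each token, captured before they're summed.
    # learned_vec: a new learned (model_dim,) vector; norm: e.g. nn.LayerNorm.
    combined = expert_out_a + expert_out_b
    # Normalized dot product per token: ~1 when the experts agree, lower when not.
    agreement = F.cosine_similarity(expert_out_a, expert_out_b, dim=-1)  # (seq_len,)
    # Scale the (elementwise) product of the output and the learned vector by
    # the disagreement, so the correction grows when the experts diverge.
    update = (1.0 - agreement).unsqueeze(-1) * (combined * learned_vec)
    return norm(combined + update)  # add + norm

seq_len, model_dim = 4, 8
out = uncertainty_update(
    torch.randn(seq_len, model_dim), torch.randn(seq_len, model_dim),
    nn.Parameter(torch.randn(model_dim)), nn.LayerNorm(model_dim))
```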
But perhaps I'm misunderstanding how the model works and maybe this isn't possible...
-
Oh, duh. self.top_k is 2. So the dims of selected_experts are (sequence_length, 2). Flattened out, it becomes sequence_length * 2. It loses this shape with gather and then comes back to it with scatter, and we're left with dimensions (2 * sequence_length, model_dim), which can be read as e.g. two full hidden-state vectors (model_dim in length) for each token in the sequence.
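To sanity-check the shapes, here's a toy walkthrough (made-up sizes, with repeat_interleave standing in for the gather step; this is not the actual mistral code):

```python
import torch

seq_len, model_dim, n_experts, top_k = 4, 8, 8, 2
hidden = torch.randn(seq_len, model_dim)
all_probs = torch.softmax(torch.randn(seq_len, n_experts), dim=-1)

weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
print(selected_experts.shape)  # torch.Size([4, 2]) -- (sequence_length, 2)

flat_experts = selected_experts.flatten()
print(flat_experts.shape)      # torch.Size([8]) -- the flattened expert ids

# Each token's hidden state is duplicated once per selected expert, so the
# expert MLPs process 2 * sequence_length rows in total:
expanded = hidden.repeat_interleave(top_k, dim=0)
print(expanded.shape)          # torch.Size([8, 8]) -- (2 * seq_len, model_dim)

# Pretend the expert MLPs ran; their results come back with the same
# (2 * sequence_length, model_dim) shape, i.e. two full model_dim-length
# vectors per token, which the routing weights fold back together:
expert_out = expanded  # placeholder for the experts' output
combined = (expert_out.view(seq_len, top_k, model_dim)
            * weights.unsqueeze(-1)).sum(dim=1)
print(combined.shape)          # torch.Size([4, 8]) -- (seq_len, model_dim)
```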