
Logit processing in mlx #1649

Answered by awni
chimezie asked this question in Q&A
Dec 5, 2024 · 1 comment · 3 replies

> PyTorch's `gather` method makes this very trivial, but there is no equivalent in mlx for design reasons.

We have fancy indexing in MLX which should be as expressive as gather. But maybe I'm misunderstanding what you are trying to do.

Your second function (`get_3d_target_scores`) looks OK to me, though I don't quite get what it's doing at the end where you sum over the sequence length. You can simplify it a bit, which will also make it a bit faster:

```python
import mlx.core as mx

def get_3d_target_scores(batch_scores, target_tokens):
    full_seq_len = batch_scores.shape[1]
    target_seq_size = max(len(i) for i in target_tokens)
    arr = mx.take_along_axis(
        batch_scores,
        mx.array([[0] * (batch_scores.shape[1] - l…
```

@chimezie
@awni
awni Dec 7, 2024
Maintainer

@chimezie
Answer selected by chimezie