I've been working on several projects to port HF's DPO implementation and the lm_evaluation_harness to native mlx (as opposed to PyTorch via the 'mps' device, for example). A typical pattern in both is identifying the target sequences within a 3D logits structure and collecting the probability scores for the target tokens from the vocabulary distribution at the target sequence locations (along the 2nd dimension of the 3D logits structure). I'm trying to find out whether there are more efficient approaches than the ones I have taken, which mostly try to push as much of the computation as possible down into array operations rather than explicit, imperative code (e.g., via a for loop here).

For example, consider a function that takes: a 3D logits structure (which could also hold log_softmax scores) produced by feeding a batch of prompt + continuation token sequences (i.e., the 'full sequences') into a model; a list of prompt lengths for each item in the batch; and a list of full-sequence lengths (excluding any left padding) for each item. It returns the same logits structure with the rows for each non-target sequence position zeroed out and those for the target sequence positions left as they are, along with a 2D mask over the target positions. Intuitively, it clears all but the last n target tokens in the given logits from later calculations and returns a mask over the sequence positions (the 2nd dimension) that correspond to the tokens of the target. My current best thought on how to implement this is:

```python
import mlx.core as mx
import numpy as np
from typing import List, Tuple

batch_size, logits_seq_len, vocab_size = batch_scores.shape
# Each target (continuation) must fit within the logits' sequence dimension.
assert all([length - prompt_length <= logits_seq_len
            for prompt_length, length in zip(prompt_lengths, non_padding_lengths)])
# Per-item position indices along the sequence dimension.
indices = mx.stack([mx.arange(logits_seq_len)] * batch_size)
# First target position for each item, counted from the end of the sequence.
target_pos = list(map(lambda i: logits_seq_len - (i[0] - i[1]), zip(non_padding_lengths, prompt_lengths)))
target_mask = indices >= mx.array(target_pos)[..., None]
zeros = mx.zeros_like(batch_scores)
# Expand the 2D mask along the vocabulary dimension and zero out non-target rows.
expanded_mask = mx.repeat(target_mask[..., None], vocab_size, axis=2)
result = mx.where(expanded_mask, batch_scores, zeros)
return result, target_mask
```

Another common scenario is to collect, at each token sequence position, the probability score of the corresponding token from the vocabulary distribution. PyTorch's gather method makes this very trivial, but there is no equivalent in mlx for design reasons.
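For reference, the gather pattern I mean is roughly the following in PyTorch (a minimal sketch; the tensor names and toy shapes are illustrative, not from either codebase):

```python
import torch

# Toy shapes; in practice batch_scores comes from the model.
batch_scores = torch.randn(2, 6, 10)               # (batch, seq_len, vocab)
target_ids = torch.zeros(2, 6, dtype=torch.long)   # (batch, seq_len)

# gather picks, for each position, the score of that position's target token.
per_token_scores = torch.gather(
    batch_scores, dim=-1, index=target_ids.unsqueeze(-1)
).squeeze(-1)  # -> (batch, seq_len)
```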
However, I was able to do this declaratively, but it is unwieldy, and I don't know whether it is the most straightforward or the most efficient approach. The function below takes the same logits structure and a list of target tokens:

```python
import mlx.core as mx
import numpy as np
from typing import List, Tuple

def get_3d_target_scores(batch_scores: mx.array, target_tokens: List) -> mx.array:
    batch_size, full_seq_len, vocab_size = batch_scores.shape
    flattened_shape = batch_size, full_seq_len
    lengths = [len(i) for i in target_tokens]
    target_seq_size = max(lengths)
    # Left-pad each target token list with index 0 so the index array is rectangular,
    # then gather the score of the target token at every sequence position.
    arr = mx.take_along_axis(
        batch_scores,
        mx.array([[0] * (batch_scores.shape[1] - len(i)) + i for i in target_tokens])[..., None],
        -1,
    )
    # Zero out the padded (non-target) positions, then sum over the sequence dimension.
    return mx.where(
        mx.stack([mx.arange(full_seq_len)] * batch_size) >= full_seq_len - target_seq_size,
        arr.reshape(*flattened_shape),
        mx.zeros(flattened_shape),
    ).sum(1)[..., None]
```

Both functions assume the sequence batches fed to the model are left-padded (which differs from how this is handled in mlx-lm's LoRA training implementation, but makes it easier to identify target sequence positions from the end of the 2nd dimension of the logits structure). Are there more efficient and straightforward ways to do this in mlx?
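For concreteness, the left-padding assumption means something like the following toy layout (the lengths and sequences below are made up for illustration; the arithmetic mirrors the `target_pos` computation in the first function):

```python
# Toy batch of 2 left-padded full sequences (prompt p + continuation c), seq_len = 6:
#   item 0: [PAD, PAD, p, p, c, c]  -> non_padding_length = 4, prompt_length = 2
#   item 1: [PAD, p,   p, c, c, c]  -> non_padding_length = 5, prompt_length = 2
# With left padding, the continuation (target) tokens always occupy the last
# (non_padding_length - prompt_length) positions, so the target span can be
# located from the end of the sequence dimension alone:
seq_len = 6
non_padding_lengths = [4, 5]
prompt_lengths = [2, 2]
target_starts = [seq_len - (n - p) for n, p in zip(non_padding_lengths, prompt_lengths)]
# -> [4, 3]: item 0's targets sit at positions 4-5, item 1's at positions 3-5
```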
We have fancy indexing in MLX which should be as expressive as gather. But maybe I'm misunderstanding what you are trying to do.

Your second function (get_3d_target_scores) looks ok to me (I don't quite get what it's doing at the end where you sum over the sequence length). You can simplify it a bit, which will also make it a bit faster:

```python
def get_3d_target_scores(batch_scores, target_tokens):
    full_seq_len = batch_scores.shape[1]
    target_seq_size = max(len(i) for i in target_tokens)
    arr = mx.take_along_axis(
        batch_scores,
        mx.array([[0] * (batch_scores.shape[1] - len(i)) + i for i in target_tokens])[..., None],
        -1,
    )
    # Broadcasting a (seq_len, 1) position index against the gathered scores
    # replaces the explicit stack/reshape/zeros of the original version.
    return mx.where(
        mx.arange(full_seq_len)[:, None] >= full_seq_len - target_seq_size, arr, 0.0
    ).sum(1)
```

Would it be much simpler to do the above in e.g. PyTorch? I don't really think so.. but if you share the parallel code that would be interesting to compare.

For the first function it looks like you clipped the signature? If you add it in I can take a look and see if it's reasonable or not.
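As a rough sketch of the fancy-indexing route mentioned above (assuming NumPy-style integer-array indexing with broadcasting, and a left-padded `(batch, seq_len)` index array like the one built inside the function; the toy shapes are made up):

```python
import mlx.core as mx

# Toy inputs; in practice batch_scores comes from the model.
batch_size, seq_len, vocab_size = 2, 6, 10
batch_scores = mx.random.normal((batch_size, seq_len, vocab_size))
target_ids = mx.zeros((batch_size, seq_len), dtype=mx.int32)

# Integer-array indexing: for every (batch, position) pair, pick the score of
# the token stored in target_ids at that position.
per_token_scores = batch_scores[
    mx.arange(batch_size)[:, None],  # (batch, 1), broadcasts over positions
    mx.arange(seq_len)[None, :],     # (1, seq_len), broadcasts over batch
    target_ids,                      # (batch, seq_len)
]  # -> (batch, seq_len)
```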