From 50feaa437b4b0edab89229fbf327c8f096c6c8db Mon Sep 17 00:00:00 2001
From: Aidan San
Date: Sat, 14 Sep 2024 07:40:31 +0000
Subject: [PATCH] Fix MultinomialSampler hyperparameter bug

Uses all logit_processors instead of just the last logit_processor

---
 outlines/samplers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/outlines/samplers.py b/outlines/samplers.py
index b1421971f..bcb2bb628 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -148,7 +148,7 @@ def __call__(
 
         altered_next_token_logits = next_token_logits
         for logit_processor in self.logits_processors:
-            altered_next_token_logits = logit_processor(next_token_logits)
+            altered_next_token_logits = logit_processor(altered_next_token_logits)
 
         probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)