
Add logits method to Sequence class #237

Closed
wants to merge 5 commits

Conversation

SamDuffield

@SamDuffield SamDuffield commented Aug 16, 2023

Allows the user to access the next token logits for a given prompt (including with regex!)

# `model` is an Outlines model instance; `.logits` is the method this PR adds
prompt = "What is the IP address of the Google DNS servers? "
unguided_logits = generate.continuation(model, max_tokens=30).logits(prompt)
guided_logits = generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    max_tokens=30,
).logits(prompt)

@SamDuffield SamDuffield changed the title Add logits method to Sequence clase Add logits method to Sequence class Aug 16, 2023
@rlouf
Member

rlouf commented Aug 16, 2023

I think it would make more sense to return the logprob as part of the step method?

Then we could create a custom object that contains both the text output and the logprob. This might do for now, give me some time to think about it.

What are you using this for?

@SamDuffield
Author

> I think it would make more sense to return the logprob as part of the step method?
>
> Then we could create a custom object that contains both the text output and the logprob. This might do for now, give me some time to think about it.

If you include it in step, are you thinking of returning the logit only for the sampled token, or all of the logits? I think it would be useful for the user (at least me 😄) to have access to all of the logits if needed.

> What are you using this for?

It's just very useful to be able to analyse the next-token logits, e.g. to assess the model's confidence on a True/False question across a range of questions (and this pairs very well with regex).
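That confidence check could look roughly like this (a sketch with made-up logit values; the logit numbers and the two-answer setup are illustrative, not taken from the library):

```python
import math

def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical next-token logits for the two answers permitted by a
# regex such as r"(True|False)" (values are made up for illustration).
answer_logits = {"True": 3.1, "False": 1.4}
answer_probs = dict(zip(answer_logits, softmax(list(answer_logits.values()))))
confidence = answer_probs["True"]
```

Here `confidence` is the model's probability of answering "True", which can then be compared across questions.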

@SamDuffield
Author

Ah, I see what you mean: step already outputs the probs as an additional output.

It would be useful to have an additional method with a simple API like
generate.continuation(model).next_token_logits(prompt)
or
generate.continuation(model).next_token_probs(prompt)

I.e. it returns the next-token logits or probs without actually doing any sampling.

@SamDuffield
Author

Ok I've thought about this a bit more.

IMO the best API would be to have a .next_token_logits(prompt) or .next_token_probs(prompt) function that allows the user to examine the distribution without sampling. I think it's good to have this separate from the step function which samples and appends the new token, which is a different use case.

With this new function it may also make sense to have step return only the token_ids and not probs. This could also avoid some confusion, as the token_ids relate to the whole sequence while probs relate only to the appended token.

Was there a specific reason to have step return probs as well?

WDYT? @rlouf @brandonwillard
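For illustration, a method that examines the next-token distribution without sampling could look roughly like this (a hypothetical sketch, not the actual Sequence implementation; the toy `encode` and `model` stand in for a real tokenizer and forward pass):

```python
from typing import Callable, List

def next_token_logits(
    model: Callable[[List[int]], List[float]],
    encode: Callable[[str], List[int]],
    prompt: str,
) -> List[float]:
    """Return the vocabulary logits for the token that would follow
    `prompt`, without sampling or appending anything to the sequence."""
    token_ids = encode(prompt)
    return model(token_ids)

# Toy stand-ins: a "tokenizer" and a "model" over a 3-token vocabulary.
encode = lambda text: [ord(c) % 3 for c in text]
model = lambda ids: [float(len(ids)), 0.0, -1.0]

logits = next_token_logits(model, encode, "hi")
```

The point of keeping this separate from step is that nothing is sampled and the sequence is left untouched.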

@rlouf
Member

rlouf commented Aug 17, 2023

> Was there a specific reason to have step return probs as well?

Unnecessary anticipation of SMC sampling.

> WDYT? @rlouf @brandonwillard

Sounds good to me.

Sounds good to me.

@SamDuffield
Author

Do you prefer .logits, .probs, .next_token_logits, or .next_token_probs?

I can imagine use cases where the user wants either logits or probs, so IMO it's best to go with .logits or .next_token_logits: it involves (slightly) less computation, the user can easily call softmax if they need probabilities, and it's more in line with the output from self.model.

@rlouf
Member

rlouf commented Aug 17, 2023

next_token_logits sounds good

@SamDuffield SamDuffield requested a review from rlouf August 17, 2023 19:19
@SamDuffield
Author

I'm now thinking that indeed the probs might be useful for SMC 😃

Let's leave step returning probs for now

@dpsimpson dpsimpson requested review from dpsimpson and removed request for rlouf September 5, 2023 15:33
@dpsimpson

I've had a look at this and I'm not sure it's needed.

The pros:

  • It gives finer scale control for people inheriting sequence
  • It allows people to avoid superfluous sampling if they're just interested in the logits (e.g. speculative decoding)
  • It avoids needing to torch.log the probabilities, which might be useful for numerical stability.

The cons:

  • There's a real risk that people will use both this and step, which would lead to an extra model call.
  • The information that is needed is already present in the output of step, modulo a logarithm.

@rlouf - Is it possible to consistently work with logits rather than probs, just for numerical stability? I've had a look at the code and it's not clear to me why you're not using torch.multinomial.
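The numerical-stability point can be made concrete: exponentiating a very negative logit underflows to 0.0, so taking the log of the resulting probability is undefined, whereas staying in log space (log-sum-exp) keeps the value finite. A plain-Python sketch with an assumed logit value:

```python
import math

logits = [0.0, -800.0]  # hypothetical: the second token is effectively impossible

# Probs-then-log route: exp(-800) underflows to exactly 0.0, so the
# second probability is 0.0 and taking its log would fail.
m = max(logits)
exps = [math.exp(x - m) for x in logits]
probs = [e / sum(exps) for e in exps]
# probs[1] == 0.0 due to underflow; math.log(probs[1]) would raise

# Log-space route: log p_i = x_i - logsumexp(logits), which never
# exponentiates a raw logit on its own.
logsumexp = m + math.log(sum(math.exp(x - m) for x in logits))
log_probs = [x - logsumexp for x in logits]
# log_probs[1] == -800.0, finite and exact
```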

@dpsimpson

(Also, sorry @rlouf, I must've mis-clicked. I don't think I meant to remove you as a reviewer.)

@rlouf
Member

rlouf commented Dec 6, 2023

Solved in #366

@rlouf rlouf closed this Dec 6, 2023