Allow configurable sampling steps #318

Merged
merged 2 commits into from
Oct 12, 2023

@brandonwillard brandonwillard commented Oct 12, 2023

This PR makes the sampling step configurable: one can now specify a sampler function that decides exactly how the next token ids are produced from the logit values. In doing so, it also adds a greedy "sampler" function (i.e. one that always chooses the token with the largest logit).
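
As a rough illustration of the idea (a minimal sketch, not the code in this PR), a sampler can be any callable that maps logits to token ids. The names `greedy`, `multinomial`, and `sample_next`, and the `(batch, vocab)` input layout, are assumptions made for this example only:

```python
from typing import Callable, Optional

import torch

# Hypothetical signature: logits of shape (batch, vocab) -> ids of shape (batch, 1)
Sampler = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor]


def greedy(logits: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Choose the token with the largest logit; `rng` is accepted but unused."""
    return torch.argmax(logits, dim=-1, keepdim=True)


def multinomial(logits: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw one token id per row from the softmax distribution over the logits."""
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)


def sample_next(
    logits: torch.Tensor,
    sampler: Sampler = multinomial,
    rng: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Hypothetical generation step that delegates token selection to `sampler`."""
    return sampler(logits, rng)
```

With this shape, switching from sampling to modal-token generation would amount to passing `greedy` instead of `multinomial` as the `sampler` argument.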

The dimensions of the mock decoder output and the test input token ids did not
match and were causing PyTorch to issue a warning/error in CI.
@brandonwillard brandonwillard merged commit f6e33dd into dottxt-ai:main Oct 12, 2023
5 checks passed
@brandonwillard brandonwillard deleted the configurable-samplers branch October 12, 2023 19:14

"""
probs = torch.nn.functional.softmax(logits, dim=-1)
# next_token_ids = torch.multinomial(probs, num_samples=samples, generator=rng)
Member

torch.multinomial accepts matrices for probs: https://pytorch.org/docs/stable/generated/torch.multinomial.html

Member Author

I have a few follow-up updates to add shortly; I'll put that in place alongside those.

Member Author

@brandonwillard brandonwillard Oct 12, 2023

While torch.multinomial does work for matrices, we still need to check its results against the higher-dimensional tensor inputs in the tests (and update the tests that rely on RNG seeds), so the change should come in a follow-up.
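
For reference, one way the batched call could look (a sketch under assumptions, not the follow-up change itself; the helper name `sample_multinomial` and the `(..., vocab)` layout are hypothetical). `torch.multinomial` takes a 1-D or 2-D `probs`, so higher-dimensional inputs are flattened to a matrix first and reshaped afterwards:

```python
import torch


def sample_multinomial(logits, num_samples=1, rng=None):
    """Sample token ids from logits of shape (..., vocab).

    Returns a tensor of shape (..., num_samples).
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    lead_shape = probs.shape[:-1]
    # Collapse all leading dimensions into rows of a single 2-D matrix.
    flat = probs.reshape(-1, probs.shape[-1])
    ids = torch.multinomial(flat, num_samples=num_samples, generator=rng)
    # Restore the leading dimensions that were collapsed above.
    return ids.reshape(*lead_shape, num_samples)
```

A single batched call may consume the generator's state differently from row-by-row calls, so tests pinned to specific seeded outputs would likely need new expected values.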

namin commented Oct 12, 2023

👍 :)

Development

Successfully merging this pull request may close these issues.

Possibility to generate with the modal token rather than a sample
3 participants