Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logits processors: Update inplace, with batching #92

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

lapp0
Copy link
Owner

@lapp0 lapp0 commented Oct 5, 2024

Changes

For GuideLogitsProcessor,

Benchmarks for both changes

CI doesn't benchmark torch_cuda, so I've included it here.

Update logits inplace

Before [245c7fc] After [7514aff] Ratio Benchmark (Parameter)
159_0.4μs 181_1μs 1.14 time_structured_generation('numpy', 'Z*')
149_0.3μs 170_0.7μs 1.14 time_structured_generation('torch', 'Z*')
292_0.8μs 254_1μs 0.87 time_structured_generation('torch_cuda', 'Z*')
572_2μs 391_1μs 0.68 time_structured_generation('torch_cuda', '[^Z]*')

Batch update logits

Before [245c7fc] After [8aa0b0d] Ratio Benchmark (Parameter)
481_5μs 401_2μs 0.83 time_structured_generation('numpy', '[^Z]*')
466_3μs 386_1μs 0.83 time_structured_generation('torch', '[^Z]*')
159_0.8μs 106_0.5μs 0.67 time_structured_generation('numpy', 'Z*')
149_0.7μs 94.7_0.2μs 0.64 time_structured_generation('torch', 'Z*')
290_1μs 149_0.4μs 0.51 time_structured_generation('torch_cuda', 'Z*')
573_3μs 229_1μs 0.4 time_structured_generation('torch_cuda', '[^Z]*')

Testing

  • All tests pass, including test_generate.py with CUDA models
  • Did not test on Metal.

Further Work

  • We can cache the RegexGuide legal token mask on GPU to improve time_structured_generation('torch', '[^Z]*'). In this benchmark, allowed_tokens is all tokens except Z, ZZ, and ZZZ, resulting in a large LongTensor being sent to GPU each step.

@lapp0 lapp0 changed the title Logits processors inplace change logits Logits processors: Update inplace, with batching Oct 5, 2024
@lapp0 lapp0 force-pushed the logits-processors-inplace-change-logits branch from 8aa0b0d to 094af23 Compare October 7, 2024 14:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Update logits array in-place
1 participant