
Sampler Throughput Optimization #192

Merged: 30 commits into octoml:batch-serving on Feb 8, 2024

Conversation

sunggg (Member) commented on Feb 2, 2024

Overview

This PR aims to land the following:

  • Better throughput when various sampling params are enabled. By further vectorizing the sampling computations and preparing the necessary metadata asynchronously, this PR achieves a significant throughput improvement. Performance numbers are appended below.
  • Restructuring and clean-up to make the code easier to understand and extend. Mainly, this PR creates a new sampler.py that holds the data structures and key functions for the sampling computation. For example, SamplingMetadata keeps top_p, top_k, logit_bias, etc. in torch tensors so that they can be applied in a vectorized fashion. Also, an adjust_logits function is introduced as a single entry point for the various logit manipulations, such as temperature, presence_penalty, etc., and the upcoming logit_processor for JSON mode (see the sketch after this list).
  • Introduce test_sampler.py to verify logit manipulation directly.
  • Simplify the logprob logic and use detokenize_incrementally for the tokens in TopLogprobs for contextual detokenization.
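
To make the second bullet concrete, here is a minimal sketch of what vectorized logit manipulation with a SamplingMetadata-style structure can look like. This is illustrative only, not the code in sampler.py; the appeared_tokens field and the exact penalty formula are assumptions for illustration.

```python
# Illustrative sketch only, not the sampler.py implementation from this PR.
from dataclasses import dataclass

import torch


@dataclass
class SamplingMetadata:
    # One entry per sequence in the batch, stored as tensors so that every
    # manipulation below runs as a single vectorized operation.
    temperature: torch.Tensor        # shape (batch,)
    top_p: torch.Tensor              # shape (batch,)
    top_k: torch.Tensor              # shape (batch,)
    logit_bias: torch.Tensor         # shape (batch, vocab)
    presence_penalty: torch.Tensor   # shape (batch,)
    appeared_tokens: torch.Tensor    # shape (batch, vocab), bool mask of tokens already generated


def adjust_logits(logits: torch.Tensor, meta: SamplingMetadata) -> torch.Tensor:
    """Single entry point for the logit manipulations, applied to the whole batch."""
    # Presence penalty: push down tokens that already appeared in each sequence.
    logits = logits - meta.presence_penalty[:, None] * meta.appeared_tokens
    # Additive per-token bias from the logit_bias sampling param.
    logits = logits + meta.logit_bias
    # Temperature scaling (use 1.0 for greedy sequences so the division stays safe).
    logits = logits / meta.temperature[:, None]
    return logits
```

The point is that each sampling param lives in one batched tensor rather than in per-request Python objects, so every manipulation is a broadcasted tensor operation instead of a loop over sequences; top_p/top_k can then be applied to the resulting probabilities in the same batched style.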

Performance

Targeting Mistral fp16 on 2xH100, the following scenarios were measured and compared against the current HEAD.

  • vanilla random sampling (D0)
  • D0 + --apply-all-sampling-params: includes the overhead of top_p/top_k, presence/frequency/repetition penalties, and logit bias
  • D0 + --logprob: includes the overhead of logprob computation

Updates

PR v0: initial version
PR v1: vectorize further by reducing index shuffling

Engine Throughput (req/s)

command for D0: /opt/bin/cuda-reserve.py --num-gpu 2 python3 serve/benchmarks/benchmark_throughput.py --local-id mixtral-8x7b-instruct-v0.1-q0f16-presharded-2gpu --max-num-batched-tokens 16000 --dataset /opt/models/dataset/ShareGPT_V3_unfiltered_cleaned_split.json --greedy-sampling-ratio 0.0 --num-prompts 1000

|        | vanilla random sampling (D0) | D0 + --apply-all-sampling-params | D0 + --logprob |
|--------|------------------------------|----------------------------------|----------------|
| HEAD   | 24.49                        | 12.11                            | 11.07          |
| PR v0  | 25.11 (1.02x)                | 23.80 (1.96x)                    | 10.82 (0.977x) |
| PR v1  | 24.54 (1.00x)                | 23.69 (1.95x)                    | 11.45 (1.034x) |

Latency (s)

command for D0: /opt/bin/cuda-reserve.py --num-gpu 2 python3 serve/benchmarks/benchmark_latency.py --local-id mixtral-8x7b-instruct-v0.1-q0f16-presharded-2gpu --max-num-batched-tokens 16000

|        | vanilla random sampling (D0) | D0 + --apply-all-sampling-params | D0 + --logprob |
|--------|------------------------------|----------------------------------|----------------|
| HEAD   | 1.381                        | 1.451                            | 1.790          |
| PR v0  | 1.358                        | 1.425                            | 1.504          |
| PR v1  | 1.386                        | 1.471                            | 1.714          |

TODOs

  • Further throughput optimization. Currently, there are many sequential loops to post-process the results from the model layer and prepare the output for users. This overhead is especially expensive for logprob and is the main source of the significant throughput degradation.
  • Since we don't have any ordering guarantee on the requests, the current sampler needs to extract the requests that can be vectorized together and then recover the original order in the batch (see the sketch after this list). We may be able to remove this.
  • Add more sampler tests.
  • Add JSON mode support.
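
For the second TODO, the pattern being described is roughly the following (a sketch under assumed names, not the PR's implementation): boolean masks select, for example, the greedy and random-sampling requests, each group is sampled as one sub-batch, and the results are written back into their original batch positions.

```python
# Sketch of the "partition, process, restore order" pattern described above.
# Not the PR's implementation; the mask names are illustrative.
import torch


def sample_batch(logits: torch.Tensor, mask_greedy: torch.Tensor) -> torch.Tensor:
    """logits: (batch, vocab) float tensor; mask_greedy: (batch,) bool tensor."""
    mask_random = ~mask_greedy
    next_tokens = torch.empty(logits.shape[0], dtype=torch.long, device=logits.device)

    # Greedy requests: one vectorized argmax over their sub-batch.
    if mask_greedy.any():
        next_tokens[mask_greedy] = torch.argmax(logits[mask_greedy], dim=-1)

    # Random-sampling requests: one vectorized multinomial over their sub-batch.
    if mask_random.any():
        probs = torch.softmax(logits[mask_random], dim=-1)
        next_tokens[mask_random] = torch.multinomial(probs, num_samples=1).squeeze(-1)

    # The boolean-mask writes above place each result back at its original batch
    # position, which is the "recover the original order" step the TODO refers to.
    return next_tokens
```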

Note

Although I see no obvious quality impact in my local testing, please review carefully since this changes how the computation is performed.
cc. @yelite @masahi @vvchernov @zxybazh

@sunggg sunggg changed the title Sampler Throughput Optimization [WIP] Sampler Throughput Optimization Feb 2, 2024
@sunggg sunggg changed the title [WIP] Sampler Throughput Optimization Sampler Throughput Optimization Feb 5, 2024
```python
)
_test_max_tokens(staging_engine)
_test_ignore_eos(staging_engine)
# TODO (@sunggg): There is something stateful.
```
sunggg (Member, Author) commented:

Interestingly, _test_logprobs does not work when _test_stop is enabled; it seems like _test_stop has some statefulness that affects the next tests. I see the following test starts after generating one extra token at the beginning, and I don't understand where it comes from at the moment. @yelite, can you take a look?

Reply:

May I ask if the order of the unit tests impacts those results? For example, if we move _test_stop after _test_logprobs, would that work?

sunggg (Member, Author) replied on Feb 7, 2024:

Yeah, if I change the order, it works. So I suspect there is something strange going on with stop.

Reply:

I can confirm that _test_stop produces extra tokens after the sequence is finished due to a stop word. The stop-word detection happens in the main process, and the main process sends a cancellation to the worker. Somehow the worker and the main process become out of sync. I am looking into the fix.

Reply:

It turns out that it's not really a bug in the InferenceEngine, although such an out-of-sync state is undesirable from an implementation perspective. What really causes the trouble is that different tests send requests with the same request ids. Here _test_stop sends request id 2, and _test_logprobs also uses request id 2. A quick fix for this test failure is to make sure request ids are unique. I will fix the underlying out-of-sync behavior in the engine refactoring.
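
For reference, a minimal sketch of one way to keep request ids unique across tests; the helper name is hypothetical and the actual fix in the test suite may differ:

```python
# Hypothetical helper for the tests; not the actual fix that landed.
import itertools

_request_id_counter = itertools.count()


def next_request_id() -> str:
    """Return a process-wide unique request id instead of hard-coding ids like 2."""
    return str(next(_request_id_counter))
```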



```python
def _test_penalties():
    # TODO(vvchernov): Add test for repetition penalty
```
sunggg (Member, Author) commented:

@vvchernov, would you add one after this PR? I couldn't find much info about the repetition penalty in the OpenAI spec.

masahi (Member) commented on Feb 5, 2024:

> Simplify the logprob logic and use detokenize_incrementally for the tokens in TopLogprobs for contextual detokenization.

Probably better to make that a separate PR.

vvchernov left a review comment:

I'm not sure that such a large refactor of logprobs is correct.
Since performance with logprobs is still bad, could we compute them asynchronously once the next token is found and transferred to the next decode step?

Review threads (resolved): serve/mlc_serve/engine/staging_engine_worker.py, serve/mlc_serve/engine/sampling_params.py, serve/mlc_serve/engine/engine_common.py (×2)
sunggg (Member, Author) commented on Feb 5, 2024:

> Simplify the logprob logic and use detokenize_incrementally for the tokens in TopLogprobs for contextual detokenization.
>
> Probably better to make that a separate PR.

@masahi, yeah, I thought about it, but it was tightly coupled with this refactoring, so it was tricky to find a clean cut 😢
detokenize_incrementally was only a few lines of changes, but if this makes the review difficult, I will try to find a way to split it out. So, please let me know.

sunggg (Member, Author) commented on Feb 5, 2024:

> I'm not sure that such a large refactor of logprobs is correct.
@vvchernov, is this about #192 (comment)?

I think what I did is mostly a pure refactoring without any functional change. The detokenization is the only exception.

> Since performance with logprobs is still bad, could we compute them asynchronously once the next token is found and transferred to the next decode step?

Yeah, I also thought about that, and it would be an interesting optimization. I think there could be more opportunities on the CPU side as well. Actually, one interesting observation is that most of the degradation comes from the post-processing when preparing the final response. If I change this line to None while keeping the logprob computations, I see a significant throughput boost.

@sunggg sunggg marked this pull request as ready for review February 6, 2024 06:24

```python
@dataclass
class SamplingTensors:
    mask_random: torch.Tensor
```
A member commented:

@masahi we might want to look into https://github.com/patrick-kidger/jaxtyping or another project to also get more checking on the tensors. I think it's relatively valuable, especially as we have more and more people working on this code.
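
For illustration, jaxtyping-style shape annotations on a SamplingTensors-like dataclass could look like the following sketch; the fields other than mask_random are made up here.

```python
# Sketch only: how jaxtyping shape/dtype annotations could document these tensors.
from dataclasses import dataclass

import torch
from jaxtyping import Bool, Float


@dataclass
class SamplingTensors:
    # True for sequences that use random sampling, False for greedy ones.
    mask_random: Bool[torch.Tensor, "batch"]
    # Per-sequence sampling parameters, broadcast against the (batch, vocab) logits.
    temperature: Float[torch.Tensor, "batch"]
    top_p: Float[torch.Tensor, "batch"]
    # Per-sequence, per-token additive bias.
    logit_bias: Float[torch.Tensor, "batch vocab"]
```

Note that jaxtyping enforces these shapes at runtime only when functions are wrapped with its jaxtyped decorator together with a checker such as beartype or typeguard; on a plain dataclass the annotations mainly serve as precise documentation.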

jroesch (Member) left a review:

Lots of small nits and requests from me. Overall LGTM. Thanks for taking the first stab at this, @sunggg; looking forward to landing it.

Review threads (resolved): serve/mlc_serve/engine/engine_common.py (×2), serve/mlc_serve/model/sampler.py
masahi merged commit eae6ac4 into octoml:batch-serving on Feb 8, 2024 (1 check passed).
Lunderberg pushed a commit to Lunderberg/mlc-llm that referenced this pull request Feb 27, 2024