[TRTLLM][UX] add trtllm changes to support stop reason and also log prob (#1355)
Qing Lan authored Dec 6, 2023
1 parent a39e1d1 commit 9cd5a7e
Showing 2 changed files with 32 additions and 14 deletions.
@@ -12,7 +12,7 @@
 # the specific language governing permissions and limitations under the License.
 import logging
 import tensorrt_llm_toolkit
-from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
+from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token


 class TRTLLMRollingBatch(RollingBatch):
@@ -66,17 +66,33 @@ def inference(self, input_data, parameters):
         # step 0: register new active requests
         for request in new_requests:
             param = self.translate_triton_params(request.parameters)
+            output_len = param["request_output_len"]
             response = self.model.generate(request.input_text, **param)
-            self.request_cache[request.id] = response
+            self.request_cache[request.id] = {
+                "response": response,
+                "out_length": output_len,
+                "cumulative_logprob": 0
+            }

         # step 1: loop the active requests to send result
         for request in self.active_requests:
-            trt_req = self.request_cache[request.id]
-            output_text, complete = trt_req.fetch()
-            request.set_next_token(output_text, self.output_formatter,
-                                   complete)
-            if complete:
+            trt_resp = self.request_cache[request.id]["response"]
+            generation = trt_resp.fetch()
+            log_prob = generation.cum_logprob - self.request_cache[
+                request.id]["cumulative_logprob"]
+            self.request_cache[
+                request.id]["cumulative_logprob"] = generation.cum_logprob
+            token = Token(generation.token_id, generation.token_text, log_prob,
+                          None)
+            if generation.finished:
+                finish_reason = "eos_token" if generation.seq_length < self.request_cache[
+                    request.id]["out_length"] else "length"
+                request.set_next_token(token, self.output_formatter,
+                                       generation.finished, finish_reason)
                 self.request_cache.pop(request.id)
+            else:
+                request.set_next_token(token, self.output_formatter,
+                                       generation.finished)

         return self.postprocess_results()

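The new inference() loop derives a per-token log probability from the running total that the TensorRT-LLM response reports, and it picks a stop reason by comparing the generated length against the requested request_output_len. Below is a minimal standalone sketch of just those two calculations; it is not part of the commit, the helper names are invented, and the numeric values are hypothetical:

    # Standalone illustration of the per-token log prob and stop-reason logic
    # added above; helper names and values are hypothetical.

    def step_log_prob(prev_cumulative: float, cum_logprob: float) -> float:
        # The backend reports a cumulative log probability, so the value for
        # the newest token is the difference from the previous step.
        return cum_logprob - prev_cumulative

    def finish_reason(seq_length: int, request_output_len: int) -> str:
        # Stopping before the requested budget means an EOS token was
        # generated; otherwise the request hit its length limit.
        return "eos_token" if seq_length < request_output_len else "length"

    # Cumulative log prob moves from -1.20 to -1.95, so this token adds -0.75.
    assert abs(step_log_prob(-1.20, -1.95) - (-0.75)) < 1e-9
    assert finish_reason(seq_length=17, request_output_len=128) == "eos_token"
    assert finish_reason(seq_length=128, request_output_len=128) == "length"

Because the response only exposes the cumulative cum_logprob, the handler caches the previous value in request_cache and recovers the per-token delta on each step.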
16 changes: 9 additions & 7 deletions serving/docker/tensorrt-llm.Dockerfile
@@ -17,13 +17,13 @@ ARG TORCH_VERSION=2.1.0
 ARG djl_version=0.24.0~SNAPSHOT
 ARG transformers_version=4.34.0
 ARG accelerate_version=0.23.0
-ARG tensorrtlibs_version=9.1.0.post12.dev4
+ARG tensorrtlibs_version=9.2.0.post12.dev5
 ARG trtllm_toolkit_version=nightly
 ARG cuda_python_version=12.2.0
 ARG peft_wheel="https://publish.djl.ai/peft/peft-0.5.0alpha-py3-none-any.whl"
 ARG trtllm_toolkit_wheel="https://publish.djl.ai/tensorrt-llm/toolkit/tensorrt_llm_toolkit-${trtllm_toolkit_version}-py3-none-any.whl"
-ARG trtllm_wheel="https://djl-ai.s3.amazonaws.com/publish/tensorrt-llm/0.5.0/tensorrt_llm-0.5.0-py3-none-any.whl"
-ARG triton_toolkit_wheel="https://publish.djl.ai/tritonserver/r23.09/tritontoolkit-23.9-py310-none-any.whl"
+ARG trtllm_wheel="https://djl-ai.s3.amazonaws.com/publish/tensorrt-llm/0.6.1/tensorrt_llm-0.6.1-py3-none-any.whl"
+ARG triton_toolkit_wheel="https://publish.djl.ai/tritonserver/r23.11/tritontoolkit-23.11-py310-none-any.whl"
 ARG pydantic_version=1.10.13
 EXPOSE 8080

@@ -67,7 +67,7 @@ RUN apt-get update && apt-get install -y wget unzip openmpi-bin libopenmpi-dev l

 # Install PyTorch
 RUN pip install torch==${TORCH_VERSION} transformers==${transformers_version} accelerate==${accelerate_version} ${peft_wheel} sentencepiece \
-    mpi4py cuda-python==${cuda_python_version} onnx polygraphy datasets pydantic==${pydantic_version} && \
+    mpi4py cuda-python==${cuda_python_version} onnx polygraphy pynvml datasets pydantic==${pydantic_version} && \
     pip3 cache purge

 # Install TensorRT and TRT LLM
@@ -76,11 +76,13 @@ RUN pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com tensorr
     pip3 cache purge

 # download dependencies
+# install manual-build boost fs library required by tritonserver 23.11
 RUN pip install ${triton_toolkit_wheel} ${trtllm_toolkit_wheel} && \
     mkdir -p /opt/tritonserver/lib && mkdir -p /opt/tritonserver/backends/tensorrtllm && \
-    curl -o /opt/tritonserver/lib/libtritonserver.so https://publish.djl.ai/tritonserver/r23.09/libtritonserver.so && \
-    curl -o /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so https://publish.djl.ai/tensorrt-llm/0.5.0/libtriton_tensorrtllm.so && \
-    curl -o /opt/tritonserver/lib/libnvinfer_plugin_tensorrt_llm.so.9 https://publish.djl.ai/tensorrt-llm/0.5.0/libnvinfer_plugin_tensorrt_llm.so.9 && \
+    curl -o /opt/tritonserver/lib/libtritonserver.so https://publish.djl.ai/tritonserver/r23.11/libtritonserver.so && \
+    curl -o /lib/x86_64-linux-gnu/libboost_filesystem.so.1.80.0 https://publish.djl.ai/tritonserver/r23.11/libboost_filesystem.so.1.80.0 && \
+    curl -o /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so https://publish.djl.ai/tensorrt-llm/0.6.1/libtriton_tensorrtllm.so && \
+    curl -o /opt/tritonserver/lib/libnvinfer_plugin_tensorrt_llm.so.9 https://publish.djl.ai/tensorrt-llm/0.6.1/libnvinfer_plugin_tensorrt_llm.so.9 && \
     pip3 cache purge && \
     apt-get clean -y && rm -rf /var/lib/apt/lists/*

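The RUN step above pulls the prebuilt Triton r23.11 and TensorRT-LLM 0.6.1 shared objects (plus the boost filesystem library that tritonserver 23.11 requires) directly with curl rather than installing them from a package. A hypothetical post-build sanity check, not part of the commit, could confirm they landed at the expected paths:

    # Hypothetical smoke test run inside the built image; the paths are the
    # ones the Dockerfile's curl commands write to.
    import os

    expected = [
        "/opt/tritonserver/lib/libtritonserver.so",
        "/opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so",
        "/opt/tritonserver/lib/libnvinfer_plugin_tensorrt_llm.so.9",
        "/lib/x86_64-linux-gnu/libboost_filesystem.so.1.80.0",
    ]

    missing = [path for path in expected if not os.path.exists(path)]
    if missing:
        raise SystemExit(f"missing libraries: {missing}")
    print("all downloaded Triton/TensorRT-LLM libraries are present")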
