[TrtLLM] Python backend support for T5 model #1680

Merged: 4 commits, Mar 29, 2024
11 changes: 11 additions & 0 deletions .github/workflows/llm_integration.yml
@@ -812,6 +812,17 @@ jobs:
python3 llm/client.py trtllm qwen-7b
rm -rf docker_env
docker rm -f $(docker ps -aq)
- name: flan-t5-xl model with python backend
working-directory: tests/integration
run: |
rm -rf models
echo -en "CUDA_VISIBLE_DEVICES=0,1,2,3" > docker_env
python3 llm/prepare.py trtllm flan-t5-xl
./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models trtllm \
serve
python3 llm/client.py trtllm-python flan-t5-xl
rm -rf docker_env
docker rm -f $(docker ps -aq)
- name: On fail step
if: ${{ failure() }}
working-directory: tests/integration
@@ -81,8 +81,9 @@ def validate_batch_size(cls, batch_size, values):
batch_size = int(batch_size)
if batch_size > 1:
if not is_rolling_batch_enabled(
values['rolling_batch']) and is_streaming_enabled(
values['enable_streaming']):
values.get('rolling_batch', RollingBatchEnum.disable)
) and is_streaming_enabled(
values.get('enable_streaming', StreamingEnum.false)):
raise ValueError(
"We cannot enable streaming for dynamic batching")
return batch_size
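Side note on the switch to `values.get(...)` with explicit defaults above: under pydantic v1 semantics, the `values` dict handed to a validator only contains fields that have already passed validation, so plain `values['rolling_batch']` indexing can raise a KeyError instead of the intended ValueError. A minimal sketch of the pattern, assuming pydantic's v1 compatibility layer; the `Demo` model is illustrative only and not part of this PR:

```python
from pydantic.v1 import BaseModel, validator


class Demo(BaseModel):
    rolling_batch: str = "disable"
    enable_streaming: str = "false"
    batch_size: int = 1

    @validator("batch_size")
    def check_batch_size(cls, v, values):
        # .get() with a default never raises KeyError, even if an earlier
        # field failed validation and is therefore missing from `values`.
        rolling_batch = values.get("rolling_batch", "disable")
        streaming = values.get("enable_streaming", "false")
        if v > 1 and rolling_batch == "disable" and streaming == "true":
            raise ValueError("We cannot enable streaming for dynamic batching")
        return v
```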
@@ -10,11 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from djl_python.properties_manager.properties import Properties, RollingBatchEnum, StreamingEnum
from djl_python.properties_manager.properties import Properties, RollingBatchEnum
from pydantic.v1 import validator

TRT_SUPPORTED_ROLLING_BATCH_TYPES = [
RollingBatchEnum.auto.value, RollingBatchEnum.trtllm.value
RollingBatchEnum.auto.value, RollingBatchEnum.trtllm.value,
RollingBatchEnum.disable.value
]


@@ -24,11 +25,6 @@ class TensorRtLlmProperties(Properties):
def validate_rolling_batch(cls, rolling_batch: str) -> str:
rolling_batch = rolling_batch.lower()

if rolling_batch == RollingBatchEnum.disable.value:
raise ValueError(
f"You cannot disable rolling batch for TensorRT LLM."
f"Kindly enable it with auto or tensorrt values to option.rolling_batch"
)
if rolling_batch not in TRT_SUPPORTED_ROLLING_BATCH_TYPES:
raise ValueError(
f"tensorrt llm only supports "
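The practical effect of the TensorRtLlmProperties change above is that `option.rolling_batch=disable` is now an accepted value for TensorRT-LLM instead of raising. A quick sketch mirroring the unit test below; the `model_dir` value is a placeholder:

```python
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties

# Previously this raised ValueError; "disable" is now supported so the
# T5 python backend path can be selected.
props = TensorRtLlmProperties(model_dir="model_dir", rolling_batch="disable")
```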
96 changes: 90 additions & 6 deletions engines/python/setup/djl_python/tensorrt_llm.py
@@ -11,16 +11,23 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import os
import logging
import tensorrt_llm_toolkit
from tensorrt_llm_toolkit.utils import utils as toolkit_utils

from djl_python.encode_decode import decode
from transformers import AutoConfig

from djl_python.encode_decode import encode, decode
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.rolling_batch.rolling_batch import get_content_type_from_output_formatter
from djl_python.rolling_batch.trtllm_rolling_batch import TRTLLMRollingBatch
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request

from djl_python.properties_manager.properties import is_rolling_batch_enabled


class TRTLLMService(object):
"""
@@ -29,16 +36,20 @@ class TRTLLMService(object):
calls TensorRT-LLM in the back-end.
"""

PYTHON_BACKEND_SUPPORTED_MODELS = {'t5'}

def __init__(self):
self.initialized = False
self.trt_configs = None
self.rolling_batch = None
self.model = None
self.is_rolling_batch_enabled = True

def initialize(self, properties: dict):
self.trt_configs = TensorRtLlmProperties(**properties)

self.rolling_batch = TRTLLMRollingBatch(
self.trt_configs.model_id_or_path, properties, **properties)
self.is_rolling_batch_enabled = is_rolling_batch_enabled(
self.trt_configs.rolling_batch)
self._load_model(properties)
self.initialized = True
return

@@ -97,7 +108,37 @@ def parse_input(

return input_data, input_size, parameters, errors, batch

def inference(self, inputs: Input) -> Output:
def _get_config(self, properties):
model_path = self.trt_configs.model_id_or_path
if not os.path.isfile(os.path.join(model_path, 'config.json')):
model_path = toolkit_utils.get_python_backend_engine_path(
model_path, properties)
if not os.path.isfile(os.path.join(model_path, 'config.json')):
raise ValueError(
f"Could not find config.json in {self.trt_configs.model_id_or_path} or "
f"{model_path} for TensorRT python backend")

return AutoConfig.from_pretrained(
model_path, trust_remote_code=self.trt_configs.trust_remote_code)

def _load_model(self, properties):
if self.is_rolling_batch_enabled:
self.rolling_batch = TRTLLMRollingBatch(
self.trt_configs.model_id_or_path, properties, **properties)
else:
model_config = self._get_config(properties)
if model_config.model_type in self.PYTHON_BACKEND_SUPPORTED_MODELS:
self.model = tensorrt_llm_toolkit.init_inference(
self.trt_configs.model_id_or_path,
**properties,
use_python_backend=True)
else:
raise ValueError(
f"You cannot disable rolling batch if the model type is not one of"
f" {self.PYTHON_BACKEND_SUPPORTED_MODELS}. Please enable it by setting "
f"option.rolling_batch to auto or trtllm")

def rolling_batch_inference(self, inputs: Input) -> Output:
"""
Does preprocessing and sends new requests to the rolling batch script for inference

@@ -143,6 +184,46 @@ def inference(self, inputs: Input) -> Output:

return outputs

# TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
def inference(self, inputs: Input) -> Output:
"""
Does preprocessing and runs inference through the TensorRT-LLM python backend (used when rolling batch is disabled)

:param inputs (Input): a batch of inputs, each corresponding to a new request

:return outputs (Output): a batch of outputs that contain status code, output text, and other information
"""
outputs = Output()

input_data, input_size, parameters, errors, batch = self.parse_input(
inputs)
if len(input_data) == 0:
for i in range(len(batch)):
err = errors.get(i)
outputs.add(err, key="data", batch_index=i)
return outputs

params = parameters[0]
result = self.model.generate(input_data, **params)
result = [{"generated_text": s} for s in result.batch_generation()]
idx = 0
for i, item in enumerate(batch):
content_type = item.get_property("Content-Type")
accept = item.get_property("Accept")
if not accept:
content_type = content_type if content_type else "application/json"
accept = content_type if content_type.startswith(
"tensor/") else "application/json"
elif "*/*" in accept:
accept = "application/json"

encode(outputs,
result[idx:idx + input_size[i]],
accept,
key=inputs.get_content().key_at(i))
idx += input_size[i]
return outputs


_service = TRTLLMService()

@@ -163,4 +244,7 @@ def handle(inputs: Input) -> Output:
# initialization request
return None

return _service.inference(inputs)
if _service.is_rolling_batch_enabled:
return _service.rolling_batch_inference(inputs)
else:
return _service.inference(inputs)
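For context, once the container started in the workflow step above is serving flan-t5-xl with rolling batch disabled, a request takes the `handle()` -> `TRTLLMService.inference()` path shown in this file. A rough smoke-test sketch; the endpoint URL, prompt, and parameters are illustrative assumptions, not taken from this PR:

```python
import requests

# DJL Serving endpoint assumed from the local integration-test setup; adjust as needed.
resp = requests.post(
    "http://127.0.0.1:8080/invocations",
    json={
        "inputs": "translate English to German: The house is wonderful.",
        "parameters": {"max_new_tokens": 64},
    },
    headers={"Content-Type": "application/json"},
)
resp.raise_for_status()
# With rolling batch disabled, the response is produced by encode() as JSON,
# e.g. a list like [{"generated_text": "..."}].
print(resp.json())
```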
@@ -222,18 +222,12 @@ def test_trtllm_error_cases(self):
"model_dir": "model_dir",
}

def test_trtllm_rb_disable():
properties['rolling_batch'] = 'disable'
with self.assertRaises(ValueError):
TensorRtLlmProperties(**properties)

def test_trtllm_rb_invalid():
properties['rolling_batch'] = 'lmi-dist'
with self.assertRaises(ValueError):
TensorRtLlmProperties(**properties)

test_trtllm_rb_invalid()
test_trtllm_rb_disable()

def test_ds_properties(self):
ds_properties = {
9 changes: 9 additions & 0 deletions tests/integration/llm/client.py
@@ -597,6 +597,11 @@ def get_model_name():
"batch_size": [1, 8],
"seq_length": [256],
"tokenizer": "mistralai/Mixtral-8x7B-v0.1"
},
"flan-t5-xl": {
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "google/flan-t5-xl"
}
}

@@ -1085,6 +1090,8 @@ def test_handler(model, model_spec):
assert len(result) == batch_size
if "max_memory_per_gpu" in spec:
validate_memory_usage(spec["max_memory_per_gpu"][i])
if "tokenizer" in spec:
awscurl_run(req, spec.get("tokenizer"), batch_size)


def test_ds_raw_model(model, model_spec):
@@ -1278,6 +1285,8 @@ def test_unmerged_lora_correctness():
test_handler_rolling_batch(args.model, lmi_dist_aiccl_model_spec)
elif args.handler == "trtllm":
test_handler_rolling_batch(args.model, trtllm_model_spec)
elif args.handler == "trtllm-python":
test_handler(args.model, trtllm_model_spec)
elif args.handler == "deepspeed_rolling_batch":
test_handler_rolling_batch(args.model,
deepspeed_rolling_batch_model_spec)
5 changes: 5 additions & 0 deletions tests/integration/llm/prepare.py
@@ -871,6 +871,11 @@
"option.use_custom_all_reduce": False,
"option.max_rolling_batch_size": 32,
"option.output_formatter": "jsonlines"
},
"flan-t5-xl": {
"option.model_id": "s3://djl-llm/flan-t5-xl/",
"option.rolling_batch": "disable",
"option.entryPoint": "djl_python.tensorrt_llm"
}
}
