diff --git a/.github/workflows/rolling_batch_integration.yml b/.github/workflows/rolling_batch_integration.yml
index f063ba021d..365bc24562 100644
--- a/.github/workflows/rolling_batch_integration.yml
+++ b/.github/workflows/rolling_batch_integration.yml
@@ -402,10 +402,82 @@ jobs:
           name: vllm-logs
           path: tests/integration/logs/
 
+  deepspeed-test:
+    runs-on: [ self-hosted, g5 ]
+    timeout-minutes: 60
+    needs: create-runners
+    steps:
+      - uses: actions/checkout@v3
+      - name: Clean env
+        run: |
+          yes | docker system prune -a --volumes
+          sudo rm -rf /home/ubuntu/actions-runner/_work/_tool/Java_Corretto_jdk/
+          echo "wait dpkg lock..."
+          while sudo fuser /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock >/dev/null 2>&1; do sleep 5; done
+      - name: Set up Python3
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10.x'
+      - name: Install pip dependencies
+        run: pip3 install requests pillow numpy
+      - name: Build container name
+        run: ./serving/docker/scripts/docker_name_builder.sh deepspeed ${{ github.event.inputs.djl-version }}
+      - name: Download models and dockers
+        working-directory: tests/integration
+        run: |
+          docker pull deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG
+      - name: Test deepspeed_rolling_batch gpt-neox-20b
+        working-directory: tests/integration
+        run: |
+          rm -rf models
+          python3 llm/prepare.py deepspeed_rolling_batch gpt-neox-20b
+          ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
+          serve
+          python3 llm/client.py deepspeed_rolling_batch gpt-neox-20b
+          docker rm -f $(docker ps -aq)
+      - name: Test deepspeed_rolling_batch open-llama-7b
+        working-directory: tests/integration
+        run: |
+          rm -rf models
+          python3 llm/prepare.py deepspeed_rolling_batch open-llama-7b
+          ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
+          serve
+          python3 llm/client.py deepspeed_rolling_batch open-llama-7b
+          docker rm -f $(docker ps -aq)
+      - name: Test deepspeed_rolling_batch gpt2
+        working-directory: tests/integration
+        run: |
+          rm -rf models
+          python3 llm/prepare.py deepspeed_rolling_batch gpt2
+          ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
+          serve
+          python3 llm/client.py deepspeed_rolling_batch gpt2
+          docker rm -f $(docker ps -aq)
+      - name: Test deepspeed_rolling_batch llama2-13b-smoothquant
+        working-directory: tests/integration
+        run: |
+          rm -rf models
+          python3 llm/prepare.py deepspeed_rolling_batch llama2-13b-smoothquant
+          ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
+          serve
+          python3 llm/client.py deepspeed_rolling_batch llama2-13b-smoothquant
+          docker rm -f $(docker ps -aq)
+      - name: On fail step
+        if: ${{ failure() }}
+        working-directory: tests/integration
+        run: |
+          docker rm -f $(docker ps -aq) || true
+          cat logs/serving.log
+      - name: Upload test logs
+        uses: actions/upload-artifact@v3
+        with:
+          name: ds-rolling-batch-handler-logs
+          path: tests/integration/logs/
+
   stop-runners:
     if: always()
     runs-on: [ self-hosted, scheduler ]
-    needs: [ create-runners, scheduler-single-gpu-test, scheduler-multi-gpu-test, lmi-dist-test-1, lmi-dist-test-2, vllm-test ]
+    needs: [ create-runners, scheduler-single-gpu-test, scheduler-multi-gpu-test, lmi-dist-test-1, lmi-dist-test-2, vllm-test, deepspeed-test ]
     steps:
       - name: Stop all instances
         run: |
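Reviewer note: each test step above drives the endpoint through `llm/client.py`. For local debugging, a hand-rolled request can help; the sketch below assumes djl-serving's conventional port 8080 and `/invocations` path, which this patch does not itself pin down. `requests` is the same dependency the workflow installs above.

```python
# Minimal sketch of querying a launched container by hand.
# Assumptions: port 8080 and the /invocations path (djl-serving conventions).
import json

import requests

URL = "http://127.0.0.1:8080/invocations"  # assumed endpoint, see note above

payload = {
    "inputs": "The future of inference serving is",
    "parameters": {"max_new_tokens": 64, "do_sample": True},
}
# With option.output_formatter=jsonlines (set by prepare.py further down),
# tokens stream back as one JSON object per line.
with requests.post(URL, json=payload, stream=True, timeout=120) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:
            print(json.loads(line))
```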
diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py
index 5af659ded3..aef21fb449 100644
--- a/engines/python/setup/djl_python/deepspeed.py
+++ b/engines/python/setup/djl_python/deepspeed.py
@@ -39,7 +39,7 @@
 from peft import PeftConfig, PeftModel
 
 from djl_python.properties_manager.ds_properties import DeepSpeedProperties, DsQuantizeMethods
-from djl_python.properties_manager.properties import StreamingEnum, is_streaming_enabled
+from djl_python.properties_manager.properties import StreamingEnum, is_streaming_enabled, is_rolling_batch_enabled
 
 SMOOTHQUANT_SUPPORTED_MODEL_TYPES = {
     "gpt2",
@@ -106,19 +106,43 @@ def __init__(self):
         self.peft_config = None
         self.model = None
         self.tokenizer = None
+        self.rolling_batch = None
+        self.enable_rolling_batch = False
 
     def initialize(self, properties: dict):
         self.properties = DeepSpeedProperties(**properties)
+        self.enable_rolling_batch = is_rolling_batch_enabled(
+            self.properties.rolling_batch)
         self._read_model_config()
         self._validate_model_type_and_task()
-        self.create_model_pipeline()
+        if self.enable_rolling_batch:
+            from djl_python.rolling_batch.deepspeed_rolling_batch import DeepSpeedRollingBatch
+            self.model = self.create_ds_module()
+            if not self.properties.ds_config.get("replace_with_kernel_inject",
+                                                 False):
+                raise ValueError(
+                    f"option.rolling_batch=deepspeed only works with kernel_injection models: {OPTIMIZED_MODEL_TYPES}"
+                )
+            kwargs = {
+                "max_batch_size":
+                int(properties.get("max_rolling_batch_size", 4)),
+                "max_seq_len": int(properties.get("max_seq_len", 1024)),
+                "tokenizer": self.tokenizer
+            }
+            if "output_formatter" in properties:
+                kwargs["output_formatter"] = properties.get("output_formatter")
+            self.rolling_batch = DeepSpeedRollingBatch(self.model,
+                                                       self.properties.device,
+                                                       properties, **kwargs)
+        else:
+            self.create_model_pipeline()
         self.logger.info(
             f"Initialized DeepSpeed model with the following configurations\n"
             f"model: {self.properties.model_id_or_path}\n"
             f"task: {self.properties.task}\n"
             f"data_type: {self.properties.ds_config['dtype']}\n"
             f"tensor_parallel_degree: {self.properties.tensor_parallel_degree}\n"
-        )
+            f"rolling_batch: {self.enable_rolling_batch}\n")
         self.initialized = True
 
     def _validate_model_type_and_task(self):
@@ -239,7 +263,7 @@ def load_model_with_mmap(model_id_or_path, loading_method):
 
         return model, tokenizer, state_dict_mmap
 
-    def create_model_pipeline(self):
+    def create_ds_module(self):
         # If a ds checkpoint is provided, we instantiate model with meta tensors. weights loaded when DS engine invoked
         # Workaround on int8. fp16 fp32 bf16 init supported
         dtype = torch.float16 if self.properties.dtype == torch.int8 else self.properties.dtype
@@ -303,8 +327,10 @@ def create_model_pipeline(self):
         if smoothing_config.get("calibrate", False):
             smoothing_config["tokenizer"] = self.tokenizer
 
-        self.model = deepspeed.init_inference(model, self.properties.ds_config)
+        return deepspeed.init_inference(model, self.properties.ds_config)
 
+    def create_model_pipeline(self):
+        self.model = self.create_ds_module()
         # Don't create a "pipeline" if we're streaming or text-generation task, since those don't use a pipeline
         if is_streaming_enabled(self.properties.enable_streaming
                                 ) or self.properties.task == "text-generation":
@@ -348,23 +374,27 @@ def parse_input(self, inputs):
                 content_type = item.get_property("Content-Type")
                 input_map = decode(item, content_type)
                 _inputs = input_map.pop("inputs", input_map)
-                if first:
-                    parameters.append(input_map.pop("parameters", {}))
-                    first = False
-                else:
-                    param = input_map.pop("parameters", {})
-                    if parameters[0] != param:
-                        logging.warning(
-                            f"expected param: {parameters}, actual: {param}")
-                        raise ValueError(
-                            "In order to enable dynamic batching, all input batches must have the same parameters"
-                        )
-                if isinstance(_inputs, list):
-                    input_data.extend(_inputs)
-                    input_size.append(len(_inputs))
-                else:
-                    input_data.append(_inputs)
-                    input_size.append(1)
+                _param = input_map.pop("parameters", {})
+                if not self.enable_rolling_batch:
+                    if first:
+                        parameters.append(_param)
+                        first = False
+                    else:
+                        if parameters[0] != _param:
+                            logging.warning(
+                                f"expected param: {parameters}, actual: {_param}"
+                            )
+                            raise ValueError(
+                                "In order to enable dynamic batching, all input batches must have the same parameters"
+                            )
+                if not isinstance(_inputs, list):
+                    _inputs = [_inputs]
+                input_data.extend(_inputs)
+                input_size.append(len(_inputs))
+                if self.enable_rolling_batch:
+                    for _ in range(input_size[i]):
+                        parameters.append(_param)
+
             except Exception as e:  # pylint: disable=broad-except
                 logging.exception(f"Parse input failed: {i}")
                 errors[i] = str(e)
@@ -372,12 +402,44 @@ def parse_input(self, inputs):
         return input_data, input_size, parameters, errors, batch
 
     def inference(self, inputs: Input):
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs)
-        parameters = parameters[0]
-
         outputs = Output()
+        input_data, input_size, params, errors, batch = self.parse_input(
+            inputs)
+        if len(input_data) == 0:
+            for i in range(len(batch)):
+                err = errors.get(i)
+                if self.enable_rolling_batch:
+                    err = {"data": "", "last": True, "code": 424, "error": err}
+                    outputs.add(Output.binary_encode(err),
+                                key="data",
+                                batch_index=i)
+                else:
+                    outputs.add(err, key="data", batch_index=i)
+            return outputs
+        parameters = params[0]
+
+        if self.enable_rolling_batch:
+            if inputs.get_property("reset_rollingbatch"):
+                self.rolling_batch.reset()
+            result = self.rolling_batch.inference(input_data, params)
+            idx = 0
+            for i in range(len(batch)):
+                err = errors.get(i)
+                if err:
+                    err = {"data": "", "last": True, "code": 424, "error": err}
+                    outputs.add(Output.binary_encode(err),
+                                key="data",
+                                batch_index=i)
+                else:
+                    outputs.add(Output.binary_encode(result[idx]),
+                                key="data",
+                                batch_index=i)
+                    idx += 1
+
+            content_type = self.rolling_batch.get_content_type()
+            if content_type:
+                outputs.add_property("content-type", content_type)
+            return outputs
+
         if is_streaming_enabled(self.properties.enable_streaming):
             if len(batch) > 1:
                 raise NotImplementedError(
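The behavioral change worth calling out in `parse_input`: dynamic batching still requires every batch entry to share one parameter set, while the rolling-batch path now records one parameter dict per request, fanned out across each entry's inputs. A standalone sketch of that logic (not the handler itself):

```python
# Sketch of the parameter handling parse_input now performs.
def collect_parameters(batch_items, rolling_batch_enabled):
    parameters = []
    for item in batch_items:
        inputs = item.get("inputs", [])
        if not isinstance(inputs, list):
            inputs = [inputs]
        param = item.get("parameters", {})
        if rolling_batch_enabled:
            # Rolling batch: one copy of the parameters per request.
            parameters.extend([param] * len(inputs))
        elif not parameters:
            parameters.append(param)
        elif parameters[0] != param:
            # Dynamic batching: a single shared parameter set is mandatory.
            raise ValueError(
                "In order to enable dynamic batching, all input batches "
                "must have the same parameters")
    return parameters


items = [
    {"inputs": ["a", "b"], "parameters": {"max_new_tokens": 30}},
    {"inputs": "c", "parameters": {"max_new_tokens": 50}},
]
assert len(collect_parameters(items, rolling_batch_enabled=True)) == 3
```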
diff --git a/engines/python/setup/djl_python/properties_manager/ds_properties.py b/engines/python/setup/djl_python/properties_manager/ds_properties.py
index 47487bf233..b5cf51466d 100644
--- a/engines/python/setup/djl_python/properties_manager/ds_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/ds_properties.py
@@ -8,7 +8,7 @@
 
 from pydantic import root_validator, validator, Field
 
-from djl_python.properties_manager.properties import Properties
+from djl_python.properties_manager.properties import Properties, RollingBatchEnum
 
 
 class DsQuantizeMethods(str, Enum):
@@ -19,6 +19,9 @@ class DsQuantizeMethods(str, Enum):
 SUPPORTED_QUANTIZATION_MODE = [
     DsQuantizeMethods.smoothquant.value, DsQuantizeMethods.dynamicint8.value
 ]
+DS_SUPPORTED_ROLLING_BATCH_TYPES = [
+    RollingBatchEnum.auto.value, RollingBatchEnum.deepspeed.value
+]
 
 
 class DeepSpeedProperties(Properties):
@@ -75,6 +78,16 @@ def set_ds_config(cls, deepspeed_config_path):
         with open(deepspeed_config_path, "r") as f:
             return json.load(f)
 
+    @validator('rolling_batch', pre=True)
+    def validate_rolling_batch(cls, rolling_batch) -> str:
+        if rolling_batch == RollingBatchEnum.disable.value:
+            return rolling_batch
+        if rolling_batch not in DS_SUPPORTED_ROLLING_BATCH_TYPES:
+            raise ValueError(
+                f"deepspeed engine only supports "
+                f"rolling batch type {DS_SUPPORTED_ROLLING_BATCH_TYPES}.")
+        return rolling_batch
+
     @root_validator()
     def set_dtype(cls, properties):
diff --git a/engines/python/setup/djl_python/properties_manager/properties.py b/engines/python/setup/djl_python/properties_manager/properties.py
index 89d12ff1b9..de05d4f95d 100644
--- a/engines/python/setup/djl_python/properties_manager/properties.py
+++ b/engines/python/setup/djl_python/properties_manager/properties.py
@@ -24,6 +24,7 @@ class RollingBatchEnum(str, Enum):
     auto = "auto"
     disable = "disable"
     trtllm = "trtllm"
+    deepspeed = "deepspeed"
 
 
 class StreamingEnum(str, Enum):
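A quick illustration of what the new validator accepts and rejects, using a trimmed-down stand-in for `DeepSpeedProperties` and assuming pydantic v1 semantics (which the `validator`/`root_validator` imports suggest):

```python
from enum import Enum

from pydantic import BaseModel, validator


class RollingBatchEnum(str, Enum):
    auto = "auto"
    disable = "disable"
    trtllm = "trtllm"
    deepspeed = "deepspeed"


DS_SUPPORTED_ROLLING_BATCH_TYPES = [
    RollingBatchEnum.auto.value, RollingBatchEnum.deepspeed.value
]


class DemoProperties(BaseModel):
    """Stand-in for DeepSpeedProperties, reduced to the validated field."""
    rolling_batch: RollingBatchEnum = RollingBatchEnum.disable

    @validator('rolling_batch', pre=True)
    def validate_rolling_batch(cls, rolling_batch) -> str:
        if rolling_batch == RollingBatchEnum.disable.value:
            return rolling_batch
        if rolling_batch not in DS_SUPPORTED_ROLLING_BATCH_TYPES:
            raise ValueError(
                f"deepspeed engine only supports "
                f"rolling batch type {DS_SUPPORTED_ROLLING_BATCH_TYPES}.")
        return rolling_batch


DemoProperties(rolling_batch="deepspeed")  # accepted
DemoProperties(rolling_batch="disable")    # accepted, rolling batch stays off
try:
    DemoProperties(rolling_batch="trtllm")  # rejected for the DeepSpeed engine
except ValueError as e:
    print(e)
```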
diff --git a/engines/python/setup/djl_python/rolling_batch/deepspeed_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/deepspeed_rolling_batch.py
new file mode 100644
index 0000000000..588294b381
--- /dev/null
+++ b/engines/python/setup/djl_python/rolling_batch/deepspeed_rolling_batch.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
+from deepspeed.external.lmi_dist.utils.parameters import (
+    NextTokenChooserParameters,
+    StoppingCriteriaParameters,
+)
+from deepspeed.external.lmi_dist.utils.types import (Batch, Request)
+from deepspeed.inference.engine import InferenceEngine
+from deepspeed.inference.rolling_batch import DeepSpeedRollingBatchGeneration
+
+
+class DeepSpeedRollingBatch(RollingBatch):
+
+    def __init__(self, model: InferenceEngine, device, properties, **kwargs):
+        """
+        Initializes the DeepSpeedRollingBatch.
+
+        :param model: loaded DeepSpeed InferenceEngine
+        :param device: model loaded device
+        :param properties: other properties of the model, such as decoder strategy
+        :param kwargs: passed while loading the model
+        """
+
+        super().__init__(device, **kwargs)
+        self.properties = properties
+        self.batch_cls = None
+        self.batch_id_counter = 0
+        self.rolling_batch = DeepSpeedRollingBatchGeneration(
+            model=model,
+            tokenizer=kwargs.get("tokenizer"),
+            max_batch_size=kwargs.get("max_batch_size"),
+            max_seq_len=kwargs.get("max_seq_len"))
+
+    def reset(self):
+        # self.rolling_batch.rolling_batch.clear()
+        self.batch_id_counter = 0
+        super().reset()
+
+    def _warmup(self, **kwargs):
+        pass
+
+    @stop_on_any_exception
+    def inference(self, input_data, parameters):
+        """
+        Performs prefill and decode operations for the batch.
+
+        :param input_data: List of input texts for each request in a batch
+        :param parameters: List of kwargs for each request in a batch
+        :return: generated batch decoded tokens
+        """
+        new_requests = self.get_new_requests(input_data, parameters,
+                                             len(input_data))
+        new_batch = self.preprocess_requests(new_requests)
+        if new_batch or len(self.active_requests) > 0:
+            self._prefill_and_decode(new_batch)
+        return self.postprocess_results()
+
+    def _prefill_and_decode(self, new_batch):
+        if new_batch:
+            batch = new_batch
+            generations, error_requests = self.rolling_batch.prefill_batch(
+                batch)
+            self.error_requests = error_requests
+        else:
+            generations = self.rolling_batch.generate_token()
+
+        for request in self.active_requests:
+            generation = None
+            # TODO(mohaan): Change generations to a Dict with request id index
+            filtered_gens = list(
+                filter(lambda g: g.request_id == request.id, generations))
+            if len(filtered_gens) > 0:
+                generation = filtered_gens[0]
+            if generation:
+                is_last_token = generation.generated_text is not None
+                request.set_next_token("" if generation.token_is_special else
+                                       generation.token_text,
+                                       self.output_formatter,
+                                       last_token=is_last_token)
+            else:
+                request.set_next_token("",
+                                       self.output_formatter,
+                                       last_token=False)
+
+    def preprocess_requests(self, requests, **kwargs):
+        preprocessed_requests = []
+        for r in requests:
+            param = r.parameters
+            parameters = NextTokenChooserParameters(
+                temperature=param.get("temperature", 1.0),
+                repetition_penalty=param.get("repetition_penalty", 1.0),
+                top_k=param.get("top_k", 0),
+                top_p=param.get("top_p", 1.0),
+                typical_p=param.get("typical_p", 1.0),
+                do_sample=param.get("do_sample", False),
+                seed=int(param.get("seed", 0)))
+            stop_parameters = StoppingCriteriaParameters(
+                stop_sequences=param.get("stop_sequences", []),
+                max_new_tokens=param.get("max_new_tokens", 30))
+
+            request = Request(id=r.id,
+                              inputs=r.input_text,
+                              parameters=parameters,
+                              stopping_parameters=stop_parameters)
+            truncate = param.get("truncate", None)
+            if truncate is not None:
+                request.truncate = truncate
+            preprocessed_requests.append(request)
+
+        if preprocessed_requests:
+            batch = Batch(id=self.batch_id_counter,
+                          requests=preprocessed_requests,
+                          size=len(preprocessed_requests))
+            self.batch_id_counter += 1
+
+            return batch
+        else:
+            return None
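For reference, these are the per-request generation defaults `preprocess_requests` applies when a request omits a parameter; the helper below is an illustrative restatement of those `param.get(...)` calls, not part of the patch:

```python
# Defaults mirrored from preprocess_requests above; "truncate" has no
# default and is only set when a request supplies it.
GENERATION_DEFAULTS = {
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "top_k": 0,
    "top_p": 1.0,
    "typical_p": 1.0,
    "do_sample": False,
    "seed": 0,
    "stop_sequences": [],
    "max_new_tokens": 30,
}


def resolve_parameters(param: dict) -> dict:
    """Merge a request's parameters over the defaults."""
    resolved = {**GENERATION_DEFAULTS, **param}
    resolved["seed"] = int(resolved["seed"])
    return resolved


print(resolve_parameters({"max_new_tokens": 128, "do_sample": True}))
```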
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index 04d9cd9ca0..4ea0951928 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -144,6 +144,7 @@ def __init__(self, device, **kwargs):
         self.device = device
         self.pending_requests = []
         self.active_requests = []
+        self.error_requests = []
         self.req_id_counter = 0
         self.output_formatter = None
         self.waiting_steps = kwargs.get("waiting_steps", None)
@@ -163,6 +164,7 @@ def __init__(self, device, **kwargs):
     def reset(self):
         self.pending_requests = []
         self.active_requests = []
+        self.error_requests = []
         self.req_id_counter = 0
 
     @abstractmethod
@@ -203,9 +205,24 @@ def preprocess_requests(self, requests):
 
     def postprocess_results(self):
         results = []
-        for i in range(len(self.active_requests)):
-            req = self.active_requests[i]
-            res = {"data": req.get_next_token(), "last": req.is_last_token()}
+        err_reqs = dict((r.id, err) for r, err in self.error_requests)
+        for req in self.active_requests:
+            if req.id in err_reqs.keys():
+                res = {
+                    "data":
+                    "",
+                    "last":
+                    True,
+                    "code":
+                    424,
+                    "error":
+                    f"Request: `{req.input_text}` failed due to: {err_reqs.get(req.id)}"
+                }
+            else:
+                res = {
+                    "data": req.get_next_token(),
+                    "last": req.is_last_token()
+                }
             results.append(res)
 
         # add empty tokens to pending requests
@@ -215,7 +232,12 @@ def postprocess_results(self):
             results.append(res)
 
         self.active_requests = [
-            req for req in self.active_requests if not req.is_last_token()
+            req for req in self.active_requests
+            if not req.is_last_token() and req.id not in err_reqs.keys()
+        ]
+        self.pending_requests = [
+            req for req in self.pending_requests
+            if req.id not in err_reqs.keys()
         ]
 
         if len(self.active_requests) + len(self.pending_requests) == 0:
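The new error bookkeeping expects `error_requests` to hold `(request, error)` pairs, which `postprocess_results` converts into terminal entries with code 424 before dropping the failed requests from both queues. A self-contained sketch of that flow (`Req` is a stand-in for the real request objects, and the error string is illustrative):

```python
from dataclasses import dataclass


@dataclass
class Req:  # stand-in for the handler's request objects
    id: int
    input_text: str


# Pairs as produced by the engine's prefill step; values are illustrative.
error_requests = [(Req(7, "bad prompt"), "engine rejected request")]
err_reqs = dict((r.id, err) for r, err in error_requests)

active_requests = [Req(7, "bad prompt"), Req(8, "good prompt")]
results = []
for req in active_requests:
    if req.id in err_reqs:
        # Terminal error entry, mirroring postprocess_results above.
        results.append({
            "data": "",
            "last": True,
            "code": 424,
            "error":
            f"Request: `{req.input_text}` failed due to: {err_reqs[req.id]}",
        })
    else:
        results.append({"data": "", "last": False})  # token text elided

# Failed requests drop out; healthy ones keep streaming.
active_requests = [r for r in active_requests if r.id not in err_reqs]
print(results, [r.id for r in active_requests])
```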
"llama2-13b-smoothquant": { + "option.model_id": "TheBloke/Llama-2-13B-fp16", + "option.task": "text-generation", + "option.tensor_parallel_degree": 4, + "option.max_rolling_batch_size": 4, + "option.quantize": "smoothquant", + }, +} + def write_model_artifacts(properties, requirements=None, @@ -809,6 +837,18 @@ def build_trtllm_handler_model(model): write_model_artifacts(options) +def build_deepspeed_rolling_batch_model(model): + if model not in deepspeed_rolling_batch_model_list.keys(): + raise ValueError( + f"{model} is not one of the supporting handler {list(deepspeed_rolling_batch_model_list.keys())}" + ) + options = deepspeed_rolling_batch_model_list[model] + options["engine"] = "DeepSpeed" + options["option.rolling_batch"] = "deepspeed" + options["option.output_formatter"] = "jsonlines" + write_model_artifacts(options) + + supported_handler = { 'deepspeed': build_ds_handler_model, 'huggingface': build_hf_handler_model, @@ -825,6 +865,7 @@ def build_trtllm_handler_model(model): 'deepspeed_smoothquant': build_ds_smoothquant_model, 'lmi_dist_aiccl': build_lmi_dist_aiccl_model, 'trtllm': build_trtllm_handler_model, + 'deepspeed_rolling_batch': build_deepspeed_rolling_batch_model, } if __name__ == '__main__':