Skip to content

Commit

Permalink
Add DeepSpeed support for rolling batch
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaqib Ansari committed Nov 14, 2023
1 parent 5f4ca95 commit 7c839f5
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 33 deletions.
74 changes: 73 additions & 1 deletion .github/workflows/rolling_batch_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,82 @@ jobs:
name: vllm-logs
path: tests/integration/logs/

# Integration-test job for DeepSpeed rolling-batch support.
# Runs on a self-hosted g5 (GPU) runner provisioned by the create-runners job,
# then exercises four models through the deepspeed_rolling_batch handler.
deepspeed-test:
  runs-on: [ self-hosted, g5 ]
  timeout-minutes: 60
  needs: create-runners
  steps:
    - uses: actions/checkout@v3
    # Reclaim disk space and wait for any in-flight apt/dpkg operations so
    # later package installs on the shared runner do not race the lock.
    - name: Clean env
      run: |
        yes | docker system prune -a --volumes
        sudo rm -rf /home/ubuntu/actions-runner/_work/_tool/Java_Corretto_jdk/
        echo "wait dpkg lock..."
        while sudo fuser /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock >/dev/null 2>&1; do sleep 5; done
    - name: Set up Python3
      uses: actions/setup-python@v4
      with:
        python-version: '3.10.x'
    - name: Install pip dependencies
      run: pip3 install requests pillow numpy
    # Resolves DJLSERVING_DOCKER_TAG for the deepspeed container image,
    # optionally pinned to the workflow-dispatch djl-version input.
    - name: Build container name
      run: ./serving/docker/scripts/docker_name_builder.sh deepspeed ${{ github.event.inputs.djl-version }}
    - name: Download models and dockers
      working-directory: tests/integration
      run: |
        docker pull deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG
    # Each model test below follows the same pattern: prepare the model
    # artifacts, launch the serving container, run the client checks, then
    # remove the container so the next test starts clean.
    - name: Test deepspeed_rolling_batch gpt-neox-20b
      working-directory: tests/integration
      run: |
        rm -rf models
        python3 llm/prepare.py deepspeed_rolling_batch gpt-neox-20b
        ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
        serve
        python3 llm/client.py deepspeed_rolling_batch gpt-neox-20b
        docker rm -f $(docker ps -aq)
    - name: Test deepspeed_rolling_batch open-llama-7b
      working-directory: tests/integration
      run: |
        rm -rf models
        python3 llm/prepare.py deepspeed_rolling_batch open-llama-7b
        ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
        serve
        python3 llm/client.py deepspeed_rolling_batch open-llama-7b
        docker rm -f $(docker ps -aq)
    - name: Test deepspeed_rolling_batch gpt2
      working-directory: tests/integration
      run: |
        rm -rf models
        python3 llm/prepare.py deepspeed_rolling_batch gpt2
        ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
        serve
        python3 llm/client.py deepspeed_rolling_batch gpt2
        docker rm -f $(docker ps -aq)
    - name: Test deepspeed_rolling_batch llama2-13b-smoothquant
      working-directory: tests/integration
      run: |
        rm -rf models
        python3 llm/prepare.py deepspeed_rolling_batch llama2-13b-smoothquant
        ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models deepspeed \
        serve
        python3 llm/client.py deepspeed_rolling_batch llama2-13b-smoothquant
        docker rm -f $(docker ps -aq)
    # On failure, surface the serving log in the job output; `|| true` keeps
    # the container cleanup from masking the log dump when nothing is running.
    - name: On fail step
      if: ${{ failure() }}
      working-directory: tests/integration
      run: |
        docker rm -f $(docker ps -aq) || true
        cat logs/serving.log
    - name: Upload test logs
      uses: actions/upload-artifact@v3
      with:
        name: ds-rolling-batch-handler-logs
        path: tests/integration/logs/

stop-runners:
if: always()
runs-on: [ self-hosted, scheduler ]
needs: [ create-runners, scheduler-single-gpu-test, scheduler-multi-gpu-test, lmi-dist-test-1, lmi-dist-test-2, vllm-test ]
needs: [ create-runners, scheduler-single-gpu-test, scheduler-multi-gpu-test, lmi-dist-test-1, lmi-dist-test-2, vllm-test, deepspeed-test ]
steps:
- name: Stop all instances
run: |
Expand Down
116 changes: 89 additions & 27 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from peft import PeftConfig, PeftModel

from djl_python.properties_manager.ds_properties import DeepSpeedProperties, DsQuantizeMethods
from djl_python.properties_manager.properties import StreamingEnum, is_streaming_enabled
from djl_python.properties_manager.properties import StreamingEnum, is_streaming_enabled, is_rolling_batch_enabled

SMOOTHQUANT_SUPPORTED_MODEL_TYPES = {
"gpt2",
Expand Down Expand Up @@ -106,19 +106,43 @@ def __init__(self):
self.peft_config = None
self.model = None
self.tokenizer = None
self.rolling_batch = None
self.enable_rolling_batch = False

def initialize(self, properties: dict):
self.properties = DeepSpeedProperties(**properties)
self.enable_rolling_batch = is_rolling_batch_enabled(
self.properties.rolling_batch)
self._read_model_config()
self._validate_model_type_and_task()
self.create_model_pipeline()
if self.enable_rolling_batch:
from djl_python.rolling_batch.deepspeed_rolling_batch import DeepSpeedRollingBatch
self.model = self.create_ds_module()
if not self.properties.ds_config.get("replace_with_kernel_inject",
False):
raise ValueError(
f"option.rolling_batch=deepspeed only works with kernel_injection models: {OPTIMIZED_MODEL_TYPES}"
)
kwargs = {
"max_batch_size":
int(properties.get("max_rolling_batch_size", 4)),
"max_seq_len": int(properties.get("max_seq_len", 1024)),
"tokenizer": self.tokenizer
}
if "output_formatter" in properties:
kwargs["output_formatter"] = properties.get("output_formatter")
self.rolling_batch = DeepSpeedRollingBatch(self.model,
self.properties.device,
properties, **kwargs)
else:
self.create_model_pipeline()
self.logger.info(
f"Initialized DeepSpeed model with the following configurations\n"
f"model: {self.properties.model_id_or_path}\n"
f"task: {self.properties.task}\n"
f"data_type: {self.properties.ds_config['dtype']}\n"
f"tensor_parallel_degree: {self.properties.tensor_parallel_degree}\n"
)
f"rolling_batch: {self.enable_rolling_batch}\n")
self.initialized = True

def _validate_model_type_and_task(self):
Expand Down Expand Up @@ -239,7 +263,7 @@ def load_model_with_mmap(model_id_or_path, loading_method):

return model, tokenizer, state_dict_mmap

def create_model_pipeline(self):
def create_ds_module(self):
# If a ds checkpoint is provided, we instantiate model with meta tensors. weights loaded when DS engine invoked
# Workaround on int8. fp16 fp32 bf16 init supported
dtype = torch.float16 if self.properties.dtype == torch.int8 else self.properties.dtype
Expand Down Expand Up @@ -303,8 +327,10 @@ def create_model_pipeline(self):
if smoothing_config.get("calibrate", False):
smoothing_config["tokenizer"] = self.tokenizer

self.model = deepspeed.init_inference(model, self.properties.ds_config)
return deepspeed.init_inference(model, self.properties.ds_config)

def create_model_pipeline(self):
self.model = self.create_ds_module()
# Don't create a "pipeline" if we're streaming or text-generation task, since those don't use a pipeline
if is_streaming_enabled(self.properties.enable_streaming
) or self.properties.task == "text-generation":
Expand Down Expand Up @@ -348,36 +374,72 @@ def parse_input(self, inputs):
content_type = item.get_property("Content-Type")
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
if first:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
param = input_map.pop("parameters", {})
if parameters[0] != param:
logging.warning(
f"expected param: {parameters}, actual: {param}")
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
if isinstance(_inputs, list):
input_data.extend(_inputs)
input_size.append(len(_inputs))
else:
input_data.append(_inputs)
input_size.append(1)
_param = input_map.pop("parameters", {})
if not self.enable_rolling_batch:
if first:
parameters.append(_param)
first = False
else:
if parameters[0] != _param:
logging.warning(
f"expected param: {parameters}, actual: {_param}"
)
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
if not isinstance(_inputs, list):
_inputs = [_inputs]
input_data.extend(_inputs)
input_size.append(len(_inputs))
if self.enable_rolling_batch:
for _ in range(input_size[i]):
parameters.append(_param)

except Exception as e: # pylint: disable=broad-except
logging.exception(f"Parse input failed: {i}")
errors[i] = str(e)

return input_data, input_size, parameters, errors, batch

def inference(self, inputs: Input):

input_data, input_size, parameters, errors, batch = self.parse_input(
inputs)
parameters = parameters[0]

outputs = Output()
input_data, input_size, params, errors, batch = self.parse_input(
inputs)
if len(input_data) == 0:
for i in range(len(batch)):
err = errors.get(i)
if self.enable_rolling_batch:
err = {"data": "", "last": True, "code": 424, "error": err}
outputs.add(Output.binary_encode(err),
key="data",
batch_index=i)
else:
outputs.add(err, key="data", batch_index=i)
return outputs
parameters = params[0]

if self.enable_rolling_batch:
if inputs.get_property("reset_rollingbatch"):
self.rolling_batch.reset()
result = self.rolling_batch.inference(input_data, params)
idx = 0
for i in range(len(batch)):
err = errors.get(i)
if err:
err = {"data": "", "last": True, "code": 424, "error": err}
outputs.add(Output.binary_encode(err),
key="data",
batch_index=i)
else:
outputs.add(Output.binary_encode(result[idx]),
key="data",
batch_index=i)
idx += 1

content_type = self.rolling_batch.get_content_type()
if content_type:
outputs.add_property("content-type", content_type)
return outputs
if is_streaming_enabled(self.properties.enable_streaming):
if len(batch) > 1:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pydantic import root_validator, validator, Field

from djl_python.properties_manager.properties import Properties
from djl_python.properties_manager.properties import Properties, RollingBatchEnum


class DsQuantizeMethods(str, Enum):
Expand All @@ -19,6 +19,9 @@ class DsQuantizeMethods(str, Enum):
SUPPORTED_QUANTIZATION_MODE = [
DsQuantizeMethods.smoothquant.value, DsQuantizeMethods.dynamicint8.value
]
DS_SUPPORTED_ROLLING_BATCH_TYPES = [
RollingBatchEnum.auto.value, RollingBatchEnum.deepspeed.value
]


class DeepSpeedProperties(Properties):
Expand Down Expand Up @@ -75,6 +78,16 @@ def set_ds_config(cls, deepspeed_config_path):
with open(deepspeed_config_path, "r") as f:
return json.load(f)

@validator('rolling_batch', pre=True)
def validate_rolling_batch(cls, rolling_batch) -> str:
    """Validate the ``option.rolling_batch`` value for the DeepSpeed engine.

    Runs as a pydantic pre-validator on the raw input value.

    Args:
        rolling_batch: the raw rolling-batch setting (e.g. "disable",
            "auto", "deepspeed").

    Returns:
        The input value unchanged when it is "disable" or a
        DeepSpeed-supported rolling-batch type.

    Raises:
        ValueError: if the value is not "disable" and not in
            ``DS_SUPPORTED_ROLLING_BATCH_TYPES``.
    """
    # "disable" is always allowed and skips the support check.
    if rolling_batch == RollingBatchEnum.disable.value:
        return rolling_batch
    if rolling_batch not in DS_SUPPORTED_ROLLING_BATCH_TYPES:
        raise ValueError(
            f"deepspeed engine only supports "
            f"rolling batch type {DS_SUPPORTED_ROLLING_BATCH_TYPES}.")
    return rolling_batch

@root_validator()
def set_dtype(cls, properties):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class RollingBatchEnum(str, Enum):
auto = "auto"
disable = "disable"
trtllm = "trtllm"
deepspeed = "deepspeed"


class StreamingEnum(str, Enum):
Expand Down
Loading

0 comments on commit 7c839f5

Please sign in to comment.