Skip to content

Commit

Permalink
[fix] restrict per request streaming to rolling batch use-cases (#1670)
Browse files: browse the repository at this point in the history
  • Loading branch information
siddvenk authored Mar 26, 2024
1 parent fc9d407 commit 191b084
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ def parse_input(self, inputs):
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
_param = input_map.pop("parameters", {})
_param["stream"] = input_map.pop("stream", False)
if not self.enable_rolling_batch:
if first:
parameters.append(_param)
Expand All @@ -384,6 +383,9 @@ def parse_input(self, inputs):
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
else:
# Per request streaming is only supported by rolling batch
_param["stream"] = input_map.pop("stream", False)

if "seed" not in _param:
# set server provided seed if seed is not part of request
Expand Down
4 changes: 4 additions & 0 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ def parse_input(self, inputs):
else:
_inputs = input_map.pop("inputs", input_map)
_param = input_map.pop("parameters", {})

# Per request streaming is only supported by rolling batch
if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
_param["stream"] = input_map.pop("stream", False)

if not isinstance(_inputs, list):
_inputs = [_inputs]
input_data.extend(_inputs)
Expand Down
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def parse_input(self, inputs):
if "output_formatter" not in param:
param[
"output_formatter"] = self.config.output_formatter
param["stream"] = input_map.pop("stream", False)
if first or self.rolling_batch:
parameters.append(param)
first = False
Expand Down Expand Up @@ -189,6 +188,8 @@ def parse_input(self, inputs):

if not "output_formatter" in param:
param["output_formatter"] = self.config.output_formatter
if self.rolling_batch:
param["stream"] = input_map.pop("stream", False)

for _ in range(input_size[i]):
parameters.append(param)
Expand Down

0 comments on commit 191b084

Please sign in to comment.