[python] Fixes batch error handling. (#1232)
frankfliu authored Oct 28, 2023
1 parent 5e25d61 commit 248aa25
Showing 2 changed files with 68 additions and 64 deletions.
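The core of the change on the Python side is per-request error isolation in parse_input: a request that fails to decode is logged, recorded in an errors map keyed by its batch index, and skipped, instead of failing the whole batch. A minimal, self-contained sketch of that pattern (the decode stub, the parse_batch name, and the sample payloads are illustrative, not part of the repository):

import json
import logging


def decode(item: bytes) -> dict:
    # Illustrative stand-in for djl_python's decode(): parse a JSON request body.
    return json.loads(item)


def parse_batch(batch: list):
    """Collect inputs per request; record failures instead of failing the batch."""
    input_data, input_size, errors = [], [], {}
    for i, item in enumerate(batch):
        try:
            input_map = decode(item)
        except Exception as e:  # broad catch on purpose, mirroring the handler
            logging.warning("Parse input failed: %d", i)
            input_size.append(0)  # keep positional bookkeeping for output mapping
            errors[i] = str(e)
            continue
        _inputs = input_map.get("inputs", input_map)
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        input_data.extend(_inputs)
        input_size.append(len(_inputs))
    return input_data, input_size, errors


if __name__ == "__main__":
    batch = [b'{"inputs": "hello"}', b"not-json", b'{"inputs": ["a", "b"]}']
    print(parse_batch(batch))
    # expected roughly: (['hello', 'a', 'b'], [1, 0, 2], {1: 'Expecting value: ...'})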
114 changes: 53 additions & 61 deletions engines/python/setup/djl_python/huggingface.py
@@ -10,7 +10,6 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import logging
import os
import re
@@ -250,7 +249,8 @@ def initialize(self, properties: dict):
self.load_stopping_criteria_list(properties["stop_sequence"])
self.initialized = True

def parse_stop_sequence_input(self, stop_sequence):
@staticmethod
def parse_stop_sequence_input(stop_sequence):
"""
Gets a list of stop sequences by parsing the string given in
serving.properties.
@@ -290,61 +290,46 @@ def parse_input(self, inputs):
adapters = []
errors = {}
batch = inputs.get_batches()
first = True
for i, item in enumerate(batch):
try:
content_type = item.get_property("Content-Type")
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
adapters_per_item = self._fetch_adapters_from_input(
input_map, item)
if first or self.rolling_batch_type:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
param = input_map.pop("parameters", {})
if parameters[0] != param:
logging.warning(
f"expected param: {parameters}, actual: {param}")
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)

if not isinstance(_inputs, list):
_inputs = [_inputs]

if not isinstance(adapters_per_item, list):
adapters_per_item = [adapters_per_item]

if not adapters_per_item:
## inference with just base model.
adapters_per_item = [""] * len(_inputs)
else:
if len(_inputs) != len(adapters_per_item):
## input_size list needs to be appended as it's used during output processing
input_size.append(0)
raise Exception(
"Number of adapters is not equal to the number of inputs"
)

input_data.extend(_inputs)
input_size.append(len(_inputs))
adapters.extend(adapters_per_item)

if "cached_prompt" in input_map:
parameters[i]["cached_prompt"] = input_map.pop(
"cached_prompt")

seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
if item.contains_key(seed_key):
seed = parameters[i].get("seed")
if not seed:
# set server provided seed if seed is not part of request
parameters[i]["seed"] = item.get_as_string(
key=seed_key)
except Exception as e: # pylint: disable=broad-except
logging.exception(f"Parse input failed: {i}")
logging.warning(f"Parse input failed: {i}")
input_size.append(0)
errors[i] = str(e)
continue

_inputs = input_map.pop("inputs", input_map)
if not isinstance(_inputs, list):
_inputs = [_inputs]
input_data.extend(_inputs)
input_size.append(len(_inputs))

_param = input_map.pop("parameters", {})
if "cached_prompt" in input_map:
_param["cached_prompt"] = input_map.pop("cached_prompt")
if not "seed" in _param:
# set server provided seed if seed is not part of request
if item.contains_key("seed"):
_param["seed"] = item.get_as_string(key="seed")
for _ in range(input_size[i]):
parameters.append(_param)

adapters_per_item = self._fetch_adapters_from_input(
input_map, item)
if not isinstance(adapters_per_item, list):
adapters_per_item = [adapters_per_item]

if not adapters_per_item:
## inference with just base model.
adapters_per_item = [""] * len(_inputs)
adapters.extend(adapters_per_item)
if len(_inputs) != len(adapters_per_item):
logging.warning(
f"Number of adapters is not equal to the number of inputs")
errors[
i] = "Number of adapters is not equal to the number of inputs"

self.adapters = adapters
return input_data, input_size, parameters, errors, batch
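
One detail in the rewritten parse_input: parameters is now appended to once per input item (the for _ in range(input_size[i]) loop), so it stays index-aligned with input_data even when a single request carries a list of inputs. A rough sketch of that expansion, using made-up request dicts rather than the handler's Input objects:

def expand_parameters(requests: list):
    # Each request may carry one input or a list of inputs; its parameters are
    # repeated once per input so the two lists stay index-aligned downstream.
    input_data, parameters = [], []
    for req in requests:
        _inputs = req.get("inputs", [])
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        param = req.get("parameters", {})
        input_data.extend(_inputs)
        parameters.extend([param] * len(_inputs))
    return input_data, parameters


requests = [
    {"inputs": ["a", "b"], "parameters": {"max_new_tokens": 64}},
    {"inputs": "c", "parameters": {"max_new_tokens": 8}},
]
data, params = expand_parameters(requests)
assert len(data) == len(params) == 3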
@@ -354,15 +339,16 @@ def inference(self, inputs):

input_data, input_size, parameters, errors, batch = self.parse_input(
inputs)
adapters = self.adapters
if not adapters:
adapters = [""] * len(input_data)
if len(input_data) == 0:
for i in range(len(batch)):
err = errors.get(i)
if self.rolling_batch_type:
err = {"data": "", "last": True, "code": 424, "error": err}
outputs.add(err, key="data", batch_index=i)
outputs.add(Output.binary_encode(err),
key="data",
batch_index=i)
else:
outputs.add(err, key="data", batch_index=i)
return outputs

if self.rolling_batch_type:
Expand Down Expand Up @@ -405,8 +391,13 @@ def inference(self, inputs):
self.device, **parameters[0]))
return outputs

if not all(p == parameters[0] for p in parameters):
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)

if isinstance(self.model, PeftModelForCausalLM):
parameters[0]["adapters"] = adapters
parameters[0]["adapters"] = self.adapters

prediction = self.hf_pipeline(input_data, **parameters[0])
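
When every request in the batch failed to parse, inference now returns early with one error entry per batch index; for rolling batch the entry is a structured record (code 424 plus the message) that gets binary-encoded, while regular batching keeps a plain string. A sketch of building those per-index records, assuming errors is the index-to-message map from parse_input (the dict return shape is only for illustration; the real handler writes into an Output via outputs.add):

def build_error_outputs(batch_size: int, errors: dict, rolling_batch: bool) -> dict:
    # Every batch index gets an error payload; for rolling batch the record is
    # structured (code 424 plus message) so the client sees a per-request failure.
    outputs = {}
    for i in range(batch_size):
        err = errors.get(i, "unknown error")
        if rolling_batch:
            outputs[i] = {"data": "", "last": True, "code": 424, "error": err}
        else:
            outputs[i] = err
    return outputs


print(build_error_outputs(2, {0: "invalid json"}, rolling_batch=True))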

@@ -624,17 +615,18 @@ def _read_model_config(self, model_config_path: str, revision=None):
f"This is required for loading huggingface models")
raise e

def _fetch_adapters_from_input(self, input_map: dict, input: Input):
@staticmethod
def _fetch_adapters_from_input(input_map: dict, inputs: Input):
if "adapters" in input_map:
return input_map.pop("adapters", [])

# check content, possible in workflow approach
if input.contains_key("adapter"):
return input.get_as_string("adapter")
if inputs.contains_key("adapter"):
return inputs.get_as_string("adapter")

# check properties, possible from header
if "adapter" in input.get_properties():
return input.get_properties()["adapter"]
if "adapter" in inputs.get_properties():
return inputs.get_properties()["adapter"]

return []
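
_fetch_adapters_from_input checks three places in order: an "adapters" list in the decoded body, an "adapter" content key (the workflow case), and an "adapter" request property from the headers. A simplified sketch in which plain dicts stand in for the Input object's contains_key/get_as_string/get_properties accessors:

def fetch_adapters(input_map: dict, content: dict, properties: dict):
    # Resolution order mirrors the handler: decoded request body first, then a
    # content key (workflow case), then the request properties/headers.
    if "adapters" in input_map:
        return input_map.pop("adapters", [])
    if "adapter" in content:
        return content["adapter"]
    if "adapter" in properties:
        return properties["adapter"]
    return []


assert fetch_adapters({"adapters": ["lora1"]}, {}, {}) == ["lora1"]
assert fetch_adapters({}, {"adapter": "lora2"}, {}) == "lora2"
assert fetch_adapters({}, {}, {"adapter": "lora3"}) == "lora3"
assert fetch_adapters({}, {}, {}) == []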

@@ -35,6 +35,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@@ -246,6 +247,8 @@ String getSeed() {
void addResponse(byte[] json) {
ByteBuf buf = Unpooled.wrappedBuffer(json);
int size = buf.readShort();
String code = null;
String error = null;
for (int i = 0; i < size; ++i) {
String key = Objects.requireNonNull(CodecUtils.readUtf8(buf));
String value = Objects.requireNonNull(CodecUtils.readUtf8(buf));
@@ -257,16 +260,25 @@ void addResponse(byte[] json) {
last = "true".equalsIgnoreCase(value);
break;
case "code":
output.setCode(Integer.parseInt(value));
code = value;
break;
case "error":
output.setMessage(value);
error = value;
break;
default:
break;
}
}
data.appendContent(BytesSupplier.wrap(nextToken), last);
if (code != null) {
Map<String, Object> map = new ConcurrentHashMap<>(2);
map.put("code", Integer.parseInt(code));
if (error != null) {
map.put("error", error);
}
data.appendContent(BytesSupplier.wrapAsJson(map), true);
} else {
data.appendContent(BytesSupplier.wrap(nextToken), last);
}
}
}
}
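
On the Java side, code and error are no longer written straight onto the Output; they are collected while the key/value pairs are decoded and, if a code was reported, appended as a single terminal JSON object. A Python sketch of that aggregation step (it takes already-decoded pairs as a plain dict, so the binary framing handled by CodecUtils is out of scope):

import json


def finalize_chunk(pairs: dict):
    # Returns (payload, last). When the worker reported a code, the error is
    # emitted as one terminal JSON object instead of being set on the response.
    if "code" in pairs:
        error = {"code": int(pairs["code"])}
        if "error" in pairs:
            error["error"] = pairs["error"]
        return json.dumps(error), True
    return pairs.get("data", ""), pairs.get("last", "false").lower() == "true"


print(finalize_chunk({"data": "", "last": "true", "code": "424", "error": "bad input"}))
print(finalize_chunk({"data": "hello", "last": "false"}))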
