Refactor HuggingFace input parsing and rolling-batch error reporting.

huggingface.py: parse_input now records an input size of 0 and an error
message for any request that fails to decode, appends one parameters entry
per input, and reports an adapter/input count mismatch through the errors
map instead of raising. The "all input batches must have the same
parameters" check moves out of parse_input and into inference.
parse_stop_sequence_input and _fetch_adapters_from_input become static
methods, and the unused json import is dropped.

RollingBatch.java: addResponse collects the "code" and "error" fields and,
when a code is present, appends them as a final JSON chunk instead of
setting them on the output.

NOTE(review): every input of one request shares the same _param dict
object, and inference later writes parameters[0]["adapters"]; confirm that
the sharing is intended.
NOTE(review): on an adapter/input count mismatch the inputs and adapters
have already been appended before the error is recorded; confirm that the
output processing tolerates this.
NOTE(review): leading whitespace on context lines was reconstructed; apply
with whitespace tolerance if a hunk is rejected.

diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 0ba6b8372..14835dbfd 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import json
 import logging
 import os
 import re
@@ -250,7 +249,8 @@ def initialize(self, properties: dict):
             self.load_stopping_criteria_list(properties["stop_sequence"])
         self.initialized = True
 
-    def parse_stop_sequence_input(self, stop_sequence):
+    @staticmethod
+    def parse_stop_sequence_input(stop_sequence):
         """
         Gets a list of stop sequences by parsing the string given in
         serving.properties.
@@ -290,61 +290,46 @@ def parse_input(self, inputs):
         adapters = []
         errors = {}
         batch = inputs.get_batches()
-        first = True
         for i, item in enumerate(batch):
             try:
                 content_type = item.get_property("Content-Type")
                 input_map = decode(item, content_type)
-                _inputs = input_map.pop("inputs", input_map)
-                adapters_per_item = self._fetch_adapters_from_input(
-                    input_map, item)
-                if first or self.rolling_batch_type:
-                    parameters.append(input_map.pop("parameters", {}))
-                    first = False
-                else:
-                    param = input_map.pop("parameters", {})
-                    if parameters[0] != param:
-                        logging.warning(
-                            f"expected param: {parameters}, actual: {param}")
-                        raise ValueError(
-                            "In order to enable dynamic batching, all input batches must have the same parameters"
-                        )
-
-                if not isinstance(_inputs, list):
-                    _inputs = [_inputs]
-
-                if not isinstance(adapters_per_item, list):
-                    adapters_per_item = [adapters_per_item]
-
-                if not adapters_per_item:
-                    ## inference with just base model.
-                    adapters_per_item = [""] * len(_inputs)
-                else:
-                    if len(_inputs) != len(adapters_per_item):
-                        ## input_size list needs to be appended as it's used during output processing
-                        input_size.append(0)
-                        raise Exception(
-                            "Number of adapters is not equal to the number of inputs"
-                        )
-
-                input_data.extend(_inputs)
-                input_size.append(len(_inputs))
-                adapters.extend(adapters_per_item)
-
-                if "cached_prompt" in input_map:
-                    parameters[i]["cached_prompt"] = input_map.pop(
-                        "cached_prompt")
-
-                seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
-                if item.contains_key(seed_key):
-                    seed = parameters[i].get("seed")
-                    if not seed:
-                        # set server provided seed if seed is not part of request
-                        parameters[i]["seed"] = item.get_as_string(
-                            key=seed_key)
             except Exception as e:  # pylint: disable=broad-except
-                logging.exception(f"Parse input failed: {i}")
+                logging.warning(f"Parse input failed: {i}")
+                input_size.append(0)
                 errors[i] = str(e)
+                continue
+
+            _inputs = input_map.pop("inputs", input_map)
+            if not isinstance(_inputs, list):
+                _inputs = [_inputs]
+            input_data.extend(_inputs)
+            input_size.append(len(_inputs))
+
+            _param = input_map.pop("parameters", {})
+            if "cached_prompt" in input_map:
+                _param["cached_prompt"] = input_map.pop("cached_prompt")
+            if "seed" not in _param:
+                # set server provided seed if seed is not part of request
+                if item.contains_key("seed"):
+                    _param["seed"] = item.get_as_string(key="seed")
+            for _ in range(input_size[i]):
+                parameters.append(_param)
+
+            adapters_per_item = self._fetch_adapters_from_input(
+                input_map, item)
+            if not isinstance(adapters_per_item, list):
+                adapters_per_item = [adapters_per_item]
+
+            if not adapters_per_item:
+                ## inference with just base model.
+                adapters_per_item = [""] * len(_inputs)
+            adapters.extend(adapters_per_item)
+            if len(_inputs) != len(adapters_per_item):
+                logging.warning(
+                    "Number of adapters is not equal to the number of inputs")
+                errors[
+                    i] = "Number of adapters is not equal to the number of inputs"
 
         self.adapters = adapters
         return input_data, input_size, parameters, errors, batch
@@ -354,15 +339,16 @@ def inference(self, inputs):
 
         input_data, input_size, parameters, errors, batch = self.parse_input(
             inputs)
-        adapters = self.adapters
-        if not adapters:
-            adapters = [""] * len(input_data)
         if len(input_data) == 0:
             for i in range(len(batch)):
                 err = errors.get(i)
                 if self.rolling_batch_type:
                     err = {"data": "", "last": True, "code": 424, "error": err}
-                outputs.add(err, key="data", batch_index=i)
+                    outputs.add(Output.binary_encode(err),
+                                key="data",
+                                batch_index=i)
+                else:
+                    outputs.add(err, key="data", batch_index=i)
             return outputs
 
         if self.rolling_batch_type:
@@ -405,8 +391,13 @@ def inference(self, inputs):
                                      self.device, **parameters[0]))
             return outputs
 
+        if not all(p == parameters[0] for p in parameters):
+            raise ValueError(
+                "In order to enable dynamic batching, all input batches must have the same parameters"
+            )
+
         if isinstance(self.model, PeftModelForCausalLM):
-            parameters[0]["adapters"] = adapters
+            parameters[0]["adapters"] = self.adapters
 
         prediction = self.hf_pipeline(input_data, **parameters[0])
 
@@ -624,17 +615,18 @@ def _read_model_config(self, model_config_path: str, revision=None):
                 f"This is required for loading huggingface models")
             raise e
 
-    def _fetch_adapters_from_input(self, input_map: dict, input: Input):
+    @staticmethod
+    def _fetch_adapters_from_input(input_map: dict, inputs: Input):
         if "adapters" in input_map:
             return input_map.pop("adapters", [])
 
         # check content, possible in workflow approach
-        if input.contains_key("adapter"):
-            return input.get_as_string("adapter")
+        if inputs.contains_key("adapter"):
+            return inputs.get_as_string("adapter")
 
         # check properties, possible from header
-        if "adapter" in input.get_properties():
-            return input.get_properties()["adapter"]
+        if "adapter" in inputs.get_properties():
+            return inputs.get_properties()["adapter"]
 
         return []
 
diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
index 7480f5371..76d1aa64d 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
@@ -35,6 +35,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
@@ -246,6 +247,8 @@ String getSeed() {
         void addResponse(byte[] json) {
             ByteBuf buf = Unpooled.wrappedBuffer(json);
             int size = buf.readShort();
+            String code = null;
+            String error = null;
             for (int i = 0; i < size; ++i) {
                 String key = Objects.requireNonNull(CodecUtils.readUtf8(buf));
                 String value = Objects.requireNonNull(CodecUtils.readUtf8(buf));
@@ -257,16 +260,25 @@ void addResponse(byte[] json) {
                         last = "true".equalsIgnoreCase(value);
                         break;
                     case "code":
-                        output.setCode(Integer.parseInt(value));
+                        code = value;
                         break;
                     case "error":
-                        output.setMessage(value);
+                        error = value;
                         break;
                     default:
                         break;
                 }
             }
-            data.appendContent(BytesSupplier.wrap(nextToken), last);
+            if (code != null) {
+                Map<String, Object> map = new ConcurrentHashMap<>(2);
+                map.put("code", Integer.parseInt(code));
+                if (error != null) {
+                    map.put("error", error);
+                }
+                data.appendContent(BytesSupplier.wrapAsJson(map), true);
+            } else {
+                data.appendContent(BytesSupplier.wrap(nextToken), last);
+            }
         }
     }
 }