From 0faa5ab3d3f20dbe3b5be21f78fad4b99d32ed98 Mon Sep 17 00:00:00 2001
From: Heiru Wu
Date: Tue, 14 Nov 2023 17:58:28 +0800
Subject: [PATCH] chore: fix const import

---
 instill/helpers/__init__.py |  0
 instill/helpers/const.py    |  4 ++--
 instill/helpers/ray_io.py   | 29 +++++++++++++++++++----------
 3 files changed, 21 insertions(+), 12 deletions(-)
 create mode 100644 instill/helpers/__init__.py

diff --git a/instill/helpers/__init__.py b/instill/helpers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/instill/helpers/const.py b/instill/helpers/const.py
index 4294238..28fcec8 100644
--- a/instill/helpers/const.py
+++ b/instill/helpers/const.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict
+from typing import Any, Dict


 class DataType(Enum):
@@ -24,5 +24,5 @@ class TextGenerationInput:
     top_k = 1
     temperature = 0.8
     random_seed = 0
-    stop_words = ""
+    stop_words: Any = ""
     extra_params: Dict[str, str] = {}
diff --git a/instill/helpers/ray_io.py b/instill/helpers/ray_io.py
index 50cd801..52a3a80 100644
--- a/instill/helpers/ray_io.py
+++ b/instill/helpers/ray_io.py
@@ -3,7 +3,8 @@
 from typing import List

 import numpy as np
-from const import TextGenerationInput
+
+from instill.helpers.const import TextGenerationInput


 def serialize_byte_tensor(input_tensor):
@@ -77,7 +78,7 @@ def deserialize_bytes_tensor(encoded_tensor):
     while offset < len(val_buf):
         l = struct.unpack_from(" TextGenerationInput:
             input_tensor = deserialize_bytes_tensor(b_input_tensor)
             text_generation_input.prompt = str(input_tensor[0].decode("utf-8"))
             print(
-                f"[DEBUG] input `prompt` type({type(text_generation_input.prompt)}): {text_generation_input.prompt}"
+                f"[DEBUG] input `prompt` type\
+                    ({type(text_generation_input.prompt)}): {text_generation_input.prompt}"
             )

         if input_name == "max_new_tokens":
@@ -103,13 +105,15 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
                 b_input_tensor, "little"
             )
             print(
-                f"[DEBUG] input `max_new_tokens` type({type(text_generation_inputmax_new_tokens)}): {text_generation_inputmax_new_tokens}"
+                f"[DEBUG] input `max_new_tokens` type\
+                    ({type(text_generation_inputmax_new_tokens)}): {text_generation_inputmax_new_tokens}"
             )

         if input_name == "top_k":
             text_generation_input.top_k = int.from_bytes(b_input_tensor, "little")
             print(
-                f"[DEBUG] input `top_k` type({type(text_generation_input.top_k)}): {text_generation_input.top_k}"
+                f"[DEBUG] input `top_k` type\
+                    ({type(text_generation_input.top_k)}): {text_generation_input.top_k}"
             )

         if input_name == "temperature":
@@ -117,7 +121,8 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
                 0
             ]
             print(
-                f"[DEBUG] input `temperature` type({type(text_generation_input.temperature)}): {text_generation_input.temperature}"
+                f"[DEBUG] input `temperature` type\
+                    ({type(text_generation_input.temperature)}): {text_generation_input.temperature}"
             )
             temperature: float = round(temperature, 2)

@@ -126,14 +131,16 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
                 b_input_tensor, "little"
             )
             print(
-                f"[DEBUG] input `random_seed` type({type(text_generation_input.random_seed)}): {text_generation_input.random_seed}"
+                f"[DEBUG] input `random_seed` type\
+                    ({type(text_generation_input.random_seed)}): {text_generation_input.random_seed}"
             )

         if input_name == "stop_words":
             input_tensor = deserialize_bytes_tensor(b_input_tensor)
             text_generation_input.stop_words = input_tensor[0]
             print(
-                f"[DEBUG] input `stop_words` type({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
+                f"[DEBUG] input `stop_words` type\
+                    ({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
             )
             if len(text_generation_input.stop_words) == 0:
                 text_generation_input.stop_words = None
@@ -147,14 +154,16 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
                     str(text_generation_input.stop_words[0])
                 ]
             print(
-                f"[DEBUG] parsed input `stop_words` type({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
+                f"[DEBUG] parsed input `stop_words` type\
+                    ({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
             )

         if input_name == "extra_params":
             input_tensor = deserialize_bytes_tensor(b_input_tensor)
             extra_params_str = str(input_tensor[0].decode("utf-8"))
             print(
-                f"[DEBUG] input `extra_params` type({type(extra_params_str)}): {extra_params_str}"
+                f"[DEBUG] input `extra_params` type\
+                    ({type(extra_params_str)}): {extra_params_str}"
             )

             try:
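For reference, a minimal sketch (not part of the patch) of what the two substantive changes do once applied. It assumes the instill SDK is importable; the prompt value is made up for illustration.

# With instill/helpers/__init__.py in place, the absolute import used by
# ray_io.py resolves as a normal package import instead of `from const import ...`.
from instill.helpers.const import TextGenerationInput

text_generation_input = TextGenerationInput()
text_generation_input.prompt = "hello"  # assumed example value

# A backslash-newline inside an f-string literal is removed, but the
# indentation of the continued line stays in the resulting string, so the
# wrapped [DEBUG] messages print with a run of spaces after "type".
print(
    f"[DEBUG] input `prompt` type\
        ({type(text_generation_input.prompt)}): {text_generation_input.prompt}"
)

Splitting the message into two adjacent string literals inside the parentheses, rather than using a backslash continuation, would avoid carrying that indentation into the printed text.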