Skip to content

Commit

Permalink
chore: fix const import
Browse files Browse the repository at this point in the history
  • Loading branch information
heiruwu committed Nov 14, 2023
1 parent 5127b2b commit 0faa5ab
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
Empty file added instill/helpers/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions instill/helpers/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict
from typing import Any, Dict


class DataType(Enum):
Expand All @@ -24,5 +24,5 @@ class TextGenerationInput:
top_k = 1
temperature = 0.8
random_seed = 0
stop_words = ""
stop_words: Any = ""
extra_params: Dict[str, str] = {}
29 changes: 19 additions & 10 deletions instill/helpers/ray_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List

import numpy as np
from const import TextGenerationInput

from instill.helpers.const import TextGenerationInput


def serialize_byte_tensor(input_tensor):
Expand Down Expand Up @@ -77,7 +78,7 @@ def deserialize_bytes_tensor(encoded_tensor):
while offset < len(val_buf):
l = struct.unpack_from("<I", val_buf, offset)[0]
offset += 4
sb = struct.unpack_from("<{}s".format(l), val_buf, offset)[0]
sb = struct.unpack_from(f"<{l}s", val_buf, offset)[0]
offset += l
strs.append(sb)
return np.array(strs, dtype=bytes)
Expand All @@ -95,29 +96,33 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
input_tensor = deserialize_bytes_tensor(b_input_tensor)
text_generation_input.prompt = str(input_tensor[0].decode("utf-8"))
print(
f"[DEBUG] input `prompt` type({type(text_generation_input.prompt)}): {text_generation_input.prompt}"
f"[DEBUG] input `prompt` type\
({type(text_generation_input.prompt)}): {text_generation_input.prompt}"
)

if input_name == "max_new_tokens":
text_generation_inputmax_new_tokens = int.from_bytes(
b_input_tensor, "little"
)
print(
f"[DEBUG] input `max_new_tokens` type({type(text_generation_inputmax_new_tokens)}): {text_generation_inputmax_new_tokens}"
f"[DEBUG] input `max_new_tokens` type\
({type(text_generation_inputmax_new_tokens)}): {text_generation_inputmax_new_tokens}"
)

if input_name == "top_k":
text_generation_input.top_k = int.from_bytes(b_input_tensor, "little")
print(
f"[DEBUG] input `top_k` type({type(text_generation_input.top_k)}): {text_generation_input.top_k}"
f"[DEBUG] input `top_k` type\
({type(text_generation_input.top_k)}): {text_generation_input.top_k}"
)

if input_name == "temperature":
text_generation_input.temperature = struct.unpack("f", b_input_tensor)[
0
]
print(
f"[DEBUG] input `temperature` type({type(text_generation_input.temperature)}): {text_generation_input.temperature}"
f"[DEBUG] input `temperature` type\
({type(text_generation_input.temperature)}): {text_generation_input.temperature}"
)
temperature: float = round(temperature, 2)

Expand All @@ -126,14 +131,16 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
b_input_tensor, "little"
)
print(
f"[DEBUG] input `random_seed` type({type(text_generation_input.random_seed)}): {text_generation_input.random_seed}"
f"[DEBUG] input `random_seed` type\
({type(text_generation_input.random_seed)}): {text_generation_input.random_seed}"
)

if input_name == "stop_words":
input_tensor = deserialize_bytes_tensor(b_input_tensor)
text_generation_input.stop_words = input_tensor[0]
print(
f"[DEBUG] input `stop_words` type({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
f"[DEBUG] input `stop_words` type\
({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
)
if len(text_generation_input.stop_words) == 0:
text_generation_input.stop_words = None
Expand All @@ -147,14 +154,16 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
str(text_generation_input.stop_words[0])
]
print(
f"[DEBUG] parsed input `stop_words` type({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
f"[DEBUG] parsed input `stop_words` type\
({type(text_generation_input.stop_words)}): {text_generation_input.stop_words}"
)

if input_name == "extra_params":
input_tensor = deserialize_bytes_tensor(b_input_tensor)
extra_params_str = str(input_tensor[0].decode("utf-8"))
print(
f"[DEBUG] input `extra_params` type({type(extra_params_str)}): {extra_params_str}"
f"[DEBUG] input `extra_params` type\
({type(extra_params_str)}): {extra_params_str}"
)

try:
Expand Down

0 comments on commit 0faa5ab

Please sign in to comment.