
Commit

minor
xzyaoi committed Feb 15, 2023
1 parent 400e4aa commit 3b21ddd
Showing 1 changed file with 16 additions and 11 deletions.
examples/serving_h3.py (27 changes: 16 additions & 11 deletions)
@@ -16,16 +16,21 @@

 from src.models.ssm_seq import SSMLMHeadModel
 
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(int(os.environ.get('LOG_LEVEL', logging.DEBUG)))
+
 
 class H3Inference(FastInferenceInterface):
     def __init__(self, model_name: str, args=None) -> None:
         super().__init__(model_name, args if args is not None else {})
-        print("\n=============== Arguments ===============")
-        print(args.keys())
-        print(args)
+        logging.debug("\n=============== Arguments ===============")
+        logging.debug(args.keys())
+        logging.debug(args)
         #for key in args.keys():
-        #    print("{}: {}".format(arg, getattr(args, arg)))
-        print("=========================================\n")
+        #    logging.debug("{}: {}".format(arg, getattr(args, arg)))
+        logging.debug("=========================================\n")
 
         self.task_info={
             "prompt_seqs": None,
@@ -100,10 +105,10 @@ def __init__(self, model_name: str, args=None) -> None:

         self.model = model
         self.tokenizer = tokenizer
-        print(f"<H3Inference.__init__> initialization done")
+        logging.debug(f"<H3Inference.__init__> initialization done")
 
     def dispatch_request(self, args, env) -> Dict:
-        print(f"dispatch_request get {args}")
+        logging.debug(f"dispatch_request get {args}")
         args = args[0]
         args = {k: v for k, v in args.items() if v is not None}
         # Inputs
@@ -113,11 +118,11 @@ def dispatch_request(self, args, env) -> Dict:
self.task_info["top_p"] = float(args.get("top_p", 0.9))

result = self._run_inference()
print(f"<H3Inference.dispatch_request> return: {result}")
logging.debug(f"<H3Inference.dispatch_request> return: {result}")
return result

def _run_inference(self):
print(f"<H3Inference._run_inference> enter rank-<{0}>")
logging.debug(f"<H3Inference._run_inference> enter rank-<{0}>")

with torch.no_grad():
prompt = self.task_info["prompt_seqs"][0]
@@ -131,7 +136,7 @@ def _run_inference(self):
                 eos_token_id=self.tokenizer.eos_token_id)[:, input_ids.shape[1]:] # do not include input in the result
             time_elapsed = timeit.default_timer() - time
 
-            print("[INFO] H3 time costs: {:.2f} ms. <rank-{}>".format(time_elapsed * 1000, 0))
+            logging.debug("[INFO] H3 time costs: {:.2f} ms. <rank-{}>".format(time_elapsed * 1000, 0))
 
             assert output_ids is not None
 
@@ -165,7 +170,7 @@ def _run_inference(self):
                         help='group name for together coordinator.')
 
     args = parser.parse_args()
 
     coord_url = os.environ.get("COORD_URL", "127.0.0.1")
     coord_http_port = os.environ.get("COORD_HTTP_PORT", "8092")
     coord_ws_port = os.environ.get("COORD_WS_PORT", "8093")
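A note on the new logging setup: the threshold comes from the LOG_LEVEL environment variable and is cast with int(), so the value must be a numeric level (DEBUG=10, INFO=20, WARNING=30, ERROR=40, CRITICAL=50) rather than a name like "DEBUG". Note also that the module-level logging.debug(...) calls added above dispatch through the root logger, not the logger configured with LOG_LEVEL, so the threshold only applies to records sent through that named logger. A minimal sketch of the difference, assuming a basicConfig() handler that the commit itself does not add:

    import logging
    import os

    # Threshold read from LOG_LEVEL; must be numeric because of the int() cast.
    logger = logging.getLogger(__name__)
    logger.setLevel(int(os.environ.get('LOG_LEVEL', logging.DEBUG)))

    logging.basicConfig()  # root handler; the root logger level defaults to WARNING

    # Dispatched via the root logger: suppressed at the default WARNING level,
    # regardless of what LOG_LEVEL says.
    logging.debug("root-logger record")

    # Dispatched via the configured logger: honors the LOG_LEVEL threshold.
    logger.debug("named-logger record")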
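For completeness, a hypothetical invocation showing how the environment variables read by the script could be supplied; the variable names match the ones in the code, but the launch command itself is an assumption, not part of the commit:

    import os
    import subprocess

    # LOG_LEVEL must be numeric (10 = DEBUG) because serving_h3.py casts it with int().
    env = dict(os.environ,
               LOG_LEVEL="10",
               COORD_URL="127.0.0.1",
               COORD_HTTP_PORT="8092",
               COORD_WS_PORT="8093")
    subprocess.run(["python", "examples/serving_h3.py"], env=env)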
