From 3b21dddfc4f505f08112eca190efd82f30ddbe98 Mon Sep 17 00:00:00 2001
From: Xiaozhe Yao
Date: Wed, 15 Feb 2023 16:23:17 +0100
Subject: [PATCH] minor

---
 examples/serving_h3.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/examples/serving_h3.py b/examples/serving_h3.py
index e96173f..1730ad1 100644
--- a/examples/serving_h3.py
+++ b/examples/serving_h3.py
@@ -16,16 +16,21 @@
 from src.models.ssm_seq import SSMLMHeadModel
 
 
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(int(os.environ.get('LOG_LEVEL', logging.DEBUG)))
+
 class H3Inference(FastInferenceInterface):
     def __init__(self, model_name: str, args=None) -> None:
         super().__init__(model_name, args if args is not None else {})
-        print("\n=============== Arguments ===============")
-        print(args.keys())
-        print(args)
+        logging.debug("\n=============== Arguments ===============")
+        logging.debug(args.keys())
+        logging.debug(args)
         #for key in args.keys():
-        #    print("{}: {}".format(arg, getattr(args, arg)))
-        print("=========================================\n")
+        #    logging.debug("{}: {}".format(arg, getattr(args, arg)))
+        logging.debug("=========================================\n")
 
         self.task_info={
             "prompt_seqs": None,
@@ -100,10 +105,10 @@ def __init__(self, model_name: str, args=None) -> None:
 
         self.model = model
         self.tokenizer = tokenizer
-        print(f" initialization done")
+        logging.debug(f" initialization done")
 
     def dispatch_request(self, args, env) -> Dict:
-        print(f"dispatch_request get {args}")
+        logging.debug(f"dispatch_request get {args}")
         args = args[0]
         args = {k: v for k, v in args.items() if v is not None}
         # Inputs
@@ -113,11 +118,11 @@ def dispatch_request(self, args, env) -> Dict:
         self.task_info["top_p"] = float(args.get("top_p", 0.9))
 
         result = self._run_inference()
-        print(f" return: {result}")
+        logging.debug(f" return: {result}")
         return result
 
     def _run_inference(self):
-        print(f" enter rank-<{0}>")
+        logging.debug(f" enter rank-<{0}>")
         with torch.no_grad():
 
             prompt = self.task_info["prompt_seqs"][0]
@@ -131,7 +136,7 @@ def _run_inference(self):
                                              eos_token_id=self.tokenizer.eos_token_id)[:, input_ids.shape[1]:]
             # do not include input in the result
             time_elapsed = timeit.default_timer() - time
-        print("[INFO] H3 time costs: {:.2f} ms. ".format(time_elapsed * 1000, 0))
+        logging.debug("[INFO] H3 time costs: {:.2f} ms. ".format(time_elapsed * 1000, 0))
 
         assert output_ids is not None
 
@@ -165,7 +170,7 @@ def _run_inference(self):
                         help='group name for together coordinator.')
 
     args = parser.parse_args()
-    
+
     coord_url = os.environ.get("COORD_URL", "127.0.0.1")
     coord_http_port = os.environ.get("COORD_HTTP_PORT", "8092")
     coord_ws_port = os.environ.get("COORD_WS_PORT", "8093")
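
A note on the first hunk, outside the patch itself: the new module-level logger is configured with setLevel, but the call sites use the module-level helper logging.debug(...), which routes records through the root logger. The root logger's default level is WARNING, so the converted DEBUG messages will not appear unless the root logger is configured; and if LOG_LEVEL is set to a level name such as "DEBUG" rather than a number, int(...) raises ValueError. The snippet below is a minimal sketch of one way the LOG_LEVEL environment variable could drive the output; the _resolve_log_level helper is hypothetical and not part of the patch.

import logging
import os

def _resolve_log_level(raw) -> int:
    """Accept either a numeric level such as '10' or a name such as 'DEBUG' (hypothetical helper)."""
    try:
        return int(raw)
    except (TypeError, ValueError):
        level = logging.getLevelName(str(raw).upper())
        return level if isinstance(level, int) else logging.DEBUG

# Configure the root logger so DEBUG records are actually emitted, then log
# through the named logger instead of the root logger.
logging.basicConfig(level=_resolve_log_level(os.environ.get("LOG_LEVEL", logging.DEBUG)))
logger = logging.getLogger(__name__)
logger.debug("=============== Arguments ===============")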