Fix/kserve torch auth #3302

Open · wants to merge 2 commits into master
27 changes: 22 additions & 5 deletions kubernetes/kserve/kserve_wrapper/TorchserveModel.py
@@ -1,5 +1,6 @@
""" The torchserve side inference end-points request are handled to
return a KServe side response """
import json
import logging
import os
import pathlib
@@ -54,15 +55,19 @@ def __init__(
grpc_inference_address,
protocol,
model_dir,
ts_auth_enabled,
):
"""The Model Name, Inference Address, Management Address and the model directory
are specified.

Args:
name (str): Model Name
inference_address (str): The Inference Address in which we hit the inference end point
management_address (str): The Management Address in which we register the model.
model_dir (str): The location of the model artifacts.
management_address (str): The Management Address in which we register the model
grpc_inference_address (str): The GRPC Inference Address
protocol (str): The API REST protocol version
model_dir (str): The location of the model artifacts
ts_auth_enabled (bool): Whether torchserve has auth enabled
"""
super().__init__(name)

@@ -74,6 +79,7 @@ def __init__(
self.inference_address = inference_address
self.management_address = management_address
self.model_dir = model_dir
self.ts_auth_enabled = ts_auth_enabled

# Validate the protocol value passed otherwise, the default value will be used
if protocol is not None:
@@ -85,6 +91,9 @@ def __init__(
logging.info("Predict URL set to %s", self.predictor_host)
logging.info("Explain URL set to %s", self.explainer_host)
logging.info("Protocol version is %s", self.protocol)
logging.info(
"Torchserve auth is %s", "enabled" if self.ts_auth_enabled else "disabled"
)

def grpc_client(self):
if self._grpc_client_stub is None:
@@ -168,13 +177,24 @@ def load(self) -> bool:
logging.info(
f"Loading {self.name} .. {num_try} of {model_load_max_try} tries.."
)
# Sleep first so that it won't sleep after ready
logging.info(f"Sleep {model_load_delay} seconds for load {self.name}..")
time.sleep(model_load_delay)

try:
headers = None
if self.ts_auth_enabled:
with open("key_file.json") as f:
keys = json.load(f)
management_endpoint_key = keys["management"]["key"]
headers = {"Authorization": f"Bearer {management_endpoint_key}"}

response = requests.get(
READINESS_URL_FORMAT.format(
self.management_address, self.name, model_load_customized
),
timeout=model_load_timeout,
headers=headers,
).json()

default_verison = response[0]
@@ -207,9 +227,6 @@ def load(self) -> bool:
logging.info(f"Failed loading model {self.name}")
break

logging.info(f"Sleep {model_load_delay} seconds for load {self.name}..")
time.sleep(model_load_delay)

if self.ready:
logging.info(f"The model {self.name} is ready")

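For context, a minimal standalone sketch of the authenticated readiness probe that load() now performs. The key_file.json lookup and the Bearer header mirror the diff above; the management address, model name, and the exact describe-model URL are illustrative assumptions (the wrapper builds its URL from READINESS_URL_FORMAT).

import json
import os

import requests

# Illustrative values only; the wrapper reads these from config.properties.
management_address = "http://0.0.0.0:8085"
model_name = "mnist"

headers = None
if os.path.exists("key_file.json"):  # torchserve wrote this file, so token auth is on
    with open("key_file.json") as f:
        keys = json.load(f)
    headers = {"Authorization": f"Bearer {keys['management']['key']}"}

# Assumed describe-model call against the management API; the wrapper's
# READINESS_URL_FORMAT may add extra query parameters.
response = requests.get(
    f"{management_address}/models/{model_name}",
    timeout=10,
    headers=headers,
).json()
print(response[0])  # the first entry describes the model's default version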
10 changes: 8 additions & 2 deletions kubernetes/kserve/kserve_wrapper/__main__.py
@@ -25,7 +25,7 @@ def parse_config():
model_name: The name of the model specified in the config.properties
inference_address: The inference address in which the inference endpoint is hit
management_address: The management address in which the model gets registered
model_store: the path in which the .mar file resides
model_store: The path in which the .mar file resides
"""
separator = "="
ts_configuration = {}
@@ -66,7 +66,7 @@ def parse_config():
grpc_inference_address = grpc_inference_address.replace("/", "")

logging.info(
"Wrapper : Model names %s, inference address %s, management address %s, grpc_inference_address, %s, model store %s",
"Wrapper: Model names %s, inference address %s, management address %s, grpc_inference_address %s, model store %s",
model_names,
inference_address,
management_address,
@@ -92,6 +92,11 @@ def parse_config():
model_dir,
) = parse_config()

# Torchserve enables auth by default. This can be defined in env variables, cli arguments or config file
# https://github.com/pytorch/serve/blob/master/docs/token_authorization_api.md
# The simplest method to check whether it's enabled is to check whether the file `key_file.json` exists
ts_auth_enabled = os.path.exists("key_file.json")

protocol = os.environ.get("PROTOCOL_VERSION", PredictorProtocol.REST_V1.value)

models = []
@@ -104,6 +109,7 @@ def parse_config():
grpc_inference_address,
protocol,
model_dir,
ts_auth_enabled,
)
# By default model.load() is called on first request. Enabling load all
# model in TS config.properties, all models are loaded at start and the
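As the comment in the diff above notes, the wrapper treats the mere presence of key_file.json as the signal that torchserve token authorization is enabled. A minimal sketch of that detection follows; only keys["management"]["key"] is actually read by the wrapper, and the other fields shown in the comment are assumptions based on torchserve's token authorization docs.

import json
import os

# Assumed key_file.json layout written by torchserve when token auth is enabled:
#   {"management": {"key": "...", "expiration time": "..."},
#    "inference":  {"key": "...", "expiration time": "..."}}
ts_auth_enabled = os.path.exists("key_file.json")

management_key = None
if ts_auth_enabled:
    with open("key_file.json") as f:
        management_key = json.load(f)["management"]["key"]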