Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Merge pull request caikit#636 from gabe-l-hart/LivenessProbe-635
Browse files Browse the repository at this point in the history
Liveness probe 635
  • Loading branch information
gabe-l-hart authored Jan 8, 2024
2 parents dff142c + a387263 commit e33e280
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 26 deletions.
2 changes: 2 additions & 0 deletions caikit/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ runtime:
options: {}
# Legacy config for setting thread pool size. See runtime.server_thread_pool_size instead
server_thread_pool_size: null
# Timeout (in seconds) for the health probe to receive a response; null means no timeout
probe_timeout: null

# Configuration for the http server
http:
Expand Down
100 changes: 78 additions & 22 deletions caikit_health_probe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module implements a common health probe for all running runtime servers.
This module implements common health probes (liveness and readiness) for all
running runtime servers.
"""
# Standard
from contextlib import contextmanager
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import importlib.util
import os
import sys
import tempfile
import warnings

# Third Party
import psutil

# First Party
import alog

Expand All @@ -42,51 +46,81 @@


@alog.timed_function(log.debug)
def health_probe() -> bool:
"""Run a health probe against all running runtime servers.
def readiness_probe() -> bool:
"""Run a readiness probe against all running runtime servers.
This function is intended to be run from an environment where the config is
identical to the config that the server is running such as from inside a
kubernetes pod where the server is also running.
Returns:
healthy (bool): True if all servers are healthy, False otherwise
ready (bool): True if all servers are ready to take requests, False
otherwise
"""

# Get TLS key/cert files if possible
config = get_config()
tls_key = config.runtime.tls.server.key
tls_cert = config.runtime.tls.server.cert
client_ca = config.runtime.tls.client.cert
http_healthy, grpc_healthy = None, None
http_ready, grpc_ready = None, None

if config.runtime.http.enabled:
log.debug("Checking HTTP server health")
http_healthy = _http_health_probe(
http_ready = _http_readiness_probe(
config.runtime.http.port, tls_key, tls_cert, client_ca
)

if config.runtime.grpc.enabled:
log.debug("Checking gRPC server health")
grpc_healthy = _grpc_health_probe(
grpc_ready = _grpc_readiness_probe(
config.runtime.grpc.port, tls_key, tls_cert, client_ca
)

if False in [http_healthy, grpc_healthy]:
if False in [http_ready, grpc_ready]:
log.info(
"<RUN64273066I>",
"Server not healthy. HTTP: %s, gRPC: %s",
http_healthy,
grpc_healthy,
"Runtime server(s) not ready. HTTP: %s, gRPC: %s",
http_ready,
grpc_ready,
)
return False
return True


@alog.timed_function(log.debug)
def liveness_probe(runtime_proc_identifier: str = "caikit.runtime") -> bool:
    """Run a liveness probe by searching for a running runtime process.

    The probe inspects every process id reported by psutil (other than the
    probe's own) and considers the server alive if at least one of them was
    launched with the same executable as the probe and has a command line
    argument containing the given identifier.

    Args:
        runtime_proc_identifier (str): Substring to look for in the command
            line arguments of candidate processes

    Returns:
        alive (bool): True if at least one matching process was found, False
            otherwise
    """
    # Details about the probe's own process, used to exclude it from the
    # search and to match the executable of candidate processes
    probe_proc = psutil.Process()
    probe_exe = probe_proc.exe()
    probe_pid = probe_proc.pid

    matching_procs = []
    for pid in psutil.pids():
        # The probe itself never counts as a runtime server
        if pid == probe_pid:
            continue

        # Processes whose information cannot be read are ignored
        info = _get_proc_info(pid)
        if info is None:
            continue

        # Only processes running the same executable are candidates
        exe_path, cmd_args = info
        if exe_path != probe_exe:
            continue

        if any(runtime_proc_identifier in arg for arg in cmd_args):
            matching_procs.append(info)

    # Any matching process means the server is considered alive
    return bool(matching_procs)


## Implementation ##############################################################


def _http_health_probe(
def _get_proc_info(pid: int) -> Optional[Tuple[str, List[str]]]:
    """Look up the executable and command line for a process id

    Args:
        pid (int): The id of the process to inspect

    Returns:
        proc_info (Optional[Tuple[str, List[str]]]): The (exe, cmdline) pair
            for the process, or None if psutil raised an error while
            retrieving either value
    """
    try:
        target = psutil.Process(pid)
        exe_path = target.exe()
        cmd_args = target.cmdline()
    except psutil.Error:
        # Any psutil failure for this pid is reported as "no information"
        return None
    return (exe_path, cmd_args)


def _http_readiness_probe(
port: int,
tls_key: Optional[str],
tls_cert: Optional[str],
Expand All @@ -95,15 +129,15 @@ def _http_health_probe(
"""Probe the http server
The implementation of this utility is a bit tricky because mTLS makes this
quite challenging. For insecure or TLS servers, we expect a valid healthy
quite challenging. For insecure or TLS servers, we expect a valid ready
response, but for mTLS servers, we may not have a valid key/cert pair that
the client can present to the server that is signed by the expected CA if
the trusted client CA does not match the one that signed the server's
key/cert pair.
The workaround for this is to detect SSLError and consider that to be a
passing health check. If the server is healthy enough to _reject_ bad SSL
requests, it's healthy enough to server good ones!
passing readiness check. If the server is ready enough to _reject_ bad SSL
requests, it's ready enough to serve good ones!
Args:
port (int): The port that the HTTP server is serving on
Expand All @@ -115,7 +149,8 @@ def _http_health_probe(
for mutual client auth
Returns:
healthy (bool): True if all servers are healthy, False otherwise
ready (bool): True if the http server is ready to take requests, False
otherwise
"""
# NOTE: Local imports for optional dependency
with alog.ContextTimer(log.debug2, "Done with local grpc imports: "):
Expand Down Expand Up @@ -167,7 +202,7 @@ def _http_health_probe(
return False


def _grpc_health_probe(
def _grpc_readiness_probe(
port: int,
tls_key: Optional[str],
tls_cert: Optional[str],
Expand All @@ -176,7 +211,7 @@ def _grpc_health_probe(
"""Probe the grpc server
Since the gRPC server trusts its own cert for client verification, we can
make a valid health probe against the running server regardless of (m)TLS
make a valid readiness probe against the running server regardless of (m)TLS
config.
Args:
Expand All @@ -189,7 +224,8 @@ def _grpc_health_probe(
for mutual client auth
Returns:
healthy (bool): True if all servers are healthy, False otherwise
ready (bool): True if the grpc server is ready to take requests, False
otherwise
"""
# NOTE: Local imports for optional dependency
with alog.ContextTimer(log.debug2, "Done with local grpc imports: "):
Expand Down Expand Up @@ -241,7 +277,10 @@ def _grpc_health_probe(

client = health_pb2_grpc.HealthStub(channel)
try:
client.Check(health_pb2.HealthCheckRequest())
client.Check(
health_pb2.HealthCheckRequest(),
timeout=get_config().runtime.grpc.probe_timeout,
)
return True
except Exception as err: # pylint: disable=broad-exception-caught
log.debug2("Caught unexpected error: %s", err, exc_info=True)
Expand Down Expand Up @@ -294,7 +333,24 @@ def main():
thread_id=caikit_config.log.thread_id,
formatter=caikit_config.log.formatter,
)
if not health_probe():

# Pull the probe type from the command line, defaulting to readiness
probe_type_map = {
"readiness": readiness_probe,
"liveness": liveness_probe,
}
probe_type = "readiness"
probe_args = []
if len(sys.argv) > 1:
probe_type = sys.argv[1]
if len(sys.argv) > 2:
probe_args = sys.argv[2:]
log.debug("Probe type: %s", probe_type)
log.debug("Probe args: %s", probe_args)
probe_fn = probe_type_map.get(probe_type.lower())
assert probe_fn is not None, f"Invalid probe type: {probe_type}"

if not probe_fn(*probe_args):
sys.exit(1)


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"munch>=2.5.0,<5.0",
"numpy>=1.22.2,<2",
"protobuf>=3.19.0,<5",
"psutil>=5,<6",
"py-to-proto>=0.5.0,<0.6.0,!=0.2.1",
"PyYAML>=6.0,<7.0",
"semver>=2.13.0,<4.0",
Expand Down
37 changes: 33 additions & 4 deletions tests/runtime/test_caikit_health_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from enum import Enum
from unittest import mock
import os
import shlex
import subprocess
import sys

# Third Party
import pytest
Expand Down Expand Up @@ -146,7 +149,7 @@ class ProbeTestConfig:
),
],
)
def test_health_probe(test_config: ProbeTestConfig):
def test_readiness_probe(test_config: ProbeTestConfig):
"""Test all of the different ways that the servers could be running"""
with alog.ContextLog(log.info, "---LOG CONFIG: %s---", test_config):
# Get ports for both servers
Expand Down Expand Up @@ -192,11 +195,11 @@ def test_health_probe(test_config: ProbeTestConfig):
"merge",
):
# Health probe fails with no servers booted
assert not caikit_health_probe.health_probe()
assert not caikit_health_probe.readiness_probe()
# If booting the gRPC server, do so
with maybe_runtime_grpc_test_server(grpc_port):
# If only running gRPC, health probe should pass
assert caikit_health_probe.health_probe() == (
assert caikit_health_probe.readiness_probe() == (
test_config.should_become_healthy
and test_config.server_mode == ServerMode.GRPC
)
Expand All @@ -208,6 +211,32 @@ def test_health_probe(test_config: ProbeTestConfig):
):
# Probe should always pass with both possible servers up
assert (
caikit_health_probe.health_probe()
caikit_health_probe.readiness_probe()
== test_config.should_become_healthy
)


@pytest.mark.parametrize(
    ["proc_identifier", "expected"],
    [(None, True), ("caikit.runt", True), ("foobar", False)],
)
def test_liveness_probe(proc_identifier, expected):
    """Test the logic for determining if the server process is alive

    A None identifier exercises the probe's default argument; otherwise the
    identifier is passed through as the probe's single positional argument.
    """
    cmd = f"{sys.executable} -m caikit.runtime"
    args = [] if proc_identifier is None else [proc_identifier]

    # Liveness should fail if process is not booted
    assert not caikit_health_probe.liveness_probe(*args)

    proc = None
    try:
        # Start the process
        proc = subprocess.Popen(shlex.split(cmd))

        # Liveness should pass/fail as expected
        assert caikit_health_probe.liveness_probe(*args) == expected

    finally:
        # Kill the process if it started and is still running
        if proc is not None and proc.poll() is None:
            proc.kill()
            # Wait for the killed child to actually exit so that it is reaped
            # and cannot still be visible to the "not booted" assertion of the
            # next parametrized case
            proc.wait()

0 comments on commit e33e280

Please sign in to comment.