diff --git a/caikit/config/config.yml b/caikit/config/config.yml index 008a81e2d..434e1f038 100644 --- a/caikit/config/config.yml +++ b/caikit/config/config.yml @@ -142,6 +142,8 @@ runtime: options: {} # Legacy config for setting thread pool size. See runtime.server_thread_pool_size instead server_thread_pool_size: null + # Timeout for health probe to receive a response + probe_timeout: null # Configuration for the http server http: diff --git a/caikit_health_probe/__main__.py b/caikit_health_probe/__main__.py index 48a18278f..3a12337c4 100644 --- a/caikit_health_probe/__main__.py +++ b/caikit_health_probe/__main__.py @@ -12,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module implements a common health probe for all running runtime servers. +This module implements common health probes (liveness and readiness) for all +running runtime servers. """ # Standard from contextlib import contextmanager -from typing import Optional, Tuple +from typing import List, Optional, Tuple import importlib.util import os import sys import tempfile import warnings +# Third Party +import psutil + # First Party import alog @@ -42,15 +46,16 @@ @alog.timed_function(log.debug) -def health_probe() -> bool: - """Run a health probe against all running runtime servers. +def readiness_probe() -> bool: + """Run a readiness probe against all running runtime servers. This function is intended to be run from an environment where the config is identical to the config that the server is running such as from inside a kubernetes pod where the server is also running. Returns: - healthy (bool): True if all servers are healthy, False otherwise + ready (bool): True if all servers are ready to take requests, False + otherwise """ # Get TLS key/cert files if possible @@ -58,35 +63,64 @@ def health_probe() -> bool: tls_key = config.runtime.tls.server.key tls_cert = config.runtime.tls.server.cert client_ca = config.runtime.tls.client.cert - http_healthy, grpc_healthy = None, None + http_ready, grpc_ready = None, None if config.runtime.http.enabled: log.debug("Checking HTTP server health") - http_healthy = _http_health_probe( + http_ready = _http_readiness_probe( config.runtime.http.port, tls_key, tls_cert, client_ca ) if config.runtime.grpc.enabled: log.debug("Checking gRPC server health") - grpc_healthy = _grpc_health_probe( + grpc_ready = _grpc_readiness_probe( config.runtime.grpc.port, tls_key, tls_cert, client_ca ) - if False in [http_healthy, grpc_healthy]: + if False in [http_ready, grpc_ready]: log.info( "", - "Server not healthy. HTTP: %s, gRPC: %s", - http_healthy, - grpc_healthy, + "Runtime server(s) not ready. HTTP: %s, gRPC: %s", + http_ready, + grpc_ready, ) return False return True +@alog.timed_function(log.debug) +def liveness_probe(runtime_proc_identifier: str = "caikit.runtime") -> bool: + # Get all running processes that we have access to + this_proc = psutil.Process() + this_exe = this_proc.exe() + procs = [_get_proc_info(pid) for pid in psutil.pids() if pid != this_proc.pid] + + # Filter down to caikit runtime processes + caikit_procs = [ + proc_info + for proc_info in procs + if proc_info is not None + and proc_info[0] == this_exe + and any(runtime_proc_identifier in arg for arg in proc_info[1]) + ] + + # If we have running caikit processes, we consider the server to be alive + return bool(caikit_procs) + + ## Implementation ############################################################## -def _http_health_probe( +def _get_proc_info(pid: int) -> Optional[Tuple[str, List[str]]]: + """Attempt to get the given pid's information (exe and cmdline)""" + try: + proc = psutil.Process(pid) + return (proc.exe(), proc.cmdline()) + except psutil.Error: + return None + + +def _http_readiness_probe( port: int, tls_key: Optional[str], tls_cert: Optional[str], @@ -95,15 +129,15 @@ def _http_health_probe( """Probe the http server The implementation of this utility is a bit tricky because mTLS makes this - quite challenging. For insecure or TLS servers, we expect a valid healthy + quite challenging. For insecure or TLS servers, we expect a valid ready response, but for mTLS servers, we may not have a valid key/cert pair that the client can present to the server that is signed by the expected CA if the trusted client CA does not match the one that signed the server's key/cert pair. The workaround for this is to detect SSLError and consider that to be a - passing health check. If the server is healthy enough to _reject_ bad SSL - requests, it's healthy enough to server good ones! + passing readiness check. If the server is ready enough to _reject_ bad SSL + requests, it's ready enough to server good ones! Args: port (int): The port that the HTTP server is serving on @@ -115,7 +149,8 @@ def _http_health_probe( for mutual client auth Returns: - healthy (bool): True if all servers are healthy, False otherwise + ready (bool): True if the http server is ready to take requests, False + otherwise """ # NOTE: Local imports for optional dependency with alog.ContextTimer(log.debug2, "Done with local grpc imports: "): @@ -167,7 +202,7 @@ def _http_health_probe( return False -def _grpc_health_probe( +def _grpc_readiness_probe( port: int, tls_key: Optional[str], tls_cert: Optional[str], @@ -176,7 +211,7 @@ def _grpc_health_probe( """Probe the grpc server Since the gRPC server trusts its own cert for client verification, we can - make a valid health probe against the running server regardless of (m)TLS + make a valid readiness probe against the running server regardless of (m)TLS config. Args: @@ -189,7 +224,8 @@ def _grpc_health_probe( for mutual client auth Returns: - healthy (bool): True if all servers are healthy, False otherwise + ready (bool): True if the grpc server is ready to take requests, False + otherwise """ # NOTE: Local imports for optional dependency with alog.ContextTimer(log.debug2, "Done with local grpc imports: "): @@ -241,7 +277,10 @@ def _grpc_health_probe( client = health_pb2_grpc.HealthStub(channel) try: - client.Check(health_pb2.HealthCheckRequest()) + client.Check( + health_pb2.HealthCheckRequest(), + timeout=get_config().runtime.grpc.probe_timeout, + ) return True except Exception as err: # pylint: disable=broad-exception-caught log.debug2("Caught unexpected error: %s", err, exc_info=True) @@ -294,7 +333,24 @@ def main(): thread_id=caikit_config.log.thread_id, formatter=caikit_config.log.formatter, ) - if not health_probe(): + + # Pull the probe type from the command line, defaulting to readiness + probe_type_map = { + "readiness": readiness_probe, + "liveness": liveness_probe, + } + probe_type = "readiness" + probe_args = [] + if len(sys.argv) > 1: + probe_type = sys.argv[1] + if len(sys.argv) > 2: + probe_args = sys.argv[2:] + log.debug("Probe type: %s", probe_type) + log.debug("Probe args: %s", probe_args) + probe_fn = probe_type_map.get(probe_type.lower()) + assert probe_fn is not None, f"Invalid probe type: {probe_type}" + + if not probe_fn(*probe_args): sys.exit(1) diff --git a/pyproject.toml b/pyproject.toml index 72e4fc706..e7f7a3315 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "munch>=2.5.0,<5.0", "numpy>=1.22.2,<2", "protobuf>=3.19.0,<5", + "psutil>=5,<6", "py-to-proto>=0.5.0,<0.6.0,!=0.2.1", "PyYAML>=6.0,<7.0", "semver>=2.13.0,<4.0", diff --git a/tests/runtime/test_caikit_health_probe.py b/tests/runtime/test_caikit_health_probe.py index a36769e38..c42cca7d1 100644 --- a/tests/runtime/test_caikit_health_probe.py +++ b/tests/runtime/test_caikit_health_probe.py @@ -25,6 +25,9 @@ from enum import Enum from unittest import mock import os +import shlex +import subprocess +import sys # Third Party import pytest @@ -146,7 +149,7 @@ class ProbeTestConfig: ), ], ) -def test_health_probe(test_config: ProbeTestConfig): +def test_readiness_probe(test_config: ProbeTestConfig): """Test all of the different ways that the servers could be running""" with alog.ContextLog(log.info, "---LOG CONFIG: %s---", test_config): # Get ports for both servers @@ -192,11 +195,11 @@ def test_health_probe(test_config: ProbeTestConfig): "merge", ): # Health probe fails with no servers booted - assert not caikit_health_probe.health_probe() + assert not caikit_health_probe.readiness_probe() # If booting the gRPC server, do so with maybe_runtime_grpc_test_server(grpc_port): # If only running gRPC, health probe should pass - assert caikit_health_probe.health_probe() == ( + assert caikit_health_probe.readiness_probe() == ( test_config.should_become_healthy and test_config.server_mode == ServerMode.GRPC ) @@ -208,6 +211,32 @@ def test_health_probe(test_config: ProbeTestConfig): ): # Probe should always pass with both possible servers up assert ( - caikit_health_probe.health_probe() + caikit_health_probe.readiness_probe() == test_config.should_become_healthy ) + + +@pytest.mark.parametrize( + ["proc_identifier", "expected"], + [(None, True), ("caikit.runt", True), ("foobar", False)], +) +def test_liveness_probe(proc_identifier, expected): + """Test the logic for determining if the server process is alive""" + cmd = f"{sys.executable} -m caikit.runtime" + args = [] if proc_identifier is None else [proc_identifier] + + # Liveness should fail if process is not booted + assert not caikit_health_probe.liveness_probe(*args) + + proc = None + try: + # Start the process + proc = subprocess.Popen(shlex.split(cmd)) + + # Liveness should pass/fail as expected + assert caikit_health_probe.liveness_probe(*args) == expected + + finally: + # Kill the process if it started + if proc is not None and proc.poll() is None: + proc.kill()