From 5dc07015a860653cdaa99594708b2d972da1d0da Mon Sep 17 00:00:00 2001 From: Bazire Date: Mon, 9 Oct 2023 18:20:27 +0200 Subject: [PATCH] Switch from threading to multiprocessing in worker Contributes to #GSK-1863 --- .gitignore | 1 + giskard/client/giskard_client.py | 2 + giskard/commands/cli_worker.py | 25 +- giskard/datasets/base/__init__.py | 28 +- giskard/ml_worker/ml_worker.py | 15 +- giskard/ml_worker/websocket/__init__.py | 3 +- giskard/ml_worker/websocket/listener.py | 347 ++++++++++++++---------- giskard/ml_worker/websocket/utils.py | 42 ++- giskard/models/base/model.py | 15 +- giskard/utils/__init__.py | 99 +++++++ pyproject.toml | 6 +- 11 files changed, 386 insertions(+), 197 deletions(-) diff --git a/.gitignore b/.gitignore index 55677bac25..d133ea1b12 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +enron_with_categories .history .mypy_cache docker-stack.yml diff --git a/giskard/client/giskard_client.py b/giskard/client/giskard_client.py index a7259cb902..bc3248cefe 100644 --- a/giskard/client/giskard_client.py +++ b/giskard/client/giskard_client.py @@ -75,6 +75,8 @@ def __call__(self, r): class GiskardClient: def __init__(self, url: str, key: str, hf_token: str = None): self.host_url = url + self.key = key + self.hf_token = hf_token base_url = urljoin(url, "/api/v2/") self._session = sessions.BaseUrlSession(base_url=base_url) self._session.mount(base_url, ErrorHandlingAdapter()) diff --git a/giskard/commands/cli_worker.py b/giskard/commands/cli_worker.py index 254b6f746c..235b1ec64b 100644 --- a/giskard/commands/cli_worker.py +++ b/giskard/commands/cli_worker.py @@ -1,10 +1,11 @@ +from typing import Optional + import asyncio import functools import logging import os import platform import sys -from typing import Optional import click import lockfile @@ -13,19 +14,19 @@ from lockfile.pidlockfile import PIDLockFile, read_pid_from_pidfile, remove_existing_pidfile from pydantic import AnyHttpUrl -from giskard.cli_utils import common_options from giskard.cli_utils import ( + common_options, create_pid_file_path, + follow_file, + get_log_path, remove_stale_pid_file, run_daemon, - get_log_path, tail, - follow_file, validate_url, ) from giskard.path_utils import run_dir from giskard.settings import settings -from giskard.utils.analytics_collector import anonymize, analytics +from giskard.utils.analytics_collector import analytics, anonymize logger = logging.getLogger(__name__) @@ -85,7 +86,13 @@ def wrapper(*args, **kwargs): envvar="GSK_HF_TOKEN", help="Access token for Giskard hosted in a private Hugging Face Spaces", ) -def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token): +@click.option( + "--parallelism", + "nb_workers", + default=None, + help="Number of processes to use for parallelism (None for number of cpu)", +) +def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_workers): """\b Start ML Worker. @@ -102,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token): ) api_key = initialize_api_key(api_key, is_server) hf_token = initialize_hf_token(hf_token, is_server) - _start_command(is_server, url, api_key, is_daemon, hf_token) + _start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers) def initialize_api_key(api_key, is_server): @@ -126,7 +133,7 @@ def initialize_hf_token(hf_token, is_server): return hf_token -def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None): +def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None, nb_workers=None): from giskard.ml_worker.ml_worker import MLWorker start_msg = "Starting ML Worker" @@ -154,7 +161,7 @@ def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None run_daemon(is_server, url, api_key, hf_token) else: ml_worker = MLWorker(is_server, url, api_key, hf_token) - asyncio.get_event_loop().run_until_complete(ml_worker.start()) + asyncio.get_event_loop().run_until_complete(ml_worker.start(nb_workers)) except KeyboardInterrupt: logger.info("Exiting") if ml_worker: diff --git a/giskard/datasets/base/__init__.py b/giskard/datasets/base/__init__.py index 8a85f0c16a..15eae5933e 100644 --- a/giskard/datasets/base/__init__.py +++ b/giskard/datasets/base/__init__.py @@ -1,3 +1,5 @@ +from typing import Dict, Hashable, List, Optional, Union + import inspect import logging import posixpath @@ -5,34 +7,30 @@ import uuid from functools import cached_property from pathlib import Path -from typing import Dict, Optional, List, Union, Hashable import numpy as np import pandas import pandas as pd import yaml -from pandas.api.types import is_list_like -from pandas.api.types import is_numeric_dtype +from mlflow import MlflowClient +from pandas.api.types import is_list_like, is_numeric_dtype from xxhash import xxh3_128_hexdigest from zstandard import ZstdDecompressor -from mlflow import MlflowClient from giskard.client.giskard_client import GiskardClient -from giskard.client.io_utils import save_df, compress +from giskard.client.io_utils import compress, save_df from giskard.client.python_utils import warning from giskard.core.core import DatasetMeta, SupportedColumnTypes from giskard.core.validation import configured_validate_arguments -from giskard.ml_worker.testing.registry.slicing_function import ( - SlicingFunction, - SlicingFunctionType, -) +from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType from giskard.ml_worker.testing.registry.transformation_function import ( TransformationFunction, TransformationFunctionType, ) from giskard.settings import settings -from ..metadata.indexing import ColumnMetadataMixin + from ...ml_worker.utils.file_utils import get_file_name +from ..metadata.indexing import ColumnMetadataMixin SAMPLE_SIZE = 1000 @@ -521,7 +519,7 @@ def load(cls, local_path: str): ) @classmethod - def download(cls, client: GiskardClient, project_key, dataset_id, sample: bool = False): + def download(cls, client: Optional[GiskardClient], project_key, dataset_id, sample: bool = False): """ Downloads a dataset from a Giskard project and returns a Dataset object. If the client is None, then the function assumes that it is running in an internal worker and looks for the dataset locally. @@ -661,9 +659,7 @@ def to_mlflow(self, mlflow_client: MlflowClient = None, mlflow_run_id: str = Non # To avoid file being open in write mode and read at the same time, # First, we'll write it, then make sure to remove it - with tempfile.NamedTemporaryFile( - prefix="dataset-", suffix=".csv", delete=False - ) as f: + with tempfile.NamedTemporaryFile(prefix="dataset-", suffix=".csv", delete=False) as f: # Get file path local_path = f.name # Get name from file @@ -696,8 +692,10 @@ def to_wandb(self, **kwargs) -> None: Additional keyword arguments (see https://docs.wandb.ai/ref/python/init) to be added to the active WandB run. """ - from giskard.integrations.wandb.wandb_utils import wandb_run import wandb # noqa library import already checked in wandb_run + + from giskard.integrations.wandb.wandb_utils import wandb_run + from ...utils.analytics_collector import analytics with wandb_run(**kwargs) as run: diff --git a/giskard/ml_worker/ml_worker.py b/giskard/ml_worker/ml_worker.py index 25182b0ba3..24c0904f4e 100644 --- a/giskard/ml_worker/ml_worker.py +++ b/giskard/ml_worker/ml_worker.py @@ -1,16 +1,20 @@ +from typing import Optional + import logging import random import secrets -import stomp import time + +import stomp from pydantic import AnyHttpUrl -from websocket._exceptions import WebSocketException, WebSocketBadStatusException +from websocket._exceptions import WebSocketBadStatusException, WebSocketException import giskard +from giskard.cli_utils import validate_url from giskard.client.giskard_client import GiskardClient from giskard.ml_worker.testing.registry.registry import load_plugins from giskard.settings import settings -from giskard.cli_utils import validate_url +from giskard.utils import shutdown_pool, start_pool logger = logging.getLogger(__name__) @@ -140,9 +144,9 @@ def _connect_websocket_client(self, is_server=False): def is_remote_worker(self): return self.ml_worker_id is not INTERNAL_WORKER_ID - async def start(self): + async def start(self, nb_workers: Optional[int] = None): load_plugins() - + start_pool(nb_workers) if self.ws_conn: self.ws_stopping = False self.connect_websocket_client() @@ -162,3 +166,4 @@ def stop(self): if self.ws_conn: self.ws_stopping = True self.ws_conn.disconnect() + shutdown_pool() diff --git a/giskard/ml_worker/websocket/__init__.py b/giskard/ml_worker/websocket/__init__.py index 8a6d64a928..0d4e764d4a 100644 --- a/giskard/ml_worker/websocket/__init__.py +++ b/giskard/ml_worker/websocket/__init__.py @@ -1,6 +1,7 @@ -from enum import Enum from typing import Dict, List, Optional +from enum import Enum + from pydantic import BaseModel, Field diff --git a/giskard/ml_worker/websocket/listener.py b/giskard/ml_worker/websocket/listener.py index 3e65b573d1..fae0b9c4cd 100644 --- a/giskard/ml_worker/websocket/listener.py +++ b/giskard/ml_worker/websocket/listener.py @@ -1,3 +1,5 @@ +from typing import Callable, Dict, Optional, Union + import json import logging import math @@ -7,6 +9,8 @@ import tempfile import time import traceback +from concurrent.futures import Future +from dataclasses import dataclass from pathlib import Path import numpy as np @@ -16,6 +20,7 @@ import stomp import giskard +from giskard.client.giskard_client import GiskardClient from giskard.core.suite import Suite from giskard.datasets.base import Dataset from giskard.ml_worker import websocket @@ -25,11 +30,9 @@ from giskard.ml_worker.ml_worker import MLWorker from giskard.ml_worker.testing.registry.giskard_test import GiskardTest from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction -from giskard.ml_worker.testing.registry.transformation_function import ( - TransformationFunction, -) +from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction from giskard.ml_worker.utils.file_utils import get_file_name -from giskard.ml_worker.websocket import GetInfoParam, PushKind, CallToActionKind +from giskard.ml_worker.websocket import CallToActionKind, GetInfoParam, PushKind from giskard.ml_worker.websocket.action import MLWorkerAction from giskard.ml_worker.websocket.utils import ( do_run_adhoc_test, @@ -39,37 +42,133 @@ log_artifact_local, map_dataset_process_function_meta_ws, map_function_meta_ws, - map_suite_input_ws, map_result_to_single_test_result_ws, + map_suite_input_ws, parse_action_param, parse_function_arguments, ) from giskard.models.base import BaseModel -from giskard.models.model_explanation import ( - explain, - explain_text, -) +from giskard.models.model_explanation import explain, explain_text from giskard.push import Push -from giskard.utils import threaded +from giskard.utils import call_in_pool, log_pool_stats, shutdown_pool from giskard.utils.analytics_collector import analytics logger = logging.getLogger(__name__) -def websocket_log_actor(ml_worker: MLWorker, req: dict, *args, **kwargs): +MAX_STOMP_ML_WORKER_REPLY_SIZE = 1500 + + +@dataclass +class MLWorkerInfo: + id: str + is_remote: bool + + def __init__(self, worker: MLWorker): + self.id = worker.ml_worker_id + self.is_remote = worker.is_remote_worker() + + +def websocket_log_actor(ml_worker: MLWorkerInfo, req: dict, *args, **kwargs): param = req["param"] if "param" in req.keys() else {} action = req["action"] if "action" in req.keys() else "" - logger.info(f"ML Worker {ml_worker.ml_worker_id} performing {action} params: {param}") + logger.info(f"ML Worker {ml_worker.id} performing {action} params: {param}") WEBSOCKET_ACTORS = dict((action.name, websocket_log_actor) for action in MLWorkerAction) -MAX_STOMP_ML_WORKER_REPLY_SIZE = 1500 +def wrapped_handle_result(action: MLWorkerAction, ml_worker: MLWorker, start: float, rep_id: Optional[str]): + def handle_result(future: Union[Future, Callable[..., websocket.WorkerReply]]): + log_pool_stats() + + info = None # Needs to be defined in case of cancellation + + try: + info: websocket.WorkerReply = future.result() if isinstance(future, Future) else future() + except Exception as e: + info: websocket.WorkerReply = websocket.ErrorReply( + error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc() + ) + logger.warning(e) + finally: + analytics.track( + "mlworker:websocket:action", + { + "name": action.name, + "worker": ml_worker.ml_worker_id, + "language": "PYTHON", + "type": "ERROR" if isinstance(info, websocket.ErrorReply) else "SUCCESS", + "action_time": time.process_time() - start, + "error": info.error_str if isinstance(info, websocket.ErrorReply) else "", + "error_type": info.error_type if isinstance(info, websocket.ErrorReply) else "", + }, + ) + + if rep_id: + # Reply if there is an ID + logger.debug( + f"[WRAPPED_CALLBACK] replying {len(info.json(by_alias=True))} {info.json(by_alias=True)} for {action.name}" + ) + # Message fragmentation + FRAG_LEN = max(ml_worker.ws_max_reply_payload_size, MAX_STOMP_ML_WORKER_REPLY_SIZE) + payload = info.json(by_alias=True) if info else "{}" + frag_count = math.ceil(len(payload) / FRAG_LEN) + for frag_i in range(frag_count): + ml_worker.ws_conn.send( + f"/app/ml-worker/{ml_worker.ml_worker_id}/rep", + json.dumps( + { + "id": rep_id, + "action": action.name, + "payload": fragment_message(payload, frag_i, FRAG_LEN), + "f_index": frag_i, + "f_count": frag_count, + } + ), + ) -# Open a new thread to process and reply, avoid slowing down the WebSocket message loop -@threaded -def dispatch_action(callback, ml_worker, action, req): + analytics.track( + "mlworker:websocket:action:reply", + { + "name": action.name, + "worker": ml_worker.ml_worker_id, + "language": "PYTHON", + "action_time": time.process_time() - start, + "is_error": isinstance(info, websocket.ErrorReply), + "frag_len": FRAG_LEN, + "frag_count": frag_count, + "reply_len": len(payload), + }, + ) + + # Post-processing of stopWorker + if action == MLWorkerAction.stopWorker: + ml_worker.ws_stopping = True + ml_worker.ws_conn.disconnect() + shutdown_pool() + + return handle_result + + +def parse_and_execute( + *, + callback: Callable, + action: MLWorkerAction, + params, + ml_worker: MLWorkerInfo, + client_params: Optional[Dict[str, str]], +) -> websocket.WorkerReply: + action_params = parse_action_param(action, params) + return callback( + ml_worker=ml_worker, + client=GiskardClient(**client_params) if client_params is not None else None, + action=action.name, + params=action_params, + ) + + +def dispatch_action(callback, ml_worker, action, req, execute_in_pool): # Parse the response ID rep_id = req["id"] if "id" in req.keys() else None # Parse the param @@ -84,85 +183,61 @@ def dispatch_action(callback, ml_worker, action, req): "language": "PYTHON", }, ) + # Ws connection is lock pickable, so not usable as args + # GiskardClient is losing Session when pickling + client_params = ( + { + "url": ml_worker.client.host_url, + "key": ml_worker.client.key, + "hf_token": ml_worker.client.hf_token, + } + if ml_worker.client is not None + else None + ) start = time.process_time() - try: - params = parse_action_param(action, params) - # Call the function and get the response - info: websocket.WorkerReply = callback(ml_worker=ml_worker, action=action.name, params=params) - except Exception as e: - info: websocket.WorkerReply = websocket.ErrorReply( - error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc() - ) - logger.warning(e) - finally: - analytics.track( - "mlworker:websocket:action", - { - "name": action.name, - "worker": ml_worker.ml_worker_id, - "language": "PYTHON", - "type": "ERROR" if isinstance(info, websocket.ErrorReply) else "SUCCESS", - "action_time": time.process_time() - start, - "error": info.error_str if isinstance(info, websocket.ErrorReply) else "", - "error_type": info.error_type if isinstance(info, websocket.ErrorReply) else "", - }, - ) - if rep_id: - # Reply if there is an ID - logger.debug( - f"[WRAPPED_CALLBACK] replying {len(info.json(by_alias=True))} {info.json(by_alias=True)} for {action.name}" + result_handler = wrapped_handle_result(action, ml_worker, start, rep_id) + # If execution should be done in a pool + if execute_in_pool: + logger.debug("Submitting for action %s '%s' into the pool with %s", action.name, callback.__name__, params) + future = call_in_pool( + parse_and_execute, + callback=callback, + action=action, + params=params, + ml_worker=MLWorkerInfo(ml_worker), + client_params=client_params, ) - # Message fragmentation - FRAG_LEN = max(ml_worker.ws_max_reply_payload_size, MAX_STOMP_ML_WORKER_REPLY_SIZE) - payload = info.json(by_alias=True) if info else "{}" - frag_count = math.ceil(len(payload) / FRAG_LEN) - for frag_i in range(frag_count): - ml_worker.ws_conn.send( - f"/app/ml-worker/{ml_worker.ml_worker_id}/rep", - json.dumps( - { - "id": rep_id, - "action": action.name, - "payload": fragment_message(payload, frag_i, FRAG_LEN), - "f_index": frag_i, - "f_count": frag_count, - } - ), - ) + future.add_done_callback(result_handler) + log_pool_stats() + return - analytics.track( - "mlworker:websocket:action:reply", - { - "name": action.name, - "worker": ml_worker.ml_worker_id, - "language": "PYTHON", - "action_time": time.process_time() - start, - "is_error": isinstance(info, websocket.ErrorReply), - "frag_len": FRAG_LEN, - "frag_count": frag_count, - "reply_len": len(payload), - }, + result_handler( + lambda: parse_and_execute( + callback=callback, + action=action, + params=params, + ml_worker=MLWorkerInfo(ml_worker), + client_params=client_params, ) - - # Post-processing of stopWorker - if action == MLWorkerAction.stopWorker and ml_worker.ws_stopping is True: - ml_worker.ws_conn.disconnect() + ) -def websocket_actor(action: MLWorkerAction): +def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True): """ Register a function as an actor to an action from WebSocket connection """ def websocket_actor_callback(callback: callable): - if action in MLWorkerAction: - logger.debug(f'Registered "{callback.__name__}" for ML Worker "{action.name}"') + if action not in MLWorkerAction: + raise NotImplementedError(f"Missing implementation for {action}, not in MLWorkerAction") + logger.debug(f'Registered "{callback.__name__}" for ML Worker "{action.name}"') + + def wrapped_callback(ml_worker: MLWorker, req: dict, *args, **kwargs): + dispatch_action(callback, ml_worker, action, req, execute_in_pool) - def wrapped_callback(ml_worker: MLWorker, req: dict, *args, **kwargs): - dispatch_action(callback, ml_worker, action, req) + WEBSOCKET_ACTORS[action.name] = wrapped_callback - WEBSOCKET_ACTORS[action.name] = wrapped_callback return callback return websocket_actor_callback @@ -214,8 +289,8 @@ def on_message(self, frame): logger.info(f"MAX_STOMP_ML_WORKER_REPLY_SIZE set to {mtu}") -@websocket_actor(MLWorkerAction.getInfo) -def on_ml_worker_get_info(ml_worker: MLWorker, params: GetInfoParam, *args, **kwargs) -> websocket.GetInfo: +@websocket_actor(MLWorkerAction.getInfo, execute_in_pool=False) +def on_ml_worker_get_info(ml_worker: MLWorkerInfo, params: GetInfoParam, *args, **kwargs) -> websocket.GetInfo: logger.info("Collecting ML Worker info from WebSocket") # TODO(Bazire): seems to be deprecated https://setuptools.pypa.io/en/latest/pkg_resources.html#workingset-objects @@ -238,16 +313,15 @@ def on_ml_worker_get_info(ml_worker: MLWorker, params: GetInfoParam, *args, **kw interpreter=sys.executable, interpreterVersion=platform.python_version(), installedPackages=installed_packages, - mlWorkerId=ml_worker.ml_worker_id, - isRemote=ml_worker.is_remote_worker(), + mlWorkerId=ml_worker.id, + isRemote=ml_worker.is_remote, ) -@websocket_actor(MLWorkerAction.stopWorker) -def on_ml_worker_stop_worker(ml_worker: MLWorker, *args, **kwargs) -> websocket.Empty: +@websocket_actor(MLWorkerAction.stopWorker, execute_in_pool=False) +def on_ml_worker_stop_worker(*args, **kwargs) -> websocket.Empty: # Stop the server properly after sending disconnect logger.info("Stopping ML Worker") - ml_worker.ws_stopping = True return websocket.Empty() @@ -307,11 +381,11 @@ def run_other_model(dataset, prediction_results): @websocket_actor(MLWorkerAction.runModel) -def run_model(ml_worker: MLWorker, params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty: +def run_model(client: GiskardClient, params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty: try: - model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id) + model = BaseModel.download(client, params.model.project_key, params.model.id) dataset = Dataset.download( - ml_worker.client, + client, params.dataset.project_key, params.dataset.id, sample=params.dataset.sample, @@ -342,8 +416,8 @@ def run_model(ml_worker: MLWorker, params: websocket.RunModelParam, *args, **kwa tmp_dir = Path(f) predictions_csv = get_file_name("predictions", "csv", params.dataset.sample) results.to_csv(index=False, path_or_buf=tmp_dir / predictions_csv) - if ml_worker.client: - ml_worker.client.log_artifact( + if client: + client.log_artifact( tmp_dir / predictions_csv, f"{params.project_key}/models/inspections/{params.inspectionId}", ) @@ -355,8 +429,8 @@ def run_model(ml_worker: MLWorker, params: websocket.RunModelParam, *args, **kwa calculated_csv = get_file_name("calculated", "csv", params.dataset.sample) calculated.to_csv(index=False, path_or_buf=tmp_dir / calculated_csv) - if ml_worker.client: - ml_worker.client.log_artifact( + if client: + client.log_artifact( tmp_dir / calculated_csv, f"{params.project_key}/models/inspections/{params.inspectionId}", ) @@ -370,9 +444,9 @@ def run_model(ml_worker: MLWorker, params: websocket.RunModelParam, *args, **kwa @websocket_actor(MLWorkerAction.runModelForDataFrame) def run_model_for_data_frame( - ml_worker: MLWorker, params: websocket.RunModelForDataFrameParam, *args, **kwargs + client: Optional[GiskardClient], params: websocket.RunModelForDataFrameParam, *args, **kwargs ) -> websocket.RunModelForDataFrame: - model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id) + model = BaseModel.download(client, params.model.project_key, params.model.id) df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows]) ds = Dataset( model.prepare_dataframe(df, column_dtypes=params.column_dtypes), @@ -398,9 +472,9 @@ def run_model_for_data_frame( @websocket_actor(MLWorkerAction.explain) -def explain_ws(ml_worker: MLWorker, params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain: - model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id) - dataset = Dataset.download(ml_worker.client, params.dataset.project_key, params.dataset.id, params.dataset.sample) +def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain: + model = BaseModel.download(client, params.model.project_key, params.model.id) + dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample) explanations = explain(model, dataset, params.columns) return websocket.Explain( @@ -409,8 +483,10 @@ def explain_ws(ml_worker: MLWorker, params: websocket.ExplainParam, *args, **kwa @websocket_actor(MLWorkerAction.explainText) -def explain_text_ws(ml_worker: MLWorker, params: websocket.ExplainTextParam, *args, **kwargs) -> websocket.ExplainText: - model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id) +def explain_text_ws( + client: Optional[GiskardClient], params: websocket.ExplainTextParam, *args, **kwargs +) -> websocket.ExplainText: + model = BaseModel.download(client, params.model.project_key, params.model.id) text_column = params.feature_name if params.column_types[text_column] != "text": @@ -441,22 +517,22 @@ def get_catalog(*args, **kwargs) -> websocket.Catalog: @websocket_actor(MLWorkerAction.datasetProcessing) def dataset_processing( - ml_worker: MLWorker, params: websocket.DatasetProcessingParam, *args, **kwargs + client: Optional[GiskardClient], params: websocket.DatasetProcessingParam, *args, **kwargs ) -> websocket.DatasetProcessing: - dataset = Dataset.download(ml_worker.client, params.dataset.project_key, params.dataset.id, params.dataset.sample) + dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample) for function in params.functions: - arguments = parse_function_arguments(ml_worker, function.arguments) + arguments = parse_function_arguments(client, function.arguments) if function.slicingFunction: dataset.add_slicing_function( - SlicingFunction.download( - function.slicingFunction.id, ml_worker.client, function.slicingFunction.project_key - )(**arguments) + SlicingFunction.download(function.slicingFunction.id, client, function.slicingFunction.project_key)( + **arguments + ) ) else: dataset.add_transformation_function( TransformationFunction.download( - function.transformationFunction.id, ml_worker.client, function.transformationFunction.project_key + function.transformationFunction.id, client, function.transformationFunction.project_key )(**arguments) ) @@ -472,11 +548,7 @@ def dataset_processing( modifications=[ websocket.DatasetRowModificationResult( rowId=row[0], - modifications={ - key: str(value) - for key, value in row[1].items() - if not pd.isna(value) - }, + modifications={key: str(value) for key, value in row[1].items() if not pd.isna(value)}, ) for row in modified_rows.iterrows() ], @@ -485,16 +557,16 @@ def dataset_processing( @websocket_actor(MLWorkerAction.runAdHocTest) def run_ad_hoc_test( - ml_worker: MLWorker, params: websocket.RunAdHocTestParam, *args, **kwargs + client: Optional[GiskardClient], params: websocket.RunAdHocTestParam, *args, **kwargs ) -> websocket.RunAdHocTest: - test: GiskardTest = GiskardTest.download(params.testUuid, ml_worker.client, None) + test: GiskardTest = GiskardTest.download(params.testUuid, client, None) - arguments = parse_function_arguments(ml_worker, params.arguments) + arguments = parse_function_arguments(client, params.arguments) arguments["debug"] = params.debug if params.debug else None debug_info = extract_debug_info(params.arguments) if params.debug else None - test_result = do_run_adhoc_test(ml_worker.client, arguments, test, debug_info) + test_result = do_run_adhoc_test(client, arguments, test, debug_info) return websocket.RunAdHocTest( results=[ @@ -506,19 +578,21 @@ def run_ad_hoc_test( @websocket_actor(MLWorkerAction.runTestSuite) -def run_test_suite(ml_worker: MLWorker, params: websocket.TestSuiteParam, *args, **kwargs) -> websocket.TestSuite: +def run_test_suite( + client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs +) -> websocket.TestSuite: log_listener = LogListener() try: tests = [ { - "test": GiskardTest.download(t.testUuid, ml_worker.client, None), - "arguments": parse_function_arguments(ml_worker, t.arguments), + "test": GiskardTest.download(t.testUuid, client, None), + "arguments": parse_function_arguments(client, t.arguments), "id": t.id, } for t in params.tests ] - global_arguments = parse_function_arguments(ml_worker, params.globalArguments) + global_arguments = parse_function_arguments(client, params.globalArguments) test_names = list( map( @@ -558,11 +632,11 @@ def run_test_suite(ml_worker: MLWorker, params: websocket.TestSuiteParam, *args, @websocket_actor(MLWorkerAction.generateTestSuite) def generate_test_suite( - ml_worker: MLWorker, params: websocket.GenerateTestSuiteParam, *args, **kwargs + client: Optional[GiskardClient], params: websocket.GenerateTestSuiteParam, *args, **kwargs ) -> websocket.GenerateTestSuite: inputs = [map_suite_input_ws(i) for i in params.inputs] - suite = Suite().generate_tests(inputs).to_dto(ml_worker.client, params.project_key) + suite = Suite().generate_tests(inputs).to_dto(client, params.project_key) return websocket.GenerateTestSuite( tests=[ @@ -578,19 +652,21 @@ def generate_test_suite( ) -@websocket_actor(MLWorkerAction.echo) +@websocket_actor(MLWorkerAction.echo, execute_in_pool=False) def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg: return params @websocket_actor(MLWorkerAction.getPush) -def get_push(ml_worker: MLWorker, params: websocket.GetPushParam, *args, **kwargs) -> websocket.GetPushResponse: +def get_push( + client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs +) -> websocket.GetPushResponse: object_uuid = "" object_params = {} project_key = params.model.project_key try: - model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id) - dataset = Dataset.download(ml_worker.client, params.dataset.project_key, params.dataset.id) + model = BaseModel.download(client, params.model.project_key, params.model.id) + dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id) df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows]) if params.column_dtypes: @@ -622,8 +698,7 @@ def get_push(ml_worker: MLWorker, params: websocket.GetPushParam, *args, **kwarg from giskard.push.contribution import create_contribution_push from giskard.push.perturbation import create_perturbation_push - from giskard.push.prediction import create_overconfidence_push - from giskard.push.prediction import create_borderline_push + from giskard.push.prediction import create_borderline_push, create_overconfidence_push contribs = create_contribution_push(model, dataset, df) perturbs = create_perturbation_push(model, dataset, df) @@ -655,24 +730,24 @@ def get_push(ml_worker: MLWorker, params: websocket.GetPushParam, *args, **kwarg or params.cta_kind == CallToActionKind.CREATE_SLICE_OPEN_DEBUGGER ): push.slicing_function.meta.tags.append("generated") - object_uuid = push.slicing_function.upload(ml_worker.client, project_key) + object_uuid = push.slicing_function.upload(client, project_key) if params.cta_kind == CallToActionKind.SAVE_PERTURBATION: for perturbation in push.transformation_function: - object_uuid = perturbation.upload(ml_worker.client, project_key) + object_uuid = perturbation.upload(client, project_key) if params.cta_kind == CallToActionKind.SAVE_EXAMPLE: - object_uuid = push.saved_example.upload(ml_worker.client, project_key) + object_uuid = push.saved_example.upload(client, project_key) if params.cta_kind == CallToActionKind.CREATE_TEST or params.cta_kind == CallToActionKind.ADD_TEST_TO_CATALOG: for test in push.tests: - object_uuid = test.upload(ml_worker.client, project_key) + object_uuid = test.upload(client, project_key) # create empty dict object_params = {} # for every object in push.test_params, check if they're a subclass of Savable and if yes upload them for test_param_name in push.test_params: test_param = push.test_params[test_param_name] if isinstance(test_param, RegistryArtifact): - object_params[test_param_name] = test_param.upload(ml_worker.client, project_key) + object_params[test_param_name] = test_param.upload(client, project_key) elif isinstance(test_param, Dataset): - object_params[test_param_name] = test_param.upload(ml_worker.client, project_key) + object_params[test_param_name] = test_param.upload(client, project_key) else: object_params[test_param_name] = test_param diff --git a/giskard/ml_worker/websocket/utils.py b/giskard/ml_worker/websocket/utils.py index 3353166cdf..108f68353a 100644 --- a/giskard/ml_worker/websocket/utils.py +++ b/giskard/ml_worker/websocket/utils.py @@ -1,35 +1,34 @@ +from typing import Any, Dict, List, Optional + import logging import os import posixpath import shutil from pathlib import Path -from typing import List, Dict, Any from mlflow.store.artifact.artifact_repo import verify_artifact_path -from giskard.core.suite import ModelInput, DatasetInput, SuiteInput +from giskard.client.giskard_client import GiskardClient +from giskard.core.suite import DatasetInput, ModelInput, SuiteInput from giskard.datasets.base import Dataset from giskard.ml_worker import websocket from giskard.ml_worker.exceptions.IllegalArgumentError import IllegalArgumentError -from giskard.ml_worker.ml_worker import MLWorker from giskard.ml_worker.testing.registry.registry import tests_registry from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction -from giskard.ml_worker.testing.registry.transformation_function import ( - TransformationFunction, -) -from giskard.ml_worker.testing.test_result import TestResult, TestMessageLevel +from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction +from giskard.ml_worker.testing.test_result import TestMessageLevel, TestResult from giskard.ml_worker.websocket import ( + DatasetProcessingParam, EchoMsg, ExplainParam, - GetInfoParam, - RunModelParam, - TestSuiteParam, ExplainTextParam, - RunAdHocTestParam, - DatasetProcessingParam, GenerateTestSuiteParam, - RunModelForDataFrameParam, + GetInfoParam, GetPushParam, + RunAdHocTestParam, + RunModelForDataFrameParam, + RunModelParam, + TestSuiteParam, ) from giskard.ml_worker.websocket.action import MLWorkerAction from giskard.models.base import BaseModel @@ -38,7 +37,7 @@ logger = logging.getLogger(__name__) -def parse_action_param(action, params): +def parse_action_param(action: MLWorkerAction, params): # TODO: Sort by usage frequency from future MixPanel metrics #NOSONAR if action == MLWorkerAction.getInfo: return GetInfoParam.parse_obj(params) @@ -157,7 +156,7 @@ def map_dataset_process_function_meta_ws(callable_type): } -def parse_function_arguments(ml_worker: MLWorker, request_arguments: List[websocket.FuncArgument]): +def parse_function_arguments(client: Optional[GiskardClient], request_arguments: List[websocket.FuncArgument]): arguments = dict() # Processing empty list @@ -169,21 +168,21 @@ def parse_function_arguments(ml_worker: MLWorker, request_arguments: List[websoc continue if arg.dataset is not None: arguments[arg.name] = Dataset.download( - ml_worker.client, + client, arg.dataset.project_key, arg.dataset.id, arg.dataset.sample, ) elif arg.model is not None: - arguments[arg.name] = BaseModel.download(ml_worker.client, arg.model.project_key, arg.model.id) + arguments[arg.name] = BaseModel.download(client, arg.model.project_key, arg.model.id) elif arg.slicingFunction is not None: arguments[arg.name] = SlicingFunction.download( - arg.slicingFunction.id, ml_worker.client, arg.slicingFunction.project_key - )(**parse_function_arguments(ml_worker, arg.args)) + arg.slicingFunction.id, client, arg.slicingFunction.project_key + )(**parse_function_arguments(client, arg.args)) elif arg.transformationFunction is not None: arguments[arg.name] = TransformationFunction.download( - arg.transformationFunction.id, ml_worker.client, arg.transformationFunction.project_key - )(**parse_function_arguments(ml_worker, arg.args)) + arg.transformationFunction.id, client, arg.transformationFunction.project_key + )(**parse_function_arguments(client, arg.args)) elif arg.float_arg is not None: arguments[arg.name] = float(arg.float_arg) elif arg.int_arg is not None: @@ -246,7 +245,6 @@ def do_run_adhoc_test(client, arguments, test, debug_info=None): logger.info(f"Executing {test.meta.display_name or f'{test.meta.module}.{test.meta.name}'}") test_result = test.get_builder()(**arguments).execute() if test_result.output_df is not None: # i.e. if debug is True and test has failed - if debug_info is None: raise ValueError( "You have requested to debug the test, " diff --git a/giskard/models/base/model.py b/giskard/models/base/model.py index 87c6cd4b38..824ded2592 100644 --- a/giskard/models/base/model.py +++ b/giskard/models/base/model.py @@ -1,3 +1,5 @@ +from typing import Iterable, Optional, Union + import builtins import importlib import logging @@ -8,25 +10,24 @@ import uuid from abc import ABC, abstractmethod from pathlib import Path -from typing import Iterable, Optional, Union import cloudpickle import numpy as np import pandas as pd import yaml -from .model_prediction import ModelPredictionResults -from ..cache import get_cache_enabled -from ..utils import np_types_to_native from ...client.giskard_client import GiskardClient from ...core.core import ModelMeta, ModelType, SupportedModelTypes from ...core.validation import configured_validate_arguments from ...datasets.base import Dataset +from ...ml_worker.exceptions.giskard_exception import GiskardException from ...ml_worker.utils.logging import Timer from ...models.cache import ModelCache from ...path_utils import get_size from ...settings import settings -from ...ml_worker.exceptions.giskard_exception import GiskardException +from ..cache import get_cache_enabled +from ..utils import np_types_to_native +from .model_prediction import ModelPredictionResults META_FILENAME = "giskard-model-meta.yaml" @@ -376,7 +377,7 @@ def upload(self, client: GiskardClient, project_key, validate_ds=None) -> str: return str(self.id) @classmethod - def download(cls, client: GiskardClient, project_key, model_id): + def download(cls, client: Optional[GiskardClient], project_key, model_id): """ Downloads the specified model from the Giskard server and loads it into memory. @@ -477,8 +478,8 @@ def to_mlflow(self): raise NotImplementedError def _llm_agent(self, dataset=None, allow_dataset_queries: bool = False, scan_result=None): - from ...llm.talk.talk import create_ml_llm from ...llm.config import llm_config + from ...llm.talk.talk import create_ml_llm data_source_tools = [] if allow_dataset_queries: diff --git a/giskard/utils/__init__.py b/giskard/utils/__init__.py index d8f7f03789..36f5c8d9c2 100644 --- a/giskard/utils/__init__.py +++ b/giskard/utils/__init__.py @@ -1,5 +1,11 @@ +import logging +from asyncio import Future +from concurrent.futures import ProcessPoolExecutor +from functools import wraps from threading import Thread +LOGGER = logging.getLogger(__name__) + def threaded(fn): def wrapper(*args, **kwargs): @@ -10,6 +16,99 @@ def wrapper(*args, **kwargs): return wrapper +class WorkerPool: + "Utility class to wrap a Process pool" + + def __init__(self): + self.pool = None + + def start(self, *args, **kwargs): + if self.pool is not None: + return + LOGGER.info("Starting worker pool...") + self.pool = ProcessPoolExecutor(*args, **kwargs) + LOGGER.info("Pool is started") + + def shutdown(self, wait=True, cancel_futures=False): + if self.pool is None: + return + self.pool.shutdown(wait=wait, cancel_futures=cancel_futures) + self.pool = None + + def submit(self, *args, **kwargs) -> Future: + if self.pool is None: + raise ValueError("Pool is not started") + return self.pool.submit(*args, **kwargs) + + def map(self, *args, **kwargs): + if self.pool is None: + raise ValueError("Pool is not started") + return self.pool.map(*args, **kwargs) + + def log_stats(self): + if self.pool is None: + LOGGER.debug("Pool is not yet started") + return + LOGGER.debug( + "Pool is currently having :\n - %s pending items\n - %s workers", + len(self.pool._pending_work_items), + len(self.pool._processes), + ) + + +POOL = WorkerPool() + + +def log_pool_stats(): + """Log pools stats to have some debug info""" + POOL.log_stats() + + +def start_pool(max_workers: int = None): + """Start the pool and warm it up, to get all workers up. + + Args: + max_workers (int, optional): _description_. Defaults to None. + """ + POOL.start(max_workers=max_workers) + # Doing warmup to spin up all workers + for _ in POOL.map(int, range(100)): + # Consume the results + pass + log_pool_stats() + + +def shutdown_pool(): + """Stop the pool""" + POOL.shutdown(wait=True, cancel_futures=True) + + +def call_in_pool(fn, *args, **kwargs): + """Submit the function call with args and kwargs inside the process pool + + Args: + fn (function): the function to call + + Returns: + Future: the promise of the results + """ + return POOL.submit(fn, *args, **kwargs) + + +def pooled(fn): + """Decorator to make a function be called inside the pool. + + Args: + fn (function): the function to wrap + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + return call_in_pool(fn, *args, **kwargs) + + return wrapper + + def fullname(o): klass = o.__class__ module = klass.__module__ diff --git a/pyproject.toml b/pyproject.toml index ca6aef3c9a..9368501360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,10 @@ watch-doc = "python -m sphinx_autobuild --watch giskard docs docs/_build/html" clean = "rm -rf coverage.xml coverage* .coverage*" notebook = "jupyter notebook --ip 0.0.0.0 --port 8888 --no-browser --notebook-dir ./docs/reference/notebooks --NotebookApp.token=''" check-deps = "deptry ." -debug-worker = "python3 -m debugpy --listen localhost:5678 --wait-for-client giskard/cli.py worker start" - +debug-worker = "python3 -Xfrozen_modules=off -m debugpy --listen localhost:5678 --wait-for-client giskard/cli.py worker start" +worker = "python3 giskard/cli.py worker start" +debug-internal-worker = "python3 -Xfrozen_modules=off -m debugpy --listen localhost:5678 --wait-for-client giskard/cli.py worker -s start" +internal-worker = "python3 giskard/cli.py worker -s start" [tool.pdm.dev-dependencies] dev = [