Switch from threading to multiprocessing in worker
Contributes to #GSK-1863
Hartorn committed Oct 10, 2023
1 parent bf1502f commit 8a8edf2
Showing 11 changed files with 367 additions and 193 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
enron_with_categories
.history
.mypy_cache
docker-stack.yml
2 changes: 2 additions & 0 deletions giskard/client/giskard_client.py
@@ -75,6 +75,8 @@ def __call__(self, r):
class GiskardClient:
    def __init__(self, url: str, key: str, hf_token: str = None):
        self.host_url = url
        self.key = key
        self.hf_token = hf_token
        base_url = urljoin(url, "/api/v2/")
        self._session = sessions.BaseUrlSession(base_url=base_url)
        self._session.mount(base_url, ErrorHandlingAdapter())
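For reference, the constructor now keeps the API key and Hugging Face token on the instance alongside host_url. A minimal usage sketch; the URL and credentials below are placeholders:

from giskard.client.giskard_client import GiskardClient

# Placeholder values for illustration only.
client = GiskardClient(
    url="http://localhost:19000",  # Giskard hub URL (example)
    key="YOUR_API_KEY",            # API key (placeholder)
    hf_token=None,                 # only needed for a private Hugging Face Space
)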
25 changes: 16 additions & 9 deletions giskard/commands/cli_worker.py
@@ -1,10 +1,11 @@
from typing import Optional

import asyncio
import functools
import logging
import os
import platform
import sys
from typing import Optional

import click
import lockfile
@@ -13,19 +14,19 @@
from lockfile.pidlockfile import PIDLockFile, read_pid_from_pidfile, remove_existing_pidfile
from pydantic import AnyHttpUrl

from giskard.cli_utils import common_options
from giskard.cli_utils import (
    common_options,
    create_pid_file_path,
    follow_file,
    get_log_path,
    remove_stale_pid_file,
    run_daemon,
    get_log_path,
    tail,
    follow_file,
    validate_url,
)
from giskard.path_utils import run_dir
from giskard.settings import settings
from giskard.utils.analytics_collector import anonymize, analytics
from giskard.utils.analytics_collector import analytics, anonymize

logger = logging.getLogger(__name__)

@@ -85,7 +86,13 @@ def wrapper(*args, **kwargs):
    envvar="GSK_HF_TOKEN",
    help="Access token for Giskard hosted in a private Hugging Face Spaces",
)
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
@click.option(
    "--parallelism",
    "nb_workers",
    default=None,
    help="Number of processes to use for parallelism (None for number of cpu)",
)
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_workers):
    """\b
    Start ML Worker.
@@ -102,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
    )
    api_key = initialize_api_key(api_key, is_server)
    hf_token = initialize_hf_token(hf_token, is_server)
    _start_command(is_server, url, api_key, is_daemon, hf_token)
    _start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers)


def initialize_api_key(api_key, is_server):
@@ -126,7 +133,7 @@ def initialize_hf_token(hf_token, is_server):
    return hf_token


def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None):
def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None, nb_workers=None):
    from giskard.ml_worker.ml_worker import MLWorker

    start_msg = "Starting ML Worker"
@@ -154,7 +161,7 @@ def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None
            run_daemon(is_server, url, api_key, hf_token)
        else:
            ml_worker = MLWorker(is_server, url, api_key, hf_token)
            asyncio.get_event_loop().run_until_complete(ml_worker.start())
            asyncio.get_event_loop().run_until_complete(ml_worker.start(nb_workers))
    except KeyboardInterrupt:
        logger.info("Exiting")
        if ml_worker:
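The new --parallelism option is passed through start_command and _start_command to MLWorker.start as nb_workers. A minimal sketch of the equivalent programmatic call, with placeholder connection values and assuming is_server=False selects the external worker mode:

import asyncio

from giskard.ml_worker.ml_worker import MLWorker

# Placeholder connection settings; the CLI derives these from --url/--api-key/etc.
url, api_key, hf_token = "http://localhost:19000", "YOUR_API_KEY", None

ml_worker = MLWorker(False, url, api_key, hf_token)
# nb_workers=4 caps the process pool at 4 workers; None falls back to the CPU count.
asyncio.get_event_loop().run_until_complete(ml_worker.start(nb_workers=4))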
28 changes: 13 additions & 15 deletions giskard/datasets/base/__init__.py
@@ -1,38 +1,36 @@
from typing import Dict, Hashable, List, Optional, Union

import inspect
import logging
import posixpath
import tempfile
import uuid
from functools import cached_property
from pathlib import Path
from typing import Dict, Optional, List, Union, Hashable

import numpy as np
import pandas
import pandas as pd
import yaml
from pandas.api.types import is_list_like
from pandas.api.types import is_numeric_dtype
from mlflow import MlflowClient
from pandas.api.types import is_list_like, is_numeric_dtype
from xxhash import xxh3_128_hexdigest
from zstandard import ZstdDecompressor
from mlflow import MlflowClient

from giskard.client.giskard_client import GiskardClient
from giskard.client.io_utils import save_df, compress
from giskard.client.io_utils import compress, save_df
from giskard.client.python_utils import warning
from giskard.core.core import DatasetMeta, SupportedColumnTypes
from giskard.core.validation import configured_validate_arguments
from giskard.ml_worker.testing.registry.slicing_function import (
    SlicingFunction,
    SlicingFunctionType,
)
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType
from giskard.ml_worker.testing.registry.transformation_function import (
    TransformationFunction,
    TransformationFunctionType,
)
from giskard.settings import settings
from ..metadata.indexing import ColumnMetadataMixin

from ...ml_worker.utils.file_utils import get_file_name
from ..metadata.indexing import ColumnMetadataMixin

SAMPLE_SIZE = 1000

@@ -521,7 +519,7 @@ def load(cls, local_path: str):
)

    @classmethod
    def download(cls, client: GiskardClient, project_key, dataset_id, sample: bool = False):
    def download(cls, client: Optional[GiskardClient], project_key, dataset_id, sample: bool = False):
        """
        Downloads a dataset from a Giskard project and returns a Dataset object.
        If the client is None, then the function assumes that it is running in an internal worker and looks for the dataset locally.
@@ -661,9 +659,7 @@ def to_mlflow(self, mlflow_client: MlflowClient = None, mlflow_run_id: str = Non

        # To avoid file being open in write mode and read at the same time,
        # First, we'll write it, then make sure to remove it
        with tempfile.NamedTemporaryFile(
            prefix="dataset-", suffix=".csv", delete=False
        ) as f:
        with tempfile.NamedTemporaryFile(prefix="dataset-", suffix=".csv", delete=False) as f:
            # Get file path
            local_path = f.name
            # Get name from file
@@ -696,8 +692,10 @@ def to_wandb(self, **kwargs) -> None:
        Additional keyword arguments
        (see https://docs.wandb.ai/ref/python/init) to be added to the active WandB run.
        """
        from giskard.integrations.wandb.wandb_utils import wandb_run
        import wandb  # noqa library import already checked in wandb_run

        from giskard.integrations.wandb.wandb_utils import wandb_run

        from ...utils.analytics_collector import analytics

        with wandb_run(**kwargs) as run:
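As the updated docstring notes, client is now Optional so that the internal worker can resolve datasets from the local cache instead of the hub API. A small illustrative sketch, assuming the top-level Dataset export and using placeholder identifiers:

from giskard import Dataset
from giskard.client.giskard_client import GiskardClient

# Placeholder URL, key and IDs for illustration only.
client = GiskardClient("http://localhost:19000", "YOUR_API_KEY")

# Remote path: fetch (a sample of) the dataset through the hub.
ds = Dataset.download(client, "my_project", "<dataset-uuid>", sample=True)

# Internal-worker path: client=None makes download() look for the dataset locally.
ds_local = Dataset.download(None, "my_project", "<dataset-uuid>")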
15 changes: 10 additions & 5 deletions giskard/ml_worker/ml_worker.py
@@ -1,16 +1,20 @@
from typing import Optional

import logging
import random
import secrets
import stomp
import time

import stomp
from pydantic import AnyHttpUrl
from websocket._exceptions import WebSocketException, WebSocketBadStatusException
from websocket._exceptions import WebSocketBadStatusException, WebSocketException

import giskard
from giskard.cli_utils import validate_url
from giskard.client.giskard_client import GiskardClient
from giskard.ml_worker.testing.registry.registry import load_plugins
from giskard.settings import settings
from giskard.cli_utils import validate_url
from giskard.utils import shutdown_pool, start_pool

logger = logging.getLogger(__name__)

@@ -140,9 +144,9 @@ def _connect_websocket_client(self, is_server=False):
    def is_remote_worker(self):
        return self.ml_worker_id is not INTERNAL_WORKER_ID

    async def start(self):
    async def start(self, nb_workers: Optional[int] = None):
        load_plugins()

        start_pool(nb_workers)
        if self.ws_conn:
            self.ws_stopping = False
            self.connect_websocket_client()
@@ -162,3 +166,4 @@ def stop(self):
        if self.ws_conn:
            self.ws_stopping = True
            self.ws_conn.disconnect()
        shutdown_pool()
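The start_pool/shutdown_pool helpers imported from giskard.utils are not part of this excerpt. A minimal sketch of a module-level process pool along those lines, assuming it wraps concurrent.futures.ProcessPoolExecutor; names and behaviour here are illustrative, not the committed implementation:

from typing import Optional

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count

_POOL: Optional[ProcessPoolExecutor] = None  # module-level pool shared by the worker


def start_pool(nb_workers: Optional[int] = None) -> None:
    """Create the process pool; None means one worker per CPU."""
    global _POOL
    if _POOL is None:
        _POOL = ProcessPoolExecutor(max_workers=nb_workers or cpu_count())


def shutdown_pool() -> None:
    """Dispose of the pool when the ML worker stops."""
    global _POOL
    if _POOL is not None:
        _POOL.shutdown(wait=True)
        _POOL = None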
3 changes: 2 additions & 1 deletion giskard/ml_worker/websocket/__init__.py
@@ -72,7 +72,8 @@ class Catalog(WorkerReply):


class DataRow(BaseModel):
    columns: Dict[str, str]
    columns: Dict[str, str] = Field(..., repr=False)



class DataFrame(BaseModel):
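Field(..., repr=False) keeps the potentially large columns mapping required but excludes it from the model's printed representation. A small illustration of the intended effect, using a stand-in model rather than the actual DataRow:

from typing import Dict

from pydantic import BaseModel, Field


class Row(BaseModel):
    # Required field (the ... default), but hidden from __repr__/__str__ output.
    columns: Dict[str, str] = Field(..., repr=False)
    index: int = 0


print(Row(columns={"text": "hello"}, index=3))
# The columns payload no longer shows up in the printed representation; only index does.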