Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSK-1863 : Using multiprocess instead of threading #1469

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
enron_with_categories
Inokinoki marked this conversation as resolved.
Show resolved Hide resolved
.history
.mypy_cache
docker-stack.yml
Expand Down
2 changes: 2 additions & 0 deletions giskard/client/giskard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __call__(self, r):
class GiskardClient:
def __init__(self, url: str, key: str, hf_token: str = None):
self.host_url = url
self.key = key
self.hf_token = hf_token
base_url = urljoin(url, "/api/v2/")
self._session = sessions.BaseUrlSession(base_url=base_url)
self._session.mount(base_url, ErrorHandlingAdapter())
Expand Down
25 changes: 16 additions & 9 deletions giskard/commands/cli_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional

import asyncio
import functools
import logging
import os
import platform
import sys
from typing import Optional

import click
import lockfile
Expand All @@ -13,19 +14,19 @@
from lockfile.pidlockfile import PIDLockFile, read_pid_from_pidfile, remove_existing_pidfile
from pydantic import AnyHttpUrl

from giskard.cli_utils import common_options
from giskard.cli_utils import (
common_options,
create_pid_file_path,
follow_file,
get_log_path,
remove_stale_pid_file,
run_daemon,
get_log_path,
tail,
follow_file,
validate_url,
)
from giskard.path_utils import run_dir
from giskard.settings import settings
from giskard.utils.analytics_collector import anonymize, analytics
from giskard.utils.analytics_collector import analytics, anonymize

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,7 +86,13 @@ def wrapper(*args, **kwargs):
envvar="GSK_HF_TOKEN",
help="Access token for Giskard hosted in a private Hugging Face Spaces",
)
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
@click.option(
"--parallelism",
"nb_workers",
default=None,
help="Number of processes to use for parallelism (None for number of cpu)",
)
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_workers):
"""\b
Start ML Worker.

Expand All @@ -102,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
)
api_key = initialize_api_key(api_key, is_server)
hf_token = initialize_hf_token(hf_token, is_server)
_start_command(is_server, url, api_key, is_daemon, hf_token)
_start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers)


def initialize_api_key(api_key, is_server):
Expand All @@ -126,7 +133,7 @@ def initialize_hf_token(hf_token, is_server):
return hf_token


def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None):
def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None, nb_workers=None):
from giskard.ml_worker.ml_worker import MLWorker

start_msg = "Starting ML Worker"
Expand Down Expand Up @@ -154,7 +161,7 @@ def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None
run_daemon(is_server, url, api_key, hf_token)
else:
ml_worker = MLWorker(is_server, url, api_key, hf_token)
asyncio.get_event_loop().run_until_complete(ml_worker.start())
asyncio.get_event_loop().run_until_complete(ml_worker.start(nb_workers))
except KeyboardInterrupt:
logger.info("Exiting")
if ml_worker:
Expand Down
28 changes: 13 additions & 15 deletions giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
from typing import Dict, Hashable, List, Optional, Union

import inspect
import logging
import posixpath
import tempfile
import uuid
from functools import cached_property
from pathlib import Path
from typing import Dict, Optional, List, Union, Hashable

import numpy as np
import pandas
import pandas as pd
import yaml
from pandas.api.types import is_list_like
from pandas.api.types import is_numeric_dtype
from mlflow import MlflowClient
from pandas.api.types import is_list_like, is_numeric_dtype
from xxhash import xxh3_128_hexdigest
from zstandard import ZstdDecompressor
from mlflow import MlflowClient

from giskard.client.giskard_client import GiskardClient
from giskard.client.io_utils import save_df, compress
from giskard.client.io_utils import compress, save_df
from giskard.client.python_utils import warning
from giskard.core.core import DatasetMeta, SupportedColumnTypes
from giskard.core.validation import configured_validate_arguments
from giskard.ml_worker.testing.registry.slicing_function import (
SlicingFunction,
SlicingFunctionType,
)
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType
from giskard.ml_worker.testing.registry.transformation_function import (
TransformationFunction,
TransformationFunctionType,
)
from giskard.settings import settings
from ..metadata.indexing import ColumnMetadataMixin

from ...ml_worker.utils.file_utils import get_file_name
from ..metadata.indexing import ColumnMetadataMixin

SAMPLE_SIZE = 1000

Expand Down Expand Up @@ -521,7 +519,7 @@ def load(cls, local_path: str):
)

@classmethod
def download(cls, client: GiskardClient, project_key, dataset_id, sample: bool = False):
def download(cls, client: Optional[GiskardClient], project_key, dataset_id, sample: bool = False):
"""
Downloads a dataset from a Giskard project and returns a Dataset object.
If the client is None, then the function assumes that it is running in an internal worker and looks for the dataset locally.
Expand Down Expand Up @@ -661,9 +659,7 @@ def to_mlflow(self, mlflow_client: MlflowClient = None, mlflow_run_id: str = Non

# To avoid having the file open for writing and reading at the same time,
# we first write it, then make sure to remove it afterwards
with tempfile.NamedTemporaryFile(
prefix="dataset-", suffix=".csv", delete=False
) as f:
with tempfile.NamedTemporaryFile(prefix="dataset-", suffix=".csv", delete=False) as f:
# Get file path
local_path = f.name
# Get name from file
Expand Down Expand Up @@ -696,8 +692,10 @@ def to_wandb(self, **kwargs) -> None:
Additional keyword arguments
(see https://docs.wandb.ai/ref/python/init) to be added to the active WandB run.
"""
from giskard.integrations.wandb.wandb_utils import wandb_run
import wandb # noqa library import already checked in wandb_run

from giskard.integrations.wandb.wandb_utils import wandb_run

from ...utils.analytics_collector import analytics

with wandb_run(**kwargs) as run:
Expand Down
15 changes: 10 additions & 5 deletions giskard/ml_worker/ml_worker.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from typing import Optional

import logging
import random
import secrets
import stomp
import time

import stomp
from pydantic import AnyHttpUrl
from websocket._exceptions import WebSocketException, WebSocketBadStatusException
from websocket._exceptions import WebSocketBadStatusException, WebSocketException

import giskard
from giskard.cli_utils import validate_url
from giskard.client.giskard_client import GiskardClient
from giskard.ml_worker.testing.registry.registry import load_plugins
from giskard.settings import settings
from giskard.cli_utils import validate_url
from giskard.utils import shutdown_pool, start_pool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -140,9 +144,9 @@ def _connect_websocket_client(self, is_server=False):
def is_remote_worker(self):
return self.ml_worker_id is not INTERNAL_WORKER_ID

async def start(self):
async def start(self, nb_workers: Optional[int] = None):
load_plugins()

start_pool(nb_workers)
if self.ws_conn:
self.ws_stopping = False
self.connect_websocket_client()
Expand All @@ -162,3 +166,4 @@ def stop(self):
if self.ws_conn:
self.ws_stopping = True
self.ws_conn.disconnect()
shutdown_pool()
andreybavt marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion giskard/ml_worker/websocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Dict, List, Optional

from enum import Enum

from pydantic import BaseModel, Field


Expand Down
Loading