From 7063efe5e838b863c7dbd00622b2351006e1be0b Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Tue, 20 Aug 2024 11:48:09 -0400 Subject: [PATCH 01/10] initial commit for pydoclint added --- .pre-commit-config.yaml | 11 +++++++++++ packages/syft/setup.cfg | 1 + 2 files changed, 12 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dcd417cccc..0224819e88b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -180,6 +180,17 @@ repos: - id: prettier exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) + - repo: https://github.com/jsh9/pydoclint + rev: 0.5.6 + hooks: + - id: pydoclint + args: [ + # --config=packages/syft/pyproject.toml, + --quiet, + --check-return-types=False, + ] + # args: [--style=google, --check-return-types=False] + # - repo: meta # hooks: # - id: identity diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 64221e223b1..b71eb9ed013 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -101,6 +101,7 @@ dev = pre-commit==3.7.1 ruff==0.4.7 safety>=2.4.0b2 + pydoclint==0.5.6 telemetry = opentelemetry-api==1.14.0 From dc3f8423cc4eff6f2a40c94e29a7b467224395a3 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Tue, 20 Aug 2024 16:18:15 -0400 Subject: [PATCH 02/10] Enforcing Google pydoclint style in Syft --- .pre-commit-config.yaml | 4 +- .../syft/src/syft/custom_worker/builder.py | 75 ++- packages/syft/src/syft/dev/prof.py | 14 +- packages/syft/src/syft/serde/arrow.py | 6 +- .../src/syft/serde/lib_service_registry.py | 16 +- packages/syft/src/syft/serde/serializable.py | 34 +- .../src/syft/service/action/action_object.py | 365 +++++-------- .../src/syft/service/action/action_store.py | 46 +- .../src/syft/service/action/action_types.py | 23 +- packages/syft/src/syft/service/api/api.py | 35 +- .../syft/service/network/network_service.py | 30 +- .../service/network/rathole_config_builder.py | 63 ++- .../src/syft/service/network/server_peer.py | 138 +++-- .../syft/src/syft/service/network/utils.py | 2 +- .../src/syft/service/notifier/notifier.py | 2 +- .../syft/service/notifier/notifier_service.py | 95 +++- .../src/syft/service/notifier/smtp_client.py | 22 +- .../syft/src/syft/service/project/project.py | 7 +- .../syft/service/project/project_service.py | 2 +- .../syft/src/syft/service/request/request.py | 7 +- .../syft/service/settings/settings_service.py | 21 +- .../service/worker/worker_pool_service.py | 111 +++- .../src/syft/store/dict_document_store.py | 58 +- .../syft/src/syft/store/document_store.py | 71 ++- packages/syft/src/syft/store/locks.py | 100 ++-- .../src/syft/store/mongo_document_store.py | 67 +-- .../src/syft/store/sqlite_document_store.py | 75 ++- .../src/syft/types/syft_object_registry.py | 23 +- .../components/tabulator_template.py | 8 +- packages/syft/src/syft/util/table.py | 4 +- packages/syft/src/syft/util/util.py | 516 ++++++++++++++++-- 31 files changed, 1309 insertions(+), 731 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0224819e88b..46b4a0e4129 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -187,9 +187,9 @@ repos: args: [ # --config=packages/syft/pyproject.toml, --quiet, - --check-return-types=False, + --style=google, + --allow-init-docstring=true, ] - # args: [--style=google, --check-return-types=False] # - repo: meta # hooks: diff --git a/packages/syft/src/syft/custom_worker/builder.py b/packages/syft/src/syft/custom_worker/builder.py index e47f341a27f..7cda1cbd80a 100644 --- a/packages/syft/src/syft/custom_worker/builder.py +++ b/packages/syft/src/syft/custom_worker/builder.py @@ -31,6 +31,11 @@ class CustomWorkerBuilder: @cached_property def builder(self) -> BuilderBase: + """Returns the appropriate builder instance based on the environment. + + Returns: + BuilderBase: An instance of either KubernetesBuilder or DockerBuilder. + """ if IN_KUBERNETES: return KubernetesBuilder() else: @@ -39,16 +44,22 @@ def builder(self) -> BuilderBase: def build_image( self, config: WorkerConfig, - tag: str | None = None, + tag: str | None, **kwargs: Any, ) -> ImageBuildResult: - """ - Builds a Docker image from the given configuration. + """Builds a Docker image from the given configuration. + Args: config (WorkerConfig): The configuration for building the Docker image. - tag (str): The tag to use for the image. - """ + tag (str | None): The tag to use for the image. Defaults to None. + **kwargs (Any): Additional keyword arguments for the build process. + + Returns: + ImageBuildResult: The result of the image build process. + Raises: + TypeError: If the config type is not recognized. + """ if isinstance(config, DockerWorkerConfig): return self._build_dockerfile(config, tag, **kwargs) elif isinstance(config, CustomWorkerConfig): @@ -64,13 +75,18 @@ def push_image( password: str, **kwargs: Any, ) -> ImagePushResult: - """ - Pushes a Docker image to the given repo. + """Pushes a Docker image to the given registry. + Args: - repo (str): The repo to push the image to. - tag (str): The tag to use for the image. - """ + tag (str): The tag of the image to push. + registry_url (str): The URL of the registry. + username (str): The username for registry authentication. + password (str): The password for registry authentication. + **kwargs (Any): Additional keyword arguments for the push process. + Returns: + ImagePushResult: The result of the image push process. + """ return self.builder.push_image( tag=tag, username=username, @@ -84,6 +100,16 @@ def _build_dockerfile( tag: str, **kwargs: Any, ) -> ImageBuildResult: + """Builds a Docker image using a Dockerfile. + + Args: + config (DockerWorkerConfig): The configuration containing the Dockerfile. + tag (str): The tag to use for the image. + **kwargs (Any): Additional keyword arguments for the build process. + + Returns: + ImageBuildResult: The result of the image build process. + """ return self.builder.build_image( dockerfile=config.dockerfile, tag=tag, @@ -95,8 +121,18 @@ def _build_template( config: CustomWorkerConfig, **kwargs: Any, ) -> ImageBuildResult: - # Builds a Docker pre-made CPU/GPU image template using a CustomWorkerConfig - # remove once GPU is supported + """Builds a Docker image using a pre-made template. + + Args: + config (CustomWorkerConfig): The configuration containing template settings. + **kwargs (Any): Additional keyword arguments for the build process. + + Returns: + ImageBuildResult: The result of the image build process. + + Raises: + Exception: If GPU support is requested but not supported. + """ if config.build.gpu: raise Exception("GPU custom worker is not supported yet") @@ -120,16 +156,19 @@ def _build_template( ) def find_worker_image(self, type: str) -> Path: - """ - Find the Worker Dockerfile and it's context path - - PROD will be in `$APPDIR/grid/` - - DEV will be in `packages/grid/backend/grid/images` - - In both the cases context dir does not matter (unless we're calling COPY) + """Finds the Worker Dockerfile and its context path. + + The production Dockerfile will be located at `$APPDIR/grid/`. + The development Dockerfile will be located in `packages/grid/backend/grid/images`. Args: - type (str): The type of worker. + type (str): The type of worker (e.g., 'cpu' or 'gpu'). + Returns: Path: The path to the Dockerfile. + + Raises: + FileNotFoundError: If the Dockerfile is not found in any of the expected locations. """ filename = f"worker_{type}.dockerfile" lookup_paths = [ diff --git a/packages/syft/src/syft/dev/prof.py b/packages/syft/src/syft/dev/prof.py index a6aeffc5780..902f47ed6a4 100644 --- a/packages/syft/src/syft/dev/prof.py +++ b/packages/syft/src/syft/dev/prof.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Generator import contextlib import os import signal @@ -8,15 +9,16 @@ @contextlib.contextmanager -def pyspy() -> None: # type: ignore +def pyspy() -> Generator[subprocess.Popen, None, None]: """Profile a block of code using py-spy. Intended for development purposes only. Example: - ``` - with pyspy(): - # do some work - a = [i for i in range(1000000)] - ``` + with pyspy(): + # do some work + a = [i for i in range(1000000)] + + Yields: + subprocess.Popen: The process object running py-spy. """ fd, fname = tempfile.mkstemp(".svg") os.close(fd) diff --git a/packages/syft/src/syft/serde/arrow.py b/packages/syft/src/syft/serde/arrow.py index 31a5ad1d27c..ec9554b9ba2 100644 --- a/packages/syft/src/syft/serde/arrow.py +++ b/packages/syft/src/syft/serde/arrow.py @@ -80,13 +80,13 @@ def numpyutf8toarray(input_index: np.ndarray) -> np.ndarray: def arraytonumpyutf8(string_list: str | np.ndarray) -> bytes: - """Encodes string Numpyarray to utf-8 encoded numpy array. + """Encodes string Numpyarray to utf-8 encoded numpy array. Args: - string_list (np.ndarray): NumpyArray to be encoded + string_list (str | np.ndarray): NumpyArray or string to be encoded. Returns: - bytes: serialized utf-8 encoded int Numpy array + bytes: serialized utf-8 encoded int Numpy array. """ array_shape = np.array(string_list).shape string_list = np.array(string_list).flatten() diff --git a/packages/syft/src/syft/serde/lib_service_registry.py b/packages/syft/src/syft/serde/lib_service_registry.py index 517df6c643c..a3f0901b32a 100644 --- a/packages/syft/src/syft/serde/lib_service_registry.py +++ b/packages/syft/src/syft/serde/lib_service_registry.py @@ -120,15 +120,16 @@ def init_child( child_obj: type | object, absolute_path: str, ) -> Self | None: - """Get the child of parent as a CMPBase object + """Get the child of parent as a CMPBase object. Args: - parent_obj (_type_): parent object - child_path (_type_): _description_ - child_obj (_type_): _description_ + parent_obj (type | object): The parent object. + child_path (str): The path of the child object. + child_obj (type | object): The child object. + absolute_path (str): The absolute path of the child object. Returns: - _type_: _description_ + Self | None: The initialized CMPBase object or None if not applicable. """ parent_is_parent_module = CMPBase.parent_is_parent_module(parent_obj, child_obj) if CMPBase.isfunction(child_obj) and parent_is_parent_module: @@ -141,11 +142,6 @@ def init_child( elif inspect.ismodule(child_obj) and CMPBase.is_submodule( parent_obj, child_obj ): - ## TODO, we could register modules and functions in 2 ways: - # A) as numpy.float32 (what we are doing now) - # B) as numpy.core.float32 (currently not supported) - # only allow submodules - return CMPModule( child_path, permissions=self.permissions, diff --git a/packages/syft/src/syft/serde/serializable.py b/packages/syft/src/syft/serde/serializable.py index 9a683dbcf57..6296f276710 100644 --- a/packages/syft/src/syft/serde/serializable.py +++ b/packages/syft/src/syft/serde/serializable.py @@ -10,7 +10,6 @@ module_type = type(syft) - T = TypeVar("T", bound=type) @@ -26,25 +25,28 @@ def serializable( Recursively serialize attributes of the class. Args: - `attrs` : List of attributes to serialize - `without` : List of attributes to exclude from serialization - `inherit` : Whether to inherit serializable attribute list from base class - `inheritable` : Whether the serializable attribute list can be inherited by derived class - - For non-pydantic classes, - - `inheritable=True` => Derived classes will include base class `attrs` - - `inheritable=False` => Derived classes will not include base class `attrs` - - `inherit=True` => Base class `attrs` + `attrs` - `without` - - `inherit=False` => `attrs` - `without` - - For pydantic classes, + attrs (list[str] | None): List of attributes to serialize. Defaults to None. + without (list[str] | None): List of attributes to exclude from serialization. Defaults to None. + inherit (bool | None): Whether to inherit the serializable attribute list from the base class. Defaults to True. + inheritable (bool | None): Whether the serializable attribute list can be inherited by derived + classes. Defaults to True. + canonical_name (str | None): The canonical name for the serialization. Defaults to None. + version (int | None): The version number for the serialization. Defaults to None. + + For non-pydantic classes: + - `inheritable=True` => Derived classes will include base class `attrs`. + - `inheritable=False` => Derived classes will not include base class `attrs`. + - `inherit=True` => Base class `attrs` + `attrs` - `without`. + - `inherit=False` => `attrs` - `without`. + + For pydantic classes: - No need to provide `attrs`. They will be automatically inferred. - Providing `attrs` will override the inferred attributes. - - `without` will work only on attributes of `Optional` type - - `inherit`, `inheritable` will not work as pydantic inherits by default + - `without` will work only on attributes of `Optional` type. + - `inherit`, `inheritable` will not work as pydantic inherits by default. Returns: - Decorated class + Callable[[T], T]: The decorated class. """ def rs_decorator(cls: T) -> T: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 1196e3b5dc1..6628a476ba7 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -10,7 +10,6 @@ import logging from pathlib import Path import sys -import threading import time import traceback import types @@ -30,7 +29,6 @@ # relative from ...client.api import APIRegistry -from ...client.api import SyftAPI from ...client.api import SyftAPICall from ...client.client import SyftClient from ...serde.serializable import serializable @@ -41,7 +39,6 @@ from ...service.response import SyftSuccess from ...service.response import SyftWarning from ...store.linked_obj import LinkedObject -from ...types.base import SyftBaseModel from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftBaseObject @@ -97,19 +94,16 @@ def repr_cls(c: Any) -> str: class Action(SyftObject): """Serializable Action object. - Parameters: - path: str - The path of the Type of the remote object. - op: str - The method to be executed from the remote object. - remote_self: Optional[LineageID] - The extended UID of the SyftObject - args: List[LineageID] - `op` args - kwargs: Dict[str, LineageID] - `op` kwargs - result_id: Optional[LineageID] - Extended UID of the resulted SyftObject + Attributes: + path (str | None): The path of the Type of the remote object. + op (str | None): The method to be executed from the remote object. + remote_self (LineageID | None): The extended UID of the SyftObject. + args (list[LineageID]): Arguments passed to the operation. + kwargs (dict[str, LineageID]): Keyword arguments passed to the operation. + result_id (LineageID): Extended UID of the resulted SyftObject. + action_type (ActionType | None): The type of action being performed. + create_object (SyftObject | None): The Syft object to be created. + user_code_id (UID | None): The UID associated with the user code. """ __canonical_name__ = "Action" @@ -155,11 +149,6 @@ def syft_history_hash(self) -> int: hashes = 0 if self.remote_self: hashes += hash(self.remote_self.syft_history_hash) - # 🔵 TODO: resolve this - # if the object is ActionDataEmpty then the type might not be equal to the - # real thing. This is the same issue with determining the result type from - # a pointer operation in the past, so we should think about what we want here - # hashes += hash(self.path) hashes += hash(self.op) for arg in self.args: hashes += hash(arg.syft_history_hash) @@ -406,17 +395,14 @@ class PreHookContext(SyftBaseObject): """Hook context - Parameters: - obj: Any - The ActionObject to use for the action - op_name: str - The method name to use for the action - server_uid: Optional[UID] - Optional Syft server UID - result_id: Optional[Union[UID, LineageID]] - Optional result Syft UID - action: Optional[Action] - The action generated by the current hook + Attributes: + obj (Any): The ActionObject to use for the action. + op_name (str): The method name to use for the action. + server_uid (Optional[UID]): Optional Syft server UID. + result_id (Optional[Union[UID, LineageID]]): Optional result Syft UID. + action (Optional[Action]): The action generated by the current hook. + result_twin_type (TwinMode | None): The twin mode for the result. + action_type (ActionType | None): The type of action being performed. """ obj: Any = None @@ -431,18 +417,17 @@ class PreHookContext(SyftBaseObject): def make_action_side_effect( context: PreHookContext, *args: Any, **kwargs: Any ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: - """Create a new action from context_op_name, and add it to the PreHookContext - - Parameters: - context: PreHookContext - PreHookContext object - *args: - Operation *args - **kwargs - Operation *kwargs + """Create a new action from context_op_name, and add it to the PreHookContext. + + Args: + context (PreHookContext): The PreHookContext object. + *args (Any): Operation arguments. + **kwargs (Any): Operation keyword arguments. + Returns: - - Ok[[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success - - Err[str] on failure + Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: + - Ok[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success. + - Err[str] on failure. """ try: action = context.obj.syft_make_action_with_self( @@ -460,90 +445,24 @@ def make_action_side_effect( return Ok((context, args, kwargs)) -class TraceResultRegistry: - __result_registry__: dict[int, TraceResult] = {} - - @classmethod - def set_trace_result_for_current_thread( - cls, - client: SyftClient, - ) -> None: - cls.__result_registry__[threading.get_ident()] = TraceResult( - client=client, is_tracing=True - ) - - @classmethod - def get_trace_result_for_thread(cls) -> TraceResult | None: - return cls.__result_registry__.get(threading.get_ident(), None) - - @classmethod - def reset_result_for_thread(cls) -> None: - if threading.get_ident() in cls.__result_registry__: - del cls.__result_registry__[threading.get_ident()] - - @classmethod - def current_thread_is_tracing(cls) -> bool: - trace_result = cls.get_trace_result_for_thread() - if trace_result is None: - return False - else: - return trace_result.is_tracing - - -class TraceResult(SyftBaseModel): - result: list = [] - client: SyftClient - is_tracing: bool = False - - -def trace_action_side_effect( +def send_action_side_effect( context: PreHookContext, *args: Any, **kwargs: Any ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: - action = context.action - if action is not None and TraceResultRegistry.current_thread_is_tracing(): - trace_result = TraceResultRegistry.get_trace_result_for_thread() - trace_result.result += [action] # type: ignore - return Ok((context, args, kwargs)) - - -def convert_to_pointers( - api: SyftAPI, - server_uid: UID | None = None, - args: list | None = None, - kwargs: dict | None = None, -) -> tuple[list, dict]: - # relative - from ..dataset.dataset import Asset - - def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: - if ( - not isinstance(arg, ActionObject | Asset | UID) - and api.signing_key is not None # type: ignore[unreachable] - ): - arg = ActionObject.from_obj( # type: ignore[unreachable] - syft_action_data=arg, - syft_client_verify_key=api.signing_key.verify_key, - syft_server_location=api.server_uid, - ) - arg.syft_server_uid = server_uid - r = arg._save_to_blob_storage() - if isinstance(r, SyftError): - print(r.message) - if isinstance(r, SyftWarning): - logger.debug(r.message) - arg = api.services.action.set(arg) - return arg + """Create a new action from the context.op_name, and execute it on the remote server. - arg_list = [process_arg(arg) for arg in args] if args else [] - kwarg_dict = {k: process_arg(v) for k, v in kwargs.items()} if kwargs else {} - - return arg_list, kwarg_dict + Args: + context (PreHookContext): The PreHookContext object. + *args (Any): Operation arguments. + **kwargs (Any): Operation keyword arguments. + Returns: + Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: + - Ok[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success. + - Err[str] on failure. -def send_action_side_effect( - context: PreHookContext, *args: Any, **kwargs: Any -) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: - """Create a new action from the context.op_name, and execute it on the remote server.""" + Raises: + RuntimeError: If the action cannot be created or if an unexpected response is received. + """ try: if context.action is None: result = make_action_side_effect(context, *args, **kwargs) @@ -555,7 +474,7 @@ def send_action_side_effect( action_result = context.obj.syft_execute_action(context.action, sync=True) if not isinstance(action_result, ActionObject): - raise RuntimeError(f"Got back unexpected response : {action_result}") + raise RuntimeError(f"Got back unexpected response: {action_result}") else: context.server_uid = action_result.syft_server_uid context.result_id = action_result.id @@ -570,18 +489,21 @@ def send_action_side_effect( def propagate_server_uid( context: PreHookContext, op: str, result: Any ) -> Result[Ok[Any], Err[str]]: - """Patch the result to include the syft_server_uid - - Parameters: - context: PreHookContext - PreHookContext object - op: str - Which operation was executed - result: Any - The result to patch + """Patch the result to include the syft_server_uid. + + Args: + context (PreHookContext): The PreHookContext object. + op (str): Which operation was executed. + result (Any): The result to patch. + Returns: - - Ok[[result] on success - - Err[str] on failure + Result[Ok[Any], Err[str]]: + - Ok[result] on success. + - Err[str] on failure. + + Raises: + RuntimeError: If the parent object does not have a syft_server_uid or + if the output is not wrapped and should not propagate the server_uid. """ if context.op_name in dont_make_side_effects or not hasattr( context.obj, "syft_server_uid" @@ -592,14 +514,16 @@ def propagate_server_uid( syft_server_uid = getattr(context.obj, "syft_server_uid", None) if syft_server_uid is None: raise RuntimeError( - "Can't proagate server_uid because parent doesnt have one" + "Can't propagate server_uid because parent doesn't have one" ) if op not in context.obj._syft_dont_wrap_attrs(): if hasattr(result, "syft_server_uid"): result.syft_server_uid = syft_server_uid else: - raise RuntimeError("dont propogate server_uid because output isnt wrapped") + raise RuntimeError( + "Don't propagate server_uid because output isn't wrapped" + ) except Exception: return Err(f"propagate_server_uid failed with {traceback.format_exc()}") @@ -970,16 +894,18 @@ def syft_eq(self, ext_obj: Self | None) -> bool: def syft_execute_action( self, action: Action, sync: bool = True ) -> ActionObjectPointer: - """Execute a remote action + """Execute a remote action. - Parameters: - action: Action - Which action to execute - sync: bool - Run sync/async + Args: + action (Action): Which action to execute. + sync (bool): Run sync/async. Returns: - ActionObjectPointer + ActionObjectPointer: The pointer to the action object. + + Raises: + SyftException: If the pointer doesn't have a server_uid. + ValueError: If the API is not found. """ if self.syft_server_uid is None: raise SyftException("Pointers can't execute without a server_uid.") @@ -1060,12 +986,12 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: trace_result.result += [action] # type: ignore api = APIRegistry.api_for( - server_uid=self.syft_server_location, + server_uid=obj.syft_server_location, user_verify_key=self.syft_client_verify_key, ) if api is None: print( - f"failed saving {obj} to blob storage, api is None. You must login to {self.syft_server_location}" + f"failed saving {obj} to blob storage, api is None. You must login to {obj.syft_server_location}" ) return else: @@ -1105,33 +1031,24 @@ def syft_make_action( path: str, op: str, remote_self: UID | LineageID | None = None, - args: ( - list[UID | LineageID | ActionObjectPointer | ActionObject | Any] | None - ) = None, - kwargs: ( - dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None - ) = None, + args: list[UID | LineageID | ActionObjectPointer | ActionObject | Any] + | None = None, + kwargs: dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] + | None = None, action_type: ActionType | None = None, ) -> Action: - """Generate new action from the information - - Parameters: - path: str - The path of the Type of the remote object. - op: str - The method to be executed from the remote object. - remote_self: Optional[Union[UID, LineageID]] - The extended UID of the SyftObject - args: Optional[List[Union[UID, LineageID, ActionObjectPointer, ActionObject]]] - `op` args - kwargs: Optional[Dict[str, Union[UID, LineageID, ActionObjectPointer, ActionObject]]] - `op` kwargs - Returns: - Action object + """Generate new action from the information. - Raises: - ValueError: For invalid args or kwargs - PydanticValidationError: For args and kwargs + Args: + path (str): The path of the Type of the remote object. + op (str): The method to be executed from the remote object. + remote_self (UID | LineageID | None): The extended UID of the SyftObject. + args (list[UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation arguments. + kwargs (dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation keyword arguments. + action_type (ActionType | None): The type of action being performed. + + Returns: + Action: The generated action object. """ if args is None: args = [] @@ -1139,7 +1056,6 @@ def syft_make_action( kwargs = {} arg_ids = [self._syft_prepare_obj_uid(obj) for obj in args] - kwarg_ids = {k: self._syft_prepare_obj_uid(obj) for k, obj in kwargs.items()} action = Action( @@ -1155,25 +1071,20 @@ def syft_make_action( def syft_make_action_with_self( self, op: str, - args: list[UID | ActionObjectPointer] | None = None, + args: dict[str, UID | ActionObjectPointer] | None = None, kwargs: dict[str, UID | ActionObjectPointer] | None = None, action_type: ActionType | None = None, ) -> Action: """Generate new method action from the current object. - Parameters: - op: str - The method to be executed from the remote object. - args: List[LineageID] - `op` args - kwargs: Dict[str, LineageID] - `op` kwargs - Returns: - Action object + Args: + op (str): The method to be executed from the remote object. + args (dict[str, UID | ActionObjectPointer] | None): Operation arguments. + kwargs (dict[str, UID | ActionObjectPointer] | None): Operation keyword arguments. + action_type (ActionType | None): The type of action being performed. - Raises: - ValueError: For invalid args or kwargs - PydanticValidationError: For args and kwargs + Returns: + Action: The generated action object. """ path = self.syft_get_path() return self.syft_make_action( @@ -1201,7 +1112,7 @@ def get_sync_dependencies( return [] def syft_get_path(self) -> str: - """Get the type path of the underlying object""" + """Get the type path of the underlying object.""" if ( isinstance(self, AnyActionObject) and self.syft_internal_type @@ -1217,12 +1128,11 @@ def syft_remote_method( ) -> Callable: """Generate a Callable object for remote calls. - Parameters: - op: str - he method to be executed from the remote object. + Args: + op (str): The method to be executed from the remote object. Returns: - A function + Callable: A function to perform the operation. """ def wrapper( @@ -1267,7 +1177,7 @@ def _send( return res def get_from(self, client: SyftClient) -> Any: - """Get the object from a Syft Client""" + """Get the object from a Syft Client.""" res = client.api.services.action.get(self.id) if not isinstance(res, ActionObject): return SyftError(message=f"{res}") @@ -1302,7 +1212,14 @@ def has_storage_permission(self) -> bool: return api.services.action.has_storage_permission(self.id) def get(self, block: bool = False) -> Any: - """Get the object from a Syft Client""" + """Get the object from a Syft Client. + + Args: + block (bool): Whether to block until the object is available. + + Returns: + Any: The object retrieved from the Syft Client. + """ # relative if block: @@ -1402,13 +1319,22 @@ def from_obj( ) -> ActionObject: """Create an ActionObject from an existing object. - Parameters: - syft_action_data: Any - The object to be converted to a Syft ActionObject - id: Optional[UID] - Which ID to use for the ActionObject. Optional - syft_lineage_id: Optional[LineageID] - Which LineageID to use for the ActionObject. Optional + Args: + syft_action_data (Any): The object to be converted to a Syft ActionObject. + id (UID | None): Which ID to use for the ActionObject. Optional. + syft_lineage_id (LineageID | None): Which LineageID to use for the ActionObject. Optional. + syft_client_verify_key (SyftVerifyKey | None): The client verification key. + syft_server_location (UID | None): The server location UID. + syft_resolved (bool | None): Whether the object is resolved. + data_server_id (UID | None): The data server ID. + syft_blob_storage_entry_id (UID | None): The blob storage entry ID. + + Returns: + ActionObject: The created ActionObject. + + Raises: + SyftException: If the object's type is unsupported. + ValueError: If the UID and LineageID don't match. """ if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id: raise ValueError("UID and LineageID should match") @@ -1517,7 +1443,6 @@ def obj_not_ready( @classmethod def empty( - # TODO: fix the mypy issue cls, syft_internal_type: type[Any] | None = None, id: UID | None = None, @@ -1526,17 +1451,19 @@ def empty( data_server_id: UID | None = None, syft_blob_storage_entry_id: UID | None = None, ) -> Self: - """Create an ActionObject from a type, using a ActionDataEmpty object - - Parameters: - syft_internal_type: Type - The Type for which to create a ActionDataEmpty object - id: Optional[UID] - Which ID to use for the ActionObject. Optional - syft_lineage_id: Optional[LineageID] - Which LineageID to use for the ActionObject. Optional - """ + """Create an ActionObject from a type, using an ActionDataEmpty object. + + Args: + syft_internal_type (type[Any] | None): The Type for which to create an ActionDataEmpty object. + id (UID | None): Which ID to use for the ActionObject. + syft_lineage_id (LineageID | None): Which LineageID to use for the ActionObject. + syft_resolved (bool | None): Whether the object is resolved. + data_server_id (UID | None): The data server ID. + syft_blob_storage_entry_id (UID | None): The blob storage entry ID. + Returns: + Self: The created ActionObject. + """ syft_internal_type = ( type(None) if syft_internal_type is None else syft_internal_type ) @@ -1585,10 +1512,10 @@ def __post_init__(self) -> None: def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: """ - Add pre-hooks + Add pre-hooks. Args: - eager_execution: bool: If eager execution is enabled, hooks for + eager_execution (bool): If eager execution is enabled, hooks for tracing and executing the action on remote are added. """ @@ -1607,10 +1534,10 @@ def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: def _syft_add_post_hooks__(self, eager_execution: bool) -> None: """ - Add post-hooks + Add post-hooks. Args: - eager_execution: bool: If eager execution is enabled, hooks for + eager_execution (bool): If eager execution is enabled, hooks for tracing and executing the action on remote are added. """ if eager_execution: @@ -1622,7 +1549,7 @@ def _syft_add_post_hooks__(self, eager_execution: bool) -> None: def _syft_run_pre_hooks__( self, context: PreHookContext, name: str, args: Any, kwargs: Any ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: - """Hooks executed before the actual call""" + """Hooks executed before the actual call.""" result_args, result_kwargs = args, kwargs if name in self.syft_pre_hooks__: for hook in self.syft_pre_hooks__[name]: @@ -1657,7 +1584,7 @@ def _syft_run_pre_hooks__( def _syft_run_post_hooks__( self, context: PreHookContext, name: str, result: Any ) -> Any: - """Hooks executed after the actual call""" + """Hooks executed after the actual call.""" new_result = result if name in self.syft_post_hooks__: for hook in self.syft_post_hooks__[name]: @@ -1691,7 +1618,7 @@ def _syft_run_post_hooks__( def _syft_output_action_object( self, result: Any, context: PreHookContext | None = None ) -> Any: - """Wrap the result in an ActionObject""" + """Wrap the result in an ActionObject.""" if issubclass(type(result), ActionObject): return result @@ -1943,9 +1870,11 @@ def __getattribute__(self, name: str) -> Any: * use the syft/_syft prefix for internal methods. * add the method name to the passthrough_attrs. - Parameters: - name: str - The name of the attribute to access. + Args: + name (str): The name of the attribute to access. + + Returns: + Any: The value of the requested attribute. """ # bypass ipython canary verification if name == "_ipython_canary_method_should_not_exist_": diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 250b3c5e9b5..f69ff80894b 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -40,14 +40,7 @@ class ActionStore: @serializable(canonical_name="KeyValueActionStore", version=1) class KeyValueActionStore(ActionStore): - """Generic Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including connection configuration, database name, or client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ + """Generic Key-Value Action store.""" def __init__( self, @@ -56,6 +49,15 @@ def __init__( root_verify_key: SyftVerifyKey | None = None, document_store: DocumentStore | None = None, ) -> None: + """ + Generic Key-Value Action store. + + Args: + server_uid (UID): Unique identifier for the server instance. + store_config (StoreConfig): Backend specific configuration, including connection configuration, database name, or client class type. + root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. + document_store (DocumentStore | None): Document store used for storing user information. + """ self.server_uid = server_uid self.store_config = store_config self.settings = BasePartitionSettings(name="Action") @@ -373,14 +375,7 @@ def migrate_data( @serializable(canonical_name="DictActionStore", version=1) class DictActionStore(KeyValueActionStore): - """Dictionary-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ + """Dictionary-Based Key-Value Action store.""" def __init__( self, @@ -389,6 +384,15 @@ def __init__( root_verify_key: SyftVerifyKey | None = None, document_store: DocumentStore | None = None, ) -> None: + """ + Dictionary-Based Key-Value Action store. + + Args: + server_uid (UID): Unique identifier for the server instance. + store_config (StoreConfig | None): Backend specific configuration, including connection configuration, database name, or client class type. + root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. + document_store (DocumentStore | None): Document store used for storing user information. + """ store_config = store_config if store_config is not None else DictStoreConfig() super().__init__( server_uid=server_uid, @@ -403,10 +407,14 @@ class SQLiteActionStore(KeyValueActionStore): """SQLite-Based Key-Value Action store. Parameters: + server_uid: UID + Unique identifier for the server instance. store_config: StoreConfig SQLite specific configuration, including connection settings or client class type. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + document_store: Optional[DocumentStore] + Document store used for storing user information. """ pass @@ -414,13 +422,17 @@ class SQLiteActionStore(KeyValueActionStore): @serializable(canonical_name="MongoActionStore", version=1) class MongoActionStore(KeyValueActionStore): - """Mongo-Based Action store. + """Mongo-Based Action store. Parameters: + server_uid: UID + Unique identifier for the server instance. store_config: StoreConfig Mongo specific configuration. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + document_store: Optional[DocumentStore] + Document store used for storing user information. """ pass diff --git a/packages/syft/src/syft/service/action/action_types.py b/packages/syft/src/syft/service/action/action_types.py index c7bd730d557..b946e8643cb 100644 --- a/packages/syft/src/syft/service/action/action_types.py +++ b/packages/syft/src/syft/service/action/action_types.py @@ -11,11 +11,15 @@ def action_type_for_type(obj_or_type: Any) -> type: - """Convert standard type to Syft types + """Convert standard type to Syft types. - Parameters: - obj_or_type: Union[object, type] - Can be an object or a class + Args: + obj_or_type (Any): Can be an object or a class. If it's an instance of + `ActionDataEmpty`, the internal type is used. + + Returns: + type: Corresponding Syft type for the given object or type. If no corresponding + type is found, the default Syft type for `Any` is returned. """ if isinstance(obj_or_type, ActionDataEmpty): obj_or_type = obj_or_type.syft_internal_type @@ -31,11 +35,14 @@ def action_type_for_type(obj_or_type: Any) -> type: def action_type_for_object(obj: Any) -> type: - """Convert standard type to Syft types + """Convert an object's type to the corresponding Syft type. + + Args: + obj (Any): The object to convert. - Parameters: - obj_or_type: Union[object, type] - Can be an object or a class + Returns: + type: Corresponding Syft type for the given object. If no corresponding + type is found, the default Syft type for `Any` is returned. """ _type = type(obj) diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 89e19146e99..1e39fb54738 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -418,7 +418,8 @@ def has_permission(self, context: AuthedServiceContext) -> bool: """Check if the user has permission to access the endpoint. Args: - context: The context of the user requesting the code. + context (AuthedServiceContext): The context of the user requesting the code. + Returns: bool: True if the user has permission to access the endpoint, False otherwise. """ @@ -430,7 +431,8 @@ def select_code(self, context: AuthedServiceContext) -> Result[Ok, Err]: """Select the code to execute based on the user's permissions and public code availability. Args: - context: The context of the user requesting the code. + context (AuthedServiceContext): The context of the user requesting the code. + Returns: Result[Ok, Err]: The selected code to execute. """ @@ -442,9 +444,10 @@ def exec(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> Any: """Execute the code based on the user's permissions and public code availability. Args: - context: The context of the user requesting the code. - *args: Any - **kwargs: Any + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. + **kwargs (Any): Additional keyword arguments to pass to the code. + Returns: Any: The result of the executed code. """ @@ -458,7 +461,16 @@ def exec(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> Any: def exec_mock_function( self, context: AuthedServiceContext, *args: Any, **kwargs: Any ) -> Any: - """Execute the public code if it exists.""" + """Execute the public code if it exists. + + Args: + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. + **kwargs (Any): Additional keyword arguments to pass to the code. + + Returns: + Any: The result of the executed public code or an error if no public code is available. + """ if self.mock_function: return self.exec_code(self.mock_function, context, *args, **kwargs) @@ -467,14 +479,15 @@ def exec_mock_function( def exec_private_function( self, context: AuthedServiceContext, *args: Any, **kwargs: Any ) -> Any: - """Execute the private code if user is has the proper permissions. + """Execute the private code if the user has the proper permissions. Args: - context: The context of the user requesting the code. - *args: Any - **kwargs: Any + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. + **kwargs (Any): Additional keyword arguments to pass to the code. + Returns: - Any: The result of the executed code. + Any: The result of the executed code or an error message if the user does not have permission. """ if self.private_function is None: return SyftError(message="No private code available") diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 24f117b7323..04943656243 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -621,7 +621,7 @@ def add_route( called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: - SyftSuccess | SyftError + SyftSuccess | SyftError: Success message or error message. """ # verify if the peer is truly the one sending the request to add the route to itself if called_by_peer and peer_verify_key != context.credentials: @@ -677,10 +677,7 @@ def delete_route_on_peer( route (ServerRoute): The route to be deleted. Returns: - SyftSuccess: If the route is successfully deleted. - SyftError: If there is an error deleting the route. - SyftInfo: If there is only one route left for the peer and - the admin chose not to remove it + SyftSuccess | SyftError | SyftInfo: Success, error, or informational response. """ # creates a client on the remote server based on the credentials # of the current server's client @@ -711,20 +708,17 @@ def delete_route( ) -> SyftSuccess | SyftError | SyftInfo: """ Delete a route for a given peer. - If a peer has no routes left, there will be a prompt asking if the user want to remove it. + If a peer has no routes left, there will be a prompt asking if the user wants to remove it. If the answer is yes, it will be removed from the stash and will no longer be a peer. Args: context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the remote server peer. - route (ServerRoute): The route to be deleted. + route (ServerRoute | None): The route to be deleted. called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: - SyftSuccess: If the route is successfully deleted. - SyftError: If there is an error deleting the route. - SyftInfo: If there is only one route left for the peer and - the admin chose not to remove it + SyftSuccess | SyftError | SyftInfo: Success, error, or informational response. """ if called_by_peer and peer_verify_key != context.credentials: # verify if the peer is truly the one sending the request to delete the route to itself @@ -811,13 +805,12 @@ def update_route_priority_on_peer( Args: context (AuthedServiceContext): The authentication context. peer (ServerPeer): The peer representing the remote server. - route (ServerRoute): The route to be added. + route (ServerRoute): The route to be updated. priority (int | None): The new priority value for the route. If not - provided, it will be assigned the highest priority among all peers + provided, it will be assigned the highest priority among all peers. Returns: - SyftSuccess | SyftError: A success message if the route is verified, - otherwise an error message. + SyftSuccess | SyftError: Success or error message. """ # creates a client on the remote server based on the credentials # of the current server's client @@ -850,17 +843,18 @@ def update_route_priority( called_by_peer: bool = False, ) -> SyftSuccess | SyftError: """ - Updates a route's priority for the given peer + Updates a route's priority for the given peer. Args: context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the peer whose route priority needs to be updated. route (ServerRoute): The route for which the priority needs to be updated. priority (int | None): The new priority value for the route. If not - provided, it will be assigned the highest priority among all peers + provided, it will be assigned the highest priority among all peers. + called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: - SyftSuccess | SyftError: Successful / Error response + SyftSuccess | SyftError: Successful or error response. """ if called_by_peer and peer_verify_key != context.credentials: return SyftError( diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index fb0ef01d798..33eb1641a68 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -32,10 +32,9 @@ def add_host_to_server(self, peer: ServerPeer) -> None: Args: peer (ServerPeer): The peer to be added to the rathole server. - Returns: - None + Raises: + Exception: If the peer has no rathole route or if the rathole config map is not found. """ - rathole_route = peer.get_rtunnel_route() if not rathole_route: raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") @@ -88,10 +87,9 @@ def remove_host_from_server(self, peer_id: str, server_name: str) -> None: peer_id (str): The id of the peer to be removed. server_name (str): The name of the peer to be removed. - Returns: - None + Raises: + Exception: If the rathole config map is not found. """ - rathole_config_map = KubeUtils.get_configmap( client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) @@ -116,14 +114,27 @@ def remove_host_from_server(self, peer_id: str, server_name: str) -> None: self._remove_dynamic_addr_from_rathole(server_name) def _get_random_port(self) -> int: - """Get a random port number.""" + """Get a random port number. + + Returns: + int: A randomly generated port number. + """ return secrets.randbits(15) def add_host_to_client( self, peer_name: str, peer_id: str, rtunnel_token: str, remote_addr: str ) -> None: - """Add a host to the rathole client toml file.""" + """Add a host to the rathole client toml file. + + Args: + peer_name (str): The name of the peer. + peer_id (str): The id of the peer. + rtunnel_token (str): The rtunnel token for the peer. + remote_addr (str): The remote address for the rathole client. + Raises: + Exception: If the rathole config map is not found. + """ config = RatholeConfig( uuid=peer_id, secret_token=rtunnel_token, @@ -156,8 +167,14 @@ def add_host_to_client( KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) def remove_host_from_client(self, peer_id: str) -> None: - """Remove a host from the rathole client toml file.""" + """Remove a host from the rathole client toml file. + + Args: + peer_id (str): The id of the peer to be removed. + Raises: + Exception: If the rathole config map is not found. + """ rathole_config_map = KubeUtils.get_configmap( client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) @@ -183,8 +200,15 @@ def remove_host_from_client(self, peer_id: str) -> None: def _add_dynamic_addr_to_rathole( self, config: RatholeConfig, entrypoint: str = "web" ) -> None: - """Add a port to the rathole proxy config map.""" + """Add a port to the rathole proxy config map. + + Args: + config (RatholeConfig): The RatholeConfig object containing server details. + entrypoint (str): The entrypoint for the rathole proxy (default: "web"). + Raises: + Exception: If the rathole proxy config map is not found. + """ rathole_proxy_config_map = KubeUtils.get_configmap( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) @@ -229,8 +253,14 @@ def _add_dynamic_addr_to_rathole( self._expose_port_on_rathole_service(config.server_name, config.local_addr_port) def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: - """Remove a port from the rathole proxy config map.""" + """Remove a port from the rathole proxy config map. + + Args: + server_name (str): The name of the server to remove from the proxy config map. + Raises: + Exception: If the rathole proxy config map is not found. + """ rathole_proxy_config_map = KubeUtils.get_configmap( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) @@ -259,8 +289,12 @@ def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: self._remove_port_on_rathole_service(server_name) def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: - """Expose a port on the rathole service.""" + """Expose a port on the rathole service. + Args: + port_name (str): The name of the port. + port (int): The port number to expose. + """ rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) @@ -290,8 +324,11 @@ def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: rathole_service.patch(config) def _remove_port_on_rathole_service(self, port_name: str) -> None: - """Remove a port from the rathole service.""" + """Remove a port from the rathole service. + Args: + port_name (str): The name of the port to remove. + """ rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index 941396820d5..10005721f4f 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -70,20 +70,22 @@ class ServerPeer(SyftObject): pinged_timestamp: DateTime | None = None def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: - """Check if a route exists in self.server_routes + """Check if a route exists in self.server_routes. Args: - route: the route to be checked. For now it can be either - HTTPServerRoute or PythonServerRoute + route (ServerRouteType): The route to be checked. It can be either + HTTPServerRoute, PythonServerRoute, or VeilidServerRoute. Returns: - if the route exists, returns (True, index of the existed route in self.server_routes) - if the route does not exist returns (False, None) - """ + tuple[bool, int | None]: A tuple containing a boolean indicating whether the route exists, + and the index of the route if it exists, otherwise None. + Raises: + ValueError: If the route type is not supported. + """ if route: if not isinstance( - route, HTTPServerRoute | PythonServerRoute | VeilidServerRoute + route, (HTTPServerRoute, PythonServerRoute, VeilidServerRoute) ): raise ValueError(f"Unsupported route type: {type(route)}") for i, r in enumerate(self.server_routes): @@ -93,33 +95,28 @@ def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: return (False, None) def update_route_priority(self, route: ServerRoute) -> ServerRoute: - """ - Assign the new_route's priority to be current max + 1 + """Assign the new route's priority to be current max + 1. Args: route (ServerRoute): The new route whose priority is to be updated. Returns: - ServerRoute: The new route with the updated priority + ServerRoute: The new route with the updated priority. """ current_max_priority: int = max(route.priority for route in self.server_routes) route.priority = current_max_priority + 1 return route def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: - """ - Picks the route with the highest priority from the list of server routes. + """Pick the route with the highest priority from the list of server routes. Args: - oldest (bool): - If True, picks the oldest route to have the highest priority, - meaning the route with min priority value. - If False, picks the most recent route with the highest priority, - meaning the route with max priority value. + oldest (bool): If True, picks the oldest route with the highest priority + (lowest priority value). If False, picks the most recent route + with the highest priority (highest priority value). Returns: ServerRoute: The route with the highest priority. - """ highest_priority_route: ServerRoute = self.server_routes[-1] for route in self.server_routes[:-1]: @@ -132,36 +129,30 @@ def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: return highest_priority_route def update_route(self, route: ServerRoute) -> None: - """ - Update the route for the server. - If the route already exists, return it. - If the route is new, assign it to have the priority of (current_max + 1) + """Update the route for the server. + + If the route already exists, it updates the existing route. + If the route is new, it assigns it a priority of (current_max + 1). Args: route (ServerRoute): The new route to be added to the peer. """ existed, idx = self.existed_route(route) - if existed: + if existed and idx is not None: self.server_routes[idx] = route # type: ignore else: new_route = self.update_route_priority(route) self.server_routes.append(new_route) def update_routes(self, new_routes: list[ServerRoute]) -> None: - """ - Update multiple routes of the server peer. + """Update multiple routes of the server peer. - This method takes a list of new routes as input. - It first updates the priorities of the new routes. - Then, for each new route, it checks if the route already exists for the server peer. - If it does, it updates the priority of the existing route. - If it doesn't, it adds the new route to the server. + This method updates the priorities of new routes and checks if each route + already exists for the server peer. If a route exists, it updates the priority; + otherwise, it adds the new route to the server. Args: new_routes (list[ServerRoute]): The new routes to be added to the server. - - Returns: - None """ for new_route in new_routes: self.update_route(new_route) @@ -169,21 +160,20 @@ def update_routes(self, new_routes: list[ServerRoute]) -> None: def update_existed_route_priority( self, route: ServerRoute, priority: int | None = None ) -> ServerRouteType | SyftError: - """ - Update the priority of an existed route. + """Update the priority of an existing route. Args: route (ServerRoute): The route whose priority is to be updated. priority (int | None): The new priority of the route. If not given, - the route will be assigned with the highest priority. + the route will be assigned the highest priority. Returns: - ServerRoute: The route with updated priority if the route exists - SyftError: If the route does not exist or the priority is invalid + ServerRouteType | SyftError: The route with updated priority if the route exists, + otherwise a SyftError. """ if priority is not None and priority <= 0: return SyftError( - message="Priority must be greater than 0. Now it is {priority}." + message=f"Priority must be greater than 0. Now it is {priority}." ) existed, index = self.existed_route(route=route) @@ -202,6 +192,17 @@ def update_existed_route_priority( @staticmethod def from_client(client: SyftClient) -> "ServerPeer": + """Create a ServerPeer object from a SyftClient. + + Args: + client (SyftClient): The SyftClient from which to create the ServerPeer. + + Returns: + ServerPeer: The created ServerPeer object. + + Raises: + ValueError: If the client does not have metadata. + """ if not client.metadata: raise ValueError("Client has to have metadata first") @@ -212,8 +213,7 @@ def from_client(client: SyftClient) -> "ServerPeer": @property def latest_added_route(self) -> ServerRoute | None: - """ - Returns the latest added route from the list of server routes. + """Get the latest added route. Returns: ServerRoute | None: The latest added route, or None if there are no routes. @@ -223,11 +223,20 @@ def latest_added_route(self) -> ServerRoute | None: def client_with_context( self, context: ServerServiceContext ) -> Result[type[SyftClient], str]: - # third party + """Create a SyftClient using the context of a ServerService. + + Args: + context (ServerServiceContext): The context to use for creating the client. + + Returns: + Result[type[SyftClient], str]: A Result object containing the SyftClient + type if successful, or an error message if unsuccessful. + Raises: + ValueError: If there are no routes to the peer. + """ if len(self.server_routes) < 1: raise ValueError(f"No routes to peer: {self}") - # select the route with highest priority to connect to the peer final_route: ServerRoute = self.pick_highest_priority_route() connection: ServerConnection = route_to_connection(route=final_route) try: @@ -243,6 +252,17 @@ def client_with_context( ) def client_with_key(self, credentials: SyftSigningKey) -> SyftClient | SyftError: + """Create a SyftClient using a signing key. + + Args: + credentials (SyftSigningKey): The signing key to use for creating the client. + + Returns: + SyftClient | SyftError: The created SyftClient, or a SyftError if unsuccessful. + + Raises: + ValueError: If there are no routes to the peer. + """ if len(self.server_routes) < 1: raise ValueError(f"No routes to peer: {self}") @@ -257,28 +277,44 @@ def client_with_key(self, credentials: SyftSigningKey) -> SyftClient | SyftError @property def guest_client(self) -> SyftClient: + """Create a guest SyftClient with a randomly generated signing key. + + Returns: + SyftClient: The created guest SyftClient. + """ guest_key = SyftSigningKey.generate() return self.client_with_key(credentials=guest_key) def proxy_from(self, client: SyftClient) -> SyftClient: + """Create a proxy SyftClient from an existing client. + + Args: + client (SyftClient): The existing SyftClient to proxy from. + + Returns: + SyftClient: The created proxy SyftClient. + """ return client.proxy_to(self) def get_rtunnel_route(self) -> HTTPServerRoute | None: + """Get the HTTPServerRoute with an rtunnel token. + + Returns: + HTTPServerRoute | None: The route with the rtunnel token, or None if not found. + """ for route in self.server_routes: if hasattr(route, "rtunnel_token") and route.rtunnel_token: return route return None def delete_route(self, route: ServerRouteType) -> SyftError | None: - """ - Deletes a route from the peer's route list. - Takes O(n) where is n is the number of routes in self.server_routes. + """Delete a route from the peer's route list. Args: - route (ServerRouteType): The route to be deleted; + route (ServerRouteType): The route to be deleted. Returns: - SyftError: If failing to delete server route + SyftError | None: A SyftError if the deletion fails, or None if successful. """ if route: try: @@ -306,6 +342,12 @@ class ServerPeerUpdate(PartialSyftObject): def drop_veilid_route() -> Callable: + """Drop VeilidServerRoute from the server routes in the context output. + + Returns: + Callable: The function that drops VeilidServerRoute from the context output. + """ + def _drop_veilid_route(context: TransformContext) -> TransformContext: if context.output: server_routes = context.output["server_routes"] diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index c5b9e0c084e..37c114ec657 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -37,7 +37,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | No context (AuthedServiceContext): The authenticated service context. Returns: - None + SyftError | None: """ network_service = cast( diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 5eec531b91e..c35c45fcf3a 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -255,7 +255,7 @@ def select_notifiers(self, notification: Notification) -> list[BaseNotifier]: Args: notification (Notification): The notification object Returns: - List[BaseNotifier]: A list of enabled notifier objects + list[BaseNotifier]: A list of enabled notifier objects """ notifier_objs = [] for notifier_type in notification.notifier_types: diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index c8c09ba3d50..0523bea85f5 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -38,16 +38,17 @@ def __init__(self, store: DocumentStore) -> None: self.store = store self.stash = NotifierStash(store=store) - def settings( # Maybe just notifier.settings + def settings( self, context: AuthedServiceContext, ) -> NotifierSettings | SyftError: - """Get Notifier Settings + """Get Notifier Settings. Args: - context: The request context + context (AuthedServiceContext): The request context. + Returns: - Union[NotifierSettings, SyftError]: Notifier Settings or SyftError + NotifierSettings | SyftError: Notifier Settings or SyftError if an error occurs. """ result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -87,7 +88,7 @@ def set_notifier_active_to_true( def set_notifier_active_to_false( self, context: AuthedServiceContext - ) -> SyftSuccess: + ) -> SyftSuccess | SyftError: """ Essentially a duplicate of turn_off method. """ @@ -117,17 +118,16 @@ def turn_on( """Turn on email notifications. Args: - email_username (Optional[str]): Email server username. Defaults to None. - email_password (Optional[str]): Email email server password. Defaults to None. - sender_email (Optional[str]): Email sender email. Defaults to None. - Returns: - Union[SyftSuccess, SyftError]: A union type representing the success or error response. - - Raises: - None + context (AuthedServiceContext): The request context. + email_username (str | None): Email server username. Defaults to None. + email_password (str | None): Email server password. Defaults to None. + email_sender (str | None): Email sender address. Defaults to None. + email_server (str | None): Email server address. Defaults to None. + email_port (int | None): Email server port. Defaults to 587. + Returns: + SyftSuccess | SyftError: SyftSuccess if successful, SyftError otherwise. """ - result = self.stash.get(credentials=context.credentials) # 1 - If something went wrong at db level, return the error @@ -228,8 +228,13 @@ def turn_off( """ Turn off email notifications service. PySyft notifications will still work. - """ + Args: + context (AuthedServiceContext): The request context. + + Returns: + SyftSuccess | SyftError: SyftSuccess if successful, SyftError otherwise. + """ result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -254,18 +259,30 @@ def activate( """ Activate email notifications for the authenticated user. This will only work if the datasite owner has enabled notifications. - """ + Args: + context (AuthedServiceContext): The request context. + notifier_type (NOTIFIERS): The notifier type to activate. Defaults to NOTIFIERS.EMAIL. + + Returns: + SyftSuccess | SyftError: SyftSuccess if successful, SyftError otherwise. + """ user_service = context.server.get_service("userservice") return user_service.enable_notifications(context, notifier_type=notifier_type) def deactivate( self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL ) -> SyftSuccess | SyftError: - """Deactivate email notifications for the authenticated user + """Deactivate email notifications for the authenticated user. This will only work if the datasite owner has enabled notifications. - """ + Args: + context (AuthedServiceContext): The request context. + notifier_type (NOTIFIERS): The notifier type to deactivate. Defaults to NOTIFIERS.EMAIL. + + Returns: + SyftSuccess | SyftError: SyftSuccess if successful, SyftError otherwise. + """ user_service = context.server.get_service("userservice") return user_service.disable_notifications(context, notifier_type=notifier_type) @@ -279,18 +296,23 @@ def init_notifier( smtp_host: str | None = None, ) -> Result[Ok, Err]: """Initialize Notifier settings for a Server. + If settings already exist, it will use the existing one. If not, it will create a new one. Args: - server: Server to initialize the notifier - active: If notifier should be active - email_username: Email username to send notifications - email_password: Email password to send notifications - Raises: - Exception: If something went wrong + server (AbstractServer): Server to initialize the notifier. + email_username (str | None): Email username to send notifications. Defaults to None. + email_password (str | None): Email password to send notifications. Defaults to None. + email_sender (str | None): Email sender address. Defaults to None. + smtp_port (int | None): SMTP server port. Defaults to None. + smtp_host (str | None): SMTP server host. Defaults to None. + Returns: - Union: SyftSuccess or SyftError + Result[Ok, Err]: Ok if successful, Err otherwise. + + Raises: + Exception: Error in creating or initializing notifier """ try: # Create a new NotifierStash since its a static method. @@ -341,6 +363,16 @@ def init_notifier( def set_email_rate_limit( self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int ) -> SyftSuccess | SyftError: + """Set the email rate limit for a specific email type. + + Args: + context (AuthedServiceContext): The request context. + email_type (EMAIL_TYPES): The type of email for which to set the rate limit. + daily_limit (int): The daily limit for the specified email type. + + Returns: + SyftSuccess | SyftError: SyftSuccess if successful, SyftError otherwise. + """ notifier = self.stash.get(context.credentials) if notifier.is_err(): return SyftError(message="Couldn't set the email rate limit.") @@ -358,7 +390,18 @@ def set_email_rate_limit( # This method is used by other services to dispatch notifications internally def dispatch_notification( self, context: AuthedServiceContext, notification: Notification - ) -> SyftError: + ) -> SyftError | SyftSuccess: + """Dispatch a notification to the user. + + This method is used internally by other services to send notifications. + + Args: + context (AuthedServiceContext): The request context. + notification (Notification): The notification to dispatch. + + Returns: + SyftError | SyftSuccess: SyftSuccess if the notification was successfully dispatched, SyftError otherwise. + """ admin_key = context.server.get_service("userservice").admin_verify_key() notifier = self.stash.get(admin_key) if notifier.is_err(): diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py index 8210257badd..378a7cae181 100644 --- a/packages/syft/src/syft/service/notifier/smtp_client.py +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -22,11 +22,23 @@ class SMTPClient(BaseModel): @model_validator(mode="before") @classmethod def check_user_and_password(cls, values: dict) -> dict: + """Validate that both username and password are provided.""" if not (values.get("username", None) and values.get("password")): raise ValueError("Both username and password must be provided") return values def send(self, sender: str, receiver: list[str], subject: str, body: str) -> None: + """Send an email using the SMTP server. + + Args: + sender (str): The sender's email address. + receiver (list[str]): A list of recipient email addresses. + subject (str): The subject of the email. + body (str): The HTML body of the email. + + Raises: + ValueError: If subject, body, or receiver is not provided. + """ if not (subject and body and receiver): raise ValueError("Subject, body, and recipient email(s) are required") @@ -50,10 +62,16 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non def check_credentials( cls, server: str, port: int, username: str, password: str ) -> Result[Ok, Err]: - """Check if the credentials are valid. + """Check if the provided SMTP credentials are valid. + + Args: + server (str): The SMTP server address. + port (int): The port number to connect to. + username (str): The username for the SMTP server. + password (str): The password for the SMTP server. Returns: - bool: True if the credentials are valid, False otherwise. + Result[Ok, Err]: Ok if the credentials are valid, Err with an exception otherwise. """ try: with smtplib.SMTP(server, port, timeout=SOCKET_TIMEOUT) as smtp_server: diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index e1c1dfe1b47..6f6255a749c 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -317,7 +317,7 @@ def status(self, project: Project) -> SyftInfo | SyftError | None: project (Project): Project object to check the status Returns: - str: Status of the request. + SyftInfo | SyftError | None: Status of the request. During Request status calculation, we do not allow multiple responses """ @@ -562,9 +562,10 @@ def status( Args: project (Project): Project object to check the status + pretty_print (bool): Flag for pretty printing Returns: - str: Status of the poll + dict | SyftError | SyftInfo | None: Status of the poll During Poll calculation, a user would have answered the poll many times The status of the poll would be calculated based on the latest answer of the user @@ -1384,7 +1385,7 @@ def hash_object(obj: Any) -> tuple[bytes, str]: obj (Any): Object to be hashed Returns: - str: Hashed value of the object + tuple[bytes, str]: Hashed value of the object """ hash_bytes = _serialize(obj, to_bytes=True, for_hashing=True) hash = hashlib.sha256(hash_bytes) diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 0da9d043e18..027bc02116e 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -382,7 +382,7 @@ def check_for_project_request( context (AuthedServiceContext): Context of the server Returns: - Union[SyftSuccess, SyftError]: SyftSuccess if message is created else SyftError + SyftSuccess | SyftError: SyftSuccess if message is created else SyftError """ if ( diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 9b5bb00ca22..0170dfa519a 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -638,6 +638,9 @@ def deny(self, reason: str) -> SyftSuccess | SyftError: Args: reason (str): Reason for which the request has been denied. + + Returns: + SyftSuccess | SyftError: Result of the operation. """ api = self._get_api() if isinstance(api, SyftError): @@ -824,9 +827,9 @@ def deposit_result( result (Any): ActionObject or any object to be saved as an ActionObject. log_stdout (str): stdout logs. log_stderr (str): stderr logs. - approve (bool, optional): Only supported for L2 requests. If True, the request will be approved. + approve (bool | None): Only supported for L2 requests. If True, the request will be approved. Defaults to None. - + **kwargs (dict[str, Any]): Additional arguments. Returns: Job | SyftError: Job object if successful, else SyftError. diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index b54abacd078..657bc7792a1 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -96,26 +96,15 @@ def _update( Update the Server Settings using the provided values. Args: - name: Optional[str] - Server name - organization: Optional[str] - Organization name - description: Optional[str] - Server description - on_board: Optional[bool] - Show onboarding panel when a user logs in for the first time - signup_enabled: Optional[bool] - Enable/Disable registration - admin_email: Optional[str] - Administrator email - association_request_auto_approval: Optional[bool] + context (AuthedServiceContext): The authenticated service context. + settings (ServerSettingsUpdate): The settings to update. Returns: - Result[SyftSuccess, SyftError]: A result indicating the success or failure of the update operation. + Result[Ok, Err]: A result indicating the success or failure of the update operation. Example: - >>> server_client.update(name='foo', organization='bar', description='baz', signup_enabled=True) - SyftSuccess: Settings updated successfully. + >>> server_client.update(settings=ServerSettingsUpdate(signup_enabled=True)) + SyftSuccess: Settings updated successfully. """ result = self.stash.get_all(context.credentials) if result.is_ok(): diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 0a1418f9c2d..433a86f5752 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -77,16 +77,18 @@ def launch( ) -> list[ContainerSpawnStatus] | SyftError: """Creates a pool of workers from the given SyftWorkerImage. - - Retrieves the image for the given UID - - Use docker to launch containers for given image - - For each successful container instantiation create a SyftWorker object - - Creates a SyftWorkerPool object - Args: - context (AuthedServiceContext): context passed to the service - name (str): name of the pool - image_id (UID): UID of the SyftWorkerImage against which the pool should be created - num_workers (int): the number of SyftWorker that needs to be created in the pool + context (AuthedServiceContext): Context passed to the service. + pool_name (str): Name of the pool. + image_uid (UID | None): UID of the SyftWorkerImage against which the pool should be created. + num_workers (int): The number of SyftWorkers that need to be created in the pool. + registry_username (str | None, optional): Username for the registry. Defaults to None. + registry_password (str | None, optional): Password for the registry. Defaults to None. + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. + pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + list[ContainerSpawnStatus] | SyftError: List of container spawn statuses or an error. """ result = self.stash.get_by_name(context.credentials, pool_name=pool_name) @@ -173,16 +175,19 @@ def create_pool_request( pod_annotations: dict[str, str] | None = None, pod_labels: dict[str, str] | None = None, ) -> SyftError | SyftSuccess: - """ - Create a request to launch the worker pool based on a built image. + """Create a request to launch the worker pool based on a built image. Args: context (AuthedServiceContext): The authenticated service context. pool_name (str): The name of the worker pool. num_workers (int): The number of workers in the pool. - image_uid (Optional[UID]): The UID of the built image. - reason (Optional[str], optional): The reason for creating the - worker pool. Defaults to "". + image_uid (UID): The UID of the built image. + reason (str | None, optional): The reason for creating the worker pool. Defaults to "". + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. + pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + SyftError | SyftSuccess: Success or error message. """ # Check if image exists for the given image id @@ -254,18 +259,22 @@ def create_image_and_pool_request( pod_annotations: dict[str, str] | None = None, pod_labels: dict[str, str] | None = None, ) -> SyftError | SyftSuccess: - """ - Create a request to launch the worker pool based on a built image. + """Create a request to launch the worker pool based on a built image. Args: context (AuthedServiceContext): The authenticated service context. pool_name (str): The name of the worker pool. num_workers (int): The number of workers in the pool. - config: (WorkerConfig): Config of the image to be built. - tag (str | None, optional): - a human-readable manifest identifier that is typically a specific version or variant of an image, - only needed for `DockerWorkerConfig` to tag the image after it is built. + config (WorkerConfig): Config of the image to be built. + tag (str | None, optional): A human-readable manifest identifier. Required for `DockerWorkerConfig`. + registry_uid (UID | None, optional): UID of the registry in Kubernetes mode. Required for `DockerWorkerConfig`. reason (str | None, optional): The reason for creating the worker image and pool. Defaults to "". + pull_image (bool, optional): Whether to pull the image. Defaults to True. + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. + pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + SyftError | SyftSuccess: Success or error message. """ if not isinstance(config, DockerWorkerConfig | PrebuiltWorkerConfig): @@ -358,8 +367,14 @@ def create_image_and_pool_request( def get_all( self, context: AuthedServiceContext ) -> DictTuple[str, WorkerPool] | SyftError: - # TODO: During get_all, we should dynamically make a call to docker to get the status of the containers - # and update the status of the workers in the pool. + """Get all worker pools. + + Args: + context (AuthedServiceContext): The authenticated service context. + + Returns: + DictTuple[str, WorkerPool] | SyftError: All worker pools or an error. + """ result = self.stash.get_all(credentials=context.credentials) if result.is_err(): return SyftError(message=f"{result.err()}") @@ -387,13 +402,15 @@ def add_workers( Worker pool is fetched either using the unique pool id or pool name. Args: - context (AuthedServiceContext): _description_ - number (int): number of workers to add - pool_id (Optional[UID], optional): Unique UID of the pool. Defaults to None. - pool_name (Optional[str], optional): Unique name of the pool. Defaults to None. + context (AuthedServiceContext): The authenticated service context. + number (int): Number of workers to add. + pool_id (UID | None, optional): Unique UID of the pool. Defaults to None. + pool_name (str | None, optional): Unique name of the pool. Defaults to None. + registry_username (str | None, optional): Username for the registry. Defaults to None. + registry_password (str | None, optional): Password for the registry. Defaults to None. Returns: - Union[List[ContainerSpawnStatus], SyftError]: List of spawned workers with their status and error if any. + list[ContainerSpawnStatus] | SyftError: List of spawned workers with their status or an error. """ if number <= 0: @@ -472,9 +489,18 @@ def scale( pool_id: UID | None = None, pool_name: str | None = None, ) -> SyftError | SyftSuccess: - """ - Scale the worker pool to the given number of workers in Kubernetes. + """Scale the worker pool to the given number of workers in Kubernetes. + Allows both scaling up and down the worker pool. + + Args: + context (AuthedServiceContext): The authenticated service context. + number (int): Number of workers to scale to. + pool_id (UID | None, optional): Unique UID of the pool. Defaults to None. + pool_name (str | None, optional): Unique name of the pool. Defaults to None. + + Returns: + SyftError | SyftSuccess: Success or error message. """ if not IN_KUBERNETES: @@ -558,6 +584,15 @@ def scale( def filter_by_image_id( self, context: AuthedServiceContext, image_uid: UID ) -> list[WorkerPool] | SyftError: + """Filter worker pools by image ID. + + Args: + context (AuthedServiceContext): The authenticated service context. + image_uid (UID): The UID of the image. + + Returns: + list[WorkerPool] | SyftError: List of worker pools or an error. + """ result = self.stash.get_by_image_uid(context.credentials, image_uid) if result.is_err(): @@ -573,6 +608,15 @@ def filter_by_image_id( def get_by_name( self, context: AuthedServiceContext, pool_name: str ) -> list[WorkerPool] | SyftError: + """Get worker pool by name. + + Args: + context (AuthedServiceContext): The authenticated service context. + pool_name (str): The name of the worker pool. + + Returns: + list[WorkerPool] | SyftError: Worker pool or an error. + """ result = self.stash.get_by_name(context.credentials, pool_name) if result.is_err(): @@ -592,8 +636,15 @@ def sync_pool_from_request( context: AuthedServiceContext, request: Request, ) -> SyftSuccess | SyftError: - """Re-submit request from a different server""" + """Re-submit request from a different server. + Args: + context (AuthedServiceContext): The authenticated service context. + request (Request): The request object. + + Returns: + SyftSuccess | SyftError: Success or error message. + """ num_of_changes = len(request.changes) pool_name, num_workers, config, image_uid, tag = None, None, None, None, None diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index e97d9105645..1a73692eee3 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -21,14 +21,33 @@ @serializable(canonical_name="DictBackingStore", version=1) class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] - # TODO: fix the mypy issue - """Dictionary-based Store core logic""" + """Dictionary-based Store core logic + + This class provides the core logic for a dictionary-based key-value store. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the dictionary-based backing store. + + Args: + *args (Any): Positional arguments. + **kwargs (Any): Keyword arguments. + """ super().__init__() self._ddtype = kwargs.get("ddtype", None) def __getitem__(self, key: Any) -> Any: + """Retrieve an item from the store by key. + + Args: + key (Any): The key of the item to retrieve. + + Returns: + Any: The value associated with the key. + + Raises: + KeyError: If the key is not found in the store. + """ try: value = super().__getitem__(key) return value @@ -42,25 +61,23 @@ def __getitem__(self, key: Any) -> Any: class DictStorePartition(KeyValueStorePartition): """Dictionary-based StorePartition + This class represents a partition within a dictionary-based key-value store. + Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: DictStoreConfig - DictStore specific configuration + settings (PartitionSettings): PySyft specific settings, used for indexing and partitioning. + store_config (DictStoreConfig): Dictionary Store specific configuration. """ def prune(self) -> None: + """Reset the partition by reinitializing the store.""" self.init_store() -# the base document store is already a dict but we can change it later @serializable(canonical_name="DictDocumentStore", version=1) class DictDocumentStore(DocumentStore): """Dictionary-based Document Store - Parameters: - `store_config`: DictStoreConfig - Dictionary Store specific configuration, containing the store type and the backing store type + This class represents a document store implemented using a dictionary. """ partition_type = DictStorePartition @@ -80,25 +97,26 @@ def __init__( ) def reset(self) -> None: + """Reset the document store by pruning all partitions.""" for partition in self.partitions.values(): partition.prune() @serializable() class DictStoreConfig(StoreConfig): - __canonical_name__ = "DictStoreConfig" """Dictionary-based configuration + This class provides the configuration for a dictionary-based document store. + + Attributes: + store_type (type[DocumentStore]): The Document type used. Default: DictDocumentStore. + backing_store (type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. + locking_config (LockingConfig): The config used for store locking. Defaults to ThreadingLockingConfig. + Parameters: - `store_type`: Type[DocumentStore] - The Document type used. Default: DictDocumentStore - `backing_store`: Type[KeyValueBackingStore] - The backend type used. Default: DictBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to ThreadingLockingConfig. + store_type (Type[DocumentStore]): The Document type used. Default: DictDocumentStore. + backing_store (Type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. + locking_config (LockingConfig): The config used for store locking. """ store_type: type[DocumentStore] = DictDocumentStore diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 4012bc25ca5..49114c8df67 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -39,9 +39,8 @@ class BasePartitionSettings(SyftBaseModel): """Basic Partition Settings - Parameters: - name: str - Identifier to be used as prefix by stores and for partitioning + Attributes: + name (str): Identifier to be used as a prefix by stores and for partitioning. """ name: str @@ -148,7 +147,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: pk_key = partition_key.key pk_type = partition_key.type_ - # 🟡 TODO: support more advanced types than List[type] if partition_key.type_list: pk_value = partition_key.extract_list(obj) else: @@ -156,12 +154,8 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: pk_value = obj else: pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types - # TODO: fix the mypy issue - if isinstance(pk_value, types.FunctionType | types.MethodType): # type: ignore[unreachable] - pk_value = pk_value() # type: ignore[unreachable] + if isinstance(pk_value, types.FunctionType | types.MethodType): + pk_value = pk_value() if pk_value and not isinstance(pk_value, pk_type): raise Exception( @@ -179,7 +173,6 @@ def as_dict_mongo(self) -> dict[str, Any]: if key == "id": key = "_id" if self.type_list: - # We want to search inside the list of values return {key: {"$in": self.value}} return {key: self.value} @@ -202,19 +195,15 @@ class QueryKeys(SyftBaseModel): @property def all(self) -> tuple[QueryKey, ...] | list[QueryKey]: - # make sure we always return a list even if there's a single value return self.qks if isinstance(self.qks, tuple | list) else [self.qks] @staticmethod def from_obj(partition_keys: PartitionKeys, obj: SyftObject) -> QueryKeys: qks = [] for partition_key in partition_keys.all: - pk_key = partition_key.key # name of the attribute + pk_key = partition_key.key pk_type = partition_key.type_ pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types if isinstance(pk_value, types.FunctionType | types.MethodType): pk_value = pk_value() if partition_key.type_list: @@ -267,7 +256,6 @@ def as_dict_mongo(self) -> dict: if qk_key == "id": qk_key = "_id" if qk.type_list: - # We want to search inside the list of values qk_dict[qk_key] = {"$in": qk_value} else: qk_dict[qk_key] = qk_value @@ -299,14 +287,7 @@ def searchable_keys(self) -> PartitionKeys: version=1, ) class StorePartition: - """Base StorePartition - - Parameters: - settings: PartitionSettings - PySyft specific settings - store_config: StoreConfig - Backend specific configuration - """ + """Base StorePartition""" def __init__( self, @@ -316,6 +297,18 @@ def __init__( store_config: StoreConfig, has_admin_permissions: Callable[[SyftVerifyKey], bool] | None = None, ) -> None: + """Base store partition initialization + + Args: + server_uid (UID): Unique identifier for the server instance. + root_verify_key (SyftVerifyKey | None): Root signature verification key. + settings (PartitionSettings): PySyft specific settings. + store_config (StoreConfig): Backend specific configuration. + has_admin_permissions (Callable[[SyftVerifyKey], bool] | None): Callback to check admin permissions. + + Raises: + RuntimeError: If the store initialization fails. + """ if root_verify_key is None: root_verify_key = SyftSigningKey.generate().verify_key self.server_uid = server_uid @@ -353,7 +346,6 @@ def store_query_key(self, obj: Any) -> QueryKey: def store_query_keys(self, objs: Any) -> QueryKeys: return QueryKeys(qks=[self.store_query_key(obj) for obj in objs]) - # Thread-safe methods def _thread_safe_cbk(self, cbk: Callable, *args: Any, **kwargs: Any) -> Any | Err: locked = self.lock.acquire(blocking=True) if not locked: @@ -475,11 +467,6 @@ def migrate_data( self._migrate_data, to_klass, context, has_permission ) - # Potentially thread-unsafe methods. - # CAUTION: - # * Don't use self.lock here. - # * Do not call the public thread-safe methods here(with locking). - # These methods are called from the public thread-safe API, and will hang the process. def _set( self, credentials: SyftVerifyKey, @@ -570,12 +557,7 @@ def _migrate_data( @instrument @serializable(canonical_name="DocumentStore", version=1) class DocumentStore: - """Base Document Store - - Parameters: - store_config: StoreConfig - Store specific configuration. - """ + """Base Document Store""" partitions: dict[str, StorePartition] partition_type: type[StorePartition] @@ -586,6 +568,16 @@ def __init__( root_verify_key: SyftVerifyKey | None, store_config: StoreConfig, ) -> None: + """Base document store initialization + + Args: + server_uid (UID): Unique identifier for the server instance. + root_verify_key (SyftVerifyKey | None): Root signature verification key. + store_config (StoreConfig): Store specific configuration. + + Raises: + Exception: If store config is not found + """ if store_config is None: raise Exception("must have store config") self.partitions = {} @@ -840,6 +832,11 @@ class BaseUIDStoreStash(BaseStash): class StoreConfig(SyftBaseObject): """Base Store configuration + Attributes: + store_type (type[DocumentStore]): Document Store type. + client_config (StoreClientConfig | None): Backend-specific config. + locking_config (LockingConfig): The config used for store locking. + Parameters: store_type: Type Document Store type diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index 2494dffa895..12174771b89 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -19,19 +19,21 @@ @serializable(canonical_name="LockingConfig", version=1) class LockingConfig(BaseModel): """ - Locking config + Locking configuration. + + Attributes: + lock_name (str): Lock name. + namespace (str | None): Namespace to use for setting lock keys in the backend store. + expire (int | None): Lock expiration time in seconds. If explicitly set to `None`, the lock will not expire. + timeout (int | None): Timeout to acquire lock (in seconds). + retry_interval (float): Retry interval to retry acquiring a lock if previous attempts failed. Args: - lock_name: str - Lock name - namespace: Optional[str] - Namespace to use for setting lock keys in the backend store. - expire: Optional[int] - Lock expiration time in seconds. If explicitly set to `None`, lock will not expire. - timeout: Optional[int] - Timeout to acquire lock(seconds) - retry_interval: float - Retry interval to retry acquiring a lock if previous attempts failed. + lock_name (str): Lock name. + namespace (str | None): Namespace to use for setting lock keys in the backend store. + expire (int | None): Lock expiration time in seconds. If explicitly set to `None`, the lock will not expire. + timeout (int | None): Timeout to acquire lock (in seconds). + retry_interval (float): Retry interval to retry acquiring a lock if previous attempts failed. """ lock_name: str = "syft_lock" @@ -44,7 +46,7 @@ class LockingConfig(BaseModel): @serializable(canonical_name="NoLockingConfig", version=1) class NoLockingConfig(LockingConfig): """ - No-locking policy + No-locking policy. """ pass @@ -53,7 +55,7 @@ class NoLockingConfig(LockingConfig): @serializable(canonical_name="ThreadingLockingConfig", version=1) class ThreadingLockingConfig(LockingConfig): """ - Threading-based locking policy + Threading-based locking policy. """ pass @@ -72,9 +74,10 @@ def __init__(self, expire: int, **kwargs: Any) -> None: @property def _locked(self) -> bool: """ - Implementation of method to check if lock has been acquired. Must be - :returns: if the lock is acquired or not - :rtype: bool + Check if the lock has been acquired. + + Returns: + bool: True if the lock is acquired, False otherwise. """ locked = self.lock.locked() if ( @@ -88,9 +91,10 @@ def _locked(self) -> bool: def _acquire(self) -> bool: """ - Implementation of acquiring a lock in a non-blocking fashion. - :returns: if the lock was successfully acquired or not - :rtype: bool + Acquire the lock in a non-blocking fashion. + + Returns: + bool: True if the lock was successfully acquired, False otherwise. """ locked = self.lock.locked() if ( @@ -100,18 +104,15 @@ def _acquire(self) -> bool: ): self._release() - status = self.lock.acquire( - blocking=False - ) # timeout/retries handle in the `acquire` method + status = self.lock.acquire(blocking=False) if status: self.locked_timestamp = time.time() return status def _release(self) -> None: """ - Implementation of releasing an acquired lock. + Release the acquired lock. """ - try: return self.lock.release() except RuntimeError: # already unlocked @@ -119,17 +120,23 @@ def _release(self) -> None: def _renew(self) -> bool: """ - Implementation of renewing an acquired lock. + Renew the acquired lock. + + Returns: + bool: True if the lock was successfully renewed, False otherwise. """ return True class SyftLock(BaseLock): """ - Syft Lock implementations. + Syft Lock implementation. + + Args: + config (LockingConfig): Configuration specific to a locking strategy. - Params: - config: Config specific to a locking strategy. + Raises: + ValueError: If an unsupported config type is provided. """ def __init__(self, config: LockingConfig): @@ -162,24 +169,25 @@ def __init__(self, config: LockingConfig): @property def _locked(self) -> bool: """ - Implementation of method to check if lock has been acquired. + Check if the lock has been acquired. - :returns: if the lock is acquired or not - :rtype: bool + Returns: + bool: True if the lock is acquired, False otherwise. """ if self.passthrough: return False - return self._lock.locked() if self._lock else False + return self._lock._locked if self._lock else False def acquire(self, blocking: bool = True) -> bool: """ Acquire a lock, blocking or non-blocking. - :param bool blocking: acquire a lock in a blocking or non-blocking - fashion. Defaults to True. - :returns: if the lock was successfully acquired or not - :rtype: bool - """ + Args: + blocking (bool): Acquire a lock in a blocking or non-blocking fashion. Defaults to True. + + Returns: + bool: True if the lock was successfully acquired, False otherwise. + """ if not blocking: return self._acquire() @@ -195,16 +203,14 @@ def acquire(self, blocking: bool = True) -> bool: logger.debug( f"Timeout elapsed after {self.timeout} seconds while trying to acquiring lock." ) - # third party return False def _acquire(self) -> bool: """ - Implementation of acquiring a lock in a non-blocking fashion. - `acquire` makes use of this implementation to provide blocking and non-blocking implementations. + Acquire the lock in a non-blocking fashion. - :returns: if the lock was successfully acquired or not - :rtype: bool + Returns: + bool: True if the lock was successfully acquired, False otherwise. """ if self.passthrough: return True @@ -216,7 +222,10 @@ def _acquire(self) -> bool: def _release(self) -> bool | None: """ - Implementation of releasing an acquired lock. + Release the acquired lock. + + Returns: + bool | None: True if the lock was successfully released, None otherwise. """ if self.passthrough: return None @@ -229,7 +238,10 @@ def _release(self) -> bool | None: def _renew(self) -> bool: """ - Implementation of renewing an acquired lock. + Renew the acquired lock. + + Returns: + bool: True if the lock was successfully renewed, False otherwise. """ if self.passthrough: return True diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index 070a738faae..3d51b46908b 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -120,13 +120,14 @@ def from_mongo( @serializable(attrs=["storage_type"], canonical_name="MongoStorePartition", version=1) class MongoStorePartition(StorePartition): - """Mongo StorePartition + """Mongo StorePartition. - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for partitioning and indexing. - `store_config`: MongoStoreConfig - Mongo specific configuration + Attributes: + storage_type (type[StorableObjectType]): The storage type for objects in the partition. + + Args: + settings (PartitionSettings): PySyft specific settings, used for partitioning and indexing. + store_config (MongoStoreConfig): Mongo specific configuration. """ storage_type: type[StorableObjectType] = MongoBsonObject @@ -173,7 +174,7 @@ def init_store(self) -> Result[Ok, Err]: # These methods are called from the public thread-safe API, and will hang the process. def _create_update_index(self) -> Result[Ok, Err]: - """Create or update mongo database indexes""" + """Create or update mongo database indexes.""" collection_status = self.collection if collection_status.is_err(): return collection_status @@ -482,7 +483,7 @@ def _delete( ) def has_permission(self, permission: ActionObjectPermission) -> bool: - """Check if the permission is inside the permission collection""" + """Check if the permission is inside the permission collection.""" collection_permissions_status = self.permissions if collection_permissions_status.is_err(): return False @@ -640,7 +641,7 @@ def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: self.add_storage_permission(permission) def has_storage_permission(self, permission: StoragePermission) -> bool: # type: ignore - """Check if the storage_permission is inside the storage_permission collection""" + """Check if the storage_permission is inside the storage_permission collection.""" storage_permissions_or_err = self.storage_permissions if storage_permissions_or_err.is_err(): return storage_permissions_or_err @@ -810,12 +811,7 @@ def _migrate_data( @serializable(canonical_name="MongoDocumentStore", version=1) class MongoDocumentStore(DocumentStore): - """Mongo Document Store - - Parameters: - `store_config`: MongoStoreConfig - Mongo specific configuration, including connection configuration, database name, or client class type. - """ + """Mongo Document Store.""" partition_type = MongoStorePartition @@ -827,18 +823,13 @@ class MongoDocumentStore(DocumentStore): ) class MongoBackingStore(KeyValueBackingStore): """ - Core logic for the MongoDB key-value store - - Parameters: - `index_name`: str - Index name (can be either 'data' or 'permissions') - `settings`: PartitionSettings - Syft specific settings - `store_config`: StoreConfig - Connection Configuration - `ddtype`: Type - Optional and should be None - Used to make a consistent interface with SQLiteBackingStore + Core logic for the MongoDB key-value store. + + Args: + index_name (str): Index name (can be either 'data' or 'permissions'). + settings (PartitionSettings): Syft specific settings. + store_config (StoreConfig): Connection configuration. + ddtype (type | None): Optional and should be None. Used to make a consistent interface with SQLiteBackingStore. """ def __init__( @@ -1042,19 +1033,15 @@ def __del__(self) -> None: @serializable() class MongoStoreConfig(StoreConfig): __canonical_name__ = "MongoStoreConfig" - """Mongo Store configuration - - Parameters: - `client_config`: MongoStoreClientConfig - Mongo connection details: hostname, port, user, password etc. - `store_type`: Type[DocumentStore] - The type of the DocumentStore. Default: MongoDocumentStore - `db_name`: str - Database name - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. + """Mongo Store configuration. + + Args: + client_config (MongoStoreClientConfig): Mongo connection details: hostname, port, user, password etc. + store_type (Type[DocumentStore]): The type of the DocumentStore. Default: MongoDocumentStore. + db_name (str): Database name. + locking_config (LockingConfig): The config used for store locking. Available options: + * NoLockingConfig: no locking, ideal for single-thread stores. + * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. Defaults to NoLockingConfig. """ diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index e27ce2c7d16..488f27e30fd 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -81,15 +81,11 @@ def raise_exception(table_name: str, e: Exception) -> None: class SQLiteBackingStore(KeyValueBackingStore): """Core Store logic for the SQLite stores. - Parameters: - `index_name`: str - Index name - `settings`: PartitionSettings - Syft specific settings - `store_config`: SQLiteStoreConfig - Connection Configuration - `ddtype`: Type - Class used as fallback on `get` errors + Args: + index_name (str): Index name. + settings (PartitionSettings): Syft specific settings. + store_config (StoreConfig): Connection configuration. + ddtype (type | None): Class used as fallback on `get` errors. """ def __init__( @@ -358,7 +354,7 @@ def __contains__(self, key: Any) -> bool: def __iter__(self) -> Any: return iter(self.keys()) - def __del__(self) -> None: + def __del(self) -> None: try: self._close() except Exception as e: @@ -369,11 +365,9 @@ def __del__(self) -> None: class SQLiteStorePartition(KeyValueStorePartition): """SQLite StorePartition - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: SQLiteStoreConfig - SQLite specific configuration + Args: + settings (PartitionSettings): PySyft specific settings, used for indexing and partitioning. + store_config (SQLiteStoreConfig): SQLite specific configuration. """ def close(self) -> None: @@ -402,34 +396,31 @@ def commit(self) -> None: # the base document store is already a dict but we can change it later @serializable(canonical_name="SQLiteDocumentStore", version=1) class SQLiteDocumentStore(DocumentStore): - """SQLite Document Store - - Parameters: - `store_config`: StoreConfig - SQLite specific configuration, including connection details and client class type. - """ + """SQLite Document Store.""" partition_type = SQLiteStorePartition @serializable(canonical_name="SQLiteStoreClientConfig", version=1) class SQLiteStoreClientConfig(StoreClientConfig): - """SQLite connection config + """SQLite connection config. - Parameters: - `filename` : str - Database name - `path` : Path or str - Database folder - `check_same_thread`: bool - If True (default), ProgrammingError will be raised if the database connection is used + Attributes: + filename (str): Database name. + path (str | Path): Database folder. + check_same_thread (bool): If True (default), ProgrammingError will be raised if the database connection is used by a thread other than the one that created it. If False, the connection may be accessed in multiple threads; write operations may need to be serialized by the user to avoid data corruption. - `timeout`: int - How many seconds the connection should wait before raising an exception, if the database + timeout (int): How many seconds the connection should wait before raising an exception, if the database is locked by another connection. If another connection opens a transaction to modify the database, it will be locked until that transaction is committed. Default five seconds. + + Parameters: + filename (str): Database name. + path (str | Path): Database folder. + check_same_thread (bool): See Attributes section. + timeout (int): See Attributes section. """ filename: str = "syftdb.sqlite" @@ -454,19 +445,15 @@ def file_path(self) -> Path | None: @serializable() class SQLiteStoreConfig(StoreConfig): __canonical_name__ = "SQLiteStoreConfig" - """SQLite Store config, used by SQLiteStorePartition - - Parameters: - `client_config`: SQLiteStoreClientConfig - SQLite connection configuration - `store_type`: DocumentStore - Class interacting with QueueStash. Default: SQLiteDocumentStore - `backing_store`: KeyValueBackingStore - The Store core logic. Default: SQLiteBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. + """SQLite Store config, used by SQLiteStorePartition. + + Args: + client_config (SQLiteStoreClientConfig): SQLite connection configuration. + store_type (type[DocumentStore]): Class interacting with QueueStash. Default: SQLiteDocumentStore. + backing_store (type[KeyValueBackingStore]): The Store core logic. Default: SQLiteBackingStore. + locking_config (LockingConfig): The config used for store locking. Available options: + * NoLockingConfig: no locking, ideal for single-thread stores. + * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. Defaults to NoLockingConfig. """ diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 4206e92d9bb..f4dbfc51806 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -48,32 +48,35 @@ def get_latest_version(cls, canonical_name: str) -> int: @classmethod def get_identifier_for_type(cls, obj: Any) -> tuple[str, int]: """ - This is to create the string in nonrecursiveBlob + This is to create the string in nonrecursiveBlob. """ return cls.__type_to_canonical_name__[obj] @classmethod def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: """ - Retrieves the canonical name for both objects and types. + Retrieves the canonical name and version for both objects and types. This function works for both objects and types, returning the canonical name - as a string. It handles various cases, including built-in types, instances of - classes, and enum members. + as a string and its version as an integer. It handles various cases, including + built-in types, instances of classes, and enum members. If the object is not registered in the registry, a ValueError is raised. Examples: - get_canonical_name_version([1,2,3]) -> "list" - get_canonical_name_version(list) -> "type" - get_canonical_name_version(MyEnum.A) -> "MyEnum" - get_canonical_name_version(MyEnum) -> "type" + get_canonical_name_version([1, 2, 3]) -> ("list", version) + get_canonical_name_version(list) -> ("type", version) + get_canonical_name_version(MyEnum.A) -> ("MyEnum", version) + get_canonical_name_version(MyEnum) -> ("type", version) Args: - obj: The object or type for which to get the canonical name. + obj (Any): The object or type for which to get the canonical name. Returns: - The canonical name and version of the object or type. + tuple[str, int]: The canonical name and version of the object or type. + + Raises: + ValueError: If the canonical name for the object or type is not found. """ # for types we return "type" diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index f00fed8694f..47cc0f553b7 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -169,8 +169,8 @@ def build_tabulator_table_with_data( Args: table_data (list[dict]): The data to populate the table. table_metadata (dict): The metadata for the table. - uid (str, optional): The unique identifier for the table. Defaults to None. - max_height (int, optional): The maximum height of the table. Defaults to None. + uid (str | None, optional): The unique identifier for the table. Defaults to None. + max_height (int | None, optional): The maximum height of the table. Defaults to None. pagination (bool, optional): Whether to enable pagination. Defaults to True. header_sort (bool, optional): Whether to enable header sorting. Defaults to True. @@ -198,8 +198,8 @@ def build_tabulator_table( Args: obj (Any): The object to build the table from. - uid (str, optional): The unique identifier for the table. Defaults to None. - max_height (int, optional): The maximum height of the table. Defaults to None. + uid (str | None, optional): The unique identifier for the table. Defaults to None. + max_height (int | None, optional): The maximum height of the table. Defaults to None. pagination (bool, optional): Whether to enable pagination. Defaults to True. header_sort (bool, optional): Whether to enable header sorting. Defaults to True. diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 25db3fb4552..4547d3a7733 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -228,12 +228,10 @@ def prepare_table_data( add_index (bool, optional): Whether to add an index column to the table. Defaults to True. Returns: - tuple: A tuple (table_data, table_metadata) where table_data is a list of dictionaries + tuple[list[dict], dict]: A tuple (table_data, table_metadata) where table_data is a list of dictionaries where each dictionary represents a row in the table and table_metadata is a dictionary containing metadata about the table such as name, icon, etc. - """ - values = _get_values_for_table_repr(obj) if len(values) == 0: return [], {} diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index c9011c82ffc..d511729f1d5 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -63,7 +63,14 @@ def get_env(key: str, default: Any | None = None) -> str | None: def full_name_with_qualname(klass: type) -> str: - """Returns the klass module name + klass qualname.""" + """Returns the klass module name + klass qualname. + + Args: + klass (type): The class whose fully qualified name is needed. + + Returns: + str: The fully qualified name of the class. + """ try: if not hasattr(klass, "__module__"): return f"builtins.{get_qualname_for(klass)}" @@ -75,7 +82,17 @@ def full_name_with_qualname(klass: type) -> str: def full_name_with_name(klass: type) -> str: - """Returns the klass module name + klass name.""" + """Returns the klass module name + klass name. + + Args: + klass (type): The class whose fully qualified name is needed. + + Returns: + str: The fully qualified name of the class. + + Raises: + Exception: If there is an error while getting the fully qualified name. + """ try: if not hasattr(klass, "__module__"): return f"builtins.{get_name_for(klass)}" @@ -86,6 +103,14 @@ def full_name_with_name(klass: type) -> str: def get_qualname_for(klass: type) -> str: + """Get the qualified name of a class. + + Args: + klass (type): The class to get the qualified name for. + + Returns: + str: The qualified name of the class. + """ qualname = getattr(klass, "__qualname__", None) or getattr(klass, "__name__", None) if qualname is None: qualname = extract_name(klass) @@ -93,6 +118,14 @@ def get_qualname_for(klass: type) -> str: def get_name_for(klass: type) -> str: + """Get the name of a class. + + Args: + klass (type): The class to get the name for. + + Returns: + str: The name of the class. + """ klass_name = getattr(klass, "__name__", None) if klass_name is None: klass_name = extract_name(klass) @@ -115,6 +148,12 @@ def get_mb_size(data: Any, handlers: dict | None = None) -> float: which is referenced in official sys.getsizeof documentation https://docs.python.org/3/library/sys.html#sys.getsizeof. + Args: + data (Any): The object to calculate the memory size for. + handlers (dict | None): Custom handlers for additional types. + + Returns: + float: The memory size of the object in MB. """ def dict_handler(d: dict[Any, Any]) -> Iterator[Any]: @@ -162,6 +201,14 @@ def sizeof(o: Any) -> int: def get_mb_serialized_size(data: Any) -> Ok[float] | Err[str]: + """Get the size of a serialized object in MB. + + Args: + data (Any): The object to be serialized and measured. + + Returns: + Ok[float] | Err[str]: The size of the serialized object in MB if successful, or an error message if serialization fails. + """ try: serialized_data = serialize(data, to_bytes=True) return Ok(sys.getsizeof(serialized_data) / (1024 * 1024)) @@ -174,6 +221,17 @@ def get_mb_serialized_size(data: Any) -> Ok[float] | Err[str]: def extract_name(klass: type) -> str: + """Extract the name of a class from its string representation. + + Args: + klass (type): The class to extract the name from. + + Returns: + str: The extracted name of the class. + + Raises: + ValueError: If the class name could not be extracted. + """ name_regex = r".+class.+?([\w\._]+).+" regex2 = r"([\w\.]+)" matches = re.match(name_regex, str(klass)) @@ -195,6 +253,19 @@ def extract_name(klass: type) -> str: def validate_type(_object: object, _type: type, optional: bool = False) -> Any: + """Validate that an object is of a certain type. + + Args: + _object (object): The object to validate. + _type (type): The type to validate against. + optional (bool): Whether the object can be None. + + Returns: + Any: The validated object. + + Raises: + Exception: If the object is not of the expected type. + """ if isinstance(_object, _type) or (optional and (_object is None)): return _object @@ -202,6 +273,18 @@ def validate_type(_object: object, _type: type, optional: bool = False) -> Any: def validate_field(_object: object, _field: str) -> Any: + """Validate that an object has a certain field. + + Args: + _object (object): The object to validate. + _field (str): The field to validate. + + Returns: + Any: The value of the field. + + Raises: + Exception: If the field is not set on the object. + """ object = getattr(_object, _field, None) if object is not None: @@ -211,19 +294,17 @@ def validate_field(_object: object, _field: str) -> Any: def get_fully_qualified_name(obj: object) -> str: - """Return the full path and name of a class + """Return the full path and name of a class. Sometimes we want to return the entire path and name encoded using periods. Args: - obj: the object we want to get the name of + obj (object): The object we want to get the name of. Returns: - the full path and name of the object - + str: The full path and name of the object. """ - fqn = obj.__class__.__module__ try: @@ -234,13 +315,12 @@ def get_fully_qualified_name(obj: object) -> str: def aggressive_set_attr(obj: object, name: str, attr: object) -> None: - """Different objects prefer different types of monkeypatching - try them all + """Different objects prefer different types of monkeypatching - try them all. Args: - obj: object whose attribute has to be set - name: attribute name - attr: value given to the attribute - + obj (object): The object whose attribute has to be set. + name (str): The attribute name. + attr (object): The value given to the attribute. """ try: setattr(obj, name, attr) @@ -249,6 +329,14 @@ def aggressive_set_attr(obj: object, name: str, attr: object) -> None: def key_emoji(key: object) -> str: + """Generate an emoji representation of a key. + + Args: + key (object): The key object. + + Returns: + str: An emoji string representing the key. + """ try: if isinstance(key, bytes | SigningKey | VerifyKey): hex_chars = bytes(key).hex()[-8:] @@ -260,6 +348,14 @@ def key_emoji(key: object) -> str: def char_emoji(hex_chars: str) -> str: + """Generate an emoji based on a hexadecimal string. + + Args: + hex_chars (str): The hexadecimal string. + + Returns: + str: An emoji string generated from the hexadecimal string. + """ base = ord("\U0001f642") hex_base = ord("0") code = 0 @@ -270,6 +366,11 @@ def char_emoji(hex_chars: str) -> str: def get_root_data_path() -> Path: + """Get the root data path for storing datasets. + + Returns: + Path: The root data path. + """ # get the PySyft / data directory to share datasets between notebooks # on Linux and MacOS the directory is: ~/.syft/data" # on Windows the directory is: C:/Users/$USER/.syft/data @@ -281,6 +382,15 @@ def get_root_data_path() -> Path: def download_file(url: str, full_path: str | Path) -> Path | None: + """Download a file from a URL. + + Args: + url (str): The URL of the file to download. + full_path (str | Path): The full path where the file should be saved. + + Returns: + Path | None: The path to the downloaded file, or None if the download failed. + """ full_path = Path(full_path) if not full_path.exists(): r = requests.get(url, allow_redirects=True, verify=verify_tls()) # nosec @@ -293,25 +403,46 @@ def download_file(url: str, full_path: str | Path) -> Path | None: def verify_tls() -> bool: + """Verify whether TLS should be used. + + Returns: + bool: True if TLS should be used, False otherwise. + """ return not str_to_bool(str(os.environ.get("IGNORE_TLS_ERRORS", "0"))) def ssl_test() -> bool: + """Check if SSL is properly configured. + + Returns: + bool: True if SSL is configured, False otherwise. + """ return len(os.environ.get("REQUESTS_CA_BUNDLE", "")) > 0 def initializer(event_loop: BaseSelectorEventLoop | None = None) -> None: """Set the same event loop to other threads/processes. + This is needed because there are new threads/processes started with - the Executor and they do not have have an event loop set + the Executor and they do not have an event loop set. + Args: - event_loop: The event loop. + event_loop (BaseSelectorEventLoop | None): The event loop to set. """ if event_loop: asyncio.set_event_loop(event_loop) def split_rows(rows: Sequence, cpu_count: int) -> list: + """Split a sequence of rows into chunks for parallel processing. + + Args: + rows (Sequence): The sequence of rows to split. + cpu_count (int): The number of chunks to split into. + + Returns: + list: A list of row chunks. + """ n = len(rows) a, b = divmod(n, cpu_count) start = 0 @@ -324,6 +455,14 @@ def split_rows(rows: Sequence, cpu_count: int) -> list: def list_sum(*inp_lst: list[Any]) -> Any: + """Sum a list of lists element-wise. + + Args: + *inp_lst (list[Any]): The list of lists to sum. + + Returns: + Any: The sum of the lists. + """ s = inp_lst[0] for i in inp_lst[1:]: s = s + i @@ -332,6 +471,14 @@ def list_sum(*inp_lst: list[Any]) -> Any: @contextmanager def concurrency_override(count: int = 1) -> Iterator: + """Context manager to override concurrency count. + + Args: + count (int): The concurrency count to set. Defaults to 1. + + Yields: + Iterator: A context manager. + """ # this only effects local code so its best to use in unit tests try: os.environ["FORCE_CONCURRENCY_COUNT"] = f"{count}" @@ -345,8 +492,17 @@ def print_process( # type: ignore finish: EventClass, success: EventClass, lock: LockBase, - refresh_rate=0.1, + refresh_rate: float = 0.1, ) -> None: + """Print a dynamic process message that updates periodically. + + Args: + message (str): The message to print. + finish (EventClass): Event to signal the finish of the process. + success (EventClass): Event to signal the success of the process. + lock (LockBase): A lock to synchronize the print output. + refresh_rate (float, optional): The refresh rate for updating the message. Defaults to 0.1. + """ with lock: while not finish.is_set(): print(f"{bcolors.bold(message)} .", end="\r") @@ -368,12 +524,13 @@ def print_process( # type: ignore def print_dynamic_log( message: str, ) -> tuple[EventClass, EventClass]: - """ - Prints a dynamic log message that will change its color (to green or red) when some process is done. + """Print a dynamic log message that will change its color when some process is done. - message: str = Message to be printed. + Args: + message (str): The message to be printed. - return: tuple of events that can control the log print from the outside of this method. + Returns: + tuple[EventClass, EventClass]: Tuple of events that can control the log print from outside this method. """ finish = multiprocessing.Event() success = multiprocessing.Event() @@ -392,6 +549,19 @@ def print_dynamic_log( def find_available_port( host: str, port: int | None = None, search: bool = False ) -> int: + """Find an available port on the specified host. + + Args: + host (str): The host to check for available ports. + port (int | None): The port to check. Defaults to a random port. + search (bool): Whether to search for the next available port if the given port is in use. + + Returns: + int: The available port number. + + Raises: + Exception: If the port is not available and search is False. + """ if port is None: port = random.randint(1500, 65000) # nosec port_available = False @@ -426,9 +596,8 @@ def find_available_port( def get_random_available_port() -> int: """Retrieve a random available port number from the host OS. - Returns - ------- - int: Available port number. + Returns: + int: Available port number. """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as soc: soc.bind(("localhost", 0)) @@ -436,15 +605,20 @@ def get_random_available_port() -> int: def get_loaded_syft() -> ModuleType: + """Get the loaded Syft module. + + Returns: + ModuleType: The loaded Syft module. + """ return sys.modules[__name__.split(".")[0]] def get_subclasses(obj_type: type) -> list[type]: - """Recursively generate the list of all classes within the sub-tree of an object + """Recursively generate the list of all classes within the sub-tree of an object. As a paradigm in Syft, we often allow for something to be known about by another part of the codebase merely because it has subclassed a particular object. While - this can be a big "magicish" it also can simplify future extensions and reduce + this can be a bit "magicish" it also can simplify future extensions and reduce the likelihood of small mistakes (if done right). This is a utility function which allows us to look for sub-classes and the sub-classes @@ -452,11 +626,10 @@ def get_subclasses(obj_type: type) -> list[type]: hierarchy. Args: - obj_type: the type we want to look for sub-classes of + obj_type (type): The type we want to look for sub-classes of. Returns: - the list of subclasses of obj_type - + list[type]: The list of subclasses of obj_type. """ classes = [] @@ -467,27 +640,25 @@ def get_subclasses(obj_type: type) -> list[type]: def index_modules(a_dict: object, keys: list[str]) -> object: - """Recursively find a syft module from its path + """Recursively find a Syft module from its path. This is the recursive inner function of index_syft_by_module_name. See that method for a full description. Args: - a_dict: a module we're traversing - keys: the list of string attributes we're using to traverse the module + a_dict (object): A module we're traversing. + keys (list[str]): The list of string attributes we're using to traverse the module. Returns: - a reference to the final object - + object: A reference to the final object. """ - if len(keys) == 0: return a_dict return index_modules(a_dict=a_dict.__dict__[keys[0]], keys=keys[1:]) def index_syft_by_module_name(fully_qualified_name: str) -> object: - """Look up a Syft class/module/function from full path and name + """Look up a Syft class/module/function from full path and name. Sometimes we want to use the fully qualified name (such as one generated from the 'get_fully_qualified_name' method below) to @@ -496,13 +667,14 @@ def index_syft_by_module_name(fully_qualified_name: str) -> object: representation of the specific object it is meant to deserialize to. Args: - fully_qualified_name: the name in str of a module, class, or function + fully_qualified_name (str): The name in str of a module, class, or function. Returns: - a reference to the actual object at that string path + object: A reference to the actual object at that string path. + Raises: + ReferenceError: If the reference does not match expected patterns. """ - # @Tudor this needs fixing during the serde refactor # we should probably just support the native type names as lookups for serde if fully_qualified_name == "builtins.NoneType": @@ -512,13 +684,22 @@ def index_syft_by_module_name(fully_qualified_name: str) -> object: if attr_list[0] != "syft": raise ReferenceError(f"Reference don't match: {attr_list[0]}") - # if attr_list[1] != "core" and attr_list[1] != "user": - # raise ReferenceError(f"Reference don't match: {attr_list[1]}") - return index_modules(a_dict=get_loaded_syft(), keys=attr_list[1:]) def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: + """Get the pointer type for an object based on its fully qualified name. + + Args: + obj (object | None): The object to get the pointer type for. + fqn (str | None): The fully qualified name of the object. + + Returns: + type: The pointer type for the object. + + Raises: + Exception: If the pointer type cannot be found. + """ if fqn is None: try: fqn = get_fully_qualified_name(obj=obj) @@ -543,6 +724,15 @@ def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: def prompt_warning_message(message: str, confirm: bool = False) -> bool: + """Prompt a warning message and optionally request user confirmation. + + Args: + message (str): The warning message to display. + confirm (bool): Whether to request user confirmation. + + Returns: + bool: True if the user confirms, False otherwise. + """ # relative from ..service.response import SyftWarning @@ -741,6 +931,11 @@ def prompt_warning_message(message: str, confirm: bool = False) -> bool: def random_name() -> str: + """Generate a random name by combining a left and right name part. + + Returns: + str: The generated random name. + """ left_i = randbelow(len(left_name) - 1) right_i = randbelow(len(right_name) - 1) return f"{left_name[left_i].capitalize()} {right_name[right_i].capitalize()}" @@ -750,9 +945,18 @@ def inherit_tags( attr_path_and_name: str, result: object, self_obj: object | None, - args: tuple | list, - kwargs: dict, + args: tuple[Any, ...] | list[Any], + kwargs: dict[str, Any], ) -> None: + """Inherit tags from input objects to the result object. + + Args: + attr_path_and_name (str): The attribute path and name to add as a tag. + result (object): The result object to inherit tags. + self_obj (object | None): The object that might have tags. + args (tuple[Any, ...] | list[Any]): Arguments that might have tags. + kwargs (dict[str, Any]): Keyword arguments that might have tags. + """ tags = [] if self_obj is not None and hasattr(self_obj, "tags"): tags.extend(list(self_obj.tags)) @@ -774,6 +978,16 @@ def inherit_tags( def autocache( url: str, extension: str | None = None, cache: bool = True ) -> Path | None: + """Automatically cache a file from a URL. + + Args: + url (str): The URL of the file to cache. + extension (str | None): The file extension to use. + cache (bool): Whether to use the cache if the file already exists. + + Returns: + Path | None: The path to the cached file, or None if caching failed. + """ try: data_path = get_root_data_path() file_hash = hashlib.sha256(url.encode("utf8")).hexdigest() @@ -790,6 +1004,14 @@ def autocache( def str_to_bool(bool_str: str | None) -> bool: + """Convert a string to a boolean value. + + Args: + bool_str (str | None): The string to convert. + + Returns: + bool: The converted boolean value. + """ result = False bool_str = str(bool_str).lower() if bool_str == "true" or bool_str == "1": @@ -800,20 +1022,19 @@ def str_to_bool(bool_str: str | None) -> bool: # local scope functions cant be pickled so this needs to be global def parallel_execution( fn: Callable[..., Any], - parties: None | list[Any] = None, + parties: list[Any] | None = None, cpu_bound: bool = False, ) -> Callable[..., list[Any]]: """Wrap a function such that it can be run in parallel at multiple parties. + Args: - fn (Callable): The function to run. - parties (Union[None, List[Any]]): Clients from syft. If this is set, then the + fn (Callable[..., Any]): The function to run. + parties (list[Any] | None): Clients from syft. If this is set, then the function should be run remotely. Defaults to None. - cpu_bound (bool): Because of the GIL (global interpreter lock) sometimes - it makes more sense to use processes than threads if it is set then - processes should be used since they really run in parallel if not then - it makes sense to use threads since there is no bottleneck on the CPU side + cpu_bound (bool): Whether to use processes instead of threads. + Returns: - Callable[..., List[Any]]: A Callable that returns a list of results. + Callable[..., list[Any]]: A Callable that returns a list of results. """ @functools.wraps(fn) @@ -822,11 +1043,16 @@ def wrapper( kwargs: dict[Any, dict[Any, Any]] | None = None, ) -> list[Any]: """Wrap sanity checks and checks what executor should be used. + Args: - args (List[List[Any]]): Args. - kwargs (Optional[Dict[Any, Dict[Any, Any]]]): Kwargs. Default to None. + args (list[list[Any]]): The list of lists of arguments. + kwargs (dict[Any, dict[Any, Any]] | None): The dictionary of keyword arguments. + Returns: - List[Any]: Results from the parties + list[Any]: The list of results from the parties. + + Raises: + Exception: If the arguments list is empty. """ if args is None or len(args) == 0: raise Exception("Parallel execution requires more than 0 args") @@ -873,6 +1099,14 @@ def wrapper( def concurrency_count(factor: float = 0.8) -> int: + """Get the current concurrency count based on CPU count and a factor. + + Args: + factor (float): The factor to apply to the CPU count. Defaults to 0.8. + + Returns: + int: The calculated concurrency count. + """ force_count = int(os.environ.get("FORCE_CONCURRENCY_COUNT", 0)) mp_count = force_count if force_count >= 1 else int(mp.cpu_count() * factor) return mp_count @@ -891,18 +1125,51 @@ class bcolors: @staticmethod def green(message: str) -> str: + """Return a green-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The green-colored message. + """ return bcolors.GREEN + message + bcolors.ENDC @staticmethod def red(message: str) -> str: + """Return a red-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The red-colored message. + """ return bcolors.RED + message + bcolors.ENDC @staticmethod def yellow(message: str) -> str: + """Return a yellow-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The yellow-colored message. + """ return bcolors.YELLOW + message + bcolors.ENDC @staticmethod def bold(message: str, end_color: bool = False) -> str: + """Return a bold string. + + Args: + message (str): The message to bold. + end_color (bool): Whether to reset color after the message. + + Returns: + str: The bolded message. + """ msg = bcolors.BOLD + message if end_color: msg += bcolors.ENDC @@ -910,6 +1177,15 @@ def bold(message: str, end_color: bool = False) -> str: @staticmethod def underline(message: str, end_color: bool = False) -> str: + """Return an underlined string. + + Args: + message (str): The message to underline. + end_color (bool): Whether to reset color after the message. + + Returns: + str: The underlined message. + """ msg = bcolors.UNDERLINE + message if end_color: msg += bcolors.ENDC @@ -917,18 +1193,47 @@ def underline(message: str, end_color: bool = False) -> str: @staticmethod def warning(message: str) -> str: + """Return a warning-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The warning-colored message. + """ return bcolors.bold(bcolors.yellow(message)) @staticmethod def success(message: str) -> str: + """Return a success-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The success-colored message. + """ return bcolors.green(message) @staticmethod def failure(message: str) -> str: + """Return a failure-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The failure-colored message. + """ return bcolors.red(message) def os_name() -> str: + """Get the name of the operating system. + + Returns: + str: The name of the operating system. + """ os_name = platform.system() if os_name.lower() == "darwin": return "macOS" @@ -938,18 +1243,38 @@ def os_name() -> str: # Note: In the future there might be other interpreters that we want to use def is_interpreter_jupyter() -> bool: + """Check if the current interpreter is Jupyter. + + Returns: + bool: True if the current interpreter is Jupyter, False otherwise. + """ return get_interpreter_module() == "ipykernel.zmqshell" def is_interpreter_colab() -> bool: + """Check if the current interpreter is Google Colab. + + Returns: + bool: True if the current interpreter is Google Colab, False otherwise. + """ return get_interpreter_module() == "google.colab._shell" def is_interpreter_standard() -> bool: + """Check if the current interpreter is a standard Python interpreter. + + Returns: + bool: True if the current interpreter is standard, False otherwise. + """ return get_interpreter_module() == "StandardInterpreter" def get_interpreter_module() -> str: + """Get the module name of the current interpreter. + + Returns: + str: The module name of the current interpreter. + """ try: # third party from IPython import get_ipython @@ -967,14 +1292,30 @@ def get_interpreter_module() -> str: def thread_ident() -> int | None: + """Get the identifier of the current thread. + + Returns: + int | None: The thread identifier, or None if not available. + """ return threading.current_thread().ident def proc_id() -> int: + """Get the process ID of the current process. + + Returns: + int: The process ID. + """ return os.getpid() def set_klass_module_to_syft(klass: type, module_name: str) -> None: + """Set the module of a class to Syft. + + Args: + klass (type): The class to set the module for. + module_name (str): The name of the module. + """ if module_name not in sys.modules["syft"].__dict__: new_module = types.ModuleType(module_name) else: @@ -984,8 +1325,14 @@ def set_klass_module_to_syft(klass: type, module_name: str) -> None: def get_queue_address(port: int) -> str: - """Get queue address based on container host name.""" + """Get queue address based on container host name. + + Args: + port (int): The port number. + Returns: + str: The queue address. + """ container_host = os.getenv("CONTAINER_HOST", None) if container_host == "k8s": return f"tcp://backend:{port}" @@ -995,14 +1342,32 @@ def get_queue_address(port: int) -> str: def get_dev_mode() -> bool: + """Check if the application is running in development mode. + + Returns: + bool: True if in development mode, False otherwise. + """ return str_to_bool(os.getenv("DEV_MODE", "False")) def generate_token() -> str: + """Generate a secure random token. + + Returns: + str: The generated token. + """ return secrets.token_hex(64) def sanitize_html(html: str) -> str: + """Sanitize HTML content by allowing specific tags and attributes. + + Args: + html (str): The HTML content to sanitize. + + Returns: + str: The sanitized HTML content. + """ policy = { "tags": ["svg", "strong", "rect", "path", "circle"], "attributes": { @@ -1041,6 +1406,14 @@ def sanitize_html(html: str) -> str: def parse_iso8601_date(date_string: str) -> datetime: + """Parse an ISO8601 date string into a datetime object. + + Args: + date_string (str): The ISO8601 date string. + + Returns: + datetime: The parsed datetime object. + """ # Handle variable length of microseconds by trimming to 6 digits if "." in date_string: base_date, microseconds = date_string.split(".") @@ -1051,6 +1424,15 @@ def parse_iso8601_date(date_string: str) -> datetime: def get_latest_tag(registry: str, repo: str) -> str | None: + """Get the latest tag from a Docker registry for a given repository. + + Args: + registry (str): The Docker registry. + repo (str): The repository name. + + Returns: + str | None: The latest tag, or None if no tags are found. + """ repo_url = f"http://{registry}/v2/{repo}" res = requests.get(url=f"{repo_url}/tags/list", timeout=5) tags = res.json().get("tags", []) @@ -1071,6 +1453,14 @@ def get_latest_tag(registry: str, repo: str) -> str | None: def get_nb_secrets(defaults: dict | None = None) -> dict: + """Get secrets for notebooks from a JSON file. + + Args: + defaults (dict | None): Default values for the secrets. + + Returns: + dict: The secrets loaded from the JSON file. + """ if defaults is None: defaults = {} @@ -1087,21 +1477,29 @@ def get_nb_secrets(defaults: dict | None = None) -> dict: class CustomRepr(reprlib.Repr): def repr_str(self, obj: Any, level: int = 0) -> str: + """Return a truncated string representation if it is too long. + + Args: + obj (Any): The object to represent. + level (int): The level of detail in the representation. + + Returns: + str: The truncated string representation. + """ if len(obj) <= self.maxstring: return repr(obj) return repr(obj[: self.maxstring] + "...") def repr_truncation(obj: Any, max_elements: int = 10) -> str: - """ - Return a truncated string representation of the object if it is too long. + """Return a truncated string representation of the object if it is too long. Args: - - obj: The object to be represented (can be str, list, dict, set...). - - max_elements: Maximum number of elements to display before truncating. + obj (Any): The object to be represented (can be str, list, dict, set...). + max_elements (int): Maximum number of elements to display before truncating. Returns: - - A string representation of the object, truncated if necessary. + str: A string representation of the object, truncated if necessary. """ r = CustomRepr() r.maxlist = max_elements # For lists From ae5809cfc92a3f66f80c8a507cb379cace1b3284 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Tue, 20 Aug 2024 16:29:06 -0400 Subject: [PATCH 03/10] Fixing bugs introducted by accidental code delete --- .../src/syft/service/action/action_object.py | 86 ++++++++++++++++++- .../src/syft/service/action/action_store.py | 6 +- .../src/syft/service/network/server_peer.py | 2 +- .../service/worker/worker_pool_service.py | 3 +- .../syft/src/syft/store/document_store.py | 4 +- packages/syft/src/syft/util/util.py | 3 +- 6 files changed, 96 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 6628a476ba7..31e665fe8ea 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -10,6 +10,7 @@ import logging from pathlib import Path import sys +import threading import time import traceback import types @@ -29,6 +30,7 @@ # relative from ...client.api import APIRegistry +from ...client.api import SyftAPI from ...client.api import SyftAPICall from ...client.client import SyftClient from ...serde.serializable import serializable @@ -39,6 +41,7 @@ from ...service.response import SyftSuccess from ...service.response import SyftWarning from ...store.linked_obj import LinkedObject +from ...types.base import SyftBaseModel from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftBaseObject @@ -445,6 +448,86 @@ def make_action_side_effect( return Ok((context, args, kwargs)) +class TraceResultRegistry: + __result_registry__: dict[int, TraceResult] = {} + + @classmethod + def set_trace_result_for_current_thread( + cls, + client: SyftClient, + ) -> None: + cls.__result_registry__[threading.get_ident()] = TraceResult( + client=client, is_tracing=True + ) + + @classmethod + def get_trace_result_for_thread(cls) -> TraceResult | None: + return cls.__result_registry__.get(threading.get_ident(), None) + + @classmethod + def reset_result_for_thread(cls) -> None: + if threading.get_ident() in cls.__result_registry__: + del cls.__result_registry__[threading.get_ident()] + + @classmethod + def current_thread_is_tracing(cls) -> bool: + trace_result = cls.get_trace_result_for_thread() + if trace_result is None: + return False + else: + return trace_result.is_tracing + + +class TraceResult(SyftBaseModel): + result: list = [] + client: SyftClient + is_tracing: bool = False + + +def trace_action_side_effect( + context: PreHookContext, *args: Any, **kwargs: Any +) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: + action = context.action + if action is not None and TraceResultRegistry.current_thread_is_tracing(): + trace_result = TraceResultRegistry.get_trace_result_for_thread() + trace_result.result += [action] # type: ignore + return Ok((context, args, kwargs)) + + +def convert_to_pointers( + api: SyftAPI, + server_uid: UID | None = None, + args: list | None = None, + kwargs: dict | None = None, +) -> tuple[list, dict]: + # relative + from ..dataset.dataset import Asset + + def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: + if ( + not isinstance(arg, ActionObject | Asset | UID) + and api.signing_key is not None # type: ignore[unreachable] + ): + arg = ActionObject.from_obj( # type: ignore[unreachable] + syft_action_data=arg, + syft_client_verify_key=api.signing_key.verify_key, + syft_server_location=api.server_uid, + ) + arg.syft_server_uid = server_uid + r = arg._save_to_blob_storage() + if isinstance(r, SyftError): + print(r.message) + if isinstance(r, SyftWarning): + logger.debug(r.message) + arg = api.services.action.set(arg) + return arg + + arg_list = [process_arg(arg) for arg in args] if args else [] + kwarg_dict = {k: process_arg(v) for k, v in kwargs.items()} if kwargs else {} + + return arg_list, kwarg_dict + + def send_action_side_effect( context: PreHookContext, *args: Any, **kwargs: Any ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: @@ -1044,7 +1127,8 @@ def syft_make_action( op (str): The method to be executed from the remote object. remote_self (UID | LineageID | None): The extended UID of the SyftObject. args (list[UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation arguments. - kwargs (dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation keyword arguments. + kwargs (dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation + keyword arguments. action_type (ActionType | None): The type of action being performed. Returns: diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index f69ff80894b..f898ec5f899 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -54,7 +54,8 @@ def __init__( Args: server_uid (UID): Unique identifier for the server instance. - store_config (StoreConfig): Backend specific configuration, including connection configuration, database name, or client class type. + store_config (StoreConfig): Backend specific configuration, including connection configuration, + database name, or client class type. root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. document_store (DocumentStore | None): Document store used for storing user information. """ @@ -389,7 +390,8 @@ def __init__( Args: server_uid (UID): Unique identifier for the server instance. - store_config (StoreConfig | None): Backend specific configuration, including connection configuration, database name, or client class type. + store_config (StoreConfig | None): Backend specific configuration, including connection configuration, + database name, or client class type. root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. document_store (DocumentStore | None): Document store used for storing user information. """ diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index 10005721f4f..e8af2c85233 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -85,7 +85,7 @@ def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: """ if route: if not isinstance( - route, (HTTPServerRoute, PythonServerRoute, VeilidServerRoute) + route, HTTPServerRoute | PythonServerRoute | VeilidServerRoute ): raise ValueError(f"Unsupported route type: {type(route)}") for i, r in enumerate(self.server_routes): diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 433a86f5752..127a108915b 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -267,7 +267,8 @@ def create_image_and_pool_request( num_workers (int): The number of workers in the pool. config (WorkerConfig): Config of the image to be built. tag (str | None, optional): A human-readable manifest identifier. Required for `DockerWorkerConfig`. - registry_uid (UID | None, optional): UID of the registry in Kubernetes mode. Required for `DockerWorkerConfig`. + registry_uid (UID | None, optional): UID of the registry in Kubernetes mode. Required + for `DockerWorkerConfig`. reason (str | None, optional): The reason for creating the worker image and pool. Defaults to "". pull_image (bool, optional): Whether to pull the image. Defaults to True. pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 49114c8df67..c3342d8a2dc 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -154,8 +154,8 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: pk_value = obj else: pk_value = getattr(obj, pk_key) - if isinstance(pk_value, types.FunctionType | types.MethodType): - pk_value = pk_value() + if isinstance(pk_value, types.FunctionType | types.MethodType): # type: ignore[unreachable] + pk_value = pk_value() # type: ignore[unreachable] if pk_value and not isinstance(pk_value, pk_type): raise Exception( diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index d511729f1d5..9178e8b3577 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -207,7 +207,8 @@ def get_mb_serialized_size(data: Any) -> Ok[float] | Err[str]: data (Any): The object to be serialized and measured. Returns: - Ok[float] | Err[str]: The size of the serialized object in MB if successful, or an error message if serialization fails. + Ok[float] | Err[str]: The size of the serialized object in MB if successful, or an error + message if serialization fails. """ try: serialized_data = serialize(data, to_bytes=True) From 222b8ae052855f9e16d51bdf9d1ee61128a62af9 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Tue, 20 Aug 2024 16:29:49 -0400 Subject: [PATCH 04/10] temporarily supressing pydoclint and checking --- .pre-commit-config.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46b4a0e4129..8f2a37e29ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -180,16 +180,16 @@ repos: - id: prettier exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) - - repo: https://github.com/jsh9/pydoclint - rev: 0.5.6 - hooks: - - id: pydoclint - args: [ - # --config=packages/syft/pyproject.toml, - --quiet, - --style=google, - --allow-init-docstring=true, - ] + # - repo: https://github.com/jsh9/pydoclint + # rev: 0.5.6 + # hooks: + # - id: pydoclint + # args: [ + # # --config=packages/syft/pyproject.toml, + # --quiet, + # --style=google, + # --allow-init-docstring=true, + # ] # - repo: meta # hooks: From 3667db5e48b13c7dab5dcb1e5013f0cc202f289a Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 21 Aug 2024 16:58:45 -0400 Subject: [PATCH 05/10] Fixing bugs introducted during the syle upgrade --- .pre-commit-config.yaml | 20 +++++++++---------- .../src/syft/service/action/action_object.py | 8 ++++---- .../src/syft/service/network/server_peer.py | 2 +- packages/syft/src/syft/store/locks.py | 2 +- .../src/syft/store/sqlite_document_store.py | 2 +- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f2a37e29ad..46b4a0e4129 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -180,16 +180,16 @@ repos: - id: prettier exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) - # - repo: https://github.com/jsh9/pydoclint - # rev: 0.5.6 - # hooks: - # - id: pydoclint - # args: [ - # # --config=packages/syft/pyproject.toml, - # --quiet, - # --style=google, - # --allow-init-docstring=true, - # ] + - repo: https://github.com/jsh9/pydoclint + rev: 0.5.6 + hooks: + - id: pydoclint + args: [ + # --config=packages/syft/pyproject.toml, + --quiet, + --style=google, + --allow-init-docstring=true, + ] # - repo: meta # hooks: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 31e665fe8ea..9f0a46a8007 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1069,12 +1069,12 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: trace_result.result += [action] # type: ignore api = APIRegistry.api_for( - server_uid=obj.syft_server_location, + server_uid=self.syft_server_location, user_verify_key=self.syft_client_verify_key, ) if api is None: print( - f"failed saving {obj} to blob storage, api is None. You must login to {obj.syft_server_location}" + f"failed saving {obj} to blob storage, api is None. You must login to {self.syft_server_location}" ) return else: @@ -1155,7 +1155,7 @@ def syft_make_action( def syft_make_action_with_self( self, op: str, - args: dict[str, UID | ActionObjectPointer] | None = None, + args: list[UID | ActionObjectPointer] | None = None, kwargs: dict[str, UID | ActionObjectPointer] | None = None, action_type: ActionType | None = None, ) -> Action: @@ -1163,7 +1163,7 @@ def syft_make_action_with_self( Args: op (str): The method to be executed from the remote object. - args (dict[str, UID | ActionObjectPointer] | None): Operation arguments. + args (list[UID | ActionObjectPointer] | None): Operation arguments. kwargs (dict[str, UID | ActionObjectPointer] | None): Operation keyword arguments. action_type (ActionType | None): The type of action being performed. diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index e8af2c85233..8dd68f0e918 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -138,7 +138,7 @@ def update_route(self, route: ServerRoute) -> None: route (ServerRoute): The new route to be added to the peer. """ existed, idx = self.existed_route(route) - if existed and idx is not None: + if existed: self.server_routes[idx] = route # type: ignore else: new_route = self.update_route_priority(route) diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index 12174771b89..bb71e84f691 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -176,7 +176,7 @@ def _locked(self) -> bool: """ if self.passthrough: return False - return self._lock._locked if self._lock else False + return self._lock.locked() if self._lock else False def acquire(self, blocking: bool = True) -> bool: """ diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 488f27e30fd..26ec5b44050 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -354,7 +354,7 @@ def __contains__(self, key: Any) -> bool: def __iter__(self) -> Any: return iter(self.keys()) - def __del(self) -> None: + def __del__(self) -> None: try: self._close() except Exception as e: From cd52f5b31ba13e3910462a7a013838ee9b4e3b9a Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 21 Aug 2024 17:22:46 -0400 Subject: [PATCH 06/10] Fixing DictStoreConfig bug --- packages/syft/src/syft/serde/lib_service_registry.py | 5 +++++ packages/syft/src/syft/service/action/action_object.py | 5 +++++ packages/syft/src/syft/store/dict_document_store.py | 3 ++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/serde/lib_service_registry.py b/packages/syft/src/syft/serde/lib_service_registry.py index a3f0901b32a..314fe9bc1f4 100644 --- a/packages/syft/src/syft/serde/lib_service_registry.py +++ b/packages/syft/src/syft/serde/lib_service_registry.py @@ -142,6 +142,11 @@ def init_child( elif inspect.ismodule(child_obj) and CMPBase.is_submodule( parent_obj, child_obj ): + ## TODO, we could register modules and functions in 2 ways: + # A) as numpy.float32 (what we are doing now) + # B) as numpy.core.float32 (currently not supported) + # only allow submodules + return CMPModule( child_path, permissions=self.permissions, diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 9f0a46a8007..1eede8bdfce 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -152,6 +152,11 @@ def syft_history_hash(self) -> int: hashes = 0 if self.remote_self: hashes += hash(self.remote_self.syft_history_hash) + # 🔵 TODO: resolve this + # if the object is ActionDataEmpty then the type might not be equal to the + # real thing. This is the same issue with determining the result type from + # a pointer operation in the past, so we should think about what we want here + # hashes += hash(self.path) hashes += hash(self.op) for arg in self.args: hashes += hash(arg.syft_history_hash) diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index 1a73692eee3..e56554eeb32 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -118,7 +118,8 @@ class DictStoreConfig(StoreConfig): backing_store (Type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. locking_config (LockingConfig): The config used for store locking. """ - + __canonical_name__ = "DictStoreConfig" + store_type: type[DocumentStore] = DictDocumentStore backing_store: type[KeyValueBackingStore] = DictBackingStore locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig) From 503e460d3e3f0506b43035cb7ade2ef07c3456a2 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 21 Aug 2024 21:11:04 -0400 Subject: [PATCH 07/10] Linting error fixed; shelving for now --- packages/syft/src/syft/dev/prof.py | 13 +------------ packages/syft/src/syft/store/dict_document_store.py | 3 ++- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/dev/prof.py b/packages/syft/src/syft/dev/prof.py index 902f47ed6a4..469105a0b2b 100644 --- a/packages/syft/src/syft/dev/prof.py +++ b/packages/syft/src/syft/dev/prof.py @@ -1,5 +1,4 @@ # stdlib -from collections.abc import Generator import contextlib import os import signal @@ -9,17 +8,7 @@ @contextlib.contextmanager -def pyspy() -> Generator[subprocess.Popen, None, None]: - """Profile a block of code using py-spy. Intended for development purposes only. - - Example: - with pyspy(): - # do some work - a = [i for i in range(1000000)] - - Yields: - subprocess.Popen: The process object running py-spy. - """ +def pyspy() -> None: # type: ignore fd, fname = tempfile.mkstemp(".svg") os.close(fd) diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index e56554eeb32..6197a5d9ae6 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -118,8 +118,9 @@ class DictStoreConfig(StoreConfig): backing_store (Type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. locking_config (LockingConfig): The config used for store locking. """ + __canonical_name__ = "DictStoreConfig" - + store_type: type[DocumentStore] = DictDocumentStore backing_store: type[KeyValueBackingStore] = DictBackingStore locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig) From 7faea1ddec9259bb8936ee1ffce3c46000f091f4 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 11 Sep 2024 17:31:26 -0400 Subject: [PATCH 08/10] Fixed indendation error --- .../syft/src/syft/service/action/action_object.py | 6 ++---- .../syft/src/syft/service/network/server_peer.py | 9 ++++----- .../src/syft/service/notifier/notifier_service.py | 14 +++++++------- .../src/syft/service/project/project_service.py | 4 ++-- packages/syft/src/syft/util/util.py | 1 - 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 1e988e7dee6..bbeafbe46f7 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -572,9 +572,7 @@ def send_action_side_effect( @as_result(SyftException) -def propagate_server_uid( - context: PreHookContext, op: str, result: Any -) -> Any: +def propagate_server_uid(context: PreHookContext, op: str, result: Any) -> Any: """Patch the result to include the syft_server_uid. Args: @@ -583,7 +581,7 @@ def propagate_server_uid( result (Any): The result to patch. Returns: - Any: + Any: Raises: SyftException: If the parent object does not have a syft_server_uid or diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index 68efe7d6739..f7e6368c76a 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -167,7 +167,7 @@ def update_existed_route_priority( Returns: ServerRouteType: The route with updated priority. - + Raises: SyftException: Route doesn't exist or priority is incorrect. """ @@ -231,9 +231,9 @@ def client_with_context(self, context: ServerServiceContext) -> SyftClient: Returns: SyftClient: The SyftClient object - - Raises: - SyftException: If there are no routes to the peer. + + Raises: + SyftException: If there are no routes to the peer. """ if len(self.server_routes) < 1: raise SyftException(f"No routes to peer: {self}") @@ -247,7 +247,6 @@ def client_with_context(self, context: ServerServiceContext) -> SyftClient: connection=connection, credentials=context.server.signing_key ) - @as_result(SyftException) def client_with_key(self, credentials: SyftSigningKey) -> SyftClient: """Create a SyftClient using a signing key. diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 8e977875971..1a72480f612 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -83,7 +83,7 @@ def _set_notifier(self, context: AuthedServiceContext, active: bool) -> SyftSucc def set_notifier_active_to_false( self, context: AuthedServiceContext - ) -> SyftSuccess | SyftError: + ) -> SyftSuccess: """ Essentially a duplicate of turn_off method. """ @@ -114,7 +114,7 @@ def turn_on( Returns: SyftSuccess : SyftSuccess if successful. - + Raises: SyftException: any error that occurs during the process """ @@ -233,7 +233,7 @@ def activate( def deactivate( self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL ) -> SyftSuccess: - """Deactivate email notifications for the authenticated user. + """Deactivate email notifications for the authenticated user. This will only work if the datasite owner has enabled notifications. Args: @@ -271,7 +271,7 @@ def init_notifier( smtp_host (str | None): SMTP server host. Defaults to None. Returns: - SyftSuccess: + SyftSuccess: Raises: SyftException: Error in creating or initializing notifier @@ -365,9 +365,9 @@ def dispatch_notification( Returns: SyftSuccess: SyftSuccess if the notification was successfully dispatched. - - Raises: - SyftException | RateLimitException: + + Raises: + SyftException | RateLimitException: - SyftException: Notifier settings not found or few other things. - RateLimitException: Surpassed email threshold limit """ diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 7501acc5a61..927e3fa432c 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -366,8 +366,8 @@ def check_for_project_request( Returns: None: No return - - Raises: + + Raises: SyftException: If notification failed to send. """ if ( diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 6732c2fe456..63b192e03a5 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -1451,7 +1451,6 @@ def get_latest_tag(registry: str, repo: str) -> str | None: return None - def get_caller_file_path() -> str | None: stack = inspect.stack() From 91097051a7ac8a36dc4ed22fa2831826e4bfb373 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 11 Sep 2024 18:55:57 -0400 Subject: [PATCH 09/10] fixed html_str error --- .../src/syft/service/notification/notification_service.py | 2 +- packages/syft/src/syft/util/util.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 15b2b9725d8..fe009063c68 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -93,7 +93,7 @@ def user_settings( def settings( self, context: AuthedServiceContext, - ) -> NotifierSettings: + ) -> NotifierSettings | None: return context.server.services.notifier.settings(context).unwrap() @service_method( diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 63b192e03a5..71f44269ac6 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -1358,11 +1358,11 @@ def generate_token() -> str: return secrets.token_hex(64) -def sanitize_html(html: str) -> str: +def sanitize_html(html_str: str) -> str: """Sanitize HTML content by allowing specific tags and attributes. Args: - html (str): The HTML content to sanitize. + html_str (str): The HTML content to sanitize. Returns: str: The sanitized HTML content. From db40822e7fcc8ff986f471e2f6b5bd3651bbe4a8 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Wed, 11 Sep 2024 19:57:57 -0400 Subject: [PATCH 10/10] fixed notifer bug --- .../syft/src/syft/service/notifier/notifier_service.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 1a72480f612..bbf85554b29 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -113,7 +113,7 @@ def turn_on( email_port (int | None): Email server port. Defaults to 587. Returns: - SyftSuccess : SyftSuccess if successful. + SyftSuccess: SyftSuccess if successful. Raises: SyftException: any error that occurs during the process @@ -243,9 +243,10 @@ def deactivate( Returns: SyftSuccess: SyftSuccess if successful. """ - return context.server.services.user.disable_notifications( + result = context.server.services.user.disable_notifications( context, notifier_type=notifier_type - ).unwrap() + ) + return result @staticmethod @as_result(SyftException) @@ -271,7 +272,7 @@ def init_notifier( smtp_host (str | None): SMTP server host. Defaults to None. Returns: - SyftSuccess: + SyftSuccess: SyftSuccess if successful. Raises: SyftException: Error in creating or initializing notifier