diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3487d8d0915..8cde6ae6f49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -180,6 +180,17 @@ repos: - id: prettier exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) + - repo: https://github.com/jsh9/pydoclint + rev: 0.5.6 + hooks: + - id: pydoclint + args: [ + # --config=packages/syft/pyproject.toml, + --quiet, + --style=google, + --allow-init-docstring=true, + ] + # - repo: meta # hooks: # - id: identity diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 6c0d2d48459..cd0cb55546e 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -103,6 +103,7 @@ dev = ruff==0.4.7 safety>=2.4.0b2 aiosmtpd==1.4.6 + pydoclint==0.5.6 telemetry = opentelemetry-api==1.27.0 diff --git a/packages/syft/src/syft/custom_worker/builder.py b/packages/syft/src/syft/custom_worker/builder.py index e47f341a27f..7cda1cbd80a 100644 --- a/packages/syft/src/syft/custom_worker/builder.py +++ b/packages/syft/src/syft/custom_worker/builder.py @@ -31,6 +31,11 @@ class CustomWorkerBuilder: @cached_property def builder(self) -> BuilderBase: + """Returns the appropriate builder instance based on the environment. + + Returns: + BuilderBase: An instance of either KubernetesBuilder or DockerBuilder. + """ if IN_KUBERNETES: return KubernetesBuilder() else: @@ -39,16 +44,22 @@ def builder(self) -> BuilderBase: def build_image( self, config: WorkerConfig, - tag: str | None = None, + tag: str | None, **kwargs: Any, ) -> ImageBuildResult: - """ - Builds a Docker image from the given configuration. + """Builds a Docker image from the given configuration. + Args: config (WorkerConfig): The configuration for building the Docker image. - tag (str): The tag to use for the image. - """ + tag (str | None): The tag to use for the image. Defaults to None. + **kwargs (Any): Additional keyword arguments for the build process. 
+ + Returns: + ImageBuildResult: The result of the image build process. + Raises: + TypeError: If the config type is not recognized. + """ if isinstance(config, DockerWorkerConfig): return self._build_dockerfile(config, tag, **kwargs) elif isinstance(config, CustomWorkerConfig): @@ -64,13 +75,18 @@ def push_image( password: str, **kwargs: Any, ) -> ImagePushResult: - """ - Pushes a Docker image to the given repo. + """Pushes a Docker image to the given registry. + Args: - repo (str): The repo to push the image to. - tag (str): The tag to use for the image. - """ + tag (str): The tag of the image to push. + registry_url (str): The URL of the registry. + username (str): The username for registry authentication. + password (str): The password for registry authentication. + **kwargs (Any): Additional keyword arguments for the push process. + Returns: + ImagePushResult: The result of the image push process. + """ return self.builder.push_image( tag=tag, username=username, @@ -84,6 +100,16 @@ def _build_dockerfile( tag: str, **kwargs: Any, ) -> ImageBuildResult: + """Builds a Docker image using a Dockerfile. + + Args: + config (DockerWorkerConfig): The configuration containing the Dockerfile. + tag (str): The tag to use for the image. + **kwargs (Any): Additional keyword arguments for the build process. + + Returns: + ImageBuildResult: The result of the image build process. + """ return self.builder.build_image( dockerfile=config.dockerfile, tag=tag, @@ -95,8 +121,18 @@ def _build_template( config: CustomWorkerConfig, **kwargs: Any, ) -> ImageBuildResult: - # Builds a Docker pre-made CPU/GPU image template using a CustomWorkerConfig - # remove once GPU is supported + """Builds a Docker image using a pre-made template. + + Args: + config (CustomWorkerConfig): The configuration containing template settings. + **kwargs (Any): Additional keyword arguments for the build process. + + Returns: + ImageBuildResult: The result of the image build process. 
+ + Raises: + Exception: If GPU support is requested but not supported. + """ if config.build.gpu: raise Exception("GPU custom worker is not supported yet") @@ -120,16 +156,19 @@ def _build_template( ) def find_worker_image(self, type: str) -> Path: - """ - Find the Worker Dockerfile and it's context path - - PROD will be in `$APPDIR/grid/` - - DEV will be in `packages/grid/backend/grid/images` - - In both the cases context dir does not matter (unless we're calling COPY) + """Finds the Worker Dockerfile and its context path. + + The production Dockerfile will be located at `$APPDIR/grid/`. + The development Dockerfile will be located in `packages/grid/backend/grid/images`. Args: - type (str): The type of worker. + type (str): The type of worker (e.g., 'cpu' or 'gpu'). + Returns: Path: The path to the Dockerfile. + + Raises: + FileNotFoundError: If the Dockerfile is not found in any of the expected locations. """ filename = f"worker_{type}.dockerfile" lookup_paths = [ diff --git a/packages/syft/src/syft/dev/prof.py b/packages/syft/src/syft/dev/prof.py index a6aeffc5780..469105a0b2b 100644 --- a/packages/syft/src/syft/dev/prof.py +++ b/packages/syft/src/syft/dev/prof.py @@ -9,15 +9,6 @@ @contextlib.contextmanager def pyspy() -> None: # type: ignore - """Profile a block of code using py-spy. Intended for development purposes only. - - Example: - ``` - with pyspy(): - # do some work - a = [i for i in range(1000000)] - ``` - """ fd, fname = tempfile.mkstemp(".svg") os.close(fd) diff --git a/packages/syft/src/syft/serde/arrow.py b/packages/syft/src/syft/serde/arrow.py index 31a5ad1d27c..ec9554b9ba2 100644 --- a/packages/syft/src/syft/serde/arrow.py +++ b/packages/syft/src/syft/serde/arrow.py @@ -80,13 +80,13 @@ def numpyutf8toarray(input_index: np.ndarray) -> np.ndarray: def arraytonumpyutf8(string_list: str | np.ndarray) -> bytes: - """Encodes string Numpyarray to utf-8 encoded numpy array. + """Encodes string Numpyarray to utf-8 encoded numpy array. 
Args: - string_list (np.ndarray): NumpyArray to be encoded + string_list (str | np.ndarray): NumpyArray or string to be encoded. Returns: - bytes: serialized utf-8 encoded int Numpy array + bytes: serialized utf-8 encoded int Numpy array. """ array_shape = np.array(string_list).shape string_list = np.array(string_list).flatten() diff --git a/packages/syft/src/syft/serde/lib_service_registry.py b/packages/syft/src/syft/serde/lib_service_registry.py index 517df6c643c..314fe9bc1f4 100644 --- a/packages/syft/src/syft/serde/lib_service_registry.py +++ b/packages/syft/src/syft/serde/lib_service_registry.py @@ -120,15 +120,16 @@ def init_child( child_obj: type | object, absolute_path: str, ) -> Self | None: - """Get the child of parent as a CMPBase object + """Get the child of parent as a CMPBase object. Args: - parent_obj (_type_): parent object - child_path (_type_): _description_ - child_obj (_type_): _description_ + parent_obj (type | object): The parent object. + child_path (str): The path of the child object. + child_obj (type | object): The child object. + absolute_path (str): The absolute path of the child object. Returns: - _type_: _description_ + Self | None: The initialized CMPBase object or None if not applicable. """ parent_is_parent_module = CMPBase.parent_is_parent_module(parent_obj, child_obj) if CMPBase.isfunction(child_obj) and parent_is_parent_module: diff --git a/packages/syft/src/syft/serde/serializable.py b/packages/syft/src/syft/serde/serializable.py index 9a683dbcf57..6296f276710 100644 --- a/packages/syft/src/syft/serde/serializable.py +++ b/packages/syft/src/syft/serde/serializable.py @@ -10,7 +10,6 @@ module_type = type(syft) - T = TypeVar("T", bound=type) @@ -26,25 +25,28 @@ def serializable( Recursively serialize attributes of the class. 
Args: - `attrs` : List of attributes to serialize - `without` : List of attributes to exclude from serialization - `inherit` : Whether to inherit serializable attribute list from base class - `inheritable` : Whether the serializable attribute list can be inherited by derived class - - For non-pydantic classes, - - `inheritable=True` => Derived classes will include base class `attrs` - - `inheritable=False` => Derived classes will not include base class `attrs` - - `inherit=True` => Base class `attrs` + `attrs` - `without` - - `inherit=False` => `attrs` - `without` - - For pydantic classes, + attrs (list[str] | None): List of attributes to serialize. Defaults to None. + without (list[str] | None): List of attributes to exclude from serialization. Defaults to None. + inherit (bool | None): Whether to inherit the serializable attribute list from the base class. Defaults to True. + inheritable (bool | None): Whether the serializable attribute list can be inherited by derived + classes. Defaults to True. + canonical_name (str | None): The canonical name for the serialization. Defaults to None. + version (int | None): The version number for the serialization. Defaults to None. + + For non-pydantic classes: + - `inheritable=True` => Derived classes will include base class `attrs`. + - `inheritable=False` => Derived classes will not include base class `attrs`. + - `inherit=True` => Base class `attrs` + `attrs` - `without`. + - `inherit=False` => `attrs` - `without`. + + For pydantic classes: - No need to provide `attrs`. They will be automatically inferred. - Providing `attrs` will override the inferred attributes. - - `without` will work only on attributes of `Optional` type - - `inherit`, `inheritable` will not work as pydantic inherits by default + - `without` will work only on attributes of `Optional` type. + - `inherit`, `inheritable` will not work as pydantic inherits by default. Returns: - Decorated class + Callable[[T], T]: The decorated class. 
""" def rs_decorator(cls: T) -> T: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index d7e566b733a..bbeafbe46f7 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -93,19 +93,16 @@ def repr_cls(c: Any) -> str: class Action(SyftObject): """Serializable Action object. - Parameters: - path: str - The path of the Type of the remote object. - op: str - The method to be executed from the remote object. - remote_self: Optional[LineageID] - The extended UID of the SyftObject - args: List[LineageID] - `op` args - kwargs: Dict[str, LineageID] - `op` kwargs - result_id: Optional[LineageID] - Extended UID of the resulted SyftObject + Attributes: + path (str | None): The path of the Type of the remote object. + op (str | None): The method to be executed from the remote object. + remote_self (LineageID | None): The extended UID of the SyftObject. + args (list[LineageID]): Arguments passed to the operation. + kwargs (dict[str, LineageID]): Keyword arguments passed to the operation. + result_id (LineageID): Extended UID of the resulted SyftObject. + action_type (ActionType | None): The type of action being performed. + create_object (SyftObject | None): The Syft object to be created. + user_code_id (UID | None): The UID associated with the user code. """ __canonical_name__ = "Action" @@ -411,17 +408,14 @@ class PreHookContext(SyftBaseObject): """Hook context - Parameters: - obj: Any - The ActionObject to use for the action - op_name: str - The method name to use for the action - server_uid: Optional[UID] - Optional Syft server UID - result_id: Optional[Union[UID, LineageID]] - Optional result Syft UID - action: Optional[Action] - The action generated by the current hook + Attributes: + obj (Any): The ActionObject to use for the action. + op_name (str): The method name to use for the action. 
+ server_uid (Optional[UID]): Optional Syft server UID. + result_id (Optional[Union[UID, LineageID]]): Optional result Syft UID. + action (Optional[Action]): The action generated by the current hook. + result_twin_type (TwinMode | None): The twin mode for the result. + action_type (ActionType | None): The type of action being performed. """ obj: Any = None @@ -437,18 +431,15 @@ class PreHookContext(SyftBaseObject): def make_action_side_effect( context: PreHookContext, *args: Any, **kwargs: Any ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: - """Create a new action from context_op_name, and add it to the PreHookContext - - Parameters: - context: PreHookContext - PreHookContext object - *args: - Operation *args - **kwargs - Operation *kwargs + """Create a new action from context_op_name, and add it to the PreHookContext. + + Args: + context (PreHookContext): The PreHookContext object. + *args (Any): Operation arguments. + **kwargs (Any): Operation keyword arguments. + Returns: - - Ok[[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success - - Err[str] on failure + tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: """ try: action = context.obj.syft_make_action_with_self( @@ -547,7 +538,19 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: def send_action_side_effect( context: PreHookContext, *args: Any, **kwargs: Any ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: - """Create a new action from the context.op_name, and execute it on the remote server.""" + """Create a new action from the context.op_name, and execute it on the remote server. + + Args: + context (PreHookContext): The PreHookContext object. + *args (Any): Operation arguments. + **kwargs (Any): Operation keyword arguments. + + Returns: + tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: + + Raises: + RuntimeError: If the action cannot be created or if an unexpected response is received. 
+ """ try: if context.action is None: context, _, _ = make_action_side_effect(context, *args, **kwargs).unwrap() @@ -570,18 +573,19 @@ def send_action_side_effect( @as_result(SyftException) def propagate_server_uid(context: PreHookContext, op: str, result: Any) -> Any: - """Patch the result to include the syft_server_uid - - Parameters: - context: PreHookContext - PreHookContext object - op: str - Which operation was executed - result: Any - The result to patch + """Patch the result to include the syft_server_uid. + + Args: + context (PreHookContext): The PreHookContext object. + op (str): Which operation was executed. + result (Any): The result to patch. + Returns: - - Ok[[result] on success - - Err[str] on failure + Any: + + Raises: + SyftException: If the parent object does not have a syft_server_uid or + if the output is not wrapped and should not propagate the server_uid. """ if context.op_name in dont_make_side_effects or not hasattr( context.obj, "syft_server_uid" @@ -956,16 +960,18 @@ def syft_eq(self, ext_obj: Self | None) -> bool: def syft_execute_action( self, action: Action, sync: bool = True ) -> ActionObjectPointer: - """Execute a remote action + """Execute a remote action. - Parameters: - action: Action - Which action to execute - sync: bool - Run sync/async + Args: + action (Action): Which action to execute. + sync (bool): Run sync/async. Returns: - ActionObjectPointer + ActionObjectPointer: The pointer to the action object. + + Raises: + SyftException: If the pointer doesn't have a server_uid. + ValueError: If the API is not found. 
""" if self.syft_server_uid is None: raise SyftException( @@ -1089,33 +1095,25 @@ def syft_make_action( path: str, op: str, remote_self: UID | LineageID | None = None, - args: ( - list[UID | LineageID | ActionObjectPointer | ActionObject | Any] | None - ) = None, - kwargs: ( - dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None - ) = None, + args: list[UID | LineageID | ActionObjectPointer | ActionObject | Any] + | None = None, + kwargs: dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] + | None = None, action_type: ActionType | None = None, ) -> Action: - """Generate new action from the information - - Parameters: - path: str - The path of the Type of the remote object. - op: str - The method to be executed from the remote object. - remote_self: Optional[Union[UID, LineageID]] - The extended UID of the SyftObject - args: Optional[List[Union[UID, LineageID, ActionObjectPointer, ActionObject]]] - `op` args - kwargs: Optional[Dict[str, Union[UID, LineageID, ActionObjectPointer, ActionObject]]] - `op` kwargs - Returns: - Action object + """Generate new action from the information. - Raises: - ValueError: For invalid args or kwargs - PydanticValidationError: For args and kwargs + Args: + path (str): The path of the Type of the remote object. + op (str): The method to be executed from the remote object. + remote_self (UID | LineageID | None): The extended UID of the SyftObject. + args (list[UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation arguments. + kwargs (dict[str, UID | LineageID | ActionObjectPointer | ActionObject | Any] | None): Operation + keyword arguments. + action_type (ActionType | None): The type of action being performed. + + Returns: + Action: The generated action object. 
""" if args is None: args = [] @@ -1123,7 +1121,6 @@ def syft_make_action( kwargs = {} arg_ids = [self._syft_prepare_obj_uid(obj) for obj in args] - kwarg_ids = {k: self._syft_prepare_obj_uid(obj) for k, obj in kwargs.items()} action = Action( @@ -1145,19 +1142,14 @@ def syft_make_action_with_self( ) -> Action: """Generate new method action from the current object. - Parameters: - op: str - The method to be executed from the remote object. - args: List[LineageID] - `op` args - kwargs: Dict[str, LineageID] - `op` kwargs - Returns: - Action object + Args: + op (str): The method to be executed from the remote object. + args (list[UID | ActionObjectPointer] | None): Operation arguments. + kwargs (dict[str, UID | ActionObjectPointer] | None): Operation keyword arguments. + action_type (ActionType | None): The type of action being performed. - Raises: - ValueError: For invalid args or kwargs - PydanticValidationError: For args and kwargs + Returns: + Action: The generated action object. """ path = self.syft_get_path() return self.syft_make_action( @@ -1184,7 +1176,7 @@ def get_sync_dependencies( return [] def syft_get_path(self) -> str: - """Get the type path of the underlying object""" + """Get the type path of the underlying object.""" if ( isinstance(self, AnyActionObject) and self.syft_internal_type @@ -1200,12 +1192,11 @@ def syft_remote_method( ) -> Callable: """Generate a Callable object for remote calls. - Parameters: - op: str - he method to be executed from the remote object. + Args: + op (str): The method to be executed from the remote object. Returns: - A function + Callable: A function to perform the operation. 
""" def wrapper( @@ -1245,7 +1236,7 @@ def _send( return res def get_from(self, client: SyftClient) -> Any: - """Get the object from a Syft Client""" + """Get the object from a Syft Client.""" res = client.api.services.action.get(self.id) return res.syft_action_data @@ -1263,7 +1254,14 @@ def has_storage_permission(self) -> bool: return False def get(self, block: bool = False) -> Any: - """Get the object from a Syft Client""" + """Get the object from a Syft Client. + + Args: + block (bool): Whether to block until the object is available. + + Returns: + Any: The object retrieved from the Syft Client. + """ # relative if block: @@ -1363,13 +1361,22 @@ def from_obj( ) -> ActionObject: """Create an ActionObject from an existing object. - Parameters: - syft_action_data: Any - The object to be converted to a Syft ActionObject - id: Optional[UID] - Which ID to use for the ActionObject. Optional - syft_lineage_id: Optional[LineageID] - Which LineageID to use for the ActionObject. Optional + Args: + syft_action_data (Any): The object to be converted to a Syft ActionObject. + id (UID | None): Which ID to use for the ActionObject. Optional. + syft_lineage_id (LineageID | None): Which LineageID to use for the ActionObject. Optional. + syft_client_verify_key (SyftVerifyKey | None): The client verification key. + syft_server_location (UID | None): The server location UID. + syft_resolved (bool | None): Whether the object is resolved. + data_server_id (UID | None): The data server ID. + syft_blob_storage_entry_id (UID | None): The blob storage entry ID. + + Returns: + ActionObject: The created ActionObject. + + Raises: + SyftException: If the object's type is unsupported. + ValueError: If the UID and LineageID don't match. 
""" if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id: raise ValueError("UID and LineageID should match") @@ -1476,7 +1483,6 @@ def obj_not_ready( @classmethod def empty( - # TODO: fix the mypy issue cls, syft_internal_type: type[Any] | None = None, id: UID | None = None, @@ -1485,17 +1491,19 @@ def empty( data_server_id: UID | None = None, syft_blob_storage_entry_id: UID | None = None, ) -> Self: - """Create an ActionObject from a type, using a ActionDataEmpty object - - Parameters: - syft_internal_type: Type - The Type for which to create a ActionDataEmpty object - id: Optional[UID] - Which ID to use for the ActionObject. Optional - syft_lineage_id: Optional[LineageID] - Which LineageID to use for the ActionObject. Optional - """ + """Create an ActionObject from a type, using an ActionDataEmpty object. + + Args: + syft_internal_type (type[Any] | None): The Type for which to create an ActionDataEmpty object. + id (UID | None): Which ID to use for the ActionObject. + syft_lineage_id (LineageID | None): Which LineageID to use for the ActionObject. + syft_resolved (bool | None): Whether the object is resolved. + data_server_id (UID | None): The data server ID. + syft_blob_storage_entry_id (UID | None): The blob storage entry ID. + Returns: + Self: The created ActionObject. + """ syft_internal_type = ( type(None) if syft_internal_type is None else syft_internal_type ) @@ -1543,10 +1551,10 @@ def __post_init__(self) -> None: def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: """ - Add pre-hooks + Add pre-hooks. Args: - eager_execution: bool: If eager execution is enabled, hooks for + eager_execution (bool): If eager execution is enabled, hooks for tracing and executing the action on remote are added. """ @@ -1565,10 +1573,10 @@ def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: def _syft_add_post_hooks__(self, eager_execution: bool) -> None: """ - Add post-hooks + Add post-hooks. 
Args: - eager_execution: bool: If eager execution is enabled, hooks for + eager_execution (bool): If eager execution is enabled, hooks for tracing and executing the action on remote are added. """ if eager_execution: @@ -1580,7 +1588,7 @@ def _syft_add_post_hooks__(self, eager_execution: bool) -> None: def _syft_run_pre_hooks__( self, context: PreHookContext, name: str, args: Any, kwargs: Any ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: - """Hooks executed before the actual call""" + """Hooks executed before the actual call.""" result_args, result_kwargs = args, kwargs if name in self.syft_pre_hooks__: for hook in self.syft_pre_hooks__[name]: @@ -1617,7 +1625,7 @@ def _syft_run_pre_hooks__( def _syft_run_post_hooks__( self, context: PreHookContext, name: str, result: Any ) -> Any: - """Hooks executed after the actual call""" + """Hooks executed after the actual call.""" new_result = result if name in self.syft_post_hooks__: for hook in self.syft_post_hooks__[name]: @@ -1651,7 +1659,7 @@ def _syft_run_post_hooks__( def _syft_output_action_object( self, result: Any, context: PreHookContext | None = None ) -> Any: - """Wrap the result in an ActionObject""" + """Wrap the result in an ActionObject.""" if issubclass(type(result), ActionObject): return result @@ -1903,9 +1911,11 @@ def __getattribute__(self, name: str) -> Any: * use the syft/_syft prefix for internal methods. * add the method name to the passthrough_attrs. - Parameters: - name: str - The name of the attribute to access. + Args: + name (str): The name of the attribute to access. + + Returns: + Any: The value of the requested attribute. 
""" # bypass ipython canary verification if name == "_ipython_canary_method_should_not_exist_": diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 0228165edb9..495509d8b1a 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -39,14 +39,7 @@ class ActionStore: @serializable(canonical_name="KeyValueActionStore", version=1) class KeyValueActionStore(ActionStore): - """Generic Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including connection configuration, database name, or client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ + """Generic Key-Value Action store.""" def __init__( self, @@ -55,6 +48,16 @@ def __init__( root_verify_key: SyftVerifyKey | None = None, document_store: DocumentStore | None = None, ) -> None: + """ + Generic Key-Value Action store. + + Args: + server_uid (UID): Unique identifier for the server instance. + store_config (StoreConfig): Backend specific configuration, including connection configuration, + database name, or client class type. + root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. + document_store (DocumentStore | None): Document store used for storing user information. + """ self.server_uid = server_uid self.store_config = store_config self.settings = BasePartitionSettings(name="Action") @@ -400,14 +403,7 @@ def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey) -> bool @serializable(canonical_name="DictActionStore", version=1) class DictActionStore(KeyValueActionStore): - """Dictionary-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including client class type. 
- root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ + """Dictionary-Based Key-Value Action store.""" def __init__( self, @@ -416,6 +412,16 @@ def __init__( root_verify_key: SyftVerifyKey | None = None, document_store: DocumentStore | None = None, ) -> None: + """ + Dictionary-Based Key-Value Action store. + + Args: + server_uid (UID): Unique identifier for the server instance. + store_config (StoreConfig | None): Backend specific configuration, including connection configuration, + database name, or client class type. + root_verify_key (SyftVerifyKey | None): Signature verification key, used for checking access permissions. + document_store (DocumentStore | None): Document store used for storing user information. + """ store_config = store_config if store_config is not None else DictStoreConfig() super().__init__( server_uid=server_uid, @@ -430,10 +436,14 @@ class SQLiteActionStore(KeyValueActionStore): """SQLite-Based Key-Value Action store. Parameters: + server_uid: UID + Unique identifier for the server instance. store_config: StoreConfig SQLite specific configuration, including connection settings or client class type. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + document_store: Optional[DocumentStore] + Document store used for storing user information. """ pass @@ -441,13 +451,17 @@ class SQLiteActionStore(KeyValueActionStore): @serializable(canonical_name="MongoActionStore", version=1) class MongoActionStore(KeyValueActionStore): - """Mongo-Based Action store. + """Mongo-Based Action store. Parameters: + server_uid: UID + Unique identifier for the server instance. store_config: StoreConfig Mongo specific configuration. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + document_store: Optional[DocumentStore] + Document store used for storing user information. 
""" pass diff --git a/packages/syft/src/syft/service/action/action_types.py b/packages/syft/src/syft/service/action/action_types.py index c7bd730d557..b946e8643cb 100644 --- a/packages/syft/src/syft/service/action/action_types.py +++ b/packages/syft/src/syft/service/action/action_types.py @@ -11,11 +11,15 @@ def action_type_for_type(obj_or_type: Any) -> type: - """Convert standard type to Syft types + """Convert standard type to Syft types. - Parameters: - obj_or_type: Union[object, type] - Can be an object or a class + Args: + obj_or_type (Any): Can be an object or a class. If it's an instance of + `ActionDataEmpty`, the internal type is used. + + Returns: + type: Corresponding Syft type for the given object or type. If no corresponding + type is found, the default Syft type for `Any` is returned. """ if isinstance(obj_or_type, ActionDataEmpty): obj_or_type = obj_or_type.syft_internal_type @@ -31,11 +35,14 @@ def action_type_for_type(obj_or_type: Any) -> type: def action_type_for_object(obj: Any) -> type: - """Convert standard type to Syft types + """Convert an object's type to the corresponding Syft type. + + Args: + obj (Any): The object to convert. - Parameters: - obj_or_type: Union[object, type] - Can be an object or a class + Returns: + type: Corresponding Syft type for the given object. If no corresponding + type is found, the default Syft type for `Any` is returned. """ _type = type(obj) diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 44567e8d8ef..c50030efd4e 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -418,7 +418,8 @@ def has_permission(self, context: AuthedServiceContext) -> bool: """Check if the user has permission to access the endpoint. Args: - context: The context of the user requesting the code. + context (AuthedServiceContext): The context of the user requesting the code. 
+ Returns: bool: True if the user has permission to access the endpoint, False otherwise. """ @@ -432,7 +433,8 @@ def select_code( """Select the code to execute based on the user's permissions and public code availability. Args: - context: The context of the user requesting the code. + context (AuthedServiceContext): The context of the user requesting the code. + Returns: Result[Ok, Err]: The selected code to execute. """ @@ -450,9 +452,10 @@ def exec( """Execute the code based on the user's permissions and public code availability. Args: - context: The context of the user requesting the code. - *args: Any - **kwargs: Any + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. + **kwargs (Any): Additional keyword arguments to pass to the code. + Returns: Any: The result of the executed code. """ @@ -466,7 +469,16 @@ def exec_mock_function( log_id: UID | None = None, **kwargs: Any, ) -> Any: - """Execute the public code if it exists.""" + """Execute the public code if it exists. + + Args: + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. + **kwargs (Any): Additional keyword arguments to pass to the code. + + Returns: + Any: The result of the executed public code or an error if no public code is available. + """ if self.mock_function: return self.exec_code( self.mock_function, context, *args, log_id=log_id, **kwargs @@ -481,14 +493,15 @@ def exec_private_function( log_id: UID | None = None, **kwargs: Any, ) -> Any: - """Execute the private code if user is has the proper permissions. + """Execute the private code if the user has the proper permissions. Args: - context: The context of the user requesting the code. - *args: Any - **kwargs: Any + context (AuthedServiceContext): The context of the user requesting the code. + *args (Any): Additional arguments to pass to the code. 
+ **kwargs (Any): Additional keyword arguments to pass to the code. + Returns: - Any: The result of the executed code. + Any: The result of the executed code or an error message if the user does not have permission. """ if self.private_function is None: raise SyftException(public_message="No private code available") diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 428501fb92d..94279c0ea87 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -579,7 +579,7 @@ def add_route( called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: - SyftSuccess + SyftSuccess : Success message or error message. """ # verify if the peer is truly the one sending the request to add the route to itself if called_by_peer and peer_verify_key != context.credentials: @@ -637,9 +637,10 @@ def delete_route_on_peer( route (ServerRoute): The route to be deleted. Returns: - SyftSuccess: If the route is successfully deleted. - SyftInfo: If there is only one route left for the peer and - the admin chose not to remove it + SyftSuccess | SyftInfo: Success or informational response. + - SyftSuccess: If the route is successfully deleted. + - SyftInfo: If there is only one route left for the peer and + the admin chose not to remove it """ # creates a client on the remote server based on the credentials # of the current server's client @@ -666,18 +667,19 @@ def delete_route( ) -> SyftSuccess | SyftInfo: """ Delete a route for a given peer. - If a peer has no routes left, there will be a prompt asking if the user want to remove it. + If a peer has no routes left, there will be a prompt asking if the user wants to remove it. If the answer is yes, it will be removed from the stash and will no longer be a peer. Args: context (AuthedServiceContext): The authentication context for the service. 
peer_verify_key (SyftVerifyKey): The verify key of the remote server peer. - route (ServerRoute): The route to be deleted. + route (ServerRoute | None): The route to be deleted. called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: - SyftSuccess: If the route is successfully deleted. - SyftInfo: If there is only one route left for the peer and + SyftSuccess | SyftInfo: Success or informational response. + - SyftSuccess: If the route is successfully deleted. + - SyftInfo: If there is only one route left for the peer and the admin chose not to remove it """ if called_by_peer and peer_verify_key != context.credentials: @@ -759,9 +761,9 @@ def update_route_priority_on_peer( Args: context (AuthedServiceContext): The authentication context. peer (ServerPeer): The peer representing the remote server. - route (ServerRoute): The route to be added. + route (ServerRoute): The route to be updated. priority (int | None): The new priority value for the route. If not - provided, it will be assigned the highest priority among all peers + provided, it will be assigned the highest priority among all peers. Returns: SyftSuccess: A success message if the route is verified, @@ -793,14 +795,15 @@ def update_route_priority( called_by_peer: bool = False, ) -> SyftSuccess: """ - Updates a route's priority for the given peer + Updates a route's priority for the given peer. Args: context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the peer whose route priority needs to be updated. route (ServerRoute): The route for which the priority needs to be updated. priority (int | None): The new priority value for the route. If not - provided, it will be assigned the highest priority among all peers + provided, it will be assigned the highest priority among all peers. + called_by_peer (bool): The flag to indicate that it's called by a remote peer. 
Returns: SyftSuccess : Successful response diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index fb0ef01d798..33eb1641a68 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -32,10 +32,9 @@ def add_host_to_server(self, peer: ServerPeer) -> None: Args: peer (ServerPeer): The peer to be added to the rathole server. - Returns: - None + Raises: + Exception: If the peer has no rathole route or if the rathole config map is not found. """ - rathole_route = peer.get_rtunnel_route() if not rathole_route: raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") @@ -88,10 +87,9 @@ def remove_host_from_server(self, peer_id: str, server_name: str) -> None: peer_id (str): The id of the peer to be removed. server_name (str): The name of the peer to be removed. - Returns: - None + Raises: + Exception: If the rathole config map is not found. """ - rathole_config_map = KubeUtils.get_configmap( client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) @@ -116,14 +114,27 @@ def remove_host_from_server(self, peer_id: str, server_name: str) -> None: self._remove_dynamic_addr_from_rathole(server_name) def _get_random_port(self) -> int: - """Get a random port number.""" + """Get a random port number. + + Returns: + int: A randomly generated port number. + """ return secrets.randbits(15) def add_host_to_client( self, peer_name: str, peer_id: str, rtunnel_token: str, remote_addr: str ) -> None: - """Add a host to the rathole client toml file.""" + """Add a host to the rathole client toml file. + + Args: + peer_name (str): The name of the peer. + peer_id (str): The id of the peer. + rtunnel_token (str): The rtunnel token for the peer. + remote_addr (str): The remote address for the rathole client. + Raises: + Exception: If the rathole config map is not found. 
+ """ config = RatholeConfig( uuid=peer_id, secret_token=rtunnel_token, @@ -156,8 +167,14 @@ def add_host_to_client( KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) def remove_host_from_client(self, peer_id: str) -> None: - """Remove a host from the rathole client toml file.""" + """Remove a host from the rathole client toml file. + + Args: + peer_id (str): The id of the peer to be removed. + Raises: + Exception: If the rathole config map is not found. + """ rathole_config_map = KubeUtils.get_configmap( client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) @@ -183,8 +200,15 @@ def remove_host_from_client(self, peer_id: str) -> None: def _add_dynamic_addr_to_rathole( self, config: RatholeConfig, entrypoint: str = "web" ) -> None: - """Add a port to the rathole proxy config map.""" + """Add a port to the rathole proxy config map. + + Args: + config (RatholeConfig): The RatholeConfig object containing server details. + entrypoint (str): The entrypoint for the rathole proxy (default: "web"). + Raises: + Exception: If the rathole proxy config map is not found. + """ rathole_proxy_config_map = KubeUtils.get_configmap( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) @@ -229,8 +253,14 @@ def _add_dynamic_addr_to_rathole( self._expose_port_on_rathole_service(config.server_name, config.local_addr_port) def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: - """Remove a port from the rathole proxy config map.""" + """Remove a port from the rathole proxy config map. + + Args: + server_name (str): The name of the server to remove from the proxy config map. + Raises: + Exception: If the rathole proxy config map is not found. 
+ """ rathole_proxy_config_map = KubeUtils.get_configmap( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) @@ -259,8 +289,12 @@ def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: self._remove_port_on_rathole_service(server_name) def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: - """Expose a port on the rathole service.""" + """Expose a port on the rathole service. + Args: + port_name (str): The name of the port. + port (int): The port number to expose. + """ rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) @@ -290,8 +324,11 @@ def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: rathole_service.patch(config) def _remove_port_on_rathole_service(self, port_name: str) -> None: - """Remove a port from the rathole service.""" + """Remove a port from the rathole service. + Args: + port_name (str): The name of the port to remove. + """ rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index 594c5fe19dc..f7e6368c76a 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -66,17 +66,19 @@ class ServerPeer(SyftObject): pinged_timestamp: DateTime | None = None def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: - """Check if a route exists in self.server_routes + """Check if a route exists in self.server_routes. Args: - route: the route to be checked. For now it can be either - HTTPServerRoute or PythonServerRoute + route (ServerRouteType): The route to be checked. It can be either + HTTPServerRoute, PythonServerRoute, or VeilidServerRoute. 
Returns: - if the route exists, returns (True, index of the existed route in self.server_routes) - if the route does not exist returns (False, None) - """ + tuple[bool, int | None]: A tuple containing a boolean indicating whether the route exists, + and the index of the route if it exists, otherwise None. + Raises: + ValueError: If the route type is not supported. + """ if route: if not isinstance( route, HTTPServerRoute | PythonServerRoute | VeilidServerRoute @@ -89,33 +91,28 @@ def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: return (False, None) def update_route_priority(self, route: ServerRoute) -> ServerRoute: - """ - Assign the new_route's priority to be current max + 1 + """Assign the new route's priority to be current max + 1. Args: route (ServerRoute): The new route whose priority is to be updated. Returns: - ServerRoute: The new route with the updated priority + ServerRoute: The new route with the updated priority. """ current_max_priority: int = max(route.priority for route in self.server_routes) route.priority = current_max_priority + 1 return route def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: - """ - Picks the route with the highest priority from the list of server routes. + """Pick the route with the highest priority from the list of server routes. Args: - oldest (bool): - If True, picks the oldest route to have the highest priority, - meaning the route with min priority value. - If False, picks the most recent route with the highest priority, - meaning the route with max priority value. + oldest (bool): If True, picks the oldest route with the highest priority + (lowest priority value). If False, picks the most recent route + with the highest priority (highest priority value). Returns: ServerRoute: The route with the highest priority. 
- """ highest_priority_route: ServerRoute = self.server_routes[-1] for route in self.server_routes[:-1]: @@ -128,10 +125,10 @@ def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: return highest_priority_route def update_route(self, route: ServerRoute) -> None: - """ - Update the route for the server. - If the route already exists, return it. - If the route is new, assign it to have the priority of (current_max + 1) + """Update the route for the server. + + If the route already exists, it updates the existing route. + If the route is new, it assigns it a priority of (current_max + 1). Args: route (ServerRoute): The new route to be added to the peer. @@ -144,20 +141,14 @@ def update_route(self, route: ServerRoute) -> None: self.server_routes.append(new_route) def update_routes(self, new_routes: list[ServerRoute]) -> None: - """ - Update multiple routes of the server peer. + """Update multiple routes of the server peer. - This method takes a list of new routes as input. - It first updates the priorities of the new routes. - Then, for each new route, it checks if the route already exists for the server peer. - If it does, it updates the priority of the existing route. - If it doesn't, it adds the new route to the server. + This method updates the priorities of new routes and checks if each route + already exists for the server peer. If a route exists, it updates the priority; + otherwise, it adds the new route to the server. Args: new_routes (list[ServerRoute]): The new routes to be added to the server. - - Returns: - None """ for new_route in new_routes: self.update_route(new_route) @@ -172,10 +163,13 @@ def update_existed_route_priority( Args: route (ServerRoute): The route whose priority is to be updated. priority (int | None): The new priority of the route. If not given, - the route will be assigned with the highest priority. + the route will be assigned the highest priority. 
Returns: - ServerRoute: The route with updated priority if the route exists + ServerRouteType: The route with updated priority. + + Raises: + SyftException: Route doesn't exist or priority is incorrect. """ if priority is not None and priority <= 0: raise SyftException( @@ -200,6 +194,17 @@ def update_existed_route_priority( @staticmethod def from_client(client: SyftClient) -> "ServerPeer": + """Create a ServerPeer object from a SyftClient. + + Args: + client (SyftClient): The SyftClient from which to create the ServerPeer. + + Returns: + ServerPeer: The created ServerPeer object. + + Raises: + ValueError: If the client does not have metadata. + """ if not client.metadata: raise ValueError("Client has to have metadata first") @@ -210,8 +215,7 @@ def from_client(client: SyftClient) -> "ServerPeer": @property def latest_added_route(self) -> ServerRoute | None: - """ - Returns the latest added route from the list of server routes. + """Get the latest added route. Returns: ServerRoute | None: The latest added route, or None if there are no routes. @@ -220,11 +224,19 @@ def latest_added_route(self) -> ServerRoute | None: @as_result(SyftException) def client_with_context(self, context: ServerServiceContext) -> SyftClient: - # third party + """Create a SyftClient using the context of a ServerService. + + Args: + context (ServerServiceContext): The context to use for creating the client. + + Returns: + SyftClient: The SyftClient object + Raises: + SyftException: If there are no routes to the peer. 
+ """ if len(self.server_routes) < 1: - raise ValueError(f"No routes to peer: {self}") - # select the route with highest priority to connect to the peer + raise SyftException(f"No routes to peer: {self}") final_route: ServerRoute = self.pick_highest_priority_route() connection: ServerConnection = route_to_connection(route=final_route) client_type = connection.get_client_type().unwrap( @@ -237,6 +249,17 @@ def client_with_context(self, context: ServerServiceContext) -> SyftClient: @as_result(SyftException) def client_with_key(self, credentials: SyftSigningKey) -> SyftClient: + """Create a SyftClient using a signing key. + + Args: + credentials (SyftSigningKey): The signing key to use for creating the client. + + Returns: + SyftClient: The created SyftClient, or a SyftError if unsuccessful. + + Raises: + SyftException: If there are no routes to the peer. + """ if len(self.server_routes) < 1: raise SyftException(public_message=f"No routes to peer: {self}") @@ -248,28 +271,45 @@ def client_with_key(self, credentials: SyftSigningKey) -> SyftClient: @property def guest_client(self) -> SyftClient: + """Create a guest SyftClient with a randomly generated signing key. + + Returns: + SyftClient: The created guest SyftClient. + """ guest_key = SyftSigningKey.generate() return self.client_with_key(credentials=guest_key).unwrap() def proxy_from(self, client: SyftClient) -> SyftClient: + """Create a proxy SyftClient from an existing client. + + Args: + client (SyftClient): The existing SyftClient to proxy from. + + Returns: + SyftClient: The created proxy SyftClient. + """ return client.proxy_to(self) def get_rtunnel_route(self) -> HTTPServerRoute | None: + """Get the HTTPServerRoute with an rtunnel token. + + Returns: + HTTPServerRoute | None: The route with the rtunnel token, or None if not found. 
+ """ for route in self.server_routes: if hasattr(route, "rtunnel_token") and route.rtunnel_token: return route return None def delete_route(self, route: ServerRouteType) -> None: - """ - Deletes a route from the peer's route list. - Takes O(n) where is n is the number of routes in self.server_routes. + """Delete a route from the peer's route list. + Takes O(n) where is n is the number of routes in self.server_routes. Args: - route (ServerRouteType): The route to be deleted; + route (ServerRouteType): The route to be deleted. Returns: - None + None: If successful. """ if route: try: @@ -297,6 +337,12 @@ class ServerPeerUpdate(PartialSyftObject): def drop_veilid_route() -> Callable: + """Drop VeilidServerRoute from the server routes in the context output. + + Returns: + Callable: The function that drops VeilidServerRoute from the context output. + """ + def _drop_veilid_route(context: TransformContext) -> TransformContext: if context.output: server_routes = context.output["server_routes"] diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index 280e836f17b..5070e7b8d93 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -36,7 +36,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> None: context (AuthedServiceContext): The authenticated service context. 
Returns: - None + SyftError | None: """ network_stash = context.server.services.network.stash diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 15b2b9725d8..fe009063c68 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -93,7 +93,7 @@ def user_settings( def settings( self, context: AuthedServiceContext, - ) -> NotifierSettings: + ) -> NotifierSettings | None: return context.server.services.notifier.settings(context).unwrap() @service_method( diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 3cf784f5095..e4410d9c570 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -275,7 +275,7 @@ def select_notifiers(self, notification: Notification) -> list[BaseNotifier]: Args: notification (Notification): The notification object Returns: - List[BaseNotifier]: A list of enabled notifier objects + list[BaseNotifier]: A list of enabled notifier objects """ notifier_objs = [] for notifier_type in notification.notifier_types: diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index d8c1da2530f..bbf85554b29 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -45,11 +45,12 @@ def __init__(self, store: DocumentStore) -> None: def settings( self, context: AuthedServiceContext, - ) -> NotifierSettings: - """Get Notifier Settings + ) -> NotifierSettings | None: + """Get Notifier Settings. Args: - context: The request context + context (AuthedServiceContext): The request context. 
+ Returns: NotifierSettings | None: The notifier settings, if it exists; None otherwise. """ @@ -104,16 +105,19 @@ def turn_on( """Turn on email notifications. Args: - email_username (Optional[str]): Email server username. Defaults to None. - email_password (Optional[str]): Email email server password. Defaults to None. - sender_email (Optional[str]): Email sender email. Defaults to None. + context (AuthedServiceContext): The request context. + email_username (str | None): Email server username. Defaults to None. + email_password (str | None): Email server password. Defaults to None. + email_sender (str | None): Email sender address. Defaults to None. + email_server (str | None): Email server address. Defaults to None. + email_port (int | None): Email server port. Defaults to 587. + Returns: - SyftSuccess: success response. + SyftSuccess: SyftSuccess if successful. Raises: SyftException: any error that occurs during the process """ - # 1 - If something went wrong at db level, return the error notifier = self.stash.get(credentials=context.credentials).unwrap() @@ -199,6 +203,12 @@ def turn_off( """ Turn off email notifications service. PySyft notifications will still work. + + Args: + context (AuthedServiceContext): The request context. + + Returns: + SyftSuccess: SyftSuccess if successful. """ notifier = self.stash.get(credentials=context.credentials).unwrap() @@ -223,8 +233,15 @@ def activate( def deactivate( self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL ) -> SyftSuccess: - """Deactivate email notifications for the authenticated user + """Deactivate email notifications for the authenticated user. This will only work if the datasite owner has enabled notifications. + + Args: + context (AuthedServiceContext): The request context. + notifier_type (NOTIFIERS): The notifier type to deactivate. Defaults to NOTIFIERS.EMAIL. + + Returns: + SyftSuccess: SyftSuccess if successful. 
""" result = context.server.services.user.disable_notifications( context, notifier_type=notifier_type @@ -242,18 +259,23 @@ def init_notifier( smtp_host: str | None = None, ) -> SyftSuccess: """Initialize Notifier settings for a Server. + If settings already exist, it will use the existing one. If not, it will create a new one. Args: - server: Server to initialize the notifier - active: If notifier should be active - email_username: Email username to send notifications - email_password: Email password to send notifications - Raises: - Exception: If something went wrong + server (AbstractServer): Server to initialize the notifier. + email_username (str | None): Email username to send notifications. Defaults to None. + email_password (str | None): Email password to send notifications. Defaults to None. + email_sender (str | None): Email sender address. Defaults to None. + smtp_port (int | None): SMTP server port. Defaults to None. + smtp_host (str | None): SMTP server host. Defaults to None. + Returns: - SyftSuccess + SyftSuccess: SyftSuccess if successful. + + Raises: + SyftException: Error in creating or initializing notifier """ try: # Create a new NotifierStash since its a static method. @@ -310,6 +332,16 @@ def init_notifier( def set_email_rate_limit( self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int ) -> SyftSuccess: + """Set the email rate limit for a specific email type. + + Args: + context (AuthedServiceContext): The request context. + email_type (EMAIL_TYPES): The type of email for which to set the rate limit. + daily_limit (int): The daily limit for the specified email type. + + Returns: + SyftSuccess: SyftSuccess if successful. + """ notifier = self.stash.get(context.credentials).unwrap( public_message="Couldn't set the email rate limit." 
) @@ -324,6 +356,22 @@ def set_email_rate_limit( def dispatch_notification( self, context: AuthedServiceContext, notification: Notification ) -> SyftSuccess: + """Dispatch a notification to the user. + + This method is used internally by other services to send notifications. + + Args: + context (AuthedServiceContext): The request context. + notification (Notification): The notification to dispatch. + + Returns: + SyftSuccess: SyftSuccess if the notification was successfully dispatched. + + Raises: + SyftException | RateLimitException: + - SyftException: Notifier settings not found or few other things. + - RateLimitException: Surpassed email threshold limit + """ admin_key = context.server.services.user.admin_verify_key() # Silently fail on notification not delivered diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py index eef25440af8..696eb55cb80 100644 --- a/packages/syft/src/syft/service/notifier/smtp_client.py +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -23,6 +23,17 @@ class SMTPClient(BaseModel): username: str | None = None def send(self, sender: str, receiver: list[str], subject: str, body: str) -> None: + """Send an email using the SMTP server. + + Args: + sender (str): The sender's email address. + receiver (list[str]): A list of recipient email addresses. + subject (str): The subject of the email. + body (str): The HTML body of the email. + + Raises: + ValueError: If subject, body, or receiver is not provided. + """ if not (subject and body and receiver): raise ValueError("Subject, body, and recipient email(s) are required") @@ -57,7 +68,13 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non def check_credentials( cls, server: str, port: int, username: str, password: str ) -> bool: - """Check if the credentials are valid. + """Check if the provided SMTP credentials are valid. + + Args: + server (str): The SMTP server address. 
+ port (int): The port number to connect to. + username (str): The username for the SMTP server. + password (str): The password for the SMTP server. Returns: bool: True if the credentials are valid, False otherwise. diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 0b64a9d8870..60b3ab7a24c 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -318,7 +318,7 @@ def status(self, project: Project) -> SyftInfo | None: project (Project): Project object to check the status Returns: - str: Status of the request. + SyftInfo | SyftError | None: Status of the request. During Request status calculation, we do not allow multiple responses """ @@ -563,9 +563,10 @@ def status( Args: project (Project): Project object to check the status + pretty_print (bool): Flag for pretty printing Returns: - str: Status of the poll + dict | SyftError | SyftInfo | None: Status of the poll During Poll calculation, a user would have answered the poll many times The status of the poll would be calculated based on the latest answer of the user @@ -1348,7 +1349,7 @@ def hash_object(obj: Any) -> tuple[bytes, str]: obj (Any): Object to be hashed Returns: - str: Hashed value of the object + tuple[bytes, str]: Hashed value of the object """ hash_bytes = _serialize(obj, to_bytes=True, for_hashing=True) hash = hashlib.sha256(hash_bytes) diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 3a2543e38a1..927e3fa432c 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -365,7 +365,10 @@ def check_for_project_request( context (AuthedServiceContext): Context of the server Returns: - SyftSuccess: SyftSuccess if message is created else SyftError + None: No return + + Raises: + SyftException: If notification 
failed to send. """ if ( isinstance(project_event, ProjectRequest) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 1a492ea1d53..de87f6c3b30 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -603,6 +603,9 @@ def deny(self, reason: str) -> SyftSuccess: Args: reason (str): Reason for which the request has been denied. + + Returns: + SyftSuccess | SyftError: Result of the operation. """ api = self._get_api() @@ -776,9 +779,9 @@ def deposit_result( result (Any): ActionObject or any object to be saved as an ActionObject. log_stdout (str): stdout logs. log_stderr (str): stderr logs. - approve (bool, optional): Only supported for L2 requests. If True, the request will be approved. + approve (bool | None): Only supported for L2 requests. If True, the request will be approved. Defaults to None. - + **kwargs (dict[str, Any]): Additional arguments. Returns: Job: Job object if successful, else raise SyftException. diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 43fe685b7e5..5ffd09c07b4 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -88,27 +88,16 @@ def update( Update the Server Settings using the provided values. Args: - name: Optional[str] - Server name - organization: Optional[str] - Organization name - description: Optional[str] - Server description - on_board: Optional[bool] - Show onboarding panel when a user logs in for the first time - signup_enabled: Optional[bool] - Enable/Disable registration - admin_email: Optional[str] - Administrator email - association_request_auto_approval: Optional[bool] + context (AuthedServiceContext): The authenticated service context. + settings (ServerSettingsUpdate): The settings to update. 
Returns: SyftSuccess: Message indicating the success of the operation, with the update server settings as the value property. Example: - >>> server_client.update(name='foo', organization='bar', description='baz', signup_enabled=True) - SyftSuccess: Settings updated successfully. + >>> server_client.update(settings=ServerSettingsUpdate(signup_enabled=True)) + SyftSuccess: Settings updated successfully. """ updated_settings = self._update(context, settings).unwrap() return SyftSuccess( diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 4ceced2bf26..2b54478db26 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -94,16 +94,18 @@ def launch( ) -> list[ContainerSpawnStatus]: """Creates a pool of workers from the given SyftWorkerImage. - - Retrieves the image for the given UID - - Use docker to launch containers for given image - - For each successful container instantiation create a SyftWorker object - - Creates a SyftWorkerPool object - Args: - context (AuthedServiceContext): context passed to the service - name (str): name of the pool - image_id (UID): UID of the SyftWorkerImage against which the pool should be created - num_workers (int): the number of SyftWorker that needs to be created in the pool + context (AuthedServiceContext): Context passed to the service. + pool_name (str): Name of the pool. + image_uid (UID | None): UID of the SyftWorkerImage against which the pool should be created. + num_workers (int): The number of SyftWorkers that need to be created in the pool. + registry_username (str | None, optional): Username for the registry. Defaults to None. + registry_password (str | None, optional): Password for the registry. Defaults to None. + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. 
+ pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + list[ContainerSpawnStatus] | SyftError: List of container spawn statuses or an error. """ pool_exists = self.pool_exists(context, pool_name=pool_name).unwrap() @@ -176,9 +178,13 @@ def create_pool_request( context (AuthedServiceContext): The authenticated service context. pool_name (str): The name of the worker pool. num_workers (int): The number of workers in the pool. - image_uid (Optional[UID]): The UID of the built image. - reason (Optional[str], optional): The reason for creating the - worker pool. Defaults to "". + image_uid (UID): The UID of the built image. + reason (str | None, optional): The reason for creating the worker pool. Defaults to "". + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. + pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + Request: Request object """ # Check if image exists for the given image id worker_image_exists = self.image_exists(context, uid=image_uid).unwrap() @@ -241,11 +247,17 @@ def create_image_and_pool_request( context (AuthedServiceContext): The authenticated service context. pool_name (str): The name of the worker pool. num_workers (int): The number of workers in the pool. - config: (WorkerConfig): Config of the image to be built. - tag (str | None, optional): - a human-readable manifest identifier that is typically a specific version or variant of an image, - only needed for `DockerWorkerConfig` to tag the image after it is built. + config (WorkerConfig): Config of the image to be built. + tag (str | None, optional): A human-readable manifest identifier. Required for `DockerWorkerConfig`. + registry_uid (UID | None, optional): UID of the registry in Kubernetes mode. Required + for `DockerWorkerConfig`. reason (str | None, optional): The reason for creating the worker image and pool. Defaults to "". 
+ pull_image (bool, optional): Whether to pull the image. Defaults to True. + pod_annotations (dict[str, str] | None, optional): Annotations for the pod. Defaults to None. + pod_labels (dict[str, str] | None, optional): Labels for the pod. Defaults to None. + + Returns: + Request: Request object. """ if not isinstance(config, DockerWorkerConfig | PrebuiltWorkerConfig): raise SyftException( @@ -327,6 +339,14 @@ def create_image_and_pool_request( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_all(self, context: AuthedServiceContext) -> DictTuple[str, WorkerPool]: + """Get all worker pools. + + Args: + context (AuthedServiceContext): The authenticated service context. + + Returns: + DictTuple[str, WorkerPool]: All worker pools. + """ # TODO: During get_all, we should dynamically make a call to docker to get the status of the containers # and update the status of the workers in the pool. worker_pools = self.stash.get_all(credentials=context.credentials).unwrap() @@ -353,13 +373,15 @@ def add_workers( Worker pool is fetched either using the unique pool id or pool name. Args: - context (AuthedServiceContext): _description_ - number (int): number of workers to add - pool_id (Optional[UID], optional): Unique UID of the pool. Defaults to None. - pool_name (Optional[str], optional): Unique name of the pool. Defaults to None. + context (AuthedServiceContext): The authenticated service context. + number (int): Number of workers to add. + pool_id (UID | None, optional): Unique UID of the pool. Defaults to None. + pool_name (str | None, optional): Unique name of the pool. Defaults to None. + registry_username (str | None, optional): Username for the registry. Defaults to None. + registry_password (str | None, optional): Password for the registry. Defaults to None. Returns: - List[ContainerSpawnStatus]: List of spawned workers with their status and error if any. + List[ContainerSpawnStatus]: List of spawned workers with their status. 
""" if number <= 0: @@ -419,6 +441,15 @@ def scale( """ Scale the worker pool to the given number of workers in Kubernetes. Allows both scaling up and down the worker pool. + + Args: + context (AuthedServiceContext): The authenticated service context. + number (int): Number of workers to scale to. + pool_id (UID | None, optional): Unique UID of the pool. Defaults to None. + pool_name (str | None, optional): Unique name of the pool. Defaults to None. + + Returns: + SyftSuccess: Success message. """ client_warning = "" @@ -498,6 +529,15 @@ def scale( def filter_by_image_id( self, context: AuthedServiceContext, image_uid: UID ) -> list[WorkerPool]: + """Filter worker pools by image ID. + + Args: + context (AuthedServiceContext): The authenticated service context. + image_uid (UID): The UID of the image. + + Returns: + list[WorkerPool]: List of worker pools. + """ return self.stash.get_by_image_uid(context.credentials, image_uid).unwrap() @service_method( @@ -508,6 +548,15 @@ def filter_by_image_id( def get_by_name( self, context: AuthedServiceContext, pool_name: str ) -> list[WorkerPool]: + """Get worker pool by name. + + Args: + context (AuthedServiceContext): The authenticated service context. + pool_name (str): The name of the worker pool. + + Returns: + list[WorkerPool]: List of worker pools. + """ return self.stash.get_by_name(context.credentials, pool_name).unwrap() @service_method( @@ -521,8 +570,15 @@ def sync_pool_from_request( context: AuthedServiceContext, request: Request, ) -> Request: - """Re-submit request from a different server""" + """Re-submit request from a different server + Args: + context (AuthedServiceContext): The authenticated service context. + request (Request): The request object. 
+ + Returns: + Request: Request object + """ num_of_changes = len(request.changes) pool_name, num_workers, config, image_uid, tag = None, None, None, None, None diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index ca0f3e1f33a..e746205a659 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -21,14 +21,33 @@ @serializable(canonical_name="DictBackingStore", version=1) class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] - # TODO: fix the mypy issue - """Dictionary-based Store core logic""" + """Dictionary-based Store core logic + + This class provides the core logic for a dictionary-based key-value store. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the dictionary-based backing store. + + Args: + *args (Any): Positional arguments. + **kwargs (Any): Keyword arguments. + """ super().__init__() self._ddtype = kwargs.get("ddtype", None) def __getitem__(self, key: Any) -> Any: + """Retrieve an item from the store by key. + + Args: + key (Any): The key of the item to retrieve. + + Returns: + Any: The value associated with the key. + + Raises: + KeyError: If the key is not found in the store. + """ try: value = super().__getitem__(key) return value @@ -42,25 +61,23 @@ def __getitem__(self, key: Any) -> Any: class DictStorePartition(KeyValueStorePartition): """Dictionary-based StorePartition + This class represents a partition within a dictionary-based key-value store. + Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: DictStoreConfig - DictStore specific configuration + settings (PartitionSettings): PySyft specific settings, used for indexing and partitioning. + store_config (DictStoreConfig): Dictionary Store specific configuration. 
""" def prune(self) -> None: + """Reset the partition by reinitializing the store.""" self.init_store().unwrap() -# the base document store is already a dict but we can change it later @serializable(canonical_name="DictDocumentStore", version=1) class DictDocumentStore(DocumentStore): """Dictionary-based Document Store - Parameters: - `store_config`: DictStoreConfig - Dictionary Store specific configuration, containing the store type and the backing store type + This class represents a document store implemented using a dictionary. """ partition_type = DictStorePartition @@ -80,27 +97,30 @@ def __init__( ) def reset(self) -> None: + """Reset the document store by pruning all partitions.""" for partition in self.partitions.values(): partition.prune() @serializable() class DictStoreConfig(StoreConfig): - __canonical_name__ = "DictStoreConfig" """Dictionary-based configuration + This class provides the configuration for a dictionary-based document store. + + Attributes: + store_type (type[DocumentStore]): The Document type used. Default: DictDocumentStore. + backing_store (type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. + locking_config (LockingConfig): The config used for store locking. Defaults to ThreadingLockingConfig. + Parameters: - `store_type`: Type[DocumentStore] - The Document type used. Default: DictDocumentStore - `backing_store`: Type[KeyValueBackingStore] - The backend type used. Default: DictBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to ThreadingLockingConfig. + store_type (Type[DocumentStore]): The Document type used. Default: DictDocumentStore. + backing_store (Type[KeyValueBackingStore]): The backend type used. Default: DictBackingStore. 
+ locking_config (LockingConfig): The config used for store locking. """ + __canonical_name__ = "DictStoreConfig" + store_type: type[DocumentStore] = DictDocumentStore backing_store: type[KeyValueBackingStore] = DictBackingStore locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig) diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index cc97802a08b..b6c56d16f64 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -44,9 +44,8 @@ class BasePartitionSettings(SyftBaseModel): """Basic Partition Settings - Parameters: - name: str - Identifier to be used as prefix by stores and for partitioning + Attributes: + name (str): Identifier to be used as a prefix by stores and for partitioning. """ name: str @@ -163,7 +162,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: pk_key = partition_key.key pk_type = partition_key.type_ - # 🟡 TODO: support more advanced types than List[type] if partition_key.type_list: pk_value = partition_key.extract_list(obj) else: @@ -171,10 +169,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: pk_value = obj else: pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types - # TODO: fix the mypy issue if isinstance(pk_value, types.FunctionType | types.MethodType): # type: ignore[unreachable] pk_value = pk_value() # type: ignore[unreachable] @@ -194,7 +188,6 @@ def as_dict_mongo(self) -> dict[str, Any]: if key == "id": key = "_id" if self.type_list: - # We want to search inside the list of values return {key: {"$in": self.value}} return {key: self.value} @@ -217,19 +210,15 @@ class QueryKeys(SyftBaseModel): @property def all(self) -> tuple[QueryKey, ...] 
| list[QueryKey]: - # make sure we always return a list even if there's a single value return self.qks if isinstance(self.qks, tuple | list) else [self.qks] @staticmethod def from_obj(partition_keys: PartitionKeys, obj: SyftObject) -> QueryKeys: qks = [] for partition_key in partition_keys.all: - pk_key = partition_key.key # name of the attribute + pk_key = partition_key.key pk_type = partition_key.type_ pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types if isinstance(pk_value, types.FunctionType | types.MethodType): pk_value = pk_value() if partition_key.type_list: @@ -282,7 +271,6 @@ def as_dict_mongo(self) -> dict: if qk_key == "id": qk_key = "_id" if qk.type_list: - # We want to search inside the list of values qk_dict[qk_key] = {"$in": qk_value} else: qk_dict[qk_key] = qk_value @@ -313,14 +301,7 @@ def searchable_keys(self) -> PartitionKeys: version=1, ) class StorePartition: - """Base StorePartition - - Parameters: - settings: PartitionSettings - PySyft specific settings - store_config: StoreConfig - Backend specific configuration - """ + """Base StorePartition""" def __init__( self, @@ -330,6 +311,18 @@ def __init__( store_config: StoreConfig, has_admin_permissions: Callable[[SyftVerifyKey], bool] | None = None, ) -> None: + """Base store partition initialization + + Args: + server_uid (UID): Unique identifier for the server instance. + root_verify_key (SyftVerifyKey | None): Root signature verification key. + settings (PartitionSettings): PySyft specific settings. + store_config (StoreConfig): Backend specific configuration. + has_admin_permissions (Callable[[SyftVerifyKey], bool] | None): Callback to check admin permissions. + + Raises: + RuntimeError: If the store initialization fails. 
+ """ if root_verify_key is None: root_verify_key = SyftSigningKey.generate().verify_key self.server_uid = server_uid @@ -601,12 +594,7 @@ def _migrate_data( @serializable(canonical_name="DocumentStore", version=1) class DocumentStore: - """Base Document Store - - Parameters: - store_config: StoreConfig - Store specific configuration. - """ + """Base Document Store""" partitions: dict[str, StorePartition] partition_type: type[StorePartition] @@ -617,6 +605,16 @@ def __init__( root_verify_key: SyftVerifyKey | None, store_config: StoreConfig, ) -> None: + """Base document store initialization + + Args: + server_uid (UID): Unique identifier for the server instance. + root_verify_key (SyftVerifyKey | None): Root signature verification key. + store_config (StoreConfig): Store specific configuration. + + Raises: + Exception: If store config is not found + """ if store_config is None: raise Exception("must have store config") self.partitions = {} @@ -675,6 +673,11 @@ def get_partition_object_types(self) -> list[type]: class StoreConfig(SyftBaseObject): """Base Store configuration + Attributes: + store_type (type[DocumentStore]): Document Store type. + client_config (StoreClientConfig | None): Backend-specific config. + locking_config (LockingConfig): The config used for store locking. + Parameters: store_type: Type Document Store type diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index 2494dffa895..bb71e84f691 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -19,19 +19,21 @@ @serializable(canonical_name="LockingConfig", version=1) class LockingConfig(BaseModel): """ - Locking config + Locking configuration. + + Attributes: + lock_name (str): Lock name. + namespace (str | None): Namespace to use for setting lock keys in the backend store. + expire (int | None): Lock expiration time in seconds. If explicitly set to `None`, the lock will not expire. 
+ timeout (int | None): Timeout to acquire lock (in seconds). + retry_interval (float): Retry interval to retry acquiring a lock if previous attempts failed. Args: - lock_name: str - Lock name - namespace: Optional[str] - Namespace to use for setting lock keys in the backend store. - expire: Optional[int] - Lock expiration time in seconds. If explicitly set to `None`, lock will not expire. - timeout: Optional[int] - Timeout to acquire lock(seconds) - retry_interval: float - Retry interval to retry acquiring a lock if previous attempts failed. + lock_name (str): Lock name. + namespace (str | None): Namespace to use for setting lock keys in the backend store. + expire (int | None): Lock expiration time in seconds. If explicitly set to `None`, the lock will not expire. + timeout (int | None): Timeout to acquire lock (in seconds). + retry_interval (float): Retry interval to retry acquiring a lock if previous attempts failed. """ lock_name: str = "syft_lock" @@ -44,7 +46,7 @@ class LockingConfig(BaseModel): @serializable(canonical_name="NoLockingConfig", version=1) class NoLockingConfig(LockingConfig): """ - No-locking policy + No-locking policy. """ pass @@ -53,7 +55,7 @@ class NoLockingConfig(LockingConfig): @serializable(canonical_name="ThreadingLockingConfig", version=1) class ThreadingLockingConfig(LockingConfig): """ - Threading-based locking policy + Threading-based locking policy. """ pass @@ -72,9 +74,10 @@ def __init__(self, expire: int, **kwargs: Any) -> None: @property def _locked(self) -> bool: """ - Implementation of method to check if lock has been acquired. Must be - :returns: if the lock is acquired or not - :rtype: bool + Check if the lock has been acquired. + + Returns: + bool: True if the lock is acquired, False otherwise. """ locked = self.lock.locked() if ( @@ -88,9 +91,10 @@ def _locked(self) -> bool: def _acquire(self) -> bool: """ - Implementation of acquiring a lock in a non-blocking fashion. 
- :returns: if the lock was successfully acquired or not - :rtype: bool + Acquire the lock in a non-blocking fashion. + + Returns: + bool: True if the lock was successfully acquired, False otherwise. """ locked = self.lock.locked() if ( @@ -100,18 +104,15 @@ def _acquire(self) -> bool: ): self._release() - status = self.lock.acquire( - blocking=False - ) # timeout/retries handle in the `acquire` method + status = self.lock.acquire(blocking=False) if status: self.locked_timestamp = time.time() return status def _release(self) -> None: """ - Implementation of releasing an acquired lock. + Release the acquired lock. """ - try: return self.lock.release() except RuntimeError: # already unlocked @@ -119,17 +120,23 @@ def _release(self) -> None: def _renew(self) -> bool: """ - Implementation of renewing an acquired lock. + Renew the acquired lock. + + Returns: + bool: True if the lock was successfully renewed, False otherwise. """ return True class SyftLock(BaseLock): """ - Syft Lock implementations. + Syft Lock implementation. + + Args: + config (LockingConfig): Configuration specific to a locking strategy. - Params: - config: Config specific to a locking strategy. + Raises: + ValueError: If an unsupported config type is provided. """ def __init__(self, config: LockingConfig): @@ -162,10 +169,10 @@ def __init__(self, config: LockingConfig): @property def _locked(self) -> bool: """ - Implementation of method to check if lock has been acquired. + Check if the lock has been acquired. - :returns: if the lock is acquired or not - :rtype: bool + Returns: + bool: True if the lock is acquired, False otherwise. """ if self.passthrough: return False @@ -174,12 +181,13 @@ def _locked(self) -> bool: def acquire(self, blocking: bool = True) -> bool: """ Acquire a lock, blocking or non-blocking. - :param bool blocking: acquire a lock in a blocking or non-blocking - fashion. Defaults to True. 
- :returns: if the lock was successfully acquired or not - :rtype: bool - """ + Args: + blocking (bool): Acquire a lock in a blocking or non-blocking fashion. Defaults to True. + + Returns: + bool: True if the lock was successfully acquired, False otherwise. + """ if not blocking: return self._acquire() @@ -195,16 +203,14 @@ def acquire(self, blocking: bool = True) -> bool: logger.debug( f"Timeout elapsed after {self.timeout} seconds while trying to acquiring lock." ) - # third party return False def _acquire(self) -> bool: """ - Implementation of acquiring a lock in a non-blocking fashion. - `acquire` makes use of this implementation to provide blocking and non-blocking implementations. + Acquire the lock in a non-blocking fashion. - :returns: if the lock was successfully acquired or not - :rtype: bool + Returns: + bool: True if the lock was successfully acquired, False otherwise. """ if self.passthrough: return True @@ -216,7 +222,10 @@ def _acquire(self) -> bool: def _release(self) -> bool | None: """ - Implementation of releasing an acquired lock. + Release the acquired lock. + + Returns: + bool | None: True if the lock was successfully released, None otherwise. """ if self.passthrough: return None @@ -229,7 +238,10 @@ def _release(self) -> bool | None: def _renew(self) -> bool: """ - Implementation of renewing an acquired lock. + Renew the acquired lock. + + Returns: + bool: True if the lock was successfully renewed, False otherwise. """ if self.passthrough: return True diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index 805cc042cdf..0cfd69bb676 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -120,13 +120,14 @@ def from_mongo( @serializable(attrs=["storage_type"], canonical_name="MongoStorePartition", version=1) class MongoStorePartition(StorePartition): - """Mongo StorePartition + """Mongo StorePartition. 
- Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for partitioning and indexing. - `store_config`: MongoStoreConfig - Mongo specific configuration + Attributes: + storage_type (type[StorableObjectType]): The storage type for objects in the partition. + + Args: + settings (PartitionSettings): PySyft specific settings, used for partitioning and indexing. + store_config (MongoStoreConfig): Mongo specific configuration. """ storage_type: type[StorableObjectType] = MongoBsonObject @@ -450,7 +451,7 @@ def _delete( ) def has_permission(self, permission: ActionObjectPermission) -> bool: - """Check if the permission is inside the permission collection""" + """Check if the permission is inside the permission collection.""" collection_permissions_status = self.permissions if collection_permissions_status.is_err(): return False @@ -737,12 +738,7 @@ def _migrate_data( @serializable(canonical_name="MongoDocumentStore", version=1) class MongoDocumentStore(DocumentStore): - """Mongo Document Store - - Parameters: - `store_config`: MongoStoreConfig - Mongo specific configuration, including connection configuration, database name, or client class type. - """ + """Mongo Document Store.""" partition_type = MongoStorePartition @@ -754,18 +750,13 @@ class MongoDocumentStore(DocumentStore): ) class MongoBackingStore(KeyValueBackingStore): """ - Core logic for the MongoDB key-value store - - Parameters: - `index_name`: str - Index name (can be either 'data' or 'permissions') - `settings`: PartitionSettings - Syft specific settings - `store_config`: StoreConfig - Connection Configuration - `ddtype`: Type - Optional and should be None - Used to make a consistent interface with SQLiteBackingStore + Core logic for the MongoDB key-value store. + + Args: + index_name (str): Index name (can be either 'data' or 'permissions'). + settings (PartitionSettings): Syft specific settings. + store_config (StoreConfig): Connection configuration. 
+ ddtype (type | None): Optional and should be None. Used to make a consistent interface with SQLiteBackingStore. """ def __init__( @@ -939,19 +930,15 @@ def __del__(self) -> None: @serializable() class MongoStoreConfig(StoreConfig): __canonical_name__ = "MongoStoreConfig" - """Mongo Store configuration - - Parameters: - `client_config`: MongoStoreClientConfig - Mongo connection details: hostname, port, user, password etc. - `store_type`: Type[DocumentStore] - The type of the DocumentStore. Default: MongoDocumentStore - `db_name`: str - Database name - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. + """Mongo Store configuration. + + Args: + client_config (MongoStoreClientConfig): Mongo connection details: hostname, port, user, password etc. + store_type (Type[DocumentStore]): The type of the DocumentStore. Default: MongoDocumentStore. + db_name (str): Database name. + locking_config (LockingConfig): The config used for store locking. Available options: + * NoLockingConfig: no locking, ideal for single-thread stores. + * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. Defaults to NoLockingConfig. """ diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 2cbef952862..a0c3048babd 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -84,15 +84,11 @@ def special_exception_public_message(table_name: str, e: Exception) -> str: class SQLiteBackingStore(KeyValueBackingStore): """Core Store logic for the SQLite stores. 
- Parameters: - `index_name`: str - Index name - `settings`: PartitionSettings - Syft specific settings - `store_config`: SQLiteStoreConfig - Connection Configuration - `ddtype`: Type - Class used as fallback on `get` errors + Args: + index_name (str): Index name. + settings (PartitionSettings): Syft specific settings. + store_config (SQLiteStoreConfig): Connection configuration. + ddtype (type | None): Class used as fallback on `get` errors. """ def __init__( @@ -359,11 +355,9 @@ def __del__(self) -> None: class SQLiteStorePartition(KeyValueStorePartition): """SQLite StorePartition - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: SQLiteStoreConfig - SQLite specific configuration + Args: + settings (PartitionSettings): PySyft specific settings, used for indexing and partitioning. + store_config (SQLiteStoreConfig): SQLite specific configuration. """ def close(self) -> None: @@ -392,34 +386,31 @@ def commit(self) -> None: # the base document store is already a dict but we can change it later @serializable(canonical_name="SQLiteDocumentStore", version=1) class SQLiteDocumentStore(DocumentStore): - """SQLite Document Store - - Parameters: - `store_config`: StoreConfig - SQLite specific configuration, including connection details and client class type. - """ + """SQLite Document Store.""" partition_type = SQLiteStorePartition @serializable(canonical_name="SQLiteStoreClientConfig", version=1) class SQLiteStoreClientConfig(StoreClientConfig): - """SQLite connection config + """SQLite connection config. - Parameters: - `filename` : str - Database name - `path` : Path or str - Database folder - `check_same_thread`: bool - If True (default), ProgrammingError will be raised if the database connection is used + Attributes: + filename (str): Database name. + path (str | Path): Database folder.
+ check_same_thread (bool): If True (default), ProgrammingError will be raised if the database connection is used by a thread other than the one that created it. If False, the connection may be accessed in multiple threads; write operations may need to be serialized by the user to avoid data corruption. - `timeout`: int - How many seconds the connection should wait before raising an exception, if the database + timeout (int): How many seconds the connection should wait before raising an exception, if the database is locked by another connection. If another connection opens a transaction to modify the database, it will be locked until that transaction is committed. Default five seconds. + + Parameters: + filename (str): Database name. + path (str | Path): Database folder. + check_same_thread (bool): See Attributes section. + timeout (int): See Attributes section. """ filename: str = "syftdb.sqlite" @@ -444,19 +435,15 @@ def file_path(self) -> Path | None: @serializable() class SQLiteStoreConfig(StoreConfig): __canonical_name__ = "SQLiteStoreConfig" - """SQLite Store config, used by SQLiteStorePartition - - Parameters: - `client_config`: SQLiteStoreClientConfig - SQLite connection configuration - `store_type`: DocumentStore - Class interacting with QueueStash. Default: SQLiteDocumentStore - `backing_store`: KeyValueBackingStore - The Store core logic. Default: SQLiteBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. + """SQLite Store config, used by SQLiteStorePartition. + + Args: + client_config (SQLiteStoreClientConfig): SQLite connection configuration. + store_type (type[DocumentStore]): Class interacting with QueueStash. Default: SQLiteDocumentStore. + backing_store (type[KeyValueBackingStore]): The Store core logic. Default: SQLiteBackingStore. 
+ locking_config (LockingConfig): The config used for store locking. Available options: + * NoLockingConfig: no locking, ideal for single-thread stores. + * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. Defaults to NoLockingConfig. """ diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 9487ae6ece6..cea9e550ad1 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -48,32 +48,35 @@ def get_latest_version(cls, canonical_name: str) -> int: @classmethod def get_identifier_for_type(cls, obj: Any) -> tuple[str, int]: """ - This is to create the string in nonrecursiveBlob + This is to create the string in nonrecursiveBlob. """ return cls.__type_to_canonical_name__[obj] @classmethod def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: """ - Retrieves the canonical name for both objects and types. + Retrieves the canonical name and version for both objects and types. This function works for both objects and types, returning the canonical name - as a string. It handles various cases, including built-in types, instances of - classes, and enum members. + as a string and its version as an integer. It handles various cases, including + built-in types, instances of classes, and enum members. If the object is not registered in the registry, a ValueError is raised. Examples: - get_canonical_name_version([1,2,3]) -> "list" - get_canonical_name_version(list) -> "type" - get_canonical_name_version(MyEnum.A) -> "MyEnum" - get_canonical_name_version(MyEnum) -> "type" + get_canonical_name_version([1, 2, 3]) -> ("list", version) + get_canonical_name_version(list) -> ("type", version) + get_canonical_name_version(MyEnum.A) -> ("MyEnum", version) + get_canonical_name_version(MyEnum) -> ("type", version) Args: - obj: The object or type for which to get the canonical name. 
+ obj (Any): The object or type for which to get the canonical name. Returns: - The canonical name and version of the object or type. + tuple[str, int]: The canonical name and version of the object or type. + + Raises: + ValueError: If the canonical name for the object or type is not found. """ # for types we return "type" diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 4e93e82c45a..fa8af6175e5 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -177,8 +177,8 @@ def build_tabulator_table_with_data( Args: table_data (list[dict]): The data to populate the table. table_metadata (dict): The metadata for the table. - uid (str, optional): The unique identifier for the table. Defaults to None. - max_height (int, optional): The maximum height of the table. Defaults to None. + uid (str | None, optional): The unique identifier for the table. Defaults to None. + max_height (int | None, optional): The maximum height of the table. Defaults to None. pagination (bool, optional): Whether to enable pagination. Defaults to True. header_sort (bool, optional): Whether to enable header sorting. Defaults to True. @@ -206,8 +206,8 @@ def build_tabulator_table( Args: obj (Any): The object to build the table from. - uid (str, optional): The unique identifier for the table. Defaults to None. - max_height (int, optional): The maximum height of the table. Defaults to None. + uid (str | None, optional): The unique identifier for the table. Defaults to None. + max_height (int | None, optional): The maximum height of the table. Defaults to None. pagination (bool, optional): Whether to enable pagination. Defaults to True. header_sort (bool, optional): Whether to enable header sorting. Defaults to True. 
diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 25db3fb4552..4547d3a7733 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -228,12 +228,10 @@ def prepare_table_data( add_index (bool, optional): Whether to add an index column to the table. Defaults to True. Returns: - tuple: A tuple (table_data, table_metadata) where table_data is a list of dictionaries + tuple[list[dict], dict]: A tuple (table_data, table_metadata) where table_data is a list of dictionaries where each dictionary represents a row in the table and table_metadata is a dictionary containing metadata about the table such as name, icon, etc. - """ - values = _get_values_for_table_repr(obj) if len(values) == 0: return [], {} diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index fa20c3fc2c2..71f44269ac6 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -62,7 +62,14 @@ def get_env(key: str, default: Any | None = None) -> str | None: def full_name_with_qualname(klass: type) -> str: - """Returns the klass module name + klass qualname.""" + """Returns the klass module name + klass qualname. + + Args: + klass (type): The class whose fully qualified name is needed. + + Returns: + str: The fully qualified name of the class. + """ try: if not hasattr(klass, "__module__"): return f"builtins.{get_qualname_for(klass)}" @@ -74,7 +81,17 @@ def full_name_with_qualname(klass: type) -> str: def full_name_with_name(klass: type) -> str: - """Returns the klass module name + klass name.""" + """Returns the klass module name + klass name. + + Args: + klass (type): The class whose fully qualified name is needed. + + Returns: + str: The fully qualified name of the class. + + Raises: + Exception: If there is an error while getting the fully qualified name. 
+ """ try: if not hasattr(klass, "__module__"): return f"builtins.{get_name_for(klass)}" @@ -85,6 +102,14 @@ def full_name_with_name(klass: type) -> str: def get_qualname_for(klass: type) -> str: + """Get the qualified name of a class. + + Args: + klass (type): The class to get the qualified name for. + + Returns: + str: The qualified name of the class. + """ qualname = getattr(klass, "__qualname__", None) or getattr(klass, "__name__", None) if qualname is None: qualname = extract_name(klass) @@ -92,6 +117,14 @@ def get_qualname_for(klass: type) -> str: def get_name_for(klass: type) -> str: + """Get the name of a class. + + Args: + klass (type): The class to get the name for. + + Returns: + str: The name of the class. + """ klass_name = getattr(klass, "__name__", None) if klass_name is None: klass_name = extract_name(klass) @@ -114,6 +147,12 @@ def get_mb_size(data: Any, handlers: dict | None = None) -> float: which is referenced in official sys.getsizeof documentation https://docs.python.org/3/library/sys.html#sys.getsizeof. + Args: + data (Any): The object to calculate the memory size for. + handlers (dict | None): Custom handlers for additional types. + + Returns: + float: The memory size of the object in MB. """ def dict_handler(d: dict[Any, Any]) -> Iterator[Any]: @@ -161,6 +200,14 @@ def sizeof(o: Any) -> int: def get_mb_serialized_size(data: Any) -> float: + """Get the size of a serialized object in MB. + + Args: + data (Any): The object to be serialized and measured. + + Returns: + float: The size of the serialized object in MB if successful. + """ try: serialized_data = serialize(data, to_bytes=True) return sys.getsizeof(serialized_data) / (1024 * 1024) @@ -173,6 +220,17 @@ def get_mb_serialized_size(data: Any) -> float: def extract_name(klass: type) -> str: + """Extract the name of a class from its string representation. + + Args: + klass (type): The class to extract the name from. + + Returns: + str: The extracted name of the class. 
+ + Raises: + ValueError: If the class name could not be extracted. + """ name_regex = r".+class.+?([\w\._]+).+" regex2 = r"([\w\.]+)" matches = re.match(name_regex, str(klass)) @@ -194,6 +252,19 @@ def extract_name(klass: type) -> str: def validate_type(_object: object, _type: type, optional: bool = False) -> Any: + """Validate that an object is of a certain type. + + Args: + _object (object): The object to validate. + _type (type): The type to validate against. + optional (bool): Whether the object can be None. + + Returns: + Any: The validated object. + + Raises: + Exception: If the object is not of the expected type. + """ if isinstance(_object, _type) or (optional and (_object is None)): return _object @@ -201,6 +272,18 @@ def validate_type(_object: object, _type: type, optional: bool = False) -> Any: def validate_field(_object: object, _field: str) -> Any: + """Validate that an object has a certain field. + + Args: + _object (object): The object to validate. + _field (str): The field to validate. + + Returns: + Any: The value of the field. + + Raises: + Exception: If the field is not set on the object. + """ object = getattr(_object, _field, None) if object is not None: @@ -210,19 +293,17 @@ def validate_field(_object: object, _field: str) -> Any: def get_fully_qualified_name(obj: object) -> str: - """Return the full path and name of a class + """Return the full path and name of a class. Sometimes we want to return the entire path and name encoded using periods. Args: - obj: the object we want to get the name of + obj (object): The object we want to get the name of. Returns: - the full path and name of the object - + str: The full path and name of the object. 
""" - fqn = obj.__class__.__module__ try: @@ -233,13 +314,12 @@ def get_fully_qualified_name(obj: object) -> str: def aggressive_set_attr(obj: object, name: str, attr: object) -> None: - """Different objects prefer different types of monkeypatching - try them all + """Different objects prefer different types of monkeypatching - try them all. Args: - obj: object whose attribute has to be set - name: attribute name - attr: value given to the attribute - + obj (object): The object whose attribute has to be set. + name (str): The attribute name. + attr (object): The value given to the attribute. """ try: setattr(obj, name, attr) @@ -248,6 +328,14 @@ def aggressive_set_attr(obj: object, name: str, attr: object) -> None: def key_emoji(key: object) -> str: + """Generate an emoji representation of a key. + + Args: + key (object): The key object. + + Returns: + str: An emoji string representing the key. + """ try: if isinstance(key, bytes | SigningKey | VerifyKey): hex_chars = bytes(key).hex()[-8:] @@ -259,6 +347,14 @@ def key_emoji(key: object) -> str: def char_emoji(hex_chars: str) -> str: + """Generate an emoji based on a hexadecimal string. + + Args: + hex_chars (str): The hexadecimal string. + + Returns: + str: An emoji string generated from the hexadecimal string. + """ base = ord("\U0001f642") hex_base = ord("0") code = 0 @@ -269,6 +365,11 @@ def char_emoji(hex_chars: str) -> str: def get_root_data_path() -> Path: + """Get the root data path for storing datasets. + + Returns: + Path: The root data path. + """ # get the PySyft / data directory to share datasets between notebooks # on Linux and MacOS the directory is: ~/.syft/data" # on Windows the directory is: C:/Users/$USER/.syft/data @@ -280,6 +381,15 @@ def get_root_data_path() -> Path: def download_file(url: str, full_path: str | Path) -> Path | None: + """Download a file from a URL. + + Args: + url (str): The URL of the file to download. + full_path (str | Path): The full path where the file should be saved. 
+ + Returns: + Path | None: The path to the downloaded file, or None if the download failed. + """ full_path = Path(full_path) if not full_path.exists(): r = requests.get(url, allow_redirects=True, verify=verify_tls()) # nosec @@ -292,25 +402,46 @@ def download_file(url: str, full_path: str | Path) -> Path | None: def verify_tls() -> bool: + """Verify whether TLS should be used. + + Returns: + bool: True if TLS should be used, False otherwise. + """ return not str_to_bool(str(os.environ.get("IGNORE_TLS_ERRORS", "0"))) def ssl_test() -> bool: + """Check if SSL is properly configured. + + Returns: + bool: True if SSL is configured, False otherwise. + """ return len(os.environ.get("REQUESTS_CA_BUNDLE", "")) > 0 def initializer(event_loop: BaseSelectorEventLoop | None = None) -> None: """Set the same event loop to other threads/processes. + This is needed because there are new threads/processes started with - the Executor and they do not have have an event loop set + the Executor and they do not have an event loop set. + Args: - event_loop: The event loop. + event_loop (BaseSelectorEventLoop | None): The event loop to set. """ if event_loop: asyncio.set_event_loop(event_loop) def split_rows(rows: Sequence, cpu_count: int) -> list: + """Split a sequence of rows into chunks for parallel processing. + + Args: + rows (Sequence): The sequence of rows to split. + cpu_count (int): The number of chunks to split into. + + Returns: + list: A list of row chunks. + """ n = len(rows) a, b = divmod(n, cpu_count) start = 0 @@ -323,6 +454,14 @@ def split_rows(rows: Sequence, cpu_count: int) -> list: def list_sum(*inp_lst: list[Any]) -> Any: + """Sum a list of lists element-wise. + + Args: + *inp_lst (list[Any]): The list of lists to sum. + + Returns: + Any: The sum of the lists. 
+ """ s = inp_lst[0] for i in inp_lst[1:]: s = s + i @@ -331,6 +470,14 @@ def list_sum(*inp_lst: list[Any]) -> Any: @contextmanager def concurrency_override(count: int = 1) -> Iterator: + """Context manager to override concurrency count. + + Args: + count (int): The concurrency count to set. Defaults to 1. + + Yields: + Iterator: A context manager. + """ # this only effects local code so its best to use in unit tests try: os.environ["FORCE_CONCURRENCY_COUNT"] = f"{count}" @@ -344,8 +491,17 @@ def print_process( # type: ignore finish: EventClass, success: EventClass, lock: LockBase, - refresh_rate=0.1, + refresh_rate: float = 0.1, ) -> None: + """Print a dynamic process message that updates periodically. + + Args: + message (str): The message to print. + finish (EventClass): Event to signal the finish of the process. + success (EventClass): Event to signal the success of the process. + lock (LockBase): A lock to synchronize the print output. + refresh_rate (float, optional): The refresh rate for updating the message. Defaults to 0.1. + """ with lock: while not finish.is_set(): print(f"{bcolors.bold(message)} .", end="\r") @@ -367,12 +523,13 @@ def print_process( # type: ignore def print_dynamic_log( message: str, ) -> tuple[EventClass, EventClass]: - """ - Prints a dynamic log message that will change its color (to green or red) when some process is done. + """Print a dynamic log message that will change its color when some process is done. - message: str = Message to be printed. + Args: + message (str): The message to be printed. - return: tuple of events that can control the log print from the outside of this method. + Returns: + tuple[EventClass, EventClass]: Tuple of events that can control the log print from outside this method. 
""" finish = multiprocessing.Event() success = multiprocessing.Event() @@ -391,6 +548,19 @@ def print_dynamic_log( def find_available_port( host: str, port: int | None = None, search: bool = False ) -> int: + """Find an available port on the specified host. + + Args: + host (str): The host to check for available ports. + port (int | None): The port to check. Defaults to a random port. + search (bool): Whether to search for the next available port if the given port is in use. + + Returns: + int: The available port number. + + Raises: + Exception: If the port is not available and search is False. + """ if port is None: port = random.randint(1500, 65000) # nosec port_available = False @@ -425,9 +595,8 @@ def find_available_port( def get_random_available_port() -> int: """Retrieve a random available port number from the host OS. - Returns - ------- - int: Available port number. + Returns: + int: Available port number. """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as soc: soc.bind(("localhost", 0)) @@ -435,15 +604,20 @@ def get_random_available_port() -> int: def get_loaded_syft() -> ModuleType: + """Get the loaded Syft module. + + Returns: + ModuleType: The loaded Syft module. + """ return sys.modules[__name__.split(".")[0]] def get_subclasses(obj_type: type) -> list[type]: - """Recursively generate the list of all classes within the sub-tree of an object + """Recursively generate the list of all classes within the sub-tree of an object. As a paradigm in Syft, we often allow for something to be known about by another part of the codebase merely because it has subclassed a particular object. While - this can be a big "magicish" it also can simplify future extensions and reduce + this can be a bit "magicish" it also can simplify future extensions and reduce the likelihood of small mistakes (if done right). 
This is a utility function which allows us to look for sub-classes and the sub-classes @@ -451,11 +625,10 @@ def get_subclasses(obj_type: type) -> list[type]: hierarchy. Args: - obj_type: the type we want to look for sub-classes of + obj_type (type): The type we want to look for sub-classes of. Returns: - the list of subclasses of obj_type - + list[type]: The list of subclasses of obj_type. """ classes = [] @@ -466,27 +639,25 @@ def get_subclasses(obj_type: type) -> list[type]: def index_modules(a_dict: object, keys: list[str]) -> object: - """Recursively find a syft module from its path + """Recursively find a Syft module from its path. This is the recursive inner function of index_syft_by_module_name. See that method for a full description. Args: - a_dict: a module we're traversing - keys: the list of string attributes we're using to traverse the module + a_dict (object): A module we're traversing. + keys (list[str]): The list of string attributes we're using to traverse the module. Returns: - a reference to the final object - + object: A reference to the final object. """ - if len(keys) == 0: return a_dict return index_modules(a_dict=a_dict.__dict__[keys[0]], keys=keys[1:]) def index_syft_by_module_name(fully_qualified_name: str) -> object: - """Look up a Syft class/module/function from full path and name + """Look up a Syft class/module/function from full path and name. Sometimes we want to use the fully qualified name (such as one generated from the 'get_fully_qualified_name' method below) to @@ -495,13 +666,14 @@ def index_syft_by_module_name(fully_qualified_name: str) -> object: representation of the specific object it is meant to deserialize to. Args: - fully_qualified_name: the name in str of a module, class, or function + fully_qualified_name (str): The name in str of a module, class, or function. Returns: - a reference to the actual object at that string path + object: A reference to the actual object at that string path. 
+ Raises: + ReferenceError: If the reference does not match expected patterns. """ - # @Tudor this needs fixing during the serde refactor # we should probably just support the native type names as lookups for serde if fully_qualified_name == "builtins.NoneType": @@ -511,13 +683,22 @@ def index_syft_by_module_name(fully_qualified_name: str) -> object: if attr_list[0] != "syft": raise ReferenceError(f"Reference don't match: {attr_list[0]}") - # if attr_list[1] != "core" and attr_list[1] != "user": - # raise ReferenceError(f"Reference don't match: {attr_list[1]}") - return index_modules(a_dict=get_loaded_syft(), keys=attr_list[1:]) def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: + """Get the pointer type for an object based on its fully qualified name. + + Args: + obj (object | None): The object to get the pointer type for. + fqn (str | None): The fully qualified name of the object. + + Returns: + type: The pointer type for the object. + + Raises: + Exception: If the pointer type cannot be found. + """ if fqn is None: try: fqn = get_fully_qualified_name(obj=obj) @@ -542,6 +723,15 @@ def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: def prompt_warning_message(message: str, confirm: bool = False) -> bool: + """Prompt a warning message and optionally request user confirmation. + + Args: + message (str): The warning message to display. + confirm (bool): Whether to request user confirmation. + + Returns: + bool: True if the user confirms, False otherwise. + """ # relative from ..service.response import SyftWarning @@ -740,6 +930,11 @@ def prompt_warning_message(message: str, confirm: bool = False) -> bool: def random_name() -> str: + """Generate a random name by combining a left and right name part. + + Returns: + str: The generated random name. 
+ """ left_i = randbelow(len(left_name) - 1) right_i = randbelow(len(right_name) - 1) return f"{left_name[left_i].capitalize()} {right_name[right_i].capitalize()}" @@ -749,9 +944,18 @@ def inherit_tags( attr_path_and_name: str, result: object, self_obj: object | None, - args: tuple | list, - kwargs: dict, + args: tuple[Any, ...] | list[Any], + kwargs: dict[str, Any], ) -> None: + """Inherit tags from input objects to the result object. + + Args: + attr_path_and_name (str): The attribute path and name to add as a tag. + result (object): The result object to inherit tags. + self_obj (object | None): The object that might have tags. + args (tuple[Any, ...] | list[Any]): Arguments that might have tags. + kwargs (dict[str, Any]): Keyword arguments that might have tags. + """ tags = [] if self_obj is not None and hasattr(self_obj, "tags"): tags.extend(list(self_obj.tags)) @@ -773,6 +977,16 @@ def inherit_tags( def autocache( url: str, extension: str | None = None, cache: bool = True ) -> Path | None: + """Automatically cache a file from a URL. + + Args: + url (str): The URL of the file to cache. + extension (str | None): The file extension to use. + cache (bool): Whether to use the cache if the file already exists. + + Returns: + Path | None: The path to the cached file, or None if caching failed. + """ try: data_path = get_root_data_path() file_hash = hashlib.sha256(url.encode("utf8")).hexdigest() @@ -789,6 +1003,14 @@ def autocache( def str_to_bool(bool_str: str | None) -> bool: + """Convert a string to a boolean value. + + Args: + bool_str (str | None): The string to convert. + + Returns: + bool: The converted boolean value. 
+ """ result = False bool_str = str(bool_str).lower() if bool_str == "true" or bool_str == "1": @@ -799,20 +1021,19 @@ def str_to_bool(bool_str: str | None) -> bool: # local scope functions cant be pickled so this needs to be global def parallel_execution( fn: Callable[..., Any], - parties: None | list[Any] = None, + parties: list[Any] | None = None, cpu_bound: bool = False, ) -> Callable[..., list[Any]]: """Wrap a function such that it can be run in parallel at multiple parties. + Args: - fn (Callable): The function to run. - parties (Union[None, List[Any]]): Clients from syft. If this is set, then the + fn (Callable[..., Any]): The function to run. + parties (list[Any] | None): Clients from syft. If this is set, then the function should be run remotely. Defaults to None. - cpu_bound (bool): Because of the GIL (global interpreter lock) sometimes - it makes more sense to use processes than threads if it is set then - processes should be used since they really run in parallel if not then - it makes sense to use threads since there is no bottleneck on the CPU side + cpu_bound (bool): Whether to use processes instead of threads. + Returns: - Callable[..., List[Any]]: A Callable that returns a list of results. + Callable[..., list[Any]]: A Callable that returns a list of results. """ @functools.wraps(fn) @@ -821,11 +1042,16 @@ def wrapper( kwargs: dict[Any, dict[Any, Any]] | None = None, ) -> list[Any]: """Wrap sanity checks and checks what executor should be used. + Args: - args (List[List[Any]]): Args. - kwargs (Optional[Dict[Any, Dict[Any, Any]]]): Kwargs. Default to None. + args (list[list[Any]]): The list of lists of arguments. + kwargs (dict[Any, dict[Any, Any]] | None): The dictionary of keyword arguments. + Returns: - List[Any]: Results from the parties + list[Any]: The list of results from the parties. + + Raises: + Exception: If the arguments list is empty. 
""" if args is None or len(args) == 0: raise Exception("Parallel execution requires more than 0 args") @@ -872,6 +1098,14 @@ def wrapper( def concurrency_count(factor: float = 0.8) -> int: + """Get the current concurrency count based on CPU count and a factor. + + Args: + factor (float): The factor to apply to the CPU count. Defaults to 0.8. + + Returns: + int: The calculated concurrency count. + """ force_count = int(os.environ.get("FORCE_CONCURRENCY_COUNT", 0)) mp_count = force_count if force_count >= 1 else int(mp.cpu_count() * factor) return mp_count @@ -890,18 +1124,51 @@ class bcolors: @staticmethod def green(message: str) -> str: + """Return a green-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The green-colored message. + """ return bcolors.GREEN + message + bcolors.ENDC @staticmethod def red(message: str) -> str: + """Return a red-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The red-colored message. + """ return bcolors.RED + message + bcolors.ENDC @staticmethod def yellow(message: str) -> str: + """Return a yellow-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The yellow-colored message. + """ return bcolors.YELLOW + message + bcolors.ENDC @staticmethod def bold(message: str, end_color: bool = False) -> str: + """Return a bold string. + + Args: + message (str): The message to bold. + end_color (bool): Whether to reset color after the message. + + Returns: + str: The bolded message. + """ msg = bcolors.BOLD + message if end_color: msg += bcolors.ENDC @@ -909,6 +1176,15 @@ def bold(message: str, end_color: bool = False) -> str: @staticmethod def underline(message: str, end_color: bool = False) -> str: + """Return an underlined string. + + Args: + message (str): The message to underline. + end_color (bool): Whether to reset color after the message. + + Returns: + str: The underlined message. 
+ """ msg = bcolors.UNDERLINE + message if end_color: msg += bcolors.ENDC @@ -916,18 +1192,47 @@ def underline(message: str, end_color: bool = False) -> str: @staticmethod def warning(message: str) -> str: + """Return a warning-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The warning-colored message. + """ return bcolors.bold(bcolors.yellow(message)) @staticmethod def success(message: str) -> str: + """Return a success-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The success-colored message. + """ return bcolors.green(message) @staticmethod def failure(message: str) -> str: + """Return a failure-colored string. + + Args: + message (str): The message to color. + + Returns: + str: The failure-colored message. + """ return bcolors.red(message) def os_name() -> str: + """Get the name of the operating system. + + Returns: + str: The name of the operating system. + """ os_name = platform.system() if os_name.lower() == "darwin": return "macOS" @@ -937,18 +1242,38 @@ def os_name() -> str: # Note: In the future there might be other interpreters that we want to use def is_interpreter_jupyter() -> bool: + """Check if the current interpreter is Jupyter. + + Returns: + bool: True if the current interpreter is Jupyter, False otherwise. + """ return get_interpreter_module() == "ipykernel.zmqshell" def is_interpreter_colab() -> bool: + """Check if the current interpreter is Google Colab. + + Returns: + bool: True if the current interpreter is Google Colab, False otherwise. + """ return get_interpreter_module() == "google.colab._shell" def is_interpreter_standard() -> bool: + """Check if the current interpreter is a standard Python interpreter. + + Returns: + bool: True if the current interpreter is standard, False otherwise. + """ return get_interpreter_module() == "StandardInterpreter" def get_interpreter_module() -> str: + """Get the module name of the current interpreter. 
+ + Returns: + str: The module name of the current interpreter. + """ try: # third party from IPython import get_ipython @@ -966,14 +1291,30 @@ def get_interpreter_module() -> str: def thread_ident() -> int | None: + """Get the identifier of the current thread. + + Returns: + int | None: The thread identifier, or None if not available. + """ return threading.current_thread().ident def proc_id() -> int: + """Get the process ID of the current process. + + Returns: + int: The process ID. + """ return os.getpid() def set_klass_module_to_syft(klass: type, module_name: str) -> None: + """Set the module of a class to Syft. + + Args: + klass (type): The class to set the module for. + module_name (str): The name of the module. + """ if module_name not in sys.modules["syft"].__dict__: new_module = types.ModuleType(module_name) else: @@ -983,8 +1324,14 @@ def set_klass_module_to_syft(klass: type, module_name: str) -> None: def get_queue_address(port: int) -> str: - """Get queue address based on container host name.""" + """Get queue address based on container host name. + + Args: + port (int): The port number. + Returns: + str: The queue address. + """ container_host = os.getenv("CONTAINER_HOST", None) if container_host == "k8s": return f"tcp://backend:{port}" @@ -994,14 +1341,32 @@ def get_queue_address(port: int) -> str: def get_dev_mode() -> bool: + """Check if the application is running in development mode. + + Returns: + bool: True if in development mode, False otherwise. + """ return str_to_bool(os.getenv("DEV_MODE", "False")) def generate_token() -> str: + """Generate a secure random token. + + Returns: + str: The generated token. + """ return secrets.token_hex(64) def sanitize_html(html_str: str) -> str: + """Sanitize HTML content by allowing specific tags and attributes. + + Args: + html_str (str): The HTML content to sanitize. + + Returns: + str: The sanitized HTML content. 
+ """ policy = { "tags": ["svg", "strong", "rect", "path", "circle", "code", "pre"], "attributes": { @@ -1040,6 +1405,14 @@ def sanitize_html(html_str: str) -> str: def parse_iso8601_date(date_string: str) -> datetime: + """Parse an ISO8601 date string into a datetime object. + + Args: + date_string (str): The ISO8601 date string. + + Returns: + datetime: The parsed datetime object. + """ # Handle variable length of microseconds by trimming to 6 digits if "." in date_string: base_date, microseconds = date_string.split(".") @@ -1050,6 +1423,15 @@ def parse_iso8601_date(date_string: str) -> datetime: def get_latest_tag(registry: str, repo: str) -> str | None: + """Get the latest tag from a Docker registry for a given repository. + + Args: + registry (str): The Docker registry. + repo (str): The repository name. + + Returns: + str | None: The latest tag, or None if no tags are found. + """ repo_url = f"http://{registry}/v2/{repo}" res = requests.get(url=f"{repo_url}/tags/list", timeout=5) tags = res.json().get("tags", []) @@ -1145,21 +1527,29 @@ def test_settings() -> Any: class CustomRepr(reprlib.Repr): def repr_str(self, obj: Any, level: int = 0) -> str: + """Return a truncated string representation if it is too long. + + Args: + obj (Any): The object to represent. + level (int): The level of detail in the representation. + + Returns: + str: The truncated string representation. + """ if len(obj) <= self.maxstring: return repr(obj) return repr(obj[: self.maxstring] + "...") def repr_truncation(obj: Any, max_elements: int = 10) -> str: - """ - Return a truncated string representation of the object if it is too long. + """Return a truncated string representation of the object if it is too long. Args: - - obj: The object to be represented (can be str, list, dict, set...). - - max_elements: Maximum number of elements to display before truncating. + obj (Any): The object to be represented (can be str, list, dict, set...). 
+ max_elements (int): Maximum number of elements to display before truncating. Returns: - - A string representation of the object, truncated if necessary. + str: A string representation of the object, truncated if necessary. """ r = CustomRepr() r.maxlist = max_elements # For lists