diff --git a/.github/workflows/pytype.yml b/.github/workflows/pytype.yml new file mode 100644 index 0000000..5d84e6e --- /dev/null +++ b/.github/workflows/pytype.yml @@ -0,0 +1,27 @@ +name: "pytype" + +on: + pull_request: + types: + - 'synchronize' + - 'opened' + +jobs: + type-check: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install pytype websocket-client + + - name: Type-Check + run: pytype -j auto . \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..2b2e6e7 --- /dev/null +++ b/README.md @@ -0,0 +1,96 @@ +

+ Join Discord + Join Forums + File Issue +

+ +# TrueNAS Websocket Client + +*Found an issue? Please report it on our [Jira bugtracker](https://jira.ixsystems.com).* + +## About + +The TrueNAS websocket client provides the command line tool `midclt` and the means to easily communicate with [middleware](https://github.com/truenas/middleware) using Python by making calls through the [websocket API](https://www.truenas.com/docs/api/scale_websocket_api.html). The client can connect to a local TrueNAS instance by default or to a specified remote socket. This offers an alternative to going through the [web UI](https://github.com/truenas/webui) or connecting via ssh. + +By default, communication facilitated by the API between the client and middleware now uses the [JSON-RPC 2.0](https://www.jsonrpc.org/specification) protocol. However, it is still possible to use the legacy client by passing a legacy uri, e.g. `'ws://some.truenas.address/websocket'` as opposed to `'ws://some.truenas.address/api/current'`. + +## Getting Started + +TrueNAS comes with this client preinstalled, but it is also possible to use the TrueNAS websocket client from a non-TrueNAS host. + +On a non-TrueNAS host, ensure that Git is installed and run `pip install git+https://github.com/truenas/api_client.git` to automatically install dependencies. You may alternatively clone this repository and run `python setup.py install`. Using a Python venv is recommended. + +## Usage + +The `midclt` command (not to be confused with the [TrueNAS CLI](https://github.com/truenas/midcli)) provides a way to make direct API calls through the client. To view its syntax, enter `midclt -h`. The `-h` option can also be used with any of `midclt`'s subcommands. + +The client's default behavior is to connect to the localhost's middlewared socket. For a remote connection, e.g. from a Windows host, you must specify the `--uri` option and authenticate with either user credentials or an API key. For example: `midclt --uri ws:///api/current -K key ...` + +### Make local API calls + +``` +root@my_truenas[~]# midclt call user.create '{"full_name": "John Doe", "username": "user", "password": "pass", "group_create": true}' +``` + +### Login to a remote TrueNAS + +``` +root@my_truenas[~]# midclt --uri ws://some.other.truenas/api/current -U user -P password call system.info +``` + +### Start a job + +``` +root@my_truenas[~]# midclt call -j pool.dataset.lock mypool/mydataset +``` + +## Development + +The TrueNAS API client can also be used in Python scripts. + +### Make local API calls + +```python +from truenas_api_client import Client + +with Client() as c: # Local IPC + print(c.ping()) # pong + user = {"full_name": "John Doe", "username": "user", "password": "pass", "group_create": True} + entry_id = c.call("user.create", user) + user = c.call("user.get_instance", entry_id) + print(user["full_name"]) # John Doe +``` + +### Login with a user account or an API key + +```python +# User account +with Client(uri="ws://some.other.truenas/api/current") as c: + c.call("auth.login", username, password) + +# API key +with Client(uri="ws://some.other.truenas/api/current") as c: + c.call("auth.login_with_api_key", key) +``` + +### Start a job + +```python +with Client() as c: + is_locked = c.call("pool.dataset.lock", "mypool/mydataset", job=True) + if is_locked: + args = {"datasets": [{"name": "mypool/mydataset", "passphrase": "passphrase"}]} + c.call("pool.dataset.unlock", "mypool/mydataset", args, job=True) +``` + +## Helpful Links + + + + + +- [Websocket API docs](https://www.truenas.com/docs/api/scale_websocket_api.html) +- [Middleware repo](https://github.com/truenas/middleware) +- [Official TrueNAS Documentation Hub](https://www.truenas.com/docs/) +- [Get started building TrueNAS Scale](https://github.com/truenas/scale-build) +- [Forums](https://www.truenas.com/community/) diff --git a/truenas_api_client/__init__.py b/truenas_api_client/__init__.py index 462366b..520bef7 100644 --- a/truenas_api_client/__init__.py +++ b/truenas_api_client/__init__.py @@ -1,6 +1,37 @@ +"""Provides a simple way to call middleware API endpoints using a websocket connection. + +The full websocket API documentation can be found at https://www.truenas.com/docs/api/core_websocket_api.html. + +Example:: + + $ midclt ping && echo 'Connected' || echo 'Unable to ping' + Connected + $ midclt call user.create '{"full_name": "John Doe", "username": "user", "password": "pass", "group_create": true}' + 70 + $ midclt call user.get_instance 70 + {"id": 70, "uid": 3000, "username": "user", "unixhash": ... } + $ midclt call user.query '[["full_name", "=", "John Doe"]]' + {"id": 70, "uid": 3000, "username": "user", "unixhash": ... } + +Example:: + + with Client() as c: # Local IPC + print(c.ping()) # pong + user = {"full_name": "John Doe", "username": "user", "password": "pass", "group_create": True} + id = c.call("user.create", user) + user = c.call("user.get_instance", id) + print(user["full_name"]) # John Doe + +Example:: + + c = Client("ws://example.com/api/current") # Remote websocket connection + c.close() + +""" import argparse from base64 import b64decode from collections import defaultdict +from collections.abc import Callable, Iterable import errno import logging import pickle @@ -10,6 +41,7 @@ import sys from threading import Event, Lock, Thread import time +from typing import Any, Literal, NotRequired, Protocol, TypeAlias, TypedDict import urllib.parse import uuid @@ -24,15 +56,35 @@ from .config import CALL_TIMEOUT from .exc import ReserveFDException, ClientException, ErrnoMixin, ValidationErrors, CallTimeout from .legacy import LegacyClient -from .jsonrpc import JSONRPCError -from .utils import MIDDLEWARE_RUN_DIR, ProgressBar, undefined +from .jsonrpc import CollectionUpdateParams, ErrorObj, JobFields, JSONRPCError, JSONRPCMessage, TruenasError +from .utils import MIDDLEWARE_RUN_DIR, ProgressBar, undefined, UndefinedType logger = logging.getLogger(__name__) class Client: - def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False, - call_timeout=undefined, verify_ssl=True): + """Implicit wrapper of either a `JSONRPCClient` or a `LegacyClient`.""" + + def __init__(self, uri: str | None=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False, + call_timeout: float | UndefinedType=undefined, verify_ssl=True): + """Initialize either a `JSONRPCClient` or a `LegacyClient`. + + Use `JSONRPCClient` unless `uri` ends with '/websocket'. + + Args: + uri: The address to connect to. Defaults to the local middlewared socket. + reserved_ports: `True` if the local socket should use a reserved port. + py_exceptions: `True` if the server should include exception objects in + `message['error']['data']['py_exception']`. + log_py_exceptions: `True` if exception tracebacks from API calls should be logged. + call_timeout: Number of seconds to allow an API call before timing out. Can be overridden on a per-call + basis. Defaults to `CALL_TIMEOUT`. + verify_ssl: `True` if SSL certificate should be verified before connecting. + + Raises: + ClientException: `WSClient` timed out or some other connection error occurred. + + """ if uri is not None and uri.endswith('/websocket'): client_class = LegacyClient else: @@ -51,16 +103,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): class WSClient: - def __init__(self, url, *, client, reserved_ports=False, verify_ssl=True): + """A supporter class for `JSONRPCClient` that manages the `WebSocket` connection to the server. + + The object used by `JSONRPCClient` to send and receive data. + + """ + def __init__(self, url: str, *, client: 'JSONRPCClient', reserved_ports: bool=False, verify_ssl: bool=True): + """Initialize a `WSClient`. + + Args: + url: The websocket to connect to. `ws://` or `wss://` for secure connection. + client: Reference to the `JSONRPCClient` instance that uses this object. + reserved_ports: `True` if the `socket` should bind to a reserved port, i.e. 600-1024. + verify_ssl: `True` if SSL certificate should be verified before connecting. + + """ self.url = url self.client = client self.reserved_ports = reserved_ports self.verify_ssl = verify_ssl - self.socket = None - self.app = None + self.socket: socket.socket + self.app: WebSocketApp def connect(self): + """Connect a `socket` and start a `WebSocketApp` in a daemon `Thread`. + + Raises: + Exception: The `socket` failed to connect. + + """ unix_socket_prefix = "ws+unix://" if self.url.startswith(unix_socket_prefix): self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -93,14 +165,27 @@ def connect(self): ) Thread(daemon=True, target=self.app.run_forever).start() - def send(self, data): + def send(self, data: bytes | str): + """Send data to the server by calling `WebSocketApp.send()`. + + Args: + data: The serialized JSON-RPC v2.0-formatted request to send. + + """ return self.app.send(data) def close(self): + """Cleanly close the `WebSocket` connection to the server.""" self.app.close() self.client.on_close(STATUS_NORMAL) def _bind_to_reserved_port(self): + """Bind to a random port in the 600-1024 range. + + Raises: + ReserveFDException: Five failed attempts with different ports. + + """ # linux doesn't have a mechanism to allow the kernel to dynamically # assign ports in the "privileged" range (i.e. 600 - 1024) so we # loop through and call bind() on a privileged port explicitly since @@ -125,6 +210,11 @@ def _bind_to_reserved_port(self): raise ReserveFDException() def _on_open(self, app): + """Callback passed to the `WebSocketApp` to execute when `run_forever` is called. + + Configure the `socket` and call `client.on_open()`. + + """ # TCP keepalive settings don't apply to local unix sockets if 'ws+unix' not in self.url: # enable keepalives on the socket @@ -150,29 +240,77 @@ def _on_open(self, app): self.client.on_open() def _on_message(self, app, data): + """Callback passed to the `WebSocketApp` to execute when data is received. + + Pass the received data to the `JSONRPCClient`. + + """ self.client._recv(json.loads(data)) def _on_error(self, app, e): + """Callback passed to the `WebSocketApp` to execute when an error occurs. + + Log the error. + + """ logger.warning("Websocket client error: %r", e) self.client._ws_connection_error = e def _on_close(self, app, code, reason): + """Callback passed to the `WebSocketApp` to execute when it closes. + + Close the `JSONRPCClient`. + + """ self.client.on_close(code, reason) class Call: - def __init__(self, method, params): + """An encapsulation of the data from a single request-response pair.""" + + def __init__(self, method: str, params: tuple): + """Initialize a `Call` object with an automatically-assigned id. + + Args: + method: The API method being called. + params: Arguments passed to the method. + + """ self.id = str(uuid.uuid4()) self.method = method self.params = params self.returned = Event() - self.result = None - self.error = None - self.py_exception = None + self.result: Any = None + self.error: ClientException | None = None + self.py_exception: BaseException | None = None + + +class _JobDict(JobFields): + """Contains data received from the server for a particular running job.""" + __ready: Event + """Is set when the job returns or ends in error.""" + __callback: '_JobCallback | None' + """Procedure to execute each time a job update is received.""" + + +_JobCallback: TypeAlias = Callable[[_JobDict], None] class Job: - def __init__(self, client, job_id, callback=None): + """A long-running background process on the server initiated by an API call. + + Every `Job` is responsible for a corresponding `_JobDict` in the client's list of jobs. + + """ + def __init__(self, client: 'JSONRPCClient', job_id: str, callback: _JobCallback | None=None): + """Initialize `Job`. + + Args: + client: Reference to the client that created this `Job` and receives updates on its progress. + job_id: The job id returned by the server. Index of this `Job` in the `client._jobs` dictionary. + callback: A procedure to be called every time a job event is received. + + """ self.client = client self.job_id = job_id # If a job event has been received already then we must set an Event @@ -190,6 +328,16 @@ def __repr__(self): return f'' def result(self): + """Wait for the job to finish and return its result. + + Returns: + Any: The job's result. + + Raises: + ValidationErrors: The job failed due to one or more validation errors. + ClientException: No job event was received or it did not succeed. + + """ # Wait indefinitely for the job event with state SUCCESS/FAILED/ABORTED self.event.wait() job = self.client._jobs.pop(self.job_id, None) @@ -197,7 +345,7 @@ def result(self): raise ClientException('No job event was received.') if job['state'] != 'SUCCESS': if job['exc_info'] and job['exc_info']['type'] == 'VALIDATION': - raise ValidationErrors(job['exc_info']['extra']) + raise ValidationErrors(job['exc_info']['extra'] or []) raise ClientException( job['error'], trace={ @@ -210,12 +358,55 @@ def result(self): return job['result'] +class _EventCallbackProtocol(Protocol): + """Specifies how event callbacks should be defined.""" + def __call__(self, mtype: str, **message: Any) -> None: ... + + +class _Payload(TypedDict): + """Contains data for managing a subscription. + + Attributes: + callback: Procedure to call when the event is triggered. + sync: If `True`, main client thread blocks until `callback` finishes each time it is invoked. Otherwise, run + `callback` in the background as a daemon `Thread`. + event: `Event` that is set when the subscription should end. + error: Information included in the Notification if the subscription ended in error. + id: Random UUID assigned by `core.subscribe`. + ready: For backwards compatibility with `LegacyClient`. + + """ + callback: _EventCallbackProtocol | None + sync: bool + event: Event + error: NotRequired[str | TruenasError | None] + id: NotRequired[str] + ready: NotRequired[Event] + + class JSONRPCClient: - def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False, - call_timeout=undefined, verify_ssl=True): - """ - Arguments: - :reserved_ports(bool): should the local socket used a reserved port + """The object used to interface with the TrueNAS API. + + Keeps track of the calls made, jobs submitted, and callbacks. Maintains a websocket connection using a `WSClient`. + + """ + def __init__(self, uri: str | None=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False, + call_timeout: float | UndefinedType=undefined, verify_ssl=True): + """Initialize a `JSONRPCClient`. + + Args: + uri: The address to connect to. Defaults to the local middlewared socket. + reserved_ports: `True` if the local socket should use a reserved port. + py_exceptions: `True` if the server should include exception objects in + `message['error']['data']['py_exception']`. + log_py_exceptions: `True` if exception tracebacks from API calls should be logged. + call_timeout: Number of seconds to allow an API call before timing out. Can be overridden on a per-call + basis. Defaults to `CALL_TIMEOUT`. + verify_ssl: `True` if SSL certificate should be verified before connecting. + + Raises: + ClientException: `WSClient` timed out or some other connection error occurred. + """ if uri is None: uri = f'ws+unix://{MIDDLEWARE_RUN_DIR}/middlewared.sock' @@ -223,19 +414,19 @@ def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_e if call_timeout is undefined: call_timeout = CALL_TIMEOUT - self._calls = {} - self._jobs = defaultdict(dict) + self._calls: dict[str, Call] = {} + self._jobs: defaultdict[str, _JobDict] = defaultdict(dict) # type: ignore self._jobs_lock = Lock() self._jobs_watching = False self._py_exceptions = py_exceptions self._log_py_exceptions = log_py_exceptions self._call_timeout = call_timeout - self._event_callbacks = defaultdict(list) + self._event_callbacks: defaultdict[str, list[_Payload]] = defaultdict(list) self._set_options_call: Call | None = None self._closed = Event() self._connected = Event() - self._connection_error = None - self._ws_connection_error = None + self._connection_error: str | None = None + self._ws_connection_error: WebSocketException self._ws = WSClient( uri, client=self, @@ -246,7 +437,7 @@ def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_e self._connected.wait(10) if not self._connected.is_set(): raise ClientException('Failed connection handshake') - if self._ws_connection_error is not None: + if hasattr(self, '_ws_connection_error'): if isinstance(self._ws_connection_error, WebSocketException): raise self._ws_connection_error if self._connection_error is not None: @@ -257,10 +448,17 @@ def __enter__(self): def __exit__(self, typ, value, traceback): self.close() - if typ is not None: - raise def _send(self, data): + """Send data to the server using `WSClient`. + + Args: + data: Object serializable with `ejson`. + + Raises: + ClientException: Connection to the server closed prematurely. + + """ try: self._ws.send(json.dumps(data)) except (AttributeError, WebSocketConnectionClosedException): @@ -268,13 +466,29 @@ def _send(self, data): # running tasks in the event loop (i.e. failover.call_remote failover.get_disks_local) raise ClientException('Unexpected closure of remote connection', errno.ECONNABORTED) - def _recv(self, message): + def _recv(self, message: JSONRPCMessage): + """Process a deserialized JSON-RPC v2.0 message from the server. + + The TrueNAS websocket `JSONRPCClient` receives data from the server in two standard forms: Notifications and + Responses. These are defined in the JSON-RPC v2.0 protocol at https://www.jsonrpc.org/specification. + + In the TrueNAS websocket client, Notifications are used to communicate subscription updates including when a + subscription is terminated. These subscription updates also include updates about jobs submitted by the client + via `core.get_jobs`. + + A Response is the server's answer to a Request sent by the client which may or may not come back immediately + depending on the Request sent. + + Args: + message: Deserialized JSON-RPC v2.0 data from the server. + + """ try: if 'method' in message: - params = message['params'] match message['method']: case 'collection_update': if self._event_callbacks: + params = message['params'] if '*' in self._event_callbacks: for event in self._event_callbacks['*']: self._run_callback(event, [params['msg'].upper()], params) @@ -282,6 +496,7 @@ def _recv(self, message): for event in self._event_callbacks[params['collection']]: self._run_callback(event, [params['msg'].upper()], params) case 'notify_unsubscribed': + params = message['params'] if params['collection'] in self._event_callbacks: for event in self._event_callbacks[params['collection']]: if 'error' in params: @@ -295,7 +510,7 @@ def _recv(self, message): try: self._parse_error(message['error'], self._set_options_call) except Exception: - logger.error('Unhandled exception in Client._parse_error', exc_info=True) + logger.error('Unhandled exception in JSONRPCClient._parse_error', exc_info=True) else: logger.error('Error setting client options: %r', self._set_options_call.error) self._connected.set() @@ -306,7 +521,7 @@ def _recv(self, message): try: self._parse_error(message['error'], call) except Exception: - logger.error('Unhandled exception in Client._parse_error', exc_info=True) + logger.error('Unhandled exception in JSONRPCClient._parse_error', exc_info=True) call.returned.set() self._unregister_call(call) else: @@ -314,14 +529,21 @@ def _recv(self, message): else: logger.error('Received unknown message %r', message) except Exception: - logger.error('Unhandled exception in Client._recv', exc_info=True) + logger.error('Unhandled exception in JSONRPCClient._recv', exc_info=True) - def _parse_error(self, error: dict, call: Call): + def _parse_error(self, error: ErrorObj, call: Call): + """Convert an error received from the server into a `ClientException` and store it. + + Args: + error: The JSON object received in an error Response. + call: The associated `Call` object with which to store the `ClientException`. + + """ code = JSONRPCError(error['code']) if self._py_exceptions and code in [JSONRPCError.INVALID_PARAMS, JSONRPCError.TRUENAS_CALL_ERROR]: data = error['data'] call.error = ClientException(data['reason'], data['error'], data['trace'], data['extra']) - if data.get('py_exception'): + if 'py_exception' in data: try: call.py_exception = pickle.loads(b64decode(data['py_exception'])) except Exception as e: @@ -334,16 +556,38 @@ def _parse_error(self, error: dict, call: Call): else: call.error = ClientException(code.name) - def _run_callback(self, event, args, kwargs): + def _run_callback(self, event: _Payload, args: Iterable[str], kwargs: CollectionUpdateParams): + """Call the passed `_Payload`'s callback function. + + Block until the callback returns if `event['sync']` is set. Otherwise, run in a separate daemon `Thread`. + + Args: + event: The `_Payload` whose callback to run. + args: Positional arguments to the callback. + kwargs: Keyword arguments to the callback. + + """ + if event['callback'] is None: + return if event['sync']: event['callback'](*args, **kwargs) else: Thread(target=event['callback'], args=args, kwargs=kwargs, daemon=True).start() def on_open(self): + """Make an API call to `core.set_options` to configure how middlewared sends its responses.""" self._set_options_call = self.call("core.set_options", {"py_exceptions": self._py_exceptions}, background=True) - def on_close(self, code, reason=None): + def on_close(self, code: int, reason: str | None=None): + """Close this `JSONRPCClient` in response to the `WebSocketApp` closing. + + End all unanswered calls and unreturned jobs with an error. + + Args: + code: One of several closing frame status codes defined in `websocket._abnf`. + reason: A message to accompany the closing code and provide more information. + + """ error = f'WebSocket connection closed with code={code!r}, reason={reason!r}' self._connection_error = error @@ -371,22 +615,33 @@ def on_close(self, code, reason=None): self._closed.set() - def _register_call(self, call): + def _register_call(self, call: Call): + """Save a `Call` and index it by its id.""" self._calls[call.id] = call - def _unregister_call(self, call): + def _unregister_call(self, call: Call): + """Remove a `Call` after it has returned.""" self._calls.pop(call.id, None) - def _jobs_callback(self, mtype, **message): - """ - Method to process the received job events. + def _jobs_callback(self, mtype: str, *, fields: JobFields, **message): + """Process a received job event. + + Update the saved job info, execute its saved callback in the background, and set its "__ready" flag if its + "state" is received. + + Args: + mtype: Indicates if the job state has changed. + **message: The members contained in `CollectionUpdateParams`. + + Keyword Args: + fields (JobFields): Contains job id and other information about the job from the server. + """ - fields = message.get('fields') job_id = fields['id'] with self._jobs_lock: if fields: job = self._jobs[job_id] - job.update(fields) + job.update(**fields) if callable(job.get('__callback')): Thread(target=job['__callback'], args=(job,), daemon=True).start() if mtype == 'CHANGED' and job['state'] in ('SUCCESS', 'FAILED', 'ABORTED'): @@ -400,13 +655,37 @@ def _jobs_callback(self, mtype, **message): event.set() def _jobs_subscribe(self): - """ - Subscribe to job updates, calling `_jobs_callback` on every new event. - """ + """Subscribe to job updates, calling `_jobs_callback` on every new event.""" self._jobs_watching = True self.subscribe('core.get_jobs', self._jobs_callback, sync=True) - def call(self, method, *params, background=False, callback=None, job=False, register_call=None, timeout=undefined): + def call(self, method: str, *params, background=False, callback: _JobCallback | None=None, + job: Literal['RETURN'] | bool=False, register_call: bool | None=None, + timeout: float | UndefinedType=undefined) -> Any: + """The primary way to send call requests to the API. + + Send a JSON-RPC v2.0 Request to the server. + + Args: + method: An API endpoint to call. + *params: Arguments to pass to the endpoint. + background: If `background=True`, send the request and return a `Call` object before receiving a response. + By default, wait for the call to return instead. + callback: The callback to pass to the job if `job` is set. + job: If set, subscribe to job updates and if `background=False`, create a `Job`. If `job='RETURN'`, return + the `Job` object rather than just its result. + timeout: Number of seconds to allow the call before timing out if `background=False`. + + Returns: + Call: If `background` is set, return an object representing the request-response pair. + Job: If `job='RETURN'`, return the `Job` object. + Any: Otherwise, return the result of the call. + + Raises: + ClientException: Connection to the server closed prematurely or the call ended in error. + CallTimeout: The call took longer than `timeout` seconds to return. + + """ if register_call is None: register_call = not background @@ -436,17 +715,36 @@ def call(self, method, *params, background=False, callback=None, job=False, regi if not background: self._unregister_call(c) - def wait(self, c, *, callback=None, job=False, timeout=undefined): + def wait(self, c: Call, *, callback: _JobCallback | None=None, job: Literal['RETURN'] | bool=False, + timeout: float | UndefinedType=undefined) -> Any: + """Wait for an API call to return and return its result. + + Args: + c: The `Call` object containing the data that was sent. + callback: The callback to pass to the job if `job` is set. + job: If set, create a `Job`. If `job='RETURN'`, return the `Job` object rather than just its result. + timeout: Override the default number of seconds until a timeout exception occurs. + + Returns: + Job: If `job='RETURN'`, return the `Job` object. + Any: If `job=True`, return the job's result. Otherwise, return the call's result. + + Raises: + CallTimeout: The call took longer than `timeout` seconds to return. + ClientException: The call ended in error and `py_exception` was not enabled for `c`. + BaseException: The call ended in error and `py_exception` was enabled for `c`. + + """ if timeout is undefined: timeout = self._call_timeout try: - if not c.returned.wait(timeout): + if not c.returned.wait(timeout): # type: ignore raise CallTimeout() if c.error: if c.py_exception: - if self._log_py_exceptions: + if self._log_py_exceptions and c.error.trace: logger.error(c.error.trace["formatted"]) raise c.py_exception else: @@ -463,14 +761,34 @@ def wait(self, c, *, callback=None, job=False, timeout=undefined): self._unregister_call(c) @staticmethod - def event_payload(): + def event_payload() -> _Payload: + """Create an empty payload. + + Returns: + _Payload: Empty `_Payload`. + + """ return { 'callback': None, 'sync': False, 'event': Event(), } - def subscribe(self, name, callback, payload=None, sync=False): + def subscribe(self, name: str, callback: _EventCallbackProtocol, payload: _Payload | None=None, + sync: bool=False) -> str: + """Subscribe to an event by calling `core.subscribe`. + + Args: + name: The name of the event to subscribe to. + callback: A procedure to call when an event is triggered. + payload: Dictionary containing subscription information. + sync: If `True`, main client thread blocks until `callback` finishes each time it is invoked. Otherwise, + run `callback` in the background as a daemon `Thread`. + + Returns: + str: The `_Payload` id assigned by `core.subscribe`. + + """ payload = payload or self.event_payload() payload.update({ 'callback': callback, @@ -480,27 +798,56 @@ def subscribe(self, name, callback, payload=None, sync=False): payload['id'] = self.call('core.subscribe', name, timeout=10) return payload['id'] - def unsubscribe(self, id_): + def unsubscribe(self, id_: str): + """Call `core.unsubscribe` and remove all associated `_Payload`s + + Args: + id_: `id` of the `_Payload` to remove. + + """ self.call('core.unsubscribe', id_) for k, events in list(self._event_callbacks.items()): - events = [v for v in events if v['id'] != id_] + events = [v for v in events if v.get('id') != id_] if events: self._event_callbacks[k] = events else: self._event_callbacks.pop(k) - def ping(self, timeout=10): + def ping(self, timeout: float=10) -> Literal['pong']: + """Call `core.ping` to verify connection to the server. + + Args: + timeout: Number of seconds to allow before raising `CallTimeout`. + + Raises: + ClientException: Connection to the server closed prematurely or the call ended in error. + CallTimeout: The call took longer than `timeout` seconds to return. + + """ c = self.call('core.ping', background=True, register_call=True) return self.wait(c, timeout=timeout) def close(self): + """Allow one second for the `WSClient` to close.""" self._ws.close() # Wait for websocketclient thread to close self._closed.wait(1) - self._ws = None + del self._ws def main(): + """The entry point for midclt. Run `midclt -h` to see usage. + + Sub-commands: + call, ping, subscribe + + Options: + -h, -q, -u URI, -U USERNAME, -P PASSWORD, -K API_KEY, -t TIMEOUT + + Raises: + ValueError: Login failed (`midclt call`) or a subscription terminated with an error (`midclt subscribe`). + + """ parser = argparse.ArgumentParser() parser.add_argument('-q', '--quiet', action='store_true') parser.add_argument('-u', '--uri') @@ -557,7 +904,8 @@ def from_json(args): if args.job_print == 'progressbar': # display the job progress and status message while we wait - def callback(progress_bar, job): + def pb_callback(progress_bar: ProgressBar, job: _JobDict): + """Update `progress_bar` with information in `job['progress']`.""" try: progress_bar.update( job['progress']['percent'], job['progress']['description'] @@ -566,16 +914,17 @@ def callback(progress_bar, job): print(f'Failed to update progress bar: {e!s}', file=sys.stderr) with ProgressBar() as progress_bar: - kwargs.update({ - 'job': True, - 'callback': lambda job: callback(progress_bar, job) - }) + kwargs.update( + job=True, + callback=lambda job: pb_callback(progress_bar, job) + ) rv = c.call(args.method[0], *list(from_json(args.method[1:])), **kwargs) progress_bar.finish() else: lastdesc = '' - def callback(job): + def callback(job: _JobDict): + """Print `job`'s description to `stderr` if it has changed.""" nonlocal lastdesc desc = job['progress']['description'] if desc is not None and desc != lastdesc: @@ -616,19 +965,20 @@ def callback(job): event = subscribe_payload['event'] number = 0 - def cb(mtype, **message): + def cb(mtype: str, **message): + """Print the event message and unsubscribe if the maximum number of events is reached.""" nonlocal number print(json.dumps(message)) number += 1 if args.number and number >= args.number: event.set() - c.subscribe(args.event, cb, subscribe_payload) + c.subscribe(args.event, cb, subscribe_payload) # type: ignore (`LegacyClient` does not return `_Payload`) if not event.wait(timeout=args.timeout): sys.exit(1) - if subscribe_payload['error']: + if 'error' in subscribe_payload and subscribe_payload['error']: raise ValueError(subscribe_payload['error']) sys.exit(0) diff --git a/truenas_api_client/config.py b/truenas_api_client/config.py index bcc066a..997f24b 100644 --- a/truenas_api_client/config.py +++ b/truenas_api_client/config.py @@ -1,3 +1,4 @@ import os CALL_TIMEOUT = int(os.environ.get("CALL_TIMEOUT", 60)) +"""Default number of seconds to allow an API call until timing out.""" diff --git a/truenas_api_client/ejson.py b/truenas_api_client/ejson.py index 2044c3b..721afa4 100644 --- a/truenas_api_client/ejson.py +++ b/truenas_api_client/ejson.py @@ -1,9 +1,42 @@ +"""Provides wrappers of the `json` module for handling Python sets and common objects of the `datetime` module. + +Specifically, this module allows `datetime.date`, `datetime.time`, +`datetime.datetime`, and `set` objects to be serialized and deserialized in +addition to the types handled by the `json` module (those types are listed +[here](https://docs.python.org/3.11/library/json.html#json.JSONDecoder)). + +Example:: + + >>> from ejson import dumps, loads + >>> obj = {'string', 4, date.today(), time(16, 22, 6)} + >>> serialized = dumps(obj) + >>> serialized + {"$set": [4, {"$type": "date", "$value": "2024-07-03"}, "string", {"$time": "16:22:06"}]} + >>> deserialized = loads(serialized) + +""" import calendar from datetime import date, datetime, time, timedelta, timezone import json class JSONEncoder(json.JSONEncoder): + """Custom JSON encoder that extends the default encoder to handle more types. + + In addition to the types already supported by `json.JSONEncoder`, this + encoder adds support for the following types: + + | Python | JSON | + | ----------------- | ------------------------------------------------- | + | datetime.date | {"$type": "date", "$value": string[YYYY-MM-DD]} | + | datetime.datetime | {"$date": number[Total milliseconds since EPOCH]} | + | datetime.time | {"$time": string[HH:MM:SS]} | + | set | {"$set": array[items...]} | + + Note: When serializing Python sets, the order that the elements appear in + the JSON array is undefined. + + """ def default(self, obj): if type(obj) is date: return {'$type': 'date', '$value': obj.isoformat()} @@ -19,13 +52,18 @@ def default(self, obj): return super(JSONEncoder, self).default(obj) -def object_hook(obj): +def object_hook(obj: dict): + """Used when deserializing `date`, `time`, `datetime`, and `set` objects. + + Passed as a kwarg to a JSON deserialization function like `json.dump()`. + + """ obj_len = len(obj) if obj_len == 1: if '$date' in obj: return datetime.fromtimestamp(obj['$date'] / 1000, tz=timezone.utc) + timedelta(milliseconds=obj['$date'] % 1000) if '$time' in obj: - return time(*[int(i) for i in obj['$time'].split(':')]) + return time(*[int(i) for i in obj['$time'].split(':')[:4]]) # type: ignore if '$set' in obj: return set(obj['$set']) if obj_len == 2 and '$type' in obj and '$value' in obj: @@ -35,12 +73,28 @@ def object_hook(obj): def dump(obj, fp, **kwargs): + """Wraps `json.dump()` and uses the custom `JSONEncoder`. + + Can serialize `date`, `time`, `datetime`, and `set` objects + to a file-like object. + + """ return json.dump(obj, fp, cls=JSONEncoder, **kwargs) -def dumps(obj, **kwargs): +def dumps(obj, **kwargs) -> str: + """Wraps `json.dumps()` and uses the custom `JSONEncoder`. + + Can serialize `date`, `time`, `datetime`, and `set` objects. + + """ return json.dumps(obj, cls=JSONEncoder, **kwargs) -def loads(obj, **kwargs): +def loads(obj: str | bytes | bytearray, **kwargs): + """Wraps `json.loads()` and uses a custom `object_hook` argument. + + Can deserialize `date`, `time`, `datetime`, and `set` objects. + + """ return json.loads(obj, object_hook=object_hook, **kwargs) diff --git a/truenas_api_client/exc.py b/truenas_api_client/exc.py index d6e3cf0..1bd27b0 100644 --- a/truenas_api_client/exc.py +++ b/truenas_api_client/exc.py @@ -1,8 +1,11 @@ -from collections import namedtuple +"""Defines general classes for handling exceptions which may be raised through the client.""" + +from collections.abc import Iterable import errno +from .jsonrpc import ErrorExtra, TruenasTraceback try: - from libzfs import Error as ZFSError + from libzfs import Error as ZFSError # pytype: disable=import-error except ImportError: # this happens on our CI/CD runners as they do not install the py-libzfs module to run our api integration tests LIBZFS = False @@ -11,21 +14,42 @@ class ReserveFDException(Exception): + """A `WSClient` instance failed to bind to a reserved port.""" pass class ErrnoMixin: + """Provides custom error codes and a function to get the name of an error code.""" + ENOMETHOD = 201 + """Service not found or method not found in service.""" ESERVICESTARTFAILURE = 202 + """Service failed to start.""" EALERTCHECKERUNAVAILABLE = 203 + """Alert checker unavailable.""" EREMOTENODEERROR = 204 + """Remote node responded with an error.""" EDATASETISLOCKED = 205 + """Locked datasets.""" EINVALIDRRDTIMESTAMP = 206 + """Invalid RRD timestamp.""" ENOTAUTHENTICATED = 207 + """Client not authenticated.""" ESSLCERTVERIFICATIONERROR = 208 + """SSL certificate/host key could not be verified.""" @classmethod - def _get_errname(cls, code): + def _get_errname(cls, code: int) -> str | None: + """Get the name of an error given its error code. + + Args: + code: An error code for either a ZFSError or a custom error defined in this class. + + Returns: + str: The name of the associated error. + None: `code` does not match any known errors. + + """ if LIBZFS and 2000 <= code <= 2100: return 'EZFS_' + ZFSError(code).name for k, v in cls.__dict__.items(): @@ -34,7 +58,19 @@ def _get_errname(cls, code): class ClientException(ErrnoMixin, Exception): - def __init__(self, error, errno=None, trace=None, extra=None): + """Represents any exception that might arise from a `Client`.""" + + def __init__(self, error: str, errno: int | None=None, trace: TruenasTraceback | None=None, + extra: list[ErrorExtra] | None=None): + """Initialize `ClientException`. + + Args: + error: An error message offering a reason for the exception. + errno: An error code to classify the error. + trace: Traceback information from the server. + extra: Any other errors pertaining to the exception. + + """ self.errno = errno self.error = error self.trace = trace @@ -44,14 +80,19 @@ def __str__(self): return self.error -Error = namedtuple('Error', ['attribute', 'errmsg', 'errcode']) +class ValidationErrors(ClientException): + """A raisable collection of `ErrorExtra`s that indicates a validation error occurred on the server.""" + def __init__(self, errors: Iterable[ErrorExtra]): + """Initialize `ValidationErrors`. -class ValidationErrors(ClientException): - def __init__(self, errors): + Args: + errors: List of error codes and messages from the server. + + """ self.errors = [] for e in errors: - self.errors.append(Error(e[0], e[1], e[2])) + self.errors.append(ErrorExtra(e[0], e[1], e[2])) super().__init__(str(self)) @@ -64,5 +105,7 @@ def __str__(self): class CallTimeout(ClientException): + """A special `ClientException` raised when a `Call` times out before it can return a result.""" def __init__(self): + """Initiate a `ClientException` with message `"Call timeout"`.""" super().__init__("Call timeout", errno.ETIMEDOUT) diff --git a/truenas_api_client/jsonrpc.py b/truenas_api_client/jsonrpc.py index c48f61f..cf11f87 100644 --- a/truenas_api_client/jsonrpc.py +++ b/truenas_api_client/jsonrpc.py @@ -1,4 +1,10 @@ +"""Collection of types used to reference the structure of JSONRPC-2.0 messages received from the server. + +https://www.jsonrpc.org/specification + +""" import enum +from typing import Any, Literal, NamedTuple, NotRequired, TypeAlias, TypedDict class JSONRPCError(enum.Enum): @@ -11,3 +17,96 @@ class JSONRPCError(enum.Enum): # Custom error codes from -32000 to -32099 as allowed by the specification above TRUENAS_TOO_MANY_CONCURRENT_CALLS = -32000 TRUENAS_CALL_ERROR = -32001 + + +class JobProgress(TypedDict): + percent: float + description: str + + +class ErrorExtra(NamedTuple): + attribute: str + errmsg: str + errcode: int + + +class ExcInfo(TypedDict): + type: str + extra: list[ErrorExtra] | None + repr: str + + +class JobFields(TypedDict): + id: str + state: str + progress: JobProgress + result: Any + exc_info: ExcInfo + error: str + exception: str + + +class CollectionUpdateParams(TypedDict): + msg: str + collection: str + id: NotRequired[Any] + fields: NotRequired[JobFields] + extra: NotRequired[dict] + + +class CollectionUpdate(TypedDict): + jsonrpc: Literal['2.0'] + method: Literal['collection_update'] + params: CollectionUpdateParams + + +TruenasTraceback = TypedDict('TruenasTraceback', { + 'class': str, + 'frames': NotRequired[list[dict[str, Any]]], + 'formatted': str, + 'repr': str, +}) +# Has to be defined this way because `class` is a keyword. + + +class TruenasError(TypedDict): + error: int + errname: str + reason: str + trace: TruenasTraceback | None + extra: list[ErrorExtra] + py_exception: NotRequired[str] + + +class NotifyUnsubscribedParams(TypedDict): + collection: str + error: TruenasError + + +class NotifyUnsubscribed(TypedDict): + jsonrpc: Literal['2.0'] + method: Literal['notify_unsubscribed'] + params: NotifyUnsubscribedParams + + +class SuccessResponse(TypedDict): + jsonrpc: Literal['2.0'] + result: Any + id: str + + +class ErrorObj(TypedDict): + code: int + message: str | None + data: TruenasError + + +class ErrorResponse(TypedDict): + jsonrpc: Literal['2.0'] + error: ErrorObj + id: str + + +Notification: TypeAlias = CollectionUpdate | NotifyUnsubscribed +Response: TypeAlias = SuccessResponse | ErrorResponse +JSONRPCMessage: TypeAlias = Notification | Response diff --git a/truenas_api_client/legacy.py b/truenas_api_client/legacy.py index 1377590..60ab72b 100644 --- a/truenas_api_client/legacy.py +++ b/truenas_api_client/legacy.py @@ -1,3 +1,5 @@ +"""The websocket client prior to implementing JSONRPC-2.0 protocol. Used for backwards compatibility.""" + from base64 import b64decode from collections import defaultdict import errno @@ -20,7 +22,7 @@ from . import ejson as json from .config import CALL_TIMEOUT from .exc import ReserveFDException, ClientException, ValidationErrors, CallTimeout -from .utils import MIDDLEWARE_RUN_DIR, undefined +from .utils import MIDDLEWARE_RUN_DIR, undefined, UndefinedType logger = logging.getLogger(__name__) @@ -190,7 +192,7 @@ def result(self): class LegacyClient: def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False, - call_timeout=undefined, verify_ssl=True): + call_timeout: float | UndefinedType=undefined, verify_ssl=True): """ Arguments: :reserved_ports(bool): should the local socket used a reserved port diff --git a/truenas_api_client/utils.py b/truenas_api_client/utils.py index 55e18c6..3793bcf 100644 --- a/truenas_api_client/utils.py +++ b/truenas_api_client/utils.py @@ -1,18 +1,52 @@ +"""Utility classes for use in the TrueNAS API client. + +Includes `Struct` for creating regular objects out of `Mapping`s with string +keys and `ProgressBar` for displaying the progress of a task in the CLI. + +Attributes: + MIDDLEWARE_RUN_DIR: Directory containing the middlewared Unix domain socket. + undefined: A dummy object similar in purpose to `None` that indicates an unset variable. + +""" import sys +from typing import Any, final, Mapping + MIDDLEWARE_RUN_DIR = '/var/run/middleware' -undefined = object() + +@final +class UndefinedType: + def __new__(cls): + if not hasattr(cls, '_instance'): + cls._instance = super().__new__(cls) + return cls._instance +undefined = UndefinedType() class Struct: - """ - Simpler wrapper to access using object attributes instead of keys. + """Simpler wrapper to access using object attributes instead of keys. + This is meant for compatibility when switch scripts to use middleware client instead of django directly. + + Example:: + + >>> d = {'a':1, 'b':'2', 'c': [3, '4', {'d':5}], 'e':{'f':{'g':6}}} + >>> s = Struct(d) + >>> s.c + [3, '4', {'d':5}] + >>> s.e.f.g + 6 + """ + def __init__(self, mapping: Mapping[str, Any]): + """Initialize a `Struct` with a `Mapping`. - def __init__(self, mapping): + Args: + mapping: Contains string keys that will become the `Struct`'s attribute names. + + """ for k, v in mapping.items(): if isinstance(v, dict): setattr(self, k, Struct(v)) @@ -21,6 +55,26 @@ def __init__(self, mapping): class ProgressBar(object): + """A simple text-based progress bar that writes to `sys.stderr`. + + Status: (message) + Total Progress: [#####################___________________] 53.00% + + Example: + ``` + with ProgressBar() as pb: + for step in range(1, 101): + pb.update(step) + ``` + + Attributes: + message: String to display next to "Status". + percentage: A float from `0.0` to `100.0` representing the total progress. + write_stream: This is `sys.stderr` by default but can be any `TextIO`. + used_flag: Indicates whether `update()` has been called. + extra: A string or printable object to display after the status message. + + """ def __init__(self): self.message = None self.percentage = 0 @@ -32,6 +86,11 @@ def __enter__(self): return self def draw(self): + """Erase the previous progress bar and draw an updated one. + + If `self.extra` is set, will display "Status: (message) Extra: (extra)". + + """ progress_width = 40 filled_width = int(self.percentage * progress_width) self.write_stream.write('\033[2K\033[A\033[2K\r') @@ -47,7 +106,14 @@ def draw(self): ) self.write_stream.flush() - def update(self, percentage=None, message=None): + def update(self, percentage: float | None=None, message: str | None=None): + """Update the progress bar with a new percentage and/or message, redrawing it. + + Args: + percentage: The new percentage to display. A value of `100.0` represents full. + message: The "Status" message to display above the progress bar. + + """ if not self.used_flag: self.write_stream.write('\n') self.used_flag = True @@ -58,6 +124,7 @@ def update(self, percentage=None, message=None): self.draw() def finish(self): + """Fill the progress bar to 100%.""" self.percentage = 1 def __exit__(self, typ, value, traceback):