diff --git a/.github/workflows/pytype.yml b/.github/workflows/pytype.yml
new file mode 100644
index 0000000..5d84e6e
--- /dev/null
+++ b/.github/workflows/pytype.yml
@@ -0,0 +1,27 @@
+name: "pytype"
+
+on:
+ pull_request:
+ types:
+ - 'synchronize'
+ - 'opened'
+
+jobs:
+ type-check:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+
+ - name: Install Dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pytype websocket-client
+
+ - name: Type-Check
+ run: pytype -j auto .
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..2b2e6e7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,96 @@
+
+
+
+
+
+
+# TrueNAS Websocket Client
+
+*Found an issue? Please report it on our [Jira bugtracker](https://jira.ixsystems.com).*
+
+## About
+
+The TrueNAS websocket client provides the command line tool `midclt` and the means to easily communicate with [middleware](https://github.com/truenas/middleware) using Python by making calls through the [websocket API](https://www.truenas.com/docs/api/scale_websocket_api.html). The client can connect to a local TrueNAS instance by default or to a specified remote socket. This offers an alternative to going through the [web UI](https://github.com/truenas/webui) or connecting via ssh.
+
+By default, communication facilitated by the API between the client and middleware now uses the [JSON-RPC 2.0](https://www.jsonrpc.org/specification) protocol. However, it is still possible to use the legacy client by passing a legacy uri, e.g. `'ws://some.truenas.address/websocket'` as opposed to `'ws://some.truenas.address/api/current'`.
+
+## Getting Started
+
+TrueNAS comes with this client preinstalled, but it is also possible to use the TrueNAS websocket client from a non-TrueNAS host.
+
+On a non-TrueNAS host, ensure that Git is installed and run `pip install git+https://github.com/truenas/api_client.git` to automatically install dependencies. You may alternatively clone this repository and run `python setup.py install`. Using a Python venv is recommended.
+
+## Usage
+
+The `midclt` command (not to be confused with the [TrueNAS CLI](https://github.com/truenas/midcli)) provides a way to make direct API calls through the client. To view its syntax, enter `midclt -h`. The `-h` option can also be used with any of `midclt`'s subcommands.
+
+The client's default behavior is to connect to the localhost's middlewared socket. For a remote connection, e.g. from a Windows host, you must specify the `--uri` option and authenticate with either user credentials or an API key. For example: `midclt --uri ws:///api/current -K key ...`
+
+### Make local API calls
+
+```
+root@my_truenas[~]# midclt call user.create '{"full_name": "John Doe", "username": "user", "password": "pass", "group_create": true}'
+```
+
+### Login to a remote TrueNAS
+
+```
+root@my_truenas[~]# midclt --uri ws://some.other.truenas/api/current -U user -P password call system.info
+```
+
+### Start a job
+
+```
+root@my_truenas[~]# midclt call -j pool.dataset.lock mypool/mydataset
+```
+
+## Development
+
+The TrueNAS API client can also be used in Python scripts.
+
+### Make local API calls
+
+```python
+from truenas_api_client import Client
+
+with Client() as c: # Local IPC
+ print(c.ping()) # pong
+ user = {"full_name": "John Doe", "username": "user", "password": "pass", "group_create": True}
+ entry_id = c.call("user.create", user)
+ user = c.call("user.get_instance", entry_id)
+ print(user["full_name"]) # John Doe
+```
+
+### Login with a user account or an API key
+
+```python
+# User account
+with Client(uri="ws://some.other.truenas/api/current") as c:
+ c.call("auth.login", username, password)
+
+# API key
+with Client(uri="ws://some.other.truenas/api/current") as c:
+ c.call("auth.login_with_api_key", key)
+```
+
+### Start a job
+
+```python
+with Client() as c:
+ is_locked = c.call("pool.dataset.lock", "mypool/mydataset", job=True)
+ if is_locked:
+ args = {"datasets": [{"name": "mypool/mydataset", "passphrase": "passphrase"}]}
+ c.call("pool.dataset.unlock", "mypool/mydataset", args, job=True)
+```
+
+## Helpful Links
+
+
+
+
+
+- [Websocket API docs](https://www.truenas.com/docs/api/scale_websocket_api.html)
+- [Middleware repo](https://github.com/truenas/middleware)
+- [Official TrueNAS Documentation Hub](https://www.truenas.com/docs/)
+- [Get started building TrueNAS Scale](https://github.com/truenas/scale-build)
+- [Forums](https://www.truenas.com/community/)
diff --git a/truenas_api_client/__init__.py b/truenas_api_client/__init__.py
index 462366b..520bef7 100644
--- a/truenas_api_client/__init__.py
+++ b/truenas_api_client/__init__.py
@@ -1,6 +1,37 @@
+"""Provides a simple way to call middleware API endpoints using a websocket connection.
+
+The full websocket API documentation can be found at https://www.truenas.com/docs/api/core_websocket_api.html.
+
+Example::
+
+ $ midclt ping && echo 'Connected' || echo 'Unable to ping'
+ Connected
+ $ midclt call user.create '{"full_name": "John Doe", "username": "user", "password": "pass", "group_create": true}'
+ 70
+ $ midclt call user.get_instance 70
+ {"id": 70, "uid": 3000, "username": "user", "unixhash": ... }
+ $ midclt call user.query '[["full_name", "=", "John Doe"]]'
+ {"id": 70, "uid": 3000, "username": "user", "unixhash": ... }
+
+Example::
+
+ with Client() as c: # Local IPC
+ print(c.ping()) # pong
+ user = {"full_name": "John Doe", "username": "user", "password": "pass", "group_create": True}
+ id = c.call("user.create", user)
+ user = c.call("user.get_instance", id)
+ print(user["full_name"]) # John Doe
+
+Example::
+
+ c = Client("ws://example.com/api/current") # Remote websocket connection
+ c.close()
+
+"""
import argparse
from base64 import b64decode
from collections import defaultdict
+from collections.abc import Callable, Iterable
import errno
import logging
import pickle
@@ -10,6 +41,7 @@
import sys
from threading import Event, Lock, Thread
import time
+from typing import Any, Literal, NotRequired, Protocol, TypeAlias, TypedDict
import urllib.parse
import uuid
@@ -24,15 +56,35 @@
from .config import CALL_TIMEOUT
from .exc import ReserveFDException, ClientException, ErrnoMixin, ValidationErrors, CallTimeout
from .legacy import LegacyClient
-from .jsonrpc import JSONRPCError
-from .utils import MIDDLEWARE_RUN_DIR, ProgressBar, undefined
+from .jsonrpc import CollectionUpdateParams, ErrorObj, JobFields, JSONRPCError, JSONRPCMessage, TruenasError
+from .utils import MIDDLEWARE_RUN_DIR, ProgressBar, undefined, UndefinedType
logger = logging.getLogger(__name__)
class Client:
- def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False,
- call_timeout=undefined, verify_ssl=True):
+ """Implicit wrapper of either a `JSONRPCClient` or a `LegacyClient`."""
+
+ def __init__(self, uri: str | None=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False,
+ call_timeout: float | UndefinedType=undefined, verify_ssl=True):
+ """Initialize either a `JSONRPCClient` or a `LegacyClient`.
+
+ Use `JSONRPCClient` unless `uri` ends with '/websocket'.
+
+ Args:
+ uri: The address to connect to. Defaults to the local middlewared socket.
+ reserved_ports: `True` if the local socket should use a reserved port.
+ py_exceptions: `True` if the server should include exception objects in
+ `message['error']['data']['py_exception']`.
+ log_py_exceptions: `True` if exception tracebacks from API calls should be logged.
+ call_timeout: Number of seconds to allow an API call before timing out. Can be overridden on a per-call
+ basis. Defaults to `CALL_TIMEOUT`.
+ verify_ssl: `True` if SSL certificate should be verified before connecting.
+
+ Raises:
+ ClientException: `WSClient` timed out or some other connection error occurred.
+
+ """
if uri is not None and uri.endswith('/websocket'):
client_class = LegacyClient
else:
@@ -51,16 +103,36 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class WSClient:
- def __init__(self, url, *, client, reserved_ports=False, verify_ssl=True):
+ """A supporter class for `JSONRPCClient` that manages the `WebSocket` connection to the server.
+
+ The object used by `JSONRPCClient` to send and receive data.
+
+ """
+ def __init__(self, url: str, *, client: 'JSONRPCClient', reserved_ports: bool=False, verify_ssl: bool=True):
+ """Initialize a `WSClient`.
+
+ Args:
+ url: The websocket to connect to. `ws://` or `wss://` for secure connection.
+ client: Reference to the `JSONRPCClient` instance that uses this object.
+ reserved_ports: `True` if the `socket` should bind to a reserved port, i.e. 600-1024.
+ verify_ssl: `True` if SSL certificate should be verified before connecting.
+
+ """
self.url = url
self.client = client
self.reserved_ports = reserved_ports
self.verify_ssl = verify_ssl
- self.socket = None
- self.app = None
+ self.socket: socket.socket
+ self.app: WebSocketApp
def connect(self):
+ """Connect a `socket` and start a `WebSocketApp` in a daemon `Thread`.
+
+ Raises:
+ Exception: The `socket` failed to connect.
+
+ """
unix_socket_prefix = "ws+unix://"
if self.url.startswith(unix_socket_prefix):
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -93,14 +165,27 @@ def connect(self):
)
Thread(daemon=True, target=self.app.run_forever).start()
- def send(self, data):
+ def send(self, data: bytes | str):
+ """Send data to the server by calling `WebSocketApp.send()`.
+
+ Args:
+ data: The serialized JSON-RPC v2.0-formatted request to send.
+
+ """
return self.app.send(data)
def close(self):
+ """Cleanly close the `WebSocket` connection to the server."""
self.app.close()
self.client.on_close(STATUS_NORMAL)
def _bind_to_reserved_port(self):
+ """Bind to a random port in the 600-1024 range.
+
+ Raises:
+ ReserveFDException: Five failed attempts with different ports.
+
+ """
# linux doesn't have a mechanism to allow the kernel to dynamically
# assign ports in the "privileged" range (i.e. 600 - 1024) so we
# loop through and call bind() on a privileged port explicitly since
@@ -125,6 +210,11 @@ def _bind_to_reserved_port(self):
raise ReserveFDException()
def _on_open(self, app):
+ """Callback passed to the `WebSocketApp` to execute when `run_forever` is called.
+
+ Configure the `socket` and call `client.on_open()`.
+
+ """
# TCP keepalive settings don't apply to local unix sockets
if 'ws+unix' not in self.url:
# enable keepalives on the socket
@@ -150,29 +240,77 @@ def _on_open(self, app):
self.client.on_open()
def _on_message(self, app, data):
+ """Callback passed to the `WebSocketApp` to execute when data is received.
+
+ Pass the received data to the `JSONRPCClient`.
+
+ """
self.client._recv(json.loads(data))
def _on_error(self, app, e):
+ """Callback passed to the `WebSocketApp` to execute when an error occurs.
+
+ Log the error.
+
+ """
logger.warning("Websocket client error: %r", e)
self.client._ws_connection_error = e
def _on_close(self, app, code, reason):
+ """Callback passed to the `WebSocketApp` to execute when it closes.
+
+ Close the `JSONRPCClient`.
+
+ """
self.client.on_close(code, reason)
class Call:
- def __init__(self, method, params):
+ """An encapsulation of the data from a single request-response pair."""
+
+ def __init__(self, method: str, params: tuple):
+ """Initialize a `Call` object with an automatically-assigned id.
+
+ Args:
+ method: The API method being called.
+ params: Arguments passed to the method.
+
+ """
self.id = str(uuid.uuid4())
self.method = method
self.params = params
self.returned = Event()
- self.result = None
- self.error = None
- self.py_exception = None
+ self.result: Any = None
+ self.error: ClientException | None = None
+ self.py_exception: BaseException | None = None
+
+
+class _JobDict(JobFields):
+ """Contains data received from the server for a particular running job."""
+ __ready: Event
+ """Is set when the job returns or ends in error."""
+ __callback: '_JobCallback | None'
+ """Procedure to execute each time a job update is received."""
+
+
+_JobCallback: TypeAlias = Callable[[_JobDict], None]
class Job:
- def __init__(self, client, job_id, callback=None):
+ """A long-running background process on the server initiated by an API call.
+
+ Every `Job` is responsible for a corresponding `_JobDict` in the client's list of jobs.
+
+ """
+ def __init__(self, client: 'JSONRPCClient', job_id: str, callback: _JobCallback | None=None):
+ """Initialize `Job`.
+
+ Args:
+ client: Reference to the client that created this `Job` and receives updates on its progress.
+ job_id: The job id returned by the server. Index of this `Job` in the `client._jobs` dictionary.
+ callback: A procedure to be called every time a job event is received.
+
+ """
self.client = client
self.job_id = job_id
# If a job event has been received already then we must set an Event
@@ -190,6 +328,16 @@ def __repr__(self):
return f''
def result(self):
+ """Wait for the job to finish and return its result.
+
+ Returns:
+ Any: The job's result.
+
+ Raises:
+ ValidationErrors: The job failed due to one or more validation errors.
+ ClientException: No job event was received or it did not succeed.
+
+ """
# Wait indefinitely for the job event with state SUCCESS/FAILED/ABORTED
self.event.wait()
job = self.client._jobs.pop(self.job_id, None)
@@ -197,7 +345,7 @@ def result(self):
raise ClientException('No job event was received.')
if job['state'] != 'SUCCESS':
if job['exc_info'] and job['exc_info']['type'] == 'VALIDATION':
- raise ValidationErrors(job['exc_info']['extra'])
+ raise ValidationErrors(job['exc_info']['extra'] or [])
raise ClientException(
job['error'],
trace={
@@ -210,12 +358,55 @@ def result(self):
return job['result']
+class _EventCallbackProtocol(Protocol):
+ """Specifies how event callbacks should be defined."""
+ def __call__(self, mtype: str, **message: Any) -> None: ...
+
+
+class _Payload(TypedDict):
+ """Contains data for managing a subscription.
+
+ Attributes:
+ callback: Procedure to call when the event is triggered.
+ sync: If `True`, main client thread blocks until `callback` finishes each time it is invoked. Otherwise, run
+ `callback` in the background as a daemon `Thread`.
+ event: `Event` that is set when the subscription should end.
+ error: Information included in the Notification if the subscription ended in error.
+ id: Random UUID assigned by `core.subscribe`.
+ ready: For backwards compatibility with `LegacyClient`.
+
+ """
+ callback: _EventCallbackProtocol | None
+ sync: bool
+ event: Event
+ error: NotRequired[str | TruenasError | None]
+ id: NotRequired[str]
+ ready: NotRequired[Event]
+
+
class JSONRPCClient:
- def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False,
- call_timeout=undefined, verify_ssl=True):
- """
- Arguments:
- :reserved_ports(bool): should the local socket used a reserved port
+ """The object used to interface with the TrueNAS API.
+
+ Keeps track of the calls made, jobs submitted, and callbacks. Maintains a websocket connection using a `WSClient`.
+
+ """
+ def __init__(self, uri: str | None=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False,
+ call_timeout: float | UndefinedType=undefined, verify_ssl=True):
+ """Initialize a `JSONRPCClient`.
+
+ Args:
+ uri: The address to connect to. Defaults to the local middlewared socket.
+ reserved_ports: `True` if the local socket should use a reserved port.
+ py_exceptions: `True` if the server should include exception objects in
+ `message['error']['data']['py_exception']`.
+ log_py_exceptions: `True` if exception tracebacks from API calls should be logged.
+ call_timeout: Number of seconds to allow an API call before timing out. Can be overridden on a per-call
+ basis. Defaults to `CALL_TIMEOUT`.
+ verify_ssl: `True` if SSL certificate should be verified before connecting.
+
+ Raises:
+ ClientException: `WSClient` timed out or some other connection error occurred.
+
"""
if uri is None:
uri = f'ws+unix://{MIDDLEWARE_RUN_DIR}/middlewared.sock'
@@ -223,19 +414,19 @@ def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_e
if call_timeout is undefined:
call_timeout = CALL_TIMEOUT
- self._calls = {}
- self._jobs = defaultdict(dict)
+ self._calls: dict[str, Call] = {}
+ self._jobs: defaultdict[str, _JobDict] = defaultdict(dict) # type: ignore
self._jobs_lock = Lock()
self._jobs_watching = False
self._py_exceptions = py_exceptions
self._log_py_exceptions = log_py_exceptions
self._call_timeout = call_timeout
- self._event_callbacks = defaultdict(list)
+ self._event_callbacks: defaultdict[str, list[_Payload]] = defaultdict(list)
self._set_options_call: Call | None = None
self._closed = Event()
self._connected = Event()
- self._connection_error = None
- self._ws_connection_error = None
+ self._connection_error: str | None = None
+ self._ws_connection_error: WebSocketException
self._ws = WSClient(
uri,
client=self,
@@ -246,7 +437,7 @@ def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_e
self._connected.wait(10)
if not self._connected.is_set():
raise ClientException('Failed connection handshake')
- if self._ws_connection_error is not None:
+ if hasattr(self, '_ws_connection_error'):
if isinstance(self._ws_connection_error, WebSocketException):
raise self._ws_connection_error
if self._connection_error is not None:
@@ -257,10 +448,17 @@ def __enter__(self):
def __exit__(self, typ, value, traceback):
self.close()
- if typ is not None:
- raise
def _send(self, data):
+ """Send data to the server using `WSClient`.
+
+ Args:
+ data: Object serializable with `ejson`.
+
+ Raises:
+ ClientException: Connection to the server closed prematurely.
+
+ """
try:
self._ws.send(json.dumps(data))
except (AttributeError, WebSocketConnectionClosedException):
@@ -268,13 +466,29 @@ def _send(self, data):
# running tasks in the event loop (i.e. failover.call_remote failover.get_disks_local)
raise ClientException('Unexpected closure of remote connection', errno.ECONNABORTED)
- def _recv(self, message):
+ def _recv(self, message: JSONRPCMessage):
+ """Process a deserialized JSON-RPC v2.0 message from the server.
+
+ The TrueNAS websocket `JSONRPCClient` receives data from the server in two standard forms: Notifications and
+ Responses. These are defined in the JSON-RPC v2.0 protocol at https://www.jsonrpc.org/specification.
+
+ In the TrueNAS websocket client, Notifications are used to communicate subscription updates including when a
+ subscription is terminated. These subscription updates also include updates about jobs submitted by the client
+ via `core.get_jobs`.
+
+ A Response is the server's answer to a Request sent by the client which may or may not come back immediately
+ depending on the Request sent.
+
+ Args:
+ message: Deserialized JSON-RPC v2.0 data from the server.
+
+ """
try:
if 'method' in message:
- params = message['params']
match message['method']:
case 'collection_update':
if self._event_callbacks:
+ params = message['params']
if '*' in self._event_callbacks:
for event in self._event_callbacks['*']:
self._run_callback(event, [params['msg'].upper()], params)
@@ -282,6 +496,7 @@ def _recv(self, message):
for event in self._event_callbacks[params['collection']]:
self._run_callback(event, [params['msg'].upper()], params)
case 'notify_unsubscribed':
+ params = message['params']
if params['collection'] in self._event_callbacks:
for event in self._event_callbacks[params['collection']]:
if 'error' in params:
@@ -295,7 +510,7 @@ def _recv(self, message):
try:
self._parse_error(message['error'], self._set_options_call)
except Exception:
- logger.error('Unhandled exception in Client._parse_error', exc_info=True)
+ logger.error('Unhandled exception in JSONRPCClient._parse_error', exc_info=True)
else:
logger.error('Error setting client options: %r', self._set_options_call.error)
self._connected.set()
@@ -306,7 +521,7 @@ def _recv(self, message):
try:
self._parse_error(message['error'], call)
except Exception:
- logger.error('Unhandled exception in Client._parse_error', exc_info=True)
+ logger.error('Unhandled exception in JSONRPCClient._parse_error', exc_info=True)
call.returned.set()
self._unregister_call(call)
else:
@@ -314,14 +529,21 @@ def _recv(self, message):
else:
logger.error('Received unknown message %r', message)
except Exception:
- logger.error('Unhandled exception in Client._recv', exc_info=True)
+ logger.error('Unhandled exception in JSONRPCClient._recv', exc_info=True)
- def _parse_error(self, error: dict, call: Call):
+ def _parse_error(self, error: ErrorObj, call: Call):
+ """Convert an error received from the server into a `ClientException` and store it.
+
+ Args:
+ error: The JSON object received in an error Response.
+ call: The associated `Call` object with which to store the `ClientException`.
+
+ """
code = JSONRPCError(error['code'])
if self._py_exceptions and code in [JSONRPCError.INVALID_PARAMS, JSONRPCError.TRUENAS_CALL_ERROR]:
data = error['data']
call.error = ClientException(data['reason'], data['error'], data['trace'], data['extra'])
- if data.get('py_exception'):
+ if 'py_exception' in data:
try:
call.py_exception = pickle.loads(b64decode(data['py_exception']))
except Exception as e:
@@ -334,16 +556,38 @@ def _parse_error(self, error: dict, call: Call):
else:
call.error = ClientException(code.name)
- def _run_callback(self, event, args, kwargs):
+ def _run_callback(self, event: _Payload, args: Iterable[str], kwargs: CollectionUpdateParams):
+ """Call the passed `_Payload`'s callback function.
+
+ Block until the callback returns if `event['sync']` is set. Otherwise, run in a separate daemon `Thread`.
+
+ Args:
+ event: The `_Payload` whose callback to run.
+ args: Positional arguments to the callback.
+ kwargs: Keyword arguments to the callback.
+
+ """
+ if event['callback'] is None:
+ return
if event['sync']:
event['callback'](*args, **kwargs)
else:
Thread(target=event['callback'], args=args, kwargs=kwargs, daemon=True).start()
def on_open(self):
+ """Make an API call to `core.set_options` to configure how middlewared sends its responses."""
self._set_options_call = self.call("core.set_options", {"py_exceptions": self._py_exceptions}, background=True)
- def on_close(self, code, reason=None):
+ def on_close(self, code: int, reason: str | None=None):
+ """Close this `JSONRPCClient` in response to the `WebSocketApp` closing.
+
+ End all unanswered calls and unreturned jobs with an error.
+
+ Args:
+ code: One of several closing frame status codes defined in `websocket._abnf`.
+ reason: A message to accompany the closing code and provide more information.
+
+ """
error = f'WebSocket connection closed with code={code!r}, reason={reason!r}'
self._connection_error = error
@@ -371,22 +615,33 @@ def on_close(self, code, reason=None):
self._closed.set()
- def _register_call(self, call):
+ def _register_call(self, call: Call):
+ """Save a `Call` and index it by its id."""
self._calls[call.id] = call
- def _unregister_call(self, call):
+ def _unregister_call(self, call: Call):
+ """Remove a `Call` after it has returned."""
self._calls.pop(call.id, None)
- def _jobs_callback(self, mtype, **message):
- """
- Method to process the received job events.
+ def _jobs_callback(self, mtype: str, *, fields: JobFields, **message):
+ """Process a received job event.
+
+ Update the saved job info, execute its saved callback in the background, and set its "__ready" flag if its
+ "state" is received.
+
+ Args:
+ mtype: Indicates if the job state has changed.
+ **message: The members contained in `CollectionUpdateParams`.
+
+ Keyword Args:
+ fields (JobFields): Contains job id and other information about the job from the server.
+
"""
- fields = message.get('fields')
job_id = fields['id']
with self._jobs_lock:
if fields:
job = self._jobs[job_id]
- job.update(fields)
+ job.update(**fields)
if callable(job.get('__callback')):
Thread(target=job['__callback'], args=(job,), daemon=True).start()
if mtype == 'CHANGED' and job['state'] in ('SUCCESS', 'FAILED', 'ABORTED'):
@@ -400,13 +655,37 @@ def _jobs_callback(self, mtype, **message):
event.set()
def _jobs_subscribe(self):
- """
- Subscribe to job updates, calling `_jobs_callback` on every new event.
- """
+ """Subscribe to job updates, calling `_jobs_callback` on every new event."""
self._jobs_watching = True
self.subscribe('core.get_jobs', self._jobs_callback, sync=True)
- def call(self, method, *params, background=False, callback=None, job=False, register_call=None, timeout=undefined):
+ def call(self, method: str, *params, background=False, callback: _JobCallback | None=None,
+ job: Literal['RETURN'] | bool=False, register_call: bool | None=None,
+ timeout: float | UndefinedType=undefined) -> Any:
+ """The primary way to send call requests to the API.
+
+ Send a JSON-RPC v2.0 Request to the server.
+
+ Args:
+ method: An API endpoint to call.
+ *params: Arguments to pass to the endpoint.
+ background: If `background=True`, send the request and return a `Call` object before receiving a response.
+ By default, wait for the call to return instead.
+ callback: The callback to pass to the job if `job` is set.
+ job: If set, subscribe to job updates and if `background=False`, create a `Job`. If `job='RETURN'`, return
+ the `Job` object rather than just its result.
+ timeout: Number of seconds to allow the call before timing out if `background=False`.
+
+ Returns:
+ Call: If `background` is set, return an object representing the request-response pair.
+ Job: If `job='RETURN'`, return the `Job` object.
+ Any: Otherwise, return the result of the call.
+
+ Raises:
+ ClientException: Connection to the server closed prematurely or the call ended in error.
+ CallTimeout: The call took longer than `timeout` seconds to return.
+
+ """
if register_call is None:
register_call = not background
@@ -436,17 +715,36 @@ def call(self, method, *params, background=False, callback=None, job=False, regi
if not background:
self._unregister_call(c)
- def wait(self, c, *, callback=None, job=False, timeout=undefined):
+ def wait(self, c: Call, *, callback: _JobCallback | None=None, job: Literal['RETURN'] | bool=False,
+ timeout: float | UndefinedType=undefined) -> Any:
+ """Wait for an API call to return and return its result.
+
+ Args:
+ c: The `Call` object containing the data that was sent.
+ callback: The callback to pass to the job if `job` is set.
+ job: If set, create a `Job`. If `job='RETURN'`, return the `Job` object rather than just its result.
+ timeout: Override the default number of seconds until a timeout exception occurs.
+
+ Returns:
+ Job: If `job='RETURN'`, return the `Job` object.
+ Any: If `job=True`, return the job's result. Otherwise, return the call's result.
+
+ Raises:
+ CallTimeout: The call took longer than `timeout` seconds to return.
+ ClientException: The call ended in error and `py_exception` was not enabled for `c`.
+ BaseException: The call ended in error and `py_exception` was enabled for `c`.
+
+ """
if timeout is undefined:
timeout = self._call_timeout
try:
- if not c.returned.wait(timeout):
+ if not c.returned.wait(timeout): # type: ignore
raise CallTimeout()
if c.error:
if c.py_exception:
- if self._log_py_exceptions:
+ if self._log_py_exceptions and c.error.trace:
logger.error(c.error.trace["formatted"])
raise c.py_exception
else:
@@ -463,14 +761,34 @@ def wait(self, c, *, callback=None, job=False, timeout=undefined):
self._unregister_call(c)
@staticmethod
- def event_payload():
+ def event_payload() -> _Payload:
+ """Create an empty payload.
+
+ Returns:
+ _Payload: Empty `_Payload`.
+
+ """
return {
'callback': None,
'sync': False,
'event': Event(),
}
- def subscribe(self, name, callback, payload=None, sync=False):
+ def subscribe(self, name: str, callback: _EventCallbackProtocol, payload: _Payload | None=None,
+ sync: bool=False) -> str:
+ """Subscribe to an event by calling `core.subscribe`.
+
+ Args:
+ name: The name of the event to subscribe to.
+ callback: A procedure to call when an event is triggered.
+ payload: Dictionary containing subscription information.
+ sync: If `True`, main client thread blocks until `callback` finishes each time it is invoked. Otherwise,
+ run `callback` in the background as a daemon `Thread`.
+
+ Returns:
+ str: The `_Payload` id assigned by `core.subscribe`.
+
+ """
payload = payload or self.event_payload()
payload.update({
'callback': callback,
@@ -480,27 +798,56 @@ def subscribe(self, name, callback, payload=None, sync=False):
payload['id'] = self.call('core.subscribe', name, timeout=10)
return payload['id']
- def unsubscribe(self, id_):
+ def unsubscribe(self, id_: str):
+ """Call `core.unsubscribe` and remove all associated `_Payload`s
+
+ Args:
+ id_: `id` of the `_Payload` to remove.
+
+ """
self.call('core.unsubscribe', id_)
for k, events in list(self._event_callbacks.items()):
- events = [v for v in events if v['id'] != id_]
+ events = [v for v in events if v.get('id') != id_]
if events:
self._event_callbacks[k] = events
else:
self._event_callbacks.pop(k)
- def ping(self, timeout=10):
+ def ping(self, timeout: float=10) -> Literal['pong']:
+ """Call `core.ping` to verify connection to the server.
+
+ Args:
+ timeout: Number of seconds to allow before raising `CallTimeout`.
+
+ Raises:
+ ClientException: Connection to the server closed prematurely or the call ended in error.
+ CallTimeout: The call took longer than `timeout` seconds to return.
+
+ """
c = self.call('core.ping', background=True, register_call=True)
return self.wait(c, timeout=timeout)
def close(self):
+ """Allow one second for the `WSClient` to close."""
self._ws.close()
# Wait for websocketclient thread to close
self._closed.wait(1)
- self._ws = None
+ del self._ws
def main():
+ """The entry point for midclt. Run `midclt -h` to see usage.
+
+ Sub-commands:
+ call, ping, subscribe
+
+ Options:
+ -h, -q, -u URI, -U USERNAME, -P PASSWORD, -K API_KEY, -t TIMEOUT
+
+ Raises:
+ ValueError: Login failed (`midclt call`) or a subscription terminated with an error (`midclt subscribe`).
+
+ """
parser = argparse.ArgumentParser()
parser.add_argument('-q', '--quiet', action='store_true')
parser.add_argument('-u', '--uri')
@@ -557,7 +904,8 @@ def from_json(args):
if args.job_print == 'progressbar':
# display the job progress and status message while we wait
- def callback(progress_bar, job):
+ def pb_callback(progress_bar: ProgressBar, job: _JobDict):
+ """Update `progress_bar` with information in `job['progress']`."""
try:
progress_bar.update(
job['progress']['percent'], job['progress']['description']
@@ -566,16 +914,17 @@ def callback(progress_bar, job):
print(f'Failed to update progress bar: {e!s}', file=sys.stderr)
with ProgressBar() as progress_bar:
- kwargs.update({
- 'job': True,
- 'callback': lambda job: callback(progress_bar, job)
- })
+ kwargs.update(
+ job=True,
+ callback=lambda job: pb_callback(progress_bar, job)
+ )
rv = c.call(args.method[0], *list(from_json(args.method[1:])), **kwargs)
progress_bar.finish()
else:
lastdesc = ''
- def callback(job):
+ def callback(job: _JobDict):
+ """Print `job`'s description to `stderr` if it has changed."""
nonlocal lastdesc
desc = job['progress']['description']
if desc is not None and desc != lastdesc:
@@ -616,19 +965,20 @@ def callback(job):
event = subscribe_payload['event']
number = 0
- def cb(mtype, **message):
+ def cb(mtype: str, **message):
+ """Print the event message and unsubscribe if the maximum number of events is reached."""
nonlocal number
print(json.dumps(message))
number += 1
if args.number and number >= args.number:
event.set()
- c.subscribe(args.event, cb, subscribe_payload)
+ c.subscribe(args.event, cb, subscribe_payload) # type: ignore (`LegacyClient` does not return `_Payload`)
if not event.wait(timeout=args.timeout):
sys.exit(1)
- if subscribe_payload['error']:
+ if 'error' in subscribe_payload and subscribe_payload['error']:
raise ValueError(subscribe_payload['error'])
sys.exit(0)
diff --git a/truenas_api_client/config.py b/truenas_api_client/config.py
index bcc066a..997f24b 100644
--- a/truenas_api_client/config.py
+++ b/truenas_api_client/config.py
@@ -1,3 +1,4 @@
import os
CALL_TIMEOUT = int(os.environ.get("CALL_TIMEOUT", 60))
+"""Default number of seconds to allow an API call until timing out."""
diff --git a/truenas_api_client/ejson.py b/truenas_api_client/ejson.py
index 2044c3b..721afa4 100644
--- a/truenas_api_client/ejson.py
+++ b/truenas_api_client/ejson.py
@@ -1,9 +1,42 @@
+"""Provides wrappers of the `json` module for handling Python sets and common objects of the `datetime` module.
+
+Specifically, this module allows `datetime.date`, `datetime.time`,
+`datetime.datetime`, and `set` objects to be serialized and deserialized in
+addition to the types handled by the `json` module (those types are listed
+[here](https://docs.python.org/3.11/library/json.html#json.JSONDecoder)).
+
+Example::
+
+ >>> from ejson import dumps, loads
+ >>> obj = {'string', 4, date.today(), time(16, 22, 6)}
+ >>> serialized = dumps(obj)
+ >>> serialized
+ {"$set": [4, {"$type": "date", "$value": "2024-07-03"}, "string", {"$time": "16:22:06"}]}
+ >>> deserialized = loads(serialized)
+
+"""
import calendar
from datetime import date, datetime, time, timedelta, timezone
import json
class JSONEncoder(json.JSONEncoder):
+ """Custom JSON encoder that extends the default encoder to handle more types.
+
+ In addition to the types already supported by `json.JSONEncoder`, this
+ encoder adds support for the following types:
+
+ | Python | JSON |
+ | ----------------- | ------------------------------------------------- |
+ | datetime.date | {"$type": "date", "$value": string[YYYY-MM-DD]} |
+ | datetime.datetime | {"$date": number[Total milliseconds since EPOCH]} |
+ | datetime.time | {"$time": string[HH:MM:SS]} |
+ | set | {"$set": array[items...]} |
+
+ Note: When serializing Python sets, the order that the elements appear in
+ the JSON array is undefined.
+
+ """
def default(self, obj):
if type(obj) is date:
return {'$type': 'date', '$value': obj.isoformat()}
@@ -19,13 +52,18 @@ def default(self, obj):
return super(JSONEncoder, self).default(obj)
-def object_hook(obj):
+def object_hook(obj: dict):
+ """Used when deserializing `date`, `time`, `datetime`, and `set` objects.
+
+ Passed as a kwarg to a JSON deserialization function like `json.dump()`.
+
+ """
obj_len = len(obj)
if obj_len == 1:
if '$date' in obj:
return datetime.fromtimestamp(obj['$date'] / 1000, tz=timezone.utc) + timedelta(milliseconds=obj['$date'] % 1000)
if '$time' in obj:
- return time(*[int(i) for i in obj['$time'].split(':')])
+ return time(*[int(i) for i in obj['$time'].split(':')[:4]]) # type: ignore
if '$set' in obj:
return set(obj['$set'])
if obj_len == 2 and '$type' in obj and '$value' in obj:
@@ -35,12 +73,28 @@ def object_hook(obj):
def dump(obj, fp, **kwargs):
+ """Wraps `json.dump()` and uses the custom `JSONEncoder`.
+
+ Can serialize `date`, `time`, `datetime`, and `set` objects
+ to a file-like object.
+
+ """
return json.dump(obj, fp, cls=JSONEncoder, **kwargs)
-def dumps(obj, **kwargs):
+def dumps(obj, **kwargs) -> str:
+ """Wraps `json.dumps()` and uses the custom `JSONEncoder`.
+
+ Can serialize `date`, `time`, `datetime`, and `set` objects.
+
+ """
return json.dumps(obj, cls=JSONEncoder, **kwargs)
-def loads(obj, **kwargs):
+def loads(obj: str | bytes | bytearray, **kwargs):
+ """Wraps `json.loads()` and uses a custom `object_hook` argument.
+
+ Can deserialize `date`, `time`, `datetime`, and `set` objects.
+
+ """
return json.loads(obj, object_hook=object_hook, **kwargs)
diff --git a/truenas_api_client/exc.py b/truenas_api_client/exc.py
index d6e3cf0..1bd27b0 100644
--- a/truenas_api_client/exc.py
+++ b/truenas_api_client/exc.py
@@ -1,8 +1,11 @@
-from collections import namedtuple
+"""Defines general classes for handling exceptions which may be raised through the client."""
+
+from collections.abc import Iterable
import errno
+from .jsonrpc import ErrorExtra, TruenasTraceback
try:
- from libzfs import Error as ZFSError
+ from libzfs import Error as ZFSError # pytype: disable=import-error
except ImportError:
# this happens on our CI/CD runners as they do not install the py-libzfs module to run our api integration tests
LIBZFS = False
@@ -11,21 +14,42 @@
class ReserveFDException(Exception):
+ """A `WSClient` instance failed to bind to a reserved port."""
pass
class ErrnoMixin:
+ """Provides custom error codes and a function to get the name of an error code."""
+
ENOMETHOD = 201
+ """Service not found or method not found in service."""
ESERVICESTARTFAILURE = 202
+ """Service failed to start."""
EALERTCHECKERUNAVAILABLE = 203
+ """Alert checker unavailable."""
EREMOTENODEERROR = 204
+ """Remote node responded with an error."""
EDATASETISLOCKED = 205
+ """Locked datasets."""
EINVALIDRRDTIMESTAMP = 206
+ """Invalid RRD timestamp."""
ENOTAUTHENTICATED = 207
+ """Client not authenticated."""
ESSLCERTVERIFICATIONERROR = 208
+ """SSL certificate/host key could not be verified."""
@classmethod
- def _get_errname(cls, code):
+ def _get_errname(cls, code: int) -> str | None:
+ """Get the name of an error given its error code.
+
+ Args:
+ code: An error code for either a ZFSError or a custom error defined in this class.
+
+ Returns:
+ str: The name of the associated error.
+ None: `code` does not match any known errors.
+
+ """
if LIBZFS and 2000 <= code <= 2100:
return 'EZFS_' + ZFSError(code).name
for k, v in cls.__dict__.items():
@@ -34,7 +58,19 @@ def _get_errname(cls, code):
class ClientException(ErrnoMixin, Exception):
- def __init__(self, error, errno=None, trace=None, extra=None):
+ """Represents any exception that might arise from a `Client`."""
+
+ def __init__(self, error: str, errno: int | None=None, trace: TruenasTraceback | None=None,
+ extra: list[ErrorExtra] | None=None):
+ """Initialize `ClientException`.
+
+ Args:
+ error: An error message offering a reason for the exception.
+ errno: An error code to classify the error.
+ trace: Traceback information from the server.
+ extra: Any other errors pertaining to the exception.
+
+ """
self.errno = errno
self.error = error
self.trace = trace
@@ -44,14 +80,19 @@ def __str__(self):
return self.error
-Error = namedtuple('Error', ['attribute', 'errmsg', 'errcode'])
+class ValidationErrors(ClientException):
+ """A raisable collection of `ErrorExtra`s that indicates a validation error occurred on the server."""
+ def __init__(self, errors: Iterable[ErrorExtra]):
+ """Initialize `ValidationErrors`.
-class ValidationErrors(ClientException):
- def __init__(self, errors):
+ Args:
+ errors: List of error codes and messages from the server.
+
+ """
self.errors = []
for e in errors:
- self.errors.append(Error(e[0], e[1], e[2]))
+ self.errors.append(ErrorExtra(e[0], e[1], e[2]))
super().__init__(str(self))
@@ -64,5 +105,7 @@ def __str__(self):
class CallTimeout(ClientException):
+ """A special `ClientException` raised when a `Call` times out before it can return a result."""
def __init__(self):
+ """Initiate a `ClientException` with message `"Call timeout"`."""
super().__init__("Call timeout", errno.ETIMEDOUT)
diff --git a/truenas_api_client/jsonrpc.py b/truenas_api_client/jsonrpc.py
index c48f61f..cf11f87 100644
--- a/truenas_api_client/jsonrpc.py
+++ b/truenas_api_client/jsonrpc.py
@@ -1,4 +1,10 @@
+"""Collection of types used to reference the structure of JSONRPC-2.0 messages received from the server.
+
+https://www.jsonrpc.org/specification
+
+"""
import enum
+from typing import Any, Literal, NamedTuple, NotRequired, TypeAlias, TypedDict
class JSONRPCError(enum.Enum):
@@ -11,3 +17,96 @@ class JSONRPCError(enum.Enum):
# Custom error codes from -32000 to -32099 as allowed by the specification above
TRUENAS_TOO_MANY_CONCURRENT_CALLS = -32000
TRUENAS_CALL_ERROR = -32001
+
+
+class JobProgress(TypedDict):
+ percent: float
+ description: str
+
+
+class ErrorExtra(NamedTuple):
+ attribute: str
+ errmsg: str
+ errcode: int
+
+
+class ExcInfo(TypedDict):
+ type: str
+ extra: list[ErrorExtra] | None
+ repr: str
+
+
+class JobFields(TypedDict):
+ id: str
+ state: str
+ progress: JobProgress
+ result: Any
+ exc_info: ExcInfo
+ error: str
+ exception: str
+
+
+class CollectionUpdateParams(TypedDict):
+ msg: str
+ collection: str
+ id: NotRequired[Any]
+ fields: NotRequired[JobFields]
+ extra: NotRequired[dict]
+
+
+class CollectionUpdate(TypedDict):
+ jsonrpc: Literal['2.0']
+ method: Literal['collection_update']
+ params: CollectionUpdateParams
+
+
+TruenasTraceback = TypedDict('TruenasTraceback', {
+ 'class': str,
+ 'frames': NotRequired[list[dict[str, Any]]],
+ 'formatted': str,
+ 'repr': str,
+})
+# Has to be defined this way because `class` is a keyword.
+
+
+class TruenasError(TypedDict):
+ error: int
+ errname: str
+ reason: str
+ trace: TruenasTraceback | None
+ extra: list[ErrorExtra]
+ py_exception: NotRequired[str]
+
+
+class NotifyUnsubscribedParams(TypedDict):
+ collection: str
+ error: TruenasError
+
+
+class NotifyUnsubscribed(TypedDict):
+ jsonrpc: Literal['2.0']
+ method: Literal['notify_unsubscribed']
+ params: NotifyUnsubscribedParams
+
+
+class SuccessResponse(TypedDict):
+ jsonrpc: Literal['2.0']
+ result: Any
+ id: str
+
+
+class ErrorObj(TypedDict):
+ code: int
+ message: str | None
+ data: TruenasError
+
+
+class ErrorResponse(TypedDict):
+ jsonrpc: Literal['2.0']
+ error: ErrorObj
+ id: str
+
+
+Notification: TypeAlias = CollectionUpdate | NotifyUnsubscribed
+Response: TypeAlias = SuccessResponse | ErrorResponse
+JSONRPCMessage: TypeAlias = Notification | Response
diff --git a/truenas_api_client/legacy.py b/truenas_api_client/legacy.py
index 1377590..60ab72b 100644
--- a/truenas_api_client/legacy.py
+++ b/truenas_api_client/legacy.py
@@ -1,3 +1,5 @@
+"""The websocket client prior to implementing JSONRPC-2.0 protocol. Used for backwards compatibility."""
+
from base64 import b64decode
from collections import defaultdict
import errno
@@ -20,7 +22,7 @@
from . import ejson as json
from .config import CALL_TIMEOUT
from .exc import ReserveFDException, ClientException, ValidationErrors, CallTimeout
-from .utils import MIDDLEWARE_RUN_DIR, undefined
+from .utils import MIDDLEWARE_RUN_DIR, undefined, UndefinedType
logger = logging.getLogger(__name__)
@@ -190,7 +192,7 @@ def result(self):
class LegacyClient:
def __init__(self, uri=None, reserved_ports=False, py_exceptions=False, log_py_exceptions=False,
- call_timeout=undefined, verify_ssl=True):
+ call_timeout: float | UndefinedType=undefined, verify_ssl=True):
"""
Arguments:
:reserved_ports(bool): should the local socket used a reserved port
diff --git a/truenas_api_client/utils.py b/truenas_api_client/utils.py
index 55e18c6..3793bcf 100644
--- a/truenas_api_client/utils.py
+++ b/truenas_api_client/utils.py
@@ -1,18 +1,52 @@
+"""Utility classes for use in the TrueNAS API client.
+
+Includes `Struct` for creating regular objects out of `Mapping`s with string
+keys and `ProgressBar` for displaying the progress of a task in the CLI.
+
+Attributes:
+ MIDDLEWARE_RUN_DIR: Directory containing the middlewared Unix domain socket.
+ undefined: A dummy object similar in purpose to `None` that indicates an unset variable.
+
+"""
import sys
+from typing import Any, final, Mapping
+
MIDDLEWARE_RUN_DIR = '/var/run/middleware'
-undefined = object()
+
+@final
+class UndefinedType:
+ def __new__(cls):
+ if not hasattr(cls, '_instance'):
+ cls._instance = super().__new__(cls)
+ return cls._instance
+undefined = UndefinedType()
class Struct:
- """
- Simpler wrapper to access using object attributes instead of keys.
+ """Simpler wrapper to access using object attributes instead of keys.
+
This is meant for compatibility when switch scripts to use middleware
client instead of django directly.
+
+ Example::
+
+ >>> d = {'a':1, 'b':'2', 'c': [3, '4', {'d':5}], 'e':{'f':{'g':6}}}
+ >>> s = Struct(d)
+ >>> s.c
+ [3, '4', {'d':5}]
+ >>> s.e.f.g
+ 6
+
"""
+ def __init__(self, mapping: Mapping[str, Any]):
+ """Initialize a `Struct` with a `Mapping`.
- def __init__(self, mapping):
+ Args:
+ mapping: Contains string keys that will become the `Struct`'s attribute names.
+
+ """
for k, v in mapping.items():
if isinstance(v, dict):
setattr(self, k, Struct(v))
@@ -21,6 +55,26 @@ def __init__(self, mapping):
class ProgressBar(object):
+ """A simple text-based progress bar that writes to `sys.stderr`.
+
+ Status: (message)
+ Total Progress: [#####################___________________] 53.00%
+
+ Example:
+ ```
+ with ProgressBar() as pb:
+ for step in range(1, 101):
+ pb.update(step)
+ ```
+
+ Attributes:
+ message: String to display next to "Status".
+ percentage: A float from `0.0` to `100.0` representing the total progress.
+ write_stream: This is `sys.stderr` by default but can be any `TextIO`.
+ used_flag: Indicates whether `update()` has been called.
+ extra: A string or printable object to display after the status message.
+
+ """
def __init__(self):
self.message = None
self.percentage = 0
@@ -32,6 +86,11 @@ def __enter__(self):
return self
def draw(self):
+ """Erase the previous progress bar and draw an updated one.
+
+ If `self.extra` is set, will display "Status: (message) Extra: (extra)".
+
+ """
progress_width = 40
filled_width = int(self.percentage * progress_width)
self.write_stream.write('\033[2K\033[A\033[2K\r')
@@ -47,7 +106,14 @@ def draw(self):
)
self.write_stream.flush()
- def update(self, percentage=None, message=None):
+ def update(self, percentage: float | None=None, message: str | None=None):
+ """Update the progress bar with a new percentage and/or message, redrawing it.
+
+ Args:
+ percentage: The new percentage to display. A value of `100.0` represents full.
+ message: The "Status" message to display above the progress bar.
+
+ """
if not self.used_flag:
self.write_stream.write('\n')
self.used_flag = True
@@ -58,6 +124,7 @@ def update(self, percentage=None, message=None):
self.draw()
def finish(self):
+ """Fill the progress bar to 100%."""
self.percentage = 1
def __exit__(self, typ, value, traceback):