diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..bac7ec4 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,48 @@ +name: tests + +# This action is triggered: +# 1. when someone creates a pull request (to any branch) +# 2. when changes are merged into the main branch (via a pull request) +on: + push: + branches: [ main ] + pull_request: + branches: [ '*' ] + +jobs: + test: + runs-on: ${{ matrix.os }} + container: ${{ matrix.container }} + + # we support Linux and macOS + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + + # Steps for running tests and analysis. + steps: + - name: Checking out repository (${{ matrix.os }}) + uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + submodules: recursive + + - name: Setting up Python 3.12 (${{ matrix.os }}) + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Installing dtspy dependencies (${{ matrix.os }}) + run: python3 -m pip install -r requirements.txt + + - name: Running tests (${{ matrix.os }}) + run: coverage run -m unittest discover + env: + DTS_KBASE_DEV_TOKEN: ${{ secrets.DTS_KBASE_DEV_TOKEN }} + + # add this when ready + #- if: ${{ matrix.os == 'ubuntu-latest' }} + # name: Uploading coverage report to codecov.io + # uses: codecov/codecov-action@v4.0.1 + # with: + # token: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index 88a7374..a77e47c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ # dtspy + +![Tests](https://github.com/kbase/dtspy/actions/workflows/tests.yml/badge.svg) + Python client for the Data Transfer Service diff --git a/dts/client.py b/dts/client.py index f38bddf..7a9dcd4 100644 --- a/dts/client.py +++ b/dts/client.py @@ -1,6 +1,5 @@ import base64 from frictionless.resources import JsonResource -import io import requests from requests.auth import AuthBase import logging @@ -21,12 +20,13 @@ class KBaseAuth(AuthBase): def __init__(self, api_key): self.api_key = api_key - def __call__(self, r): - token = base64.b64encode(bytes(self.api_key + '\n', 'utf-8')) - r.headers['Authorization'] = f'Bearer {token.decode('utf-8')}' - return r + def __call__(self, request): + b64_token = base64.b64encode(bytes(self.api_key + '\n', 'utf-8')) + token = b64_token.decode('utf-8') + request.headers['Authorization'] = f'Bearer {token}' + return request -class Client(object): +class Client: """`Client`: A client for performing file transfers with the Data Transfer System""" def __init__(self, api_key = None, @@ -39,6 +39,8 @@ def __init__(self, self.connect(server = server, port = port, api_key = api_key) else: self.uri = None + self.name = None + self.version = None def connect(self, api_key = None, @@ -48,11 +50,11 @@ def connect(self, * Connects the client to the given DTS `server` via the given `port` using the given (unencoded) `api_key`.""" - if type(api_key) != str: + if not isinstance(api_key, str): raise TypeError('api_key must be an unencoded API key.') - if type(server) != str: + if not isinstance(server, str): raise TypeError('server must be a URI for a DTS server.') - if port and type(port) != int: + if port and not isinstance(port, int): raise TypeError('port must be an integer') self.auth = KBaseAuth(api_key) if port: @@ -91,12 +93,11 @@ def databases(self): except Exception as err: logger.error(f'Other error occurred: {err}') return None - else: - results = response.json() - return [Database(id = r['id'], - name = r['name'], - organization = r['organization'], - url = r['url']) for r in results] + results = response.json() + return [Database(id = r['id'], + name = r['name'], + organization = r['organization'], + url = r['url']) for r in results] def search(self, database = None, @@ -104,92 +105,121 @@ def search(self, status = None, offset = 0, limit = None, + specific = None, ): """ `client.search(database = None, query = None, status = None, offset = 0, - limit = None) -> `list` of `frictionless.DataResource` objects + limit = None, + specific = None) -> `list` of `frictionless.DataResource` objects * Performs a synchronous search of the database with the given name using the given query string. Optional arguments: + * query: a search string that is directly interpreted by the database * status: filters for files based on their status: * `"staged"` means "search only for files that are already in the source database staging area" - * `"archived"` means "search only for files that are archived and not staged" + * `"unstaged"` means "search only for files that are not staged" * offset: a 0-based index from which to start retrieving results (default: 0) * limit: if given, the maximum number of results to retrieve + * specific: a dictionary mapping database-specific search parameters to their values """ + params = { + 'database': database, + 'query': query, + } if not self.uri: raise RuntimeError('dts.Client: not connected.') - if type(database) != str: + if query: + if not isinstance(query, str): + # we also accept numeric values + if isinstance(query, int) or isinstance(query, float): + query = str(query) + else: + raise RuntimeError('search: query must be a string or a number.') + else: + raise RuntimeError('search: missing query.') + if not isinstance(database, str): raise TypeError('search: database must be a string.') - if type(offset) != int or offset < 0: - raise TypeError('search: invalid offset: %s.'%offset) + if status: + if status not in ['staged', 'unstaged']: + raise TypeError(f'search: invalid status: {status}.') + params['status'] = status + if offset: + if not str(offset).isdigit(): + raise TypeError('search: offset must be numeric') + if int(offset) < 0: + raise ValueError(f'search: offset must be non-negative') + params['offset'] = int(offset) if limit: - if type(limit) != int: - raise TypeError('search: limit must be an int.') - elif limit < 1: - raise TypeError(f'search: invalid number of retrieved results: {N}') + if not str(limit).isdigit(): + raise TypeError('search: limit must be numeric') + if int(limit) < 1: + raise ValueError(f'search: limit must be greater than 1') + params['limit'] = int(limit) + if specific: + if not isinstance(specific, dict): + raise TypeError('search: specific must be a dict.') + params['specific'] = specific try: - params = { - 'database': database, - 'query': query, - 'status': status, - 'offset': offset, - 'limit': limit, - } - response = requests.get(url=f'{self.uri}/files', params=params, auth=self.auth) + response = requests.post(url=f'{self.uri}/files', + json=params, + auth=self.auth) response.raise_for_status() - except HTTPError as http_err: - logger.error(f'HTTP error occurred: {http_err}') + except (HTTPError, requests.exceptions.HTTPError) as err: + logger.error(f'HTTP error occurred: {err}') return None except Exception as err: logger.error(f'Other error occurred: {err}') return None - else: - return [JsonResource(r) for r in response.json()['resources']] + return [JsonResource(r) for r in response.json()['resources']] def transfer(self, file_ids = None, source = None, - destination = None): + destination = None, + timeout = None): """ `client.transfer(file_ids = None, source = None, - destination = None) -> UUID + destination = None, + timeout = None) -> UUID * Submits a request to transfer files from a source to a destination database. the files in the source database are identified by a list of string file_ids. """ if not self.uri: raise RuntimeError('dts.Client: not connected.') - if type(source) != str: + if not isinstance(source, str): raise TypeError('transfer: source database name must be a string.') - if type(destination) != str: + if not isinstance(destination, str): raise TypeError('transfer: destination database name must be a string.') - if type(file_ids) != list: - raise TypeError('batch: sequences must be a list of string file IDs.') + if not isinstance(file_ids, list): + raise TypeError('transfer: file_ids must be a list of string file IDs.') + if timeout and not isinstance(timeout, int) and not isinstance(timeout, float): + raise TypeError('transfer: timeout must be a number of seconds.') try: - response = requests.post(f'{self.uri}/transfers', - data={ - source: source, - destination: destination, - file_ids: file_ids, - }) + response = requests.post(url=f'{self.uri}/transfers', + json={ + 'source': source, + 'destination': destination, + 'file_ids': file_ids, + }, + auth=self.auth, + timeout=timeout) response.raise_for_status() - except HTTPError as http_err: - logger.error(f'HTTP error occurred: {http_err}') + except (HTTPError, requests.exceptions.HTTPError) as err: + logger.error(f'HTTP error occurred: {err}') return None except Exception as err: logger.error(f'Other error occurred: {err}') return None - else: - return uuid.UUID(response.json()["id"]) + return uuid.UUID(response.json()["id"]) - def transferStatus(self, id): - """`client.transferStatus(id)` -> TransferStatus + def transfer_status(self, id): + """`client.transfer_status(id)` -> TransferStatus * Returns status information for the transfer with the given identifier. Possible statuses are: @@ -205,43 +235,43 @@ def transferStatus(self, id): if not self.uri: raise RuntimeError('dts.Client: not connected.') try: - response = requests.get(f'{self.uri}/transfers/{str(id)}') + response = requests.get(url=f'{self.uri}/transfers/{id}', + auth=self.auth) response.raise_for_status() - except HTTPError as http_err: + except (HTTPError, requests.exceptions.HTTPError) as err: logger.error(f'HTTP error occurred: {http_err}') return None except Exception as err: logger.error(f'Other error occurred: {err}') return None - else: - results = response.json() - return TransferStatus( - id = response['id'], - status = response['status'], - message = response['message'] if 'message' in response else None, - num_files = response['num_files'], - num_files_transferred = response['num_files_transferred'], - ) - - def deleteTransfer(self, id): + results = response.json() + return TransferStatus( + id = results.get('id'), + status = results.get('status'), + message = results.get('message'), + num_files = results.get('num_files'), + num_files_transferred = results.get('num_files_transferred'), + ) + + def cancel_transfer(self, id): """ -`client.deleteTransfer(id) -> None +`client.cancel_transfer(id) -> None * Deletes a file transfer, canceling """ if not self.uri: raise RuntimeError('dts.Client: not connected.') try: - response = requests.delete(f'{self.uri}/transfers/{str(id)}') + response = requests.delete(url=f'{self.uri}/transfers/{id}', + auth=self.auth) response.raise_for_status() - except HTTPError as http_err: + except (HTTPError, requests.exceptions.HTTPError) as err: logger.error(f'HTTP error occurred: {http_err}') return None except Exception as err: logger.error(f'Other error occurred: {err}') return None - else: - return None + return None def __repr__(self): if self.uri: diff --git a/requirements.txt b/requirements.txt index f618595..bc090c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ certifi==2024.7.4 chardet==5.2.0 charset-normalizer==3.3.2 click==8.1.7 +coverage==7.6.0 frictionless==5.17.0 humanize==4.9.0 idna==3.7 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_client.py b/test/test_client.py new file mode 100644 index 0000000..55987e3 --- /dev/null +++ b/test/test_client.py @@ -0,0 +1,65 @@ +# unit tests for the dts.client package + +import dts +import os +import unittest + +class TestClient(unittest.TestCase): + """Unit tests for dts.client.Client""" + + def setUp(self): + self.token = os.getenv('DTS_KBASE_DEV_TOKEN') + if not self.token: + raise ValueError('Environment variable DTS_KBASE_DEV_TOKEN must be set!') + self.server = "https://lb-dts.staging.kbase.us" + + def test_ctor(self): + client = dts.Client(api_key = self.token, server = self.server) + self.assertTrue(client.uri) + self.assertTrue(client.name) + self.assertTrue(client.version) + + def test_connect(self): + client = dts.Client() + self.assertFalse(client.uri) + self.assertFalse(client.name) + self.assertFalse(client.version) + client.connect(api_key = self.token, server = self.server) + self.assertTrue(client.uri) + self.assertTrue(client.name) + self.assertTrue(client.version) + client.disconnect() + self.assertFalse(client.uri) + self.assertFalse(client.name) + self.assertFalse(client.version) + + def test_databases(self): + client = dts.Client(api_key = self.token, server = self.server) + dbs = client.databases() + self.assertTrue(isinstance(dbs, list)) + self.assertEqual(2, len(dbs)) + self.assertTrue(any([db.id == 'jdp' for db in dbs])) + self.assertTrue(any([db.id == 'kbase' for db in dbs])) + + def test_basic_jdp_search(self): + client = dts.Client(api_key = self.token, server = self.server) + results = client.search(database = 'jdp', query = '3300047546') + self.assertTrue(isinstance(results, list)) + self.assertTrue(len(results) > 0) + self.assertTrue(all([result.to_dict()['id'].startswith('JDP:') + for result in results])) + + def test_jdp_search_for_taxon_oid(self): + client = dts.Client(api_key = self.token, server = self.server) + taxon_oid = '2582580701' + params = {'f': 'img_taxon_oid', 'extra': 'img_taxon_oid'} + results = client.search(database = 'jdp', + query = taxon_oid, + specific = params) + self.assertTrue(isinstance(results, list)) + self.assertTrue(len(results) > 0) + self.assertTrue(any([result.to_dict()['extra']['img_taxon_oid'] == int(taxon_oid) + for result in results])) + +if __name__ == '__main__': + unittest.main()