diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py index 16429825f..00532d2da 100644 --- a/fedn/fedn/network/api/client.py +++ b/fedn/fedn/network/api/client.py @@ -1,4 +1,5 @@ import json +import os import requests @@ -18,18 +19,31 @@ class APIClient: :type verify: bool """ - def __init__(self, host, port, secure=False, verify=False): + def __init__(self, host, port=None, secure=False, verify=False, token=None, auth_scheme=None): self.host = host self.port = port self.secure = secure self.verify = verify + self.header = {} + # Auth scheme passed as argument overrides environment variable. + # "Token" is the default auth scheme. + if not auth_scheme: + auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token") + # Override potential env variable if token is passed as argument. + if not token: + token = os.environ.get("FEDN_AUTH_TOKEN", False) + + if token: + self.header = {"Authorization": f"{auth_scheme} {token}"} def _get_url(self, endpoint): if self.secure: protocol = 'https' else: protocol = 'http' - return f'{protocol}://{self.host}:{self.port}/{endpoint}' + if self.port: + return f'{protocol}://{self.host}:{self.port}/{endpoint}' + return f'{protocol}://{self.host}/{endpoint}' def get_model_trail(self): """ Get the model trail. @@ -37,7 +51,7 @@ def get_model_trail(self): :return: The model trail as dict including commit timestamp. :rtype: dict """ - response = requests.get(self._get_url('get_model_trail'), verify=self.verify) + response = requests.get(self._get_url('get_model_trail'), verify=self.verify, headers=self.header) return response.json() def list_models(self, session_id=None): @@ -46,7 +60,7 @@ def list_models(self, session_id=None): :return: All models. :rtype: dict """ - response = requests.get(self._get_url('list_models'), params={'session_id': session_id}, verify=self.verify) + response = requests.get(self._get_url('list_models'), params={'session_id': session_id}, verify=self.verify, headers=self.header) return response.json() def list_clients(self): @@ -55,7 +69,7 @@ def list_clients(self): return: All clients. rtype: dict """ - response = requests.get(self._get_url('list_clients')) + response = requests.get(self._get_url('list_clients'), verify=self.verify, headers=self.header) return response.json() def get_active_clients(self, combiner_id): @@ -66,7 +80,7 @@ def get_active_clients(self, combiner_id): :return: All active clients. :rtype: dict """ - response = requests.get(self._get_url('get_active_clients'), params={'combiner': combiner_id}, verify=self.verify) + response = requests.get(self._get_url('get_active_clients'), params={'combiner': combiner_id}, verify=self.verify, headers=self.header) return response.json() def get_client_config(self, checksum=True): @@ -78,7 +92,7 @@ def get_client_config(self, checksum=True): :return: The client configuration. :rtype: dict """ - response = requests.get(self._get_url('get_client_config'), params={'checksum': checksum}, verify=self.verify) + response = requests.get(self._get_url('get_client_config'), params={'checksum': checksum}, verify=self.verify, headers=self.header) return response.json() def list_combiners(self): @@ -87,7 +101,7 @@ def list_combiners(self): :return: All combiners with info. :rtype: dict """ - response = requests.get(self._get_url('list_combiners')) + response = requests.get(self._get_url('list_combiners'), verify=self.verify, headers=self.header) return response.json() def get_combiner(self, combiner_id): @@ -98,7 +112,7 @@ def get_combiner(self, combiner_id): :return: The combiner info. :rtype: dict """ - response = requests.get(self._get_url(f'get_combiner?combiner={combiner_id}'), verify=self.verify) + response = requests.get(self._get_url(f'get_combiner?combiner={combiner_id}'), verify=self.verify, headers=self.header) return response.json() def list_rounds(self): @@ -107,7 +121,7 @@ def list_rounds(self): :return: All rounds with config and metrics. :rtype: dict """ - response = requests.get(self._get_url('list_rounds')) + response = requests.get(self._get_url('list_rounds'), verify=self.verify, headers=self.header) return response.json() def get_round(self, round_id): @@ -118,7 +132,7 @@ def get_round(self, round_id): :return: The round config and metrics. :rtype: dict """ - response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify) + response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify, headers=self.header) return response.json() def start_session(self, session_id=None, aggregator='fedavg', model_id=None, round_timeout=180, rounds=5, round_buffer_size=-1, delete_models=True, @@ -162,7 +176,7 @@ def start_session(self, session_id=None, aggregator='fedavg', model_id=None, rou 'helper': helper, 'min_clients': min_clients, 'requested_clients': requested_clients - }, verify=self.verify + }, verify=self.verify, headers=self.header ) return response.json() @@ -172,7 +186,7 @@ def list_sessions(self): :return: All sessions in dict. :rtype: dict """ - response = requests.get(self._get_url('list_sessions'), verify=self.verify) + response = requests.get(self._get_url('list_sessions'), verify=self.verify, headers=self.header) return response.json() def get_session(self, session_id): @@ -183,7 +197,7 @@ def get_session(self, session_id): :return: The session as a json object. :rtype: dict """ - response = requests.get(self._get_url(f'get_session?session_id={session_id}'), self.verify) + response = requests.get(self._get_url(f'get_session?session_id={session_id}'), self.verify, headers=self.header) return response.json() def session_is_finished(self, session_id): @@ -218,7 +232,7 @@ def set_package(self, path: str, helper: str, name: str = None, description: str """ with open(path, 'rb') as file: response = requests.post(self._get_url('set_package'), files={'file': file}, data={ - 'helper': helper, 'name': name, 'description': description}, verify=self.verify) + 'helper': helper, 'name': name, 'description': description}, verify=self.verify, headers=self.header) return response.json() def get_package(self): @@ -227,7 +241,7 @@ def get_package(self): :return: The compute package with info. :rtype: dict """ - response = requests.get(self._get_url('get_package'), verify=self.verify) + response = requests.get(self._get_url('get_package'), verify=self.verify, headers=self.header) return response.json() def list_compute_packages(self): @@ -236,7 +250,7 @@ def list_compute_packages(self): :return: All compute packages with info. :rtype: dict """ - response = requests.get(self._get_url('list_compute_packages'), verify=self.verify) + response = requests.get(self._get_url('list_compute_packages'), verify=self.verify, headers=self.header) return response.json() def download_package(self, path): @@ -247,7 +261,7 @@ def download_package(self, path): :return: Message with success or failure. :rtype: dict """ - response = requests.get(self._get_url('download_package'), verify=self.verify) + response = requests.get(self._get_url('download_package'), verify=self.verify, headers=self.header) if response.status_code == 200: with open(path, 'wb') as file: file.write(response.content) @@ -261,7 +275,7 @@ def get_package_checksum(self): :return: The checksum. :rtype: dict """ - response = requests.get(self._get_url('get_package_checksum'), verify=self.verify) + response = requests.get(self._get_url('get_package_checksum'), verify=self.verify, headers=self.header) return response.json() def get_latest_model(self): @@ -270,7 +284,7 @@ def get_latest_model(self): :return: The latest model id. :rtype: dict """ - response = requests.get(self._get_url('get_latest_model'), verify=self.verify) + response = requests.get(self._get_url('get_latest_model'), verify=self.verify, headers=self.header) return response.json() def get_initial_model(self): @@ -279,7 +293,7 @@ def get_initial_model(self): :return: The initial model id. :rtype: dict """ - response = requests.get(self._get_url('get_initial_model'), verify=self.verify) + response = requests.get(self._get_url('get_initial_model'), verify=self.verify, headers=self.header) return response.json() def set_initial_model(self, path): @@ -291,7 +305,7 @@ def set_initial_model(self, path): :rtype: dict """ with open(path, 'rb') as file: - response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify) + response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify, headers=self.header) return response.json() def get_controller_status(self): @@ -300,7 +314,7 @@ def get_controller_status(self): :return: The status of the controller. :rtype: dict """ - response = requests.get(self._get_url('get_controller_status'), verify=self.verify) + response = requests.get(self._get_url('get_controller_status'), verify=self.verify, headers=self.header) return response.json() def get_events(self, **kwargs): @@ -309,7 +323,7 @@ def get_events(self, **kwargs): :return: The events in dict :rtype: dict """ - response = requests.get(self._get_url('get_events'), params=kwargs, verify=self.verify) + response = requests.get(self._get_url('get_events'), params=kwargs, verify=self.verify, headers=self.header) return response.json() def list_validations(self, **kwargs): @@ -318,5 +332,5 @@ def list_validations(self, **kwargs): :return: All validations in dict. :rtype: dict """ - response = requests.get(self._get_url('list_validations'), params=kwargs, verify=self.verify) + response = requests.get(self._get_url('list_validations'), params=kwargs, verify=self.verify, headers=self.header) return response.json()