diff --git a/src/ca_pwt/commands.py b/src/ca_pwt/commands.py index 98a583a..6cbc545 100644 --- a/src/ca_pwt/commands.py +++ b/src/ca_pwt/commands.py @@ -65,7 +65,7 @@ "--duplicate_action", type=click.Choice([action.value for action in DuplicateActionEnum], case_sensitive=True), help="The action to take when a duplicate is found (default is ignore). ", - default=DuplicateActionEnum.ignore.value, + default=DuplicateActionEnum.IGNORE.value, ) @@ -369,7 +369,7 @@ def import_policies_cmd( ctx: click.Context, input_file: str, access_token: str | None = None, - duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore, + duplicate_action: DuplicateActionEnum = DuplicateActionEnum.IGNORE, ): """Imports CA policies from a file""" try: @@ -443,7 +443,7 @@ def import_groups_cmd( ctx: click.Context, input_file: str, access_token: str | None = None, - duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore, + duplicate_action: DuplicateActionEnum = DuplicateActionEnum.IGNORE, ): """Imports groups from a file""" try: diff --git a/src/ca_pwt/groups.py b/src/ca_pwt/groups.py index d60acca..e5733bb 100644 --- a/src/ca_pwt/groups.py +++ b/src/ca_pwt/groups.py @@ -95,7 +95,7 @@ def get_groups_by_ids(access_token: str, group_ids: list[str], *, ignore_not_fou def import_groups( - access_token: str, groups: list[dict], duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore + access_token: str, groups: list[dict], duplicate_action: DuplicateActionEnum = DuplicateActionEnum.IGNORE ) -> list[tuple[str, str]]: """Imports groups from the specified dictionary. Returns a list of tuples with the group id and name of the imported groups. diff --git a/src/ca_pwt/helpers/graph_api.py b/src/ca_pwt/helpers/graph_api.py index 82f8643..b80d018 100644 --- a/src/ca_pwt/helpers/graph_api.py +++ b/src/ca_pwt/helpers/graph_api.py @@ -1,17 +1,18 @@ import requests import logging -from ca_pwt.helpers.utils import assert_condition from abc import ABC, abstractmethod from enum import StrEnum +from typing import Any, Callable +from ca_pwt.helpers.utils import assert_condition _REQUEST_TIMEOUT = 500 class DuplicateActionEnum(StrEnum): - ignore = "ignore" - replace = "replace" - duplicate = "duplicate" - fail = "fail" + IGNORE = "ignore" + REPLACE = "replace" + DUPLICATE = "duplicate" + FAIL = "fail" class APIResponse: @@ -27,7 +28,7 @@ def __init__(self, request_response: requests.Response, expected_status_code: in the success property will be set to True """ self.status_code = request_response.status_code - self.response: requests.Response | str = request_response + self.response: requests.Response | str | dict[str, Any] = request_response self.expected_status_code = expected_status_code self.success = self.status_code == self.expected_status_code if self._logger.isEnabledFor(logging.DEBUG): @@ -68,11 +69,14 @@ def _get_entity_path(self) -> str: """Returns the path to the entity in the Microsoft Graph API""" pass + + def _request_get(self, url: str) -> APIResponse: """Sends a GET request to the API""" self._logger.debug(f"GET {url}") return APIResponse( - requests.get(url, headers=self.request_headers, timeout=_REQUEST_TIMEOUT), expected_status_code=200 + requests.get(url, headers=self.request_headers, timeout=_REQUEST_TIMEOUT), + expected_status_code=200 ) def _request_post(self, url: str, entity: dict) -> APIResponse: @@ -155,7 +159,7 @@ def create(self, entity: dict) -> APIResponse: return self._request_post(self.entity_url, entity) def create_checking_duplicates( - self, entity: dict, odata_filter: str, duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore + self, entity: dict, odata_filter: str, duplicate_action: DuplicateActionEnum = DuplicateActionEnum.IGNORE ) -> APIResponse: """Creates an entity checking for duplicates first and taking the specified action if a duplicate is found A duplicate is determined by the odata_filter parameter, getting the top entity with the specified filter""" @@ -163,15 +167,15 @@ def create_checking_duplicates( assert_condition(odata_filter, "odata_filter cannot be None") # if duplicate_action is not duplicate, check if the entity already exists - if duplicate_action != DuplicateActionEnum.duplicate: + if duplicate_action != DuplicateActionEnum.DUPLICATE: existing_entity = self.get_top_entity(odata_filter) if existing_entity.success: - if duplicate_action == DuplicateActionEnum.ignore: + if duplicate_action == DuplicateActionEnum.IGNORE: self._logger.warning( f"Entity {self._get_entity_path()} with filter {odata_filter} already exists. Skipping..." ) return existing_entity - elif duplicate_action == DuplicateActionEnum.replace: + elif duplicate_action == DuplicateActionEnum.REPLACE: existing_entity_id = existing_entity.json()["id"] self._logger.warning(f"Replacing entity {self._get_entity_path()} with id {existing_entity_id}...") response = self.update(existing_entity_id, entity) @@ -180,7 +184,7 @@ def create_checking_duplicates( # we need to return the existing_entity_id in the response body response.response = {"id": existing_entity_id} return response - elif duplicate_action == DuplicateActionEnum.fail: + elif duplicate_action == DuplicateActionEnum.FAIL: msg = f"Entity {self._get_entity_path()} with filter {odata_filter} already exists." raise ValueError(msg) else: diff --git a/src/ca_pwt/policies.py b/src/ca_pwt/policies.py index d6b6205..11baa2c 100644 --- a/src/ca_pwt/policies.py +++ b/src/ca_pwt/policies.py @@ -70,7 +70,7 @@ def export_policies(access_token: str, odata_filter: str | None = None) -> list[ def import_policies( access_token: str, policies: list[dict], - duplicate_action: DuplicateActionEnum = DuplicateActionEnum.ignore, + duplicate_action: DuplicateActionEnum = DuplicateActionEnum.IGNORE, ) -> list[tuple[str, str]]: """Imports the specified policies. If allow_duplicates is False, it will skip policies that already exist (using the display name as