Commit

chore: several refactors to fix bugs and lint errors

0x6f677548 committed Oct 25, 2023
1 parent 4f28955 commit 602071a

Showing 9 changed files with 134 additions and 134 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -59,5 +59,5 @@ jobs:
           ruff src --ignore=E501,I001
       - name: Lint with mypy
         run: |
-          mypy src
+          mypy src --ignore-missing-imports

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -79,7 +79,7 @@ dependencies = [
   "ruff>=0.0.243",
 ]
 [tool.hatch.envs.lint.scripts]
-typing = "mypy --install-types --non-interactive {args:src/ca_pwt tests}"
+typing = "mypy --ignore-missing-imports --install-types --non-interactive {args:src/ca_pwt tests}"
 style = [
   "ruff {args:.}",
   "black --check --diff {args:.}",

4 changes: 2 additions & 2 deletions src/ca_pwt/commands.py
@@ -77,8 +77,8 @@ def _exit_with_exception(exception: Exception, exit_code: int = 1, fg: str = "re
 def _get_from_ctx_if_none(
     ctx: click.Context,
     ctx_key: str,
-    value: str | None = None,
-    invoke_func: Callable[..., str] | None = None,
+    value: str | None,
+    invoke_func: Callable[..., str],
     **kwargs: Any,
 ) -> str:
     """Get a value from the context if it is None,
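
With the `None` defaults removed, `value` and `invoke_func` become required arguments, so mypy no longer has to treat `invoke_func` as possibly `None` at the call site. The sketch below illustrates that contract; the body is an assumption for illustration only — the diff shows the signature, not the implementation.

# Minimal sketch (assumed body, not the repository's implementation) of the
# helper's contract after the change: both arguments must now be supplied.
from typing import Any, Callable
import click


def _get_from_ctx_if_none_sketch(
    ctx: click.Context,
    ctx_key: str,
    value: str | None,
    invoke_func: Callable[..., str],
    **kwargs: Any,
) -> str:
    # An explicit value wins; otherwise fall back to a value cached on ctx.obj;
    # otherwise invoke the callback and cache its result for later commands.
    if value is not None:
        return value
    obj = ctx.ensure_object(dict)
    if ctx_key in obj:
        return obj[ctx_key]
    result = invoke_func(**kwargs)
    obj[ctx_key] = result
    return result
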
4 changes: 2 additions & 2 deletions src/ca_pwt/directory_roles.py
@@ -23,8 +23,8 @@ def get_by_display_name(self, display_name: str) -> APIResponse:

         if response.success:
             # move the value property to the response property
-            response.response = response.json()["value"]
-            for entity in response.response:
+            results = response.json()["value"]
+            for entity in results:
                 if entity["displayName"] == display_name:
                     response.response = entity
                     return response
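
The fix binds the parsed `value` list to a local name instead of reassigning `response.response` before iterating over it, so a lookup that finds no match leaves the response payload untouched. A small self-contained illustration of that pattern, using a stand-in class rather than the repository's `APIResponse`:

# Stand-in illustration only; FakeResponse approximates APIResponse for this example.
class FakeResponse:
    def __init__(self, payload):
        self.response = payload


def pick_by_display_name(resp: FakeResponse, display_name: str) -> FakeResponse:
    # Iterate a local list and only overwrite resp.response on a match,
    # so a miss keeps the original payload instead of a half-mutated one.
    results = resp.response["value"]
    for entity in results:
        if entity["displayName"] == display_name:
            resp.response = entity
            return resp
    return resp


resp = FakeResponse({"value": [{"displayName": "Global Administrator", "id": "1"}]})
print(pick_by_display_name(resp, "Global Administrator").response)
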
19 changes: 9 additions & 10 deletions src/ca_pwt/groups.py
@@ -1,8 +1,7 @@
 import requests
 import logging
 from ca_pwt.helpers.graph_api import APIResponse, EntityAPI, _REQUEST_TIMEOUT
-from ca_pwt.helpers.dict import cleanup_odata_dict, remove_element_from_dict
-from ca_pwt.helpers.utils import assert_condition
+from ca_pwt.helpers.utils import assert_condition, cleanup_odata_dict, remove_element_from_dict, ensure_list

 _logger = logging.getLogger(__name__)

@@ -29,7 +28,7 @@ def add_user_to_group(self, user_id: str, group_id: str) -> APIResponse:
         )


-def load_groups(input_file: str) -> dict:
+def load_groups(input_file: str) -> list[dict]:
     """Loads groups from the specified file.
     It also cleans up the dictionary to remove unnecessary elements."""
     import json
@@ -38,10 +37,11 @@ def load_groups(input_file: str) -> dict:
         _logger.info(f"Reading groups from file {input_file}...")

         groups = json.load(f)
-        return cleanup_odata_dict(groups, ensure_list=True)
+        groups = cleanup_odata_dict(groups)
+        return ensure_list(groups)


-def save_groups(groups: dict, output_file: str):
+def save_groups(groups: list[dict], output_file: str):
     """Saves groups to the specified file."""
     import json

@@ -50,8 +50,8 @@ def save_groups(groups: dict, output_file: str):
         f.write(json.dumps(groups, indent=4))


-def cleanup_groups(source: dict) -> dict:
-    """Cleans up the groups dictionary for import by
+def cleanup_groups(source: list[dict]) -> list[dict]:
+    """Cleans up the groups list and dictionary for import by
     removing disallowed elements while importing. (e.g. id, createdDateTime,
     modifiedDateTime, templateId, deletedDateTime, renewedDateTime)"""
     _logger.info("Cleaning up groups...")
@@ -89,13 +89,12 @@ def get_groups_by_ids(access_token: str, group_ids: list[str], *, ignore_not_fou
             continue
         else:
             group_response.assert_success()
-        group_detail = group_response.json()
-        group_detail = cleanup_odata_dict(group_detail, ensure_list=False)
+        group_detail = cleanup_odata_dict(group_response.json())
         result.append(group_detail)
     return result


-def import_groups(access_token: str, groups: dict, *, allow_duplicates: bool = False) -> list[tuple[str, str]]:
+def import_groups(access_token: str, groups: list[dict], *, allow_duplicates: bool = False) -> list[tuple[str, str]]:
     """Imports groups from the specified dictionary.
     Returns a list of tuples with the group id and name of the imported groups.
     """
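
After this refactor the group helpers consistently take and return `list[dict]` rather than a bare dict. A short round-trip sketch under that assumption; the file names and access token below are placeholders, not values from the repository:

# Hypothetical round-trip with the refactored helpers ("groups.json" and the
# token are placeholders).
from ca_pwt.groups import load_groups, cleanup_groups, save_groups, import_groups

groups = load_groups("groups.json")        # always a list[dict], even for a single object
groups = cleanup_groups(groups)            # drops id, createdDateTime, templateId, ...
save_groups(groups, "groups.cleaned.json")

created = import_groups("<access-token>", groups, allow_duplicates=False)
for group_id, display_name in created:     # list of (id, displayName) tuples
    print(group_id, display_name)
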
92 changes: 0 additions & 92 deletions src/ca_pwt/helpers/dict.py

This file was deleted.

36 changes: 36 additions & 0 deletions src/ca_pwt/helpers/utils.py
@@ -1,6 +1,42 @@
 from typing import Any


+def ensure_list(source: list[dict] | dict) -> list[dict]:
+    """Ensures that the source is a list.
+    If it is not a list, it will be wrapped in a list.
+    If it is None, an empty list will be returned."""
+    if source is None:
+        return []
+    elif isinstance(source, list):
+        return source
+    elif "value" in source and isinstance(source["value"], list):
+        return source["value"]
+    else:
+        return [source]
+
+
+def cleanup_odata_dict(source: dict) -> dict:
+    """Cleans up the dictionary returned by the graph api by removing
+    not needed elements like @odata.context and @microsoft.graph.tips
+    """
+
+    if not source:
+        exception: Exception = ValueError("The dictionary is None or empty.")
+        raise exception
+
+    remove_element_from_dict(source, "@odata.context")
+    remove_element_from_dict(source, "@microsoft.graph.tips")
+    return source
+
+
+def remove_element_from_dict(source: dict, element: str) -> bool:
+    """Remove an element from a dictionary if it exists."""
+    if element in source:
+        source.pop(element)
+        return True
+    return False
+
+
 def assert_condition(condition: Any, message: str):
     """Asserts a condition and raises an AssertionError if it is False"""
     if not condition:
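
The three helpers that replace `helpers/dict.py` are small and pure, so their contract is easiest to see on a sample payload; the dictionary below is illustrative, shaped like a Graph API collection response:

# Illustrative payload shaped like a Graph API collection response.
from ca_pwt.helpers.utils import cleanup_odata_dict, ensure_list, remove_element_from_dict

payload = {
    "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#groups",
    "value": [{"id": "1", "displayName": "Team A"}],
}

payload = cleanup_odata_dict(payload)                 # strips the @odata/@microsoft.graph annotations in place
groups = ensure_list(payload)                         # unwraps the "value" list -> [{"id": "1", ...}]
removed = remove_element_from_dict(groups[0], "id")   # True, and "id" is gone from the entry

print(groups, removed)
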
30 changes: 15 additions & 15 deletions src/ca_pwt/policies.py
@@ -1,5 +1,5 @@
 import logging
-from ca_pwt.helpers.dict import remove_element_from_dict, cleanup_odata_dict
+from ca_pwt.helpers.utils import remove_element_from_dict, cleanup_odata_dict, ensure_list
 from ca_pwt.helpers.graph_api import EntityAPI
 from ca_pwt.policies_mappings import replace_values_by_keys_in_policies
 from ca_pwt.groups import get_groups_by_ids
@@ -12,20 +12,19 @@ def _get_entity_path(self) -> str:
         return "identity/conditionalAccess/policies"


-def load_policies(input_file: str) -> dict:
+def load_policies(input_file: str) -> list[dict]:
     """Loads policies from the specified file.
     It also cleans up the dictionary to remove unnecessary elements."""
     import json

     with open(input_file) as f:
         _logger.info(f"Reading policies from file {input_file}...")

-        policies = json.load(f)
-        policies = cleanup_odata_dict(policies, ensure_list=True)
-        return policies
+        policies = cleanup_odata_dict(json.load(f))
+        return ensure_list(policies)


-def save_policies(policies: dict, output_file: str):
+def save_policies(policies: list[dict], output_file: str):
     """Saves policies to the specified file."""
     import json

@@ -34,26 +33,26 @@ def save_policies(policies: dict, output_file: str):
         f.write(json.dumps(policies, indent=4))


-def cleanup_policies(source: dict) -> dict:
+def cleanup_policies(policies: list[dict]) -> list[dict]:
     """Cleans up the policies dictionary for import by
     removing disallowed elements while importing. (e.g. id, createdDateTime,
     modifiedDateTime, templateId, id"""
     _logger.info("Cleaning up policies...")

     # exclude some elements, namely createdDateTime,
     # modifiedDateTime, id, templateId, authenticationStrength@odata.context
-    for policy in source:
+    for policy in policies:
         remove_element_from_dict(policy, "createdDateTime")
         remove_element_from_dict(policy, "modifiedDateTime")
         remove_element_from_dict(policy, "id")
         remove_element_from_dict(policy, "templateId")
         grant_controls = policy["grantControls"]
         if grant_controls is not None:
             remove_element_from_dict(grant_controls, "authenticationStrength@odata.context")
-    return source
+    return policies


-def export_policies(access_token: str, odata_filter: str | None = None) -> dict:
+def export_policies(access_token: str, odata_filter: str | None = None) -> list[dict]:
     """Exports all policies with the specified filter. Filter is
     an OData filter string."""
     policies_api = PoliciesAPI(access_token=access_token)
@@ -62,14 +61,15 @@ def export_policies(access_token: str, odata_filter: str | None = None) -> dict:
     policies = response.json()

     _logger.debug(f"Obtained policies: {policies}")
-    policies = cleanup_odata_dict(policies, ensure_list=True)
+    policies = cleanup_odata_dict(policies)
+    policies = ensure_list(policies)
     _logger.debug(f"Formatted policies: {policies}")
     return policies


 def import_policies(
     access_token: str,
-    policies: dict,
+    policies: list[dict],
     *,
     allow_duplicates: bool = False,
 ) -> list[tuple[str, str]]:
@@ -86,7 +86,7 @@ def import_policies(
     policies = cleanup_policies(policies)
     created_policies: list[tuple[str, str]] = []
     for policy in policies:
-        display_name = policy.get("displayName")
+        display_name: str = str(policy.get("displayName"))

         # check if the policy already exists
         if not allow_duplicates:
@@ -108,10 +108,10 @@

 def get_groups_in_policies(
     access_token: str,
-    policies: dict,
+    policies: list[dict],
     *,
     ignore_not_found: bool = False,
-) -> dict:
+) -> list[dict]:
     """Obtains all groups referenced by the policies in the policies dict.
     If ignore_not_found is True, groups that are not found are ignored.
     Returns a dictionary with the groups."""
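
As in groups.py, the policy helpers now pass `list[dict]` end to end. A usage sketch under that assumption; the access token and OData filter are placeholders:

# Hypothetical export/import round-trip; the token and filter are placeholders.
from ca_pwt.policies import (
    export_policies,
    save_policies,
    load_policies,
    cleanup_policies,
    import_policies,
)

token = "<access-token>"
policies = export_policies(token, odata_filter="startswith(displayName, 'CA')")
save_policies(policies, "policies.json")

policies = load_policies("policies.json")   # list[dict] via cleanup_odata_dict + ensure_list
policies = cleanup_policies(policies)       # import_policies also does this internally
created = import_policies(token, policies, allow_duplicates=True)
print(created)                               # [(policy_id, display_name), ...]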