Skip to content

Commit

Permalink
Merge pull request #16 from salesforce/refactor/improve-unit-tests
Browse files Browse the repository at this point in the history
Improve unit test coverage
  • Loading branch information
kmcquade authored Mar 12, 2021
2 parents d0b2a9f + 4e9693a commit 80ba8ff
Show file tree
Hide file tree
Showing 20 changed files with 674 additions and 145 deletions.
37 changes: 16 additions & 21 deletions azure_guardrails/command/generate_terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
from azure_guardrails import set_log_level, set_stream_logger
from azure_guardrails.terraform.terraform import get_terraform_template, TerraformTemplate
from azure_guardrails.terraform.terraform import TerraformTemplateNoParams, TerraformTemplateWithParams
from azure_guardrails.shared import utils, validate
from azure_guardrails.scrapers.compliance_data import ComplianceCoverage
from azure_guardrails.shared.config import get_default_config, get_config_from_file
Expand Down Expand Up @@ -76,16 +76,6 @@
default=False,
help="Only generate policies where parameters are REQUIRED",
)
# @optgroup.option(
# "--parameter-options",
# "-o",
# type=click.Choice(["defaults", "empty"], case_sensitive=True),
# multiple=True,
# required=False,
# default=None,
# help="Include Policies with Parameters that have default values (defaults) and/or Policies that have empty defaults that you must fill in (empty).",
# # callback=validate.click_validate_supported_azure_service, # TODO: Write this validation
# )
# Mutually exclusive option groups
# https://github.com/click-contrib/click-option-group
# https://stackoverflow.com/questions/37310718/mutually-exclusive-option-groups-in-python-click
Expand Down Expand Up @@ -168,25 +158,30 @@ def generate_terraform(
if no_params:
include_empty_defaults = False
with_parameters = False
elif params_required:
if params_required:
include_empty_defaults = True
elif params_optional:
with_parameters = True
if params_optional:
with_parameters = True
include_empty_defaults = False

if service == "all":
services = Services(config=config)
else:
services = Services(service_names=[service], config=config)
if with_parameters:
display_names = services.get_display_names_by_service_with_parameters(include_empty_defaults=include_empty_defaults)
terraform_template = TerraformTemplate(parameters=display_names,
subscription_name=subscription,
management_group=management_group,
enforcement_mode=enforcement_mode)
display_names = services.get_display_names_by_service_with_parameters(
include_empty_defaults=include_empty_defaults)
terraform_template = TerraformTemplateWithParams(parameters=display_names,
subscription_name=subscription,
management_group=management_group,
enforcement_mode=enforcement_mode)
result = terraform_template.rendered()
else:
display_names = services.get_display_names_sorted_by_service(with_parameters=with_parameters)
result = get_terraform_template(policy_names=display_names,
subscription_name=subscription,
management_group=management_group, enforcement_mode=enforcement_mode)
terraform_template = TerraformTemplateNoParams(policy_names=display_names,
subscription_name=subscription,
management_group=management_group,
enforcement_mode=enforcement_mode)
result = terraform_template.rendered()
print(result)
48 changes: 35 additions & 13 deletions azure_guardrails/command/list_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"--format",
"-f",
"fmt",
type=click.Choice(["stdout", "yaml", "terraform"]),
type=click.Choice(["stdout", "yaml"]),
required=False,
default="stdout",
help="Output format",
Expand All @@ -54,24 +54,46 @@ def list_policies(service: str, with_parameters: bool, fmt: str, verbosity: int)
service_names.append("all")
if service not in service_names:
raise Exception(f"Please provide a valid service name. Valid service names are {service_names}")
print("Getting policy names according to service\n")
if verbosity >= 1:
utils.print_grey("Getting policy names according to service\n")
if fmt == "yaml":
print_policies_in_yaml(service=service, with_parameters=with_parameters)
print_policies_in_yaml(service=service, with_parameters=with_parameters, verbosity=verbosity)
else:
print_policies_in_stdout(service=service, with_parameters=with_parameters, verbosity=verbosity)


def print_policies_in_yaml(service: str, with_parameters: bool):
def get_display_names_sorted_by_service(service: str, with_parameters: bool) -> dict:
if service == "all":
services = Services()
display_names = services.get_display_names_sorted_by_service(with_parameters=with_parameters)
result = yaml.dump(display_names)
total_policies = 0
for service_name in display_names.keys():
total_policies += len(display_names[service_name])
else:
service = Service(service_name=service)
display_names = service.get_display_names(with_parameters=with_parameters)
result = yaml.dump(display_names)
total_policies = len(display_names)
services = Services(service_names=[service])
display_names = services.get_display_names_sorted_by_service(with_parameters=with_parameters)
return display_names


def print_policies_in_yaml(service: str, with_parameters: bool, verbosity: int):
display_names = get_display_names_sorted_by_service(service=service, with_parameters=with_parameters)
result = yaml.dump(display_names)
total_policies = 0
for service_name in display_names.keys():
total_policies += len(display_names[service_name])
print(result)
print(f"total policies: {str(total_policies)}")
if verbosity >= 1:
print(f"total policies: {str(total_policies)}")


def print_policies_in_stdout(service: str, with_parameters: bool, verbosity: int):
# TODO: Figure out if I should just print all of the policies as a list or if they should be indented. If indented, uncomment the commented lines below.
display_names = get_display_names_sorted_by_service(service=service, with_parameters=with_parameters)
total_policies = 0
for service_name in display_names.keys():
# print(f"{service_name}:")
total_policies += len(display_names[service_name])
for policy_name in display_names.get(service_name):
print(policy_name)
# print(f"\t{policy_name}")
# print("\n")

if verbosity >= 1:
print(f"total policies: {str(total_policies)}")
11 changes: 11 additions & 0 deletions azure_guardrails/guardrails/policy_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def includes_parameters(self) -> bool:
@property
def parameters_have_defaults(self) -> bool:
"""Determines if the policy requires parameters that do not have defaultValues"""
# TODO: Rename this property to parameters_required, it is confusing
result = True
for parameter in self.properties.parameters:
if parameter.name == "effect":
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(self, name, parameter_json):
self.metadata_json = parameter_json.get("metadata")
self.description = self.metadata_json.get("description")
self.display_name = self.metadata_json.get("displayName")
self.schema = self.metadata_json.get("schema", None)
self.category = self.metadata_json.get("category", None)
self.strong_type = self.metadata_json.get("strongType", None)
self.assign_permissions = self.metadata_json.get("assignPermissions", None)
Expand Down Expand Up @@ -206,6 +208,7 @@ def __repr__(self):
return self.properties_json

def _parameters(self) -> List[Optional[Parameter]]:
# TODO: Parameters should be a dict, not a list. These methods are silly
parameters = []
parameter_json = self.properties_json.get("parameters")
if parameter_json:
Expand All @@ -224,6 +227,14 @@ def parameter_name_exists(self, parameter_name) -> bool:
except:
return False

@property
def parameter_names(self) -> list:
"""Return the list of parameter names"""
parameters = []
for parameter in self.parameters:
parameters.append(parameter.name)
return parameters

def get_parameter_by_name(self, parameter_name) -> Parameter:
try:
parameter_json = self.properties_json.get("parameters").get(parameter_name)
Expand Down
11 changes: 11 additions & 0 deletions azure_guardrails/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import json
import csv
from pathlib import Path
from colorama import Fore
END = "\033[0m"
GREY = "\33[90m"

AZURE_POLICY_SERVICE_DIRECTORY = os.path.abspath(
os.path.join(
Expand Down Expand Up @@ -63,3 +65,12 @@ def get_compliance_table() -> list:
for row in csv_reader:
results.append(row)
return results


def print_red(string):
print(f"{Fore.RED}{string}{END}")


def print_grey(string):
print(f"{GREY}{string}{END}")
# Color code from here: https://stackoverflow.com/a/39452138
94 changes: 54 additions & 40 deletions azure_guardrails/terraform/terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,74 @@
logger = logging.getLogger(__name__)


def get_terraform_template(policy_names: dict, subscription_name: str = "",
management_group: str = "", enforcement_mode: bool = False) -> str:
if subscription_name == "" and management_group == "":
raise Exception("Please supply a value for the subscription name or the management group")
if enforcement_mode:
enforcement_string = "true"
else:
enforcement_string = "false"
# TODO: Shorten the subscription name if it is over X characters
if subscription_name:
name = f"{subscription_name}-noparams"
# TODO: Shorten the management group name if it is over X characters
else:
name = f"{management_group}-noparams"
name = name.replace("-", "_")
name = name.lower()
template_contents = dict(
name=name,
policy_names=policy_names,
subscription_name=subscription_name,
management_group=management_group,
enforcement_mode=enforcement_string,
)
template_path = os.path.join(os.path.dirname(__file__), "no-parameters")
env = Environment(loader=FileSystemLoader(template_path)) # nosec
template = env.get_template("policy-set-with-builtins-v2.tf")
return template.render(t=template_contents)


class TerraformTemplate:
class TerraformTemplateNoParams:
"""Terraform Template for when there are no parameters"""
def __init__(self, policy_names: dict, subscription_name: str = "", management_group: str = "",
enforcement_mode: bool = False):
self.name = self._name(subscription_name=subscription_name, management_group=management_group)
self.subscription_name = subscription_name
self.management_group = management_group
self.policy_names = policy_names
if enforcement_mode:
self.enforcement_string = "true"
else:
self.enforcement_string = "false"

def _name(self, subscription_name: str, management_group: str) -> str:
if subscription_name == "" and management_group == "":
raise Exception("Please supply a value for the subscription name or the management group")
# TODO: Shorten the subscription name if it is over X characters
if subscription_name:
name = f"{subscription_name}-noparams"
# TODO: Shorten the management group name if it is over X characters
else:
name = f"{management_group}-noparams"
name = name.replace("-", "_")
name = name.lower()
return name

def rendered(self) -> str:
template_contents = dict(
name=self.name,
policy_names=self.policy_names,
subscription_name=self.subscription_name,
management_group=self.management_group,
enforcement_mode=self.enforcement_string,
)
template_path = os.path.join(os.path.dirname(__file__), "no-parameters")
env = Environment(loader=FileSystemLoader(template_path)) # nosec
template = env.get_template("policy-set-with-builtins-v2.tf")
return template.render(t=template_contents)


class TerraformTemplateWithParams:
"""Terraform Template with Parameters"""

def __init__(self,
parameters: dict,
subscription_name: str = "",
management_group: str = "", enforcement_mode: bool = False):
# TODO: Shorten the subscription name if it is over X characters
if subscription_name:
self.name = f"{subscription_name}-params"
# TODO: Shorten the management group name if it is over X characters
else:
self.name = f"{management_group}-params"
# self.name = name
self.service_parameters = self._parameters(parameters)

if subscription_name == "" and management_group == "":
raise Exception("Please supply a value for the subscription name or the management group")
self.name = self._name(subscription_name=subscription_name, management_group=management_group)
self.service_parameters = self._parameters(parameters)
self.subscription_name = subscription_name
self.management_group = management_group
if enforcement_mode:
self.enforcement_string = "true"
else:
self.enforcement_string = "false"

def _name(self, subscription_name: str, management_group: str) -> str:
if subscription_name == "" and management_group == "":
raise Exception("Please supply a value for the subscription name or the management group")
# TODO: Shorten the subscription name if it is over X characters
if subscription_name:
name = f"{subscription_name}-params"
# TODO: Shorten the management group name if it is over X characters
else:
name = f"{management_group}-params"
return name

def _parameters(self, parameters) -> dict:
"""Separated this out just in case we need to do more processing"""
results = {}
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ PyYAML==5.4.1
# Required for printing things
jinja2==2.11.3
tabulate==0.8.9
ruamel.yaml
ruamel.yaml==0.16.13
colorama==0.4.4
# Scrapers
beautifulsoup4==4.9.3
requests==2.25.1
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[nosetests]
exe = True
tests = test/, test/command/, test/logic, test/shared
tests = test/, test/command/, test/guardrails, test/shared, test/terraform
verbosity=2

[tool:pytest]
testpaths = test test/command test/logic test/shared
testpaths = test test/command test/guardrails test/shared test/terraform
python_files=test/*/test_*.py
norecursedirs = .svn _build tmp* __pycache__

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"pyyaml",
"jinja2",
"tabulate",
"colorama",
"ruamel.yaml",
"beautifulsoup4",
"requests",
Expand Down
Loading

0 comments on commit 80ba8ff

Please sign in to comment.