From aa353b00d73c00364452d1002ce6219da39d13f6 Mon Sep 17 00:00:00 2001 From: Zane Clark Date: Thu, 15 Aug 2024 08:52:28 -0700 Subject: [PATCH] feat: simplify configuration merging --- schemachange/config/BaseConfig.py | 54 +----- schemachange/config/ChangeHistoryTable.py | 37 ++-- schemachange/config/DeployConfig.py | 49 ++--- schemachange/config/get_merged_config.py | 90 ++++----- schemachange/config/parse_cli_args.py | 12 +- ...g-complete.yml => schemachange-config.yml} | 1 + tests/config/test_Config.py | 175 +++--------------- tests/config/test_get_merged_config.py | 163 ++++++++++++++++ tests/config/test_get_yaml_config.py | 46 ++--- tests/config/test_parse_cli_args.py | 29 ++- tests/test_cli_misc.py | 12 +- tests/test_main.py | 6 +- 12 files changed, 311 insertions(+), 363 deletions(-) rename tests/config/{schemachange-config-complete.yml => schemachange-config.yml} (97%) create mode 100644 tests/config/test_get_merged_config.py diff --git a/schemachange/config/BaseConfig.py b/schemachange/config/BaseConfig.py index bb129057..f07df6d1 100644 --- a/schemachange/config/BaseConfig.py +++ b/schemachange/config/BaseConfig.py @@ -24,7 +24,6 @@ class BaseConfig(ABC): subcommand: Literal["deploy", "render"] config_version: int | None = None - config_folder: Path = Path(".") config_file_path: Path | None = None root_folder: Path | None = Path(".") modules_folder: Path | None = None @@ -36,7 +35,7 @@ class BaseConfig(ABC): def factory( cls, subcommand: Literal["deploy", "render"], - config_folder: Path | str | None = Path("."), + config_file_path: Path, root_folder: Path | str | None = Path("."), modules_folder: Path | str | None = None, config_vars: str | dict | None = None, @@ -50,16 +49,9 @@ def factory( "config_vars did not parse correctly, please check its configuration" ) from e - if "secrets" in kwargs: - secrets.update(kwargs.pop("secrets")) - - if config_folder is None: - config_folder = "." - return cls( subcommand=subcommand, - config_folder=validate_directory(path=config_folder), - config_file_path=Path(config_folder) / cls.default_config_file_name, + config_file_path=config_file_path, root_folder=validate_directory(path=root_folder), modules_folder=validate_directory(path=modules_folder), config_vars=validate_config_vars(config_vars=config_vars), @@ -68,48 +60,6 @@ def factory( **kwargs, ) - def asdict_exclude_defaults(self) -> dict: - retain = {} - for field in dataclasses.fields(self): - val = getattr(self, field.name) - if ( - val is None - and field.default == dataclasses.MISSING - or val == field.default - ): - continue - if ( - field.default_factory != dataclasses.MISSING - and val == field.default_factory() - ): - continue - if field.name in [ - "_create_change_history_table", - "_autocommit", - "_dry_run", - "_raise_exception_on_ignored_versioned_script", - ]: - retain[field.name.lstrip("_")] = val - else: - retain[field.name] = val - - return retain - - def merge_exclude_defaults(self: T, other: T) -> T: - self_kwargs = self.asdict_exclude_defaults() - self_kwargs.pop("config_file_path") - - other_kwargs = other.asdict_exclude_defaults() - other_kwargs.pop("config_file_path") - if "change_history_table" in other_kwargs: - other_kwargs["change_history_table"] = other.change_history_table - - if "secrets" in other_kwargs and len(other_kwargs["secrets"]) == 0: - other_kwargs.pop("secrets") - - kwargs = {**self_kwargs, **other_kwargs} - return self.__class__.factory(**kwargs) - def log_details(self): logger.info("Using root folder", root_folder=str(self.root_folder)) if self.modules_folder: diff --git a/schemachange/config/ChangeHistoryTable.py b/schemachange/config/ChangeHistoryTable.py index 6246c7ac..ba7645a5 100644 --- a/schemachange/config/ChangeHistoryTable.py +++ b/schemachange/config/ChangeHistoryTable.py @@ -1,6 +1,8 @@ import dataclasses from typing import ClassVar +from schemachange.config.utils import get_snowflake_identifier_string + @dataclasses.dataclass(frozen=True) class ChangeHistoryTable: @@ -18,29 +20,32 @@ def fully_qualified(self) -> str: @classmethod def from_str(cls, table_str: str): - details: dict[str, str] = { - "database_name": cls._default_database_name, - "schema_name": cls._default_schema_name, - "table_name": cls._default_table_name, - } + database_name = cls._default_database_name + schema_name = cls._default_schema_name + table_name = cls._default_table_name + if table_str is not None: table_name_parts = table_str.strip().split(".") if len(table_name_parts) == 1: - details["table_name"] = table_name_parts[0] + table_name = table_name_parts[0] elif len(table_name_parts) == 2: - details["table_name"] = table_name_parts[1] - details["schema_name"] = table_name_parts[0] + table_name = table_name_parts[1] + schema_name = table_name_parts[0] elif len(table_name_parts) == 3: - details["table_name"] = table_name_parts[2] - details["schema_name"] = table_name_parts[1] - details["database_name"] = table_name_parts[0] + table_name = table_name_parts[2] + schema_name = table_name_parts[1] + database_name = table_name_parts[0] else: raise ValueError(f"Invalid change history table name: {table_str}") - # if the name element does not include '"' raise to upper case on return return cls( - **{ - attr_name: attr_val if '"' in attr_val else attr_val.upper() - for (attr_name, attr_val) in details.items() - } + table_name=get_snowflake_identifier_string( + input_value=table_name, input_type="table_name" + ), + schema_name=get_snowflake_identifier_string( + input_value=schema_name, input_type="schema_name" + ), + database_name=get_snowflake_identifier_string( + input_value=database_name, input_type="database_name" + ), ) diff --git a/schemachange/config/DeployConfig.py b/schemachange/config/DeployConfig.py index 178e850f..b7988d23 100644 --- a/schemachange/config/DeployConfig.py +++ b/schemachange/config/DeployConfig.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +from pathlib import Path from typing import Literal from schemachange.config.BaseConfig import BaseConfig @@ -21,56 +22,35 @@ class DeployConfig(BaseConfig): change_history_table: ChangeHistoryTable | None = dataclasses.field( default_factory=ChangeHistoryTable ) - _create_change_history_table: bool | None - _autocommit: bool | None - _dry_run: bool | None + create_change_history_table: bool = False + autocommit: bool = False + dry_run: bool = False query_tag: str | None = None oauth_config: dict | None = None version_number_validation_regex: str | None = None - _raise_exception_on_ignored_versioned_script: bool | None - - @property - def create_change_history_table(self) -> bool: - if self._create_change_history_table is None: - return False - return self._create_change_history_table - - @property - def autocommit(self) -> bool: - if self._autocommit is None: - return False - return self._autocommit - - @property - def dry_run(self) -> bool: - if self._dry_run is None: - return False - return self._dry_run - - @property - def raise_exception_on_ignored_versioned_script(self) -> bool: - if self._raise_exception_on_ignored_versioned_script is None: - return False - return self._raise_exception_on_ignored_versioned_script + raise_exception_on_ignored_versioned_script: bool = False @classmethod def factory( cls, + config_file_path: Path, snowflake_role: str | None = None, snowflake_warehouse: str | None = None, snowflake_database: str | None = None, snowflake_schema: str | None = None, - create_change_history_table: bool | None = None, - autocommit: bool | None = None, - dry_run: bool | None = None, - raise_exception_on_ignored_versioned_script: bool | None = None, + change_history_table: str | None = None, **kwargs, ): if "subcommand" in kwargs: kwargs.pop("subcommand") + change_history_table = ChangeHistoryTable.from_str( + table_str=change_history_table + ) + return super().factory( subcommand="deploy", + config_file_path=config_file_path, snowflake_role=get_snowflake_identifier_string( snowflake_role, "snowflake_role" ), @@ -83,10 +63,7 @@ def factory( snowflake_schema=get_snowflake_identifier_string( snowflake_schema, "snowflake_schema" ), - _create_change_history_table=create_change_history_table, - _autocommit=autocommit, - _dry_run=dry_run, - _raise_exception_on_ignored_versioned_script=raise_exception_on_ignored_versioned_script, + change_history_table=change_history_table, **kwargs, ) diff --git a/schemachange/config/get_merged_config.py b/schemachange/config/get_merged_config.py index 8ddfde63..f34bbe93 100644 --- a/schemachange/config/get_merged_config.py +++ b/schemachange/config/get_merged_config.py @@ -1,48 +1,15 @@ import logging import sys -from argparse import Namespace -from enum import Enum from pathlib import Path from typing import Union, Optional -from schemachange.config.ChangeHistoryTable import ChangeHistoryTable from schemachange.config.DeployConfig import DeployConfig from schemachange.config.RenderConfig import RenderConfig from schemachange.config.parse_cli_args import parse_cli_args -from schemachange.config.utils import load_yaml_config +from schemachange.config.utils import load_yaml_config, validate_directory -def config_factory(args: Union[Namespace, dict]) -> Union[DeployConfig, RenderConfig]: - if isinstance(args, Namespace): - subcommand = args.subcommand - kwargs = args.__dict__ - else: - subcommand = args.get("subcommand") - kwargs = args - - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - if "change_history_table" in kwargs and kwargs["change_history_table"] is not None: - kwargs["change_history_table"] = ChangeHistoryTable.from_str( - table_str=kwargs["change_history_table"] - ) - - if "vars" in kwargs: - kwargs["config_vars"] = kwargs.pop("vars") - - if subcommand == "deploy": - return DeployConfig.factory(**kwargs) - elif subcommand == "render": - return RenderConfig.factory(**kwargs) - else: - raise Exception(f"unhandled subcommand: {subcommand}") - - -def get_yaml_config( - subcommand: str, - config_file_path: Optional[Path], - script_path: Optional[Path] = None, -) -> Union[DeployConfig, RenderConfig]: +def get_yaml_config_kwargs(config_file_path: Optional[Path]) -> dict: # TODO: I think the configuration key for oauthconfig should be oauth-config. # This looks like a bug in the current state of the repo to me @@ -51,32 +18,49 @@ def get_yaml_config( k.replace("-", "_").replace("oauthconfig", "oauth_config"): v for (k, v) in load_yaml_config(config_file_path).items() } - kwargs["subcommand"] = subcommand - if script_path: - kwargs["script_path"] = script_path if "verbose" in kwargs: - kwargs["log_level"] = logging.DEBUG if kwargs["verbose"] else logging.ERROR + kwargs["log_level"] = logging.DEBUG kwargs.pop("verbose") - return config_factory(kwargs) + if "vars" in kwargs: + kwargs["config_vars"] = kwargs.pop("vars") + + return kwargs def get_merged_config() -> Union[DeployConfig, RenderConfig]: - args = parse_cli_args(sys.argv[1:]) - if ( - "log_level" in args - and args.log_level is not None - and isinstance(args.log_level, Enum) - ): - args.log_level = args.log_level.value + cli_kwargs = parse_cli_args(sys.argv[1:]) + cli_config_vars = cli_kwargs.pop("config_vars", None) + if cli_config_vars is None: + cli_config_vars = {} - cli_config = config_factory(args=args) - yaml_config = get_yaml_config( - subcommand=cli_config.subcommand, - config_file_path=cli_config.config_file_path, - script_path=getattr(cli_config, "script_path", None), + config_folder = validate_directory(path=cli_kwargs.pop("config_folder", ".")) + config_file_path = Path(config_folder) / "schemachange-config.yml" + + yaml_kwargs = get_yaml_config_kwargs( + config_file_path=config_file_path, ) + yaml_config_vars = yaml_kwargs.pop("config_vars", None) + if yaml_config_vars is None: + yaml_config_vars = {} + + config_vars = { + **yaml_config_vars, + **cli_config_vars, + } # override the YAML config with the CLI configuration - return yaml_config.merge_exclude_defaults(other=cli_config) + kwargs = { + "config_file_path": config_file_path, + "config_vars": config_vars, + **{k: v for k, v in yaml_kwargs.items() if v is not None}, + **{k: v for k, v in cli_kwargs.items() if v is not None}, + } + + if cli_kwargs["subcommand"] == "deploy": + return DeployConfig.factory(**kwargs) + elif cli_kwargs["subcommand"] == "render": + return RenderConfig.factory(**kwargs) + else: + raise Exception(f"unhandled subcommand: {cli_kwargs['subcommand'] }") diff --git a/schemachange/config/parse_cli_args.py b/schemachange/config/parse_cli_args.py index 500cc785..e57f6dd6 100644 --- a/schemachange/config/parse_cli_args.py +++ b/schemachange/config/parse_cli_args.py @@ -71,7 +71,7 @@ def deprecate_verbose( del parsed_args.verbose -def parse_cli_args(args) -> argparse.Namespace: +def parse_cli_args(args) -> dict: parser = argparse.ArgumentParser( prog="schemachange", description="Apply schema changes to a Snowflake account. Full readme at " @@ -252,4 +252,12 @@ def parse_cli_args(args) -> argparse.Namespace: deprecate_verbose(args=args, verbose=verbose, parsed_args=parsed_args) - return parsed_args + parsed_kwargs = parsed_args.__dict__ + + if "log_level" in parsed_kwargs and isinstance(parsed_kwargs["log_level"], Enum): + parsed_kwargs["log_level"] = parsed_kwargs["log_level"].value + + if "vars" in parsed_kwargs: + parsed_kwargs["config_vars"] = parsed_kwargs.pop("vars") + + return parsed_kwargs diff --git a/tests/config/schemachange-config-complete.yml b/tests/config/schemachange-config.yml similarity index 97% rename from tests/config/schemachange-config-complete.yml rename to tests/config/schemachange-config.yml index 8977a15d..73a5e56a 100644 --- a/tests/config/schemachange-config-complete.yml +++ b/tests/config/schemachange-config.yml @@ -10,6 +10,7 @@ snowflake-schema: 'snowflake-schema-from-yaml' change-history-table: 'change-history-table-from-yaml' vars: var1: 'from_yaml' + var2: 'also_from_yaml' create-change-history-table: false autocommit: false verbose: false diff --git a/tests/config/test_Config.py b/tests/config/test_Config.py index e7fe3a51..f92454de 100644 --- a/tests/config/test_Config.py +++ b/tests/config/test_Config.py @@ -1,7 +1,5 @@ from __future__ import annotations -import dataclasses -from argparse import Namespace from pathlib import Path from unittest import mock @@ -11,7 +9,6 @@ from schemachange.config.ChangeHistoryTable import ChangeHistoryTable from schemachange.config.DeployConfig import DeployConfig from schemachange.config.RenderConfig import RenderConfig -from schemachange.config.get_merged_config import config_factory from schemachange.config.utils import get_config_secrets @@ -19,7 +16,10 @@ @mock.patch("pathlib.Path.is_dir", return_value=True) def yaml_config(_) -> DeployConfig: return DeployConfig.factory( - config_folder=Path(__file__).parent.parent.parent / "demo" / "basics_demo", + config_file_path=Path(__file__).parent.parent.parent + / "demo" + / "basics_demo" + / "schemachange-config.yml", root_folder=Path(__file__).parent.parent.parent / "demo" / "basics_demo", modules_folder=Path(__file__).parent.parent.parent / "demo" / "basics_demo", config_vars={"var1": "yaml_vars"}, @@ -29,7 +29,7 @@ def yaml_config(_) -> DeployConfig: snowflake_warehouse="yaml_snowflake_warehouse", snowflake_database="yaml_snowflake_database", snowflake_schema="yaml_snowflake_schema", - change_history_table=ChangeHistoryTable(table_name="yaml_change_history_table"), + change_history_table="yaml_change_history_table", create_change_history_table=True, autocommit=True, dry_run=True, @@ -189,49 +189,11 @@ def test_fully_qualified(self, table: ChangeHistoryTable, expected: str): class TestConfig: - @mock.patch("pathlib.Path.is_dir", return_value=True) - def test_inject_config_file_path_no_config_folder(self, _): - config = BaseConfig.factory(subcommand="deploy") - assert config.config_file_path == Path(".") / config.default_config_file_name - - @mock.patch("pathlib.Path.is_dir", return_value=True) - def test_inject_config_file_path_with_config_folder(self, _): - config_folder = "some_config_folder" - config = BaseConfig.factory(subcommand="deploy", config_folder=config_folder) - assert ( - config.config_file_path - == Path(config_folder) / config.default_config_file_name - ) - - @mock.patch("pathlib.Path.is_dir", return_value=False) - def test_invalid_config_folder(self, _): - with pytest.raises(Exception) as e_info: - DeployConfig.factory( - config_folder="some_config_folder_name", - root_folder="some_root_folder_name", - modules_folder="some_modules_folder_name", - config_vars={"some": "config_vars"}, - snowflake_account="some_snowflake_account", - snowflake_user="some_snowflake_user", - snowflake_role="some_snowflake_role", - snowflake_warehouse="some_snowflake_warehouse", - snowflake_database="some_snowflake_database", - snowflake_schema="some_snowflake_schema", - change_history_table=ChangeHistoryTable( - table_name="some_history_table" - ), - query_tag="some_query_tag", - oauth_config={"some": "values"}, - version_number_validation_regex="some_regex", - ) - e_info_value = str(e_info.value) - assert "Path is not valid directory: some_config_folder_name" in e_info_value - - @mock.patch("pathlib.Path.is_dir", side_effect=[True, False, False]) + @mock.patch("pathlib.Path.is_dir", side_effect=[False]) def test_invalid_root_folder(self, _): with pytest.raises(Exception) as e_info: DeployConfig.factory( - config_folder="some_config_folder_name", + config_file_path=Path("some_config_file_name"), root_folder="some_root_folder_name", modules_folder="some_modules_folder_name", config_vars={"some": "config_vars"}, @@ -241,9 +203,7 @@ def test_invalid_root_folder(self, _): snowflake_warehouse="some_snowflake_warehouse", snowflake_database="some_snowflake_database", snowflake_schema="some_snowflake_schema", - change_history_table=ChangeHistoryTable( - table_name="some_history_table" - ), + change_history_table="some_history_table", query_tag="some_query_tag", oauth_config={"some": "values"}, version_number_validation_regex="some_regex", @@ -251,11 +211,11 @@ def test_invalid_root_folder(self, _): e_info_value = str(e_info.value) assert "Path is not valid directory: some_root_folder_name" in e_info_value - @mock.patch("pathlib.Path.is_dir", side_effect=[True, True, False]) + @mock.patch("pathlib.Path.is_dir", side_effect=[True, False]) def test_invalid_modules_folder(self, _): with pytest.raises(Exception) as e_info: DeployConfig.factory( - config_folder="some_config_folder_name", + config_file_path=Path("some_config_file_name"), root_folder="some_root_folder_name", modules_folder="some_modules_folder_name", config_vars={"some": "config_vars"}, @@ -265,9 +225,7 @@ def test_invalid_modules_folder(self, _): snowflake_warehouse="some_snowflake_warehouse", snowflake_database="some_snowflake_database", snowflake_schema="some_snowflake_schema", - change_history_table=ChangeHistoryTable( - table_name="some_history_table" - ), + change_history_table="some_history_table", query_tag="some_query_tag", oauth_config={"some": "values"}, version_number_validation_regex="some_regex", @@ -277,7 +235,11 @@ def test_invalid_modules_folder(self, _): def test_config_vars_not_a_dict(self): with pytest.raises(Exception) as e_info: - BaseConfig.factory(subcommand="deploy", config_vars="a string") + BaseConfig.factory( + subcommand="deploy", + config_vars="a string", + config_file_path=Path("."), + ) assert ( "config_vars did not parse correctly, please check its configuration" in str(e_info.value) @@ -286,100 +248,27 @@ def test_config_vars_not_a_dict(self): def test_config_vars_reserved_word(self): with pytest.raises(Exception) as e_info: BaseConfig.factory( - subcommand="deploy", config_vars={"schemachange": "not allowed"} + subcommand="deploy", + config_vars={"schemachange": "not allowed"}, + config_file_path=Path("."), ) assert ( "The variable 'schemachange' has been reserved for use by schemachange, please use a different name" in str(e_info.value) ) - @pytest.mark.parametrize( - "cli_config, cli_overrides", - [ - (DeployConfig.factory(), []), - ( - DeployConfig.factory( - config_folder=Path(__file__).parent.parent.parent - / "demo" - / "citibike_demo", - root_folder=Path(__file__).parent.parent.parent - / "demo" - / "citibike_demo", - modules_folder=Path(__file__).parent.parent.parent - / "demo" - / "citibike_demo", - config_vars={"var1": "cli_vars"}, - snowflake_account="cli_snowflake_account", - snowflake_user="cli_snowflake_user", - snowflake_role="cli_snowflake_role", - snowflake_warehouse="cli_snowflake_warehouse", - snowflake_database="cli_snowflake_database", - snowflake_schema="cli_snowflake_schema", - change_history_table=ChangeHistoryTable( - table_name="cli_change_history_table" - ), - create_change_history_table=False, - autocommit=False, - dry_run=False, - query_tag="cli_query_tag", - oauth_config={"oauth": "cli_oauth"}, - version_number_validation_regex="cli_version_number_validation_regex", - raise_exception_on_ignored_versioned_script=False, - ), - [ - "config_folder", - "root_folder", - "modules_folder", - "config_vars", - "snowflake_account", - "snowflake_user", - "snowflake_role", - "snowflake_warehouse", - "snowflake_database", - "snowflake_schema", - "change_history_table", - "_create_change_history_table", - "_autocommit", - "_dry_run", - "query_tag", - "oauth_config", - "version_number_validation_regex", - "_raise_exception_on_ignored_versioned_script", - "config_file_path", - ], - ), - ], - ) - def test_merge_exclude_unset( - self, - yaml_config: DeployConfig, - cli_config: DeployConfig, - cli_overrides: list[str], - ): - merged_config = yaml_config.merge_exclude_defaults(other=cli_config) - - for name, field in dataclasses.asdict(merged_config).items(): - if name not in cli_overrides: - expected = getattr(yaml_config, name) - else: - expected = getattr(cli_config, name) - - if isinstance(expected, Path): - assert str(getattr(merged_config, name)) == str(expected) - else: - assert getattr(merged_config, name) == expected - def test_check_for_deploy_args_happy_path(self): config = DeployConfig.factory( snowflake_account="account", snowflake_user="user", snowflake_role="role", snowflake_warehouse="warehouse", + config_file_path=Path("."), ) config.check_for_deploy_args() def test_check_for_deploy_args_exception(self): - config = DeployConfig.factory() + config = DeployConfig.factory(config_file_path=Path(".")) with pytest.raises(ValueError) as e: config.check_for_deploy_args() @@ -388,28 +277,6 @@ def test_check_for_deploy_args_exception(self): ) -@pytest.mark.parametrize( - "args, expected_class", - [ - (Namespace(subcommand="deploy"), DeployConfig), - (Namespace(subcommand="render", script_path="some script"), RenderConfig), - ], -) -@mock.patch("pathlib.Path.is_file", return_value=True) -def test_config_factory( - _, args: Namespace, expected_class: DeployConfig | RenderConfig -): - result = config_factory(args) - # noinspection PyTypeChecker - assert isinstance(result, expected_class) - - -def test_config_factory_unhandled_subcommand(): - with pytest.raises(Exception) as e_info: - config_factory(Namespace(subcommand="unhandled")) - assert "unhandled subcommand" in str(e_info) - - @mock.patch("pathlib.Path.is_file", return_value=False) def test_render_config_invalid_path(_): with pytest.raises(Exception) as e_info: diff --git a/tests/config/test_get_merged_config.py b/tests/config/test_get_merged_config.py new file mode 100644 index 00000000..be1b6513 --- /dev/null +++ b/tests/config/test_get_merged_config.py @@ -0,0 +1,163 @@ +from pathlib import Path +from unittest import mock + +import pytest + +from schemachange.config.ChangeHistoryTable import ChangeHistoryTable +from schemachange.config.get_merged_config import get_merged_config + +required_args = [ + "--snowflake-account", + "account", + "--snowflake-user", + "user", + "--snowflake-warehouse", + "warehouse", + "--snowflake-role", + "role", +] + + +class TestGetMergedConfig: + @mock.patch("pathlib.Path.is_dir", return_value=True) + def test_default_config_folder(self, _): + with mock.patch("sys.argv", ["schemachange", *required_args]): + config = get_merged_config() + assert ( + config.config_file_path == Path(".") / config.default_config_file_name + ) + + @mock.patch("pathlib.Path.is_dir", return_value=True) + def test_config_folder(self, _): + with mock.patch( + "sys.argv", ["schemachange", "--config-folder", "DUMMY", *required_args] + ): + config = get_merged_config() + assert ( + config.config_file_path + == Path("DUMMY") / config.default_config_file_name + ) + + @mock.patch("pathlib.Path.is_dir", return_value=False) + def test_invalid_config_folder(self, _): + with pytest.raises(Exception) as e_info: + with mock.patch( + "sys.argv", ["schemachange", "--config-folder", "DUMMY", *required_args] + ): + config = get_merged_config() + assert ( + config.config_file_path + == Path("DUMMY") / config.default_config_file_name + ) + e_info_value = str(e_info.value) + assert "Path is not valid directory: DUMMY" in e_info_value + + @mock.patch("pathlib.Path.is_dir", return_value=True) + def test_no_cli_args(self, _): + with mock.patch( + "sys.argv", ["schemachange", "--config-folder", str(Path(__file__).parent)] + ): + config = get_merged_config() + + assert config.snowflake_account == "snowflake-account-from-yaml" + assert config.snowflake_user == "snowflake-user-from-yaml" + assert config.snowflake_warehouse == '"snowflake-warehouse-from-yaml"' + assert config.snowflake_role == '"snowflake-role-from-yaml"' + assert str(config.root_folder) == "root-folder-from-yaml" + assert str(config.modules_folder) == "modules-folder-from-yaml" + assert config.snowflake_database == '"snowflake-database-from-yaml"' + assert config.snowflake_schema == '"snowflake-schema-from-yaml"' + assert config.change_history_table == ChangeHistoryTable( + table_name='"change-history-table-from-yaml"', + schema_name="SCHEMACHANGE", + database_name="METADATA", + ) + assert config.config_vars == {"var1": "from_yaml", "var2": "also_from_yaml"} + assert config.create_change_history_table is False + assert config.autocommit is False + assert config.dry_run is False + assert config.query_tag == "query-tag-from-yaml" + assert config.oauth_config == { + "token-provider-url": "token-provider-url-from-yaml", + "token-response-name": "token-response-name-from-yaml", + "token-request-headers": { + "Content-Type": "Content-Type-from-yaml", + "User-Agent": "User-Agent-from-yaml", + }, + "token-request-payload": { + "client_id": "id-from-yaml", + "username": "username-from-yaml", + "password": "password-from-yaml", + "grant_type": "type-from-yaml", + "scope": "scope-from-yaml", + }, + } + + @mock.patch("pathlib.Path.is_dir", return_value=True) + def test_all_cli_args(self, _): + with mock.patch( + "sys.argv", + [ + "schemachange", + "--config-folder", + str(Path(__file__).parent), + "--root-folder", + "root-folder-from-cli", + "--modules-folder", + "modules-folder-from-cli", + "--vars", + '{"var1": "from_cli", "var3": "also_from_cli"}', + "--snowflake-account", + "snowflake-account-from-cli", + "--snowflake-user", + "snowflake-user-from-cli", + "--snowflake-role", + "snowflake-role-from-cli", + "--snowflake-warehouse", + "snowflake-warehouse-from-cli", + "--snowflake-database", + "snowflake-database-from-cli", + "--snowflake-schema", + "snowflake-schema-from-cli", + "--change-history-table", + "change-history-table-from-cli", + "--create-change-history-table", + "--autocommit", + "--dry-run", + "--query-tag", + "query-tag-from-cli", + "--oauth-config", + '{"token-provider-url": "https//...", "token-request-payload": {"client_id": "GUID_xyz"} }', + "--version_number_validation_regex", + "version_number_validation_regex-from-cli", + "--raise-exception-on-ignored-versioned-script", + ], + ): + config = get_merged_config() + + assert config.snowflake_account == "snowflake-account-from-cli" + assert config.snowflake_user == "snowflake-user-from-cli" + assert config.snowflake_warehouse == '"snowflake-warehouse-from-cli"' + assert config.snowflake_role == '"snowflake-role-from-cli"' + assert str(config.root_folder) == "root-folder-from-cli" + assert str(config.modules_folder) == "modules-folder-from-cli" + assert config.snowflake_database == '"snowflake-database-from-cli"' + assert config.snowflake_schema == '"snowflake-schema-from-cli"' + assert config.change_history_table == ChangeHistoryTable( + table_name='"change-history-table-from-cli"', + schema_name="SCHEMACHANGE", + database_name="METADATA", + ) + assert config.config_vars == { + "var1": "from_cli", + "var2": "also_from_yaml", + "var3": "also_from_cli", + } + assert config.create_change_history_table is True + assert config.autocommit is True + assert config.dry_run is True + assert config.query_tag == "query-tag-from-cli" + assert config.oauth_config == { + "token-provider-url": "https//...", + "token-request-payload": {"client_id": "GUID_xyz"}, + } diff --git a/tests/config/test_get_yaml_config.py b/tests/config/test_get_yaml_config.py index 96b473f1..9748fff9 100644 --- a/tests/config/test_get_yaml_config.py +++ b/tests/config/test_get_yaml_config.py @@ -6,8 +6,7 @@ import pytest -from schemachange.config.ChangeHistoryTable import ChangeHistoryTable -from schemachange.config.get_merged_config import get_yaml_config +from schemachange.config.get_merged_config import get_yaml_config_kwargs from schemachange.config.utils import load_yaml_config @@ -77,30 +76,25 @@ def test_load_yaml_config__requiring_env_var_but_env_var_not_set_should_raise_ex @mock.patch("pathlib.Path.is_dir", return_value=True) def test_get_yaml_config(_): - config_file_path = Path(__file__).parent / "schemachange-config-complete.yml" - yaml_config = get_yaml_config( - subcommand="deploy", config_file_path=config_file_path - ) - assert yaml_config.root_folder == Path("root-folder-from-yaml") - assert yaml_config.modules_folder == Path("modules-folder-from-yaml") - assert yaml_config.snowflake_account == "snowflake-account-from-yaml" - assert yaml_config.snowflake_user == "snowflake-user-from-yaml" - assert yaml_config.snowflake_role == '"snowflake-role-from-yaml"' - assert yaml_config.snowflake_warehouse == '"snowflake-warehouse-from-yaml"' - assert yaml_config.snowflake_database == '"snowflake-database-from-yaml"' - assert yaml_config.snowflake_schema == '"snowflake-schema-from-yaml"' - assert yaml_config.change_history_table == ChangeHistoryTable( - database_name="METADATA", - table_name="CHANGE-HISTORY-TABLE-FROM-YAML", - ) - assert yaml_config.query_tag == "query-tag-from-yaml" - - assert yaml_config.create_change_history_table is False - assert yaml_config.autocommit is False - assert yaml_config.dry_run is False - - assert yaml_config.config_vars == {"var1": "from_yaml"} - assert yaml_config.oauth_config == { + config_file_path = Path(__file__).parent / "schemachange-config.yml" + yaml_config = get_yaml_config_kwargs(config_file_path=config_file_path) + assert str(yaml_config["root_folder"]) == "root-folder-from-yaml" + assert str(yaml_config["modules_folder"]) == "modules-folder-from-yaml" + assert yaml_config["snowflake_account"] == "snowflake-account-from-yaml" + assert yaml_config["snowflake_user"] == "snowflake-user-from-yaml" + assert yaml_config["snowflake_role"] == "snowflake-role-from-yaml" + assert yaml_config["snowflake_warehouse"] == "snowflake-warehouse-from-yaml" + assert yaml_config["snowflake_database"] == "snowflake-database-from-yaml" + assert yaml_config["snowflake_schema"] == "snowflake-schema-from-yaml" + assert yaml_config["change_history_table"] == "change-history-table-from-yaml" + assert yaml_config["query_tag"] == "query-tag-from-yaml" + + assert yaml_config["create_change_history_table"] is False + assert yaml_config["autocommit"] is False + assert yaml_config["dry_run"] is False + + assert yaml_config["config_vars"] == {"var1": "from_yaml", "var2": "also_from_yaml"} + assert yaml_config["oauth_config"] == { "token-provider-url": "token-provider-url-from-yaml", "token-request-headers": { "Content-Type": "Content-Type-from-yaml", diff --git a/tests/config/test_parse_cli_args.py b/tests/config/test_parse_cli_args.py index d5555c6f..545caa59 100644 --- a/tests/config/test_parse_cli_args.py +++ b/tests/config/test_parse_cli_args.py @@ -20,13 +20,12 @@ def test_parse_args_defaults(): parsed_args = parse_cli_args(args) for expected_arg, expected_value in expected.items(): - parsed_arg = getattr(parsed_args, expected_arg) - assert parsed_arg == expected_value - assert parsed_args.create_change_history_table is None - assert parsed_args.autocommit is None - assert parsed_args.dry_run is None - assert parsed_args.raise_exception_on_ignored_versioned_script is None - assert parsed_args.subcommand == "deploy" + assert parsed_args[expected_arg] == expected_value + assert parsed_args["create_change_history_table"] is None + assert parsed_args["autocommit"] is None + assert parsed_args["dry_run"] is None + assert parsed_args["raise_exception_on_ignored_versioned_script"] is None + assert parsed_args["subcommand"] == "deploy" def test_parse_args_deploy_names(): @@ -73,10 +72,11 @@ def test_parse_args_deploy_names(): expected[expected_arg] = expected_value parsed_args = parse_cli_args(args) - assert parsed_args.subcommand == "deploy" + assert parsed_args["subcommand"] == "deploy" for expected_arg, expected_value in expected.items(): - parsed_arg = getattr(parsed_args, expected_arg) - assert parsed_arg == expected_value + if expected_arg == "vars": + expected_arg = "config_vars" + assert parsed_args[expected_arg] == expected_value def test_parse_args_deploy_flags(): @@ -124,18 +124,17 @@ def test_parse_args_deploy_flags(): expected[expected_arg] = expected_value parsed_args = parse_cli_args(args) - assert parsed_args.subcommand == "deploy" + assert parsed_args["subcommand"] == "deploy" for expected_arg, expected_value in expected.items(): - parsed_arg = getattr(parsed_args, expected_arg) - assert parsed_arg == expected_value + assert parsed_args[expected_arg] == expected_value def test_parse_args_verbose_deprecation(): args: list[str] = ["--verbose"] with pytest.warns(UserWarning) as warning: parsed_args = parse_cli_args(args) - assert getattr(parsed_args, "verbose", None) is None - assert parsed_args.log_level is logging.DEBUG + assert parsed_args.get("verbose", None) is None + assert parsed_args["log_level"] is logging.DEBUG assert ( str(warning[0].message) == "Argument ['-v', '--verbose'] is deprecated and will be interpreted as a DEBUG log level." diff --git a/tests/test_cli_misc.py b/tests/test_cli_misc.py index 57aa5419..d5c0ef87 100644 --- a/tests/test_cli_misc.py +++ b/tests/test_cli_misc.py @@ -80,23 +80,23 @@ def test_sorted_alphanumeric_mixed_string(): { "database_name": ChangeHistoryTable._default_database_name, "schema_name": ChangeHistoryTable._default_schema_name, - "table_name": "change_history_table".upper(), + "table_name": "change_history_table", }, ), ( "myschema.change_history_table", { "database_name": ChangeHistoryTable._default_database_name, - "schema_name": "myschema".upper(), - "table_name": "change_history_table".upper(), + "schema_name": "myschema", + "table_name": "change_history_table", }, ), ( "mydb.myschema.change_history_table", { - "database_name": "mydb".upper(), - "schema_name": "myschema".upper(), - "table_name": "change_history_table".upper(), + "database_name": "mydb", + "schema_name": "myschema", + "table_name": "change_history_table", }, ), ( diff --git a/tests/test_main.py b/tests/test_main.py index 398ec675..dc214108 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -15,7 +15,7 @@ default_base_config = { # Shared configuration options - "config_folder": Path("."), + "config_file_path": Path(".") / "schemachange-config.yml", "root_folder": Path("."), "modules_folder": None, "config_vars": {}, @@ -120,7 +120,7 @@ **default_deploy_config, **required_config, "change_history_table": ChangeHistoryTable( - database_name="DB", schema_name="SCHEMA", table_name="TABLE" + database_name="db", schema_name="schema", table_name="table" ), }, None, @@ -329,7 +329,7 @@ def test_main_deploy_config_folder( ) args[args.index("DUMMY")] = d - expected_config["config_folder"] = Path(d) + expected_config["config_file_path"] = Path(d) / "schemachange-config.yml" with mock.patch(to_mock) as mock_command: with mock.patch("sys.argv", args):