Skip to content

Commit

Permalink
feat: simplify configuration merging
Browse files Browse the repository at this point in the history
  • Loading branch information
Zane Clark authored and Zane Clark committed Aug 16, 2024
1 parent ee104d3 commit aa353b0
Show file tree
Hide file tree
Showing 12 changed files with 311 additions and 363 deletions.
54 changes: 2 additions & 52 deletions schemachange/config/BaseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class BaseConfig(ABC):

subcommand: Literal["deploy", "render"]
config_version: int | None = None
config_folder: Path = Path(".")
config_file_path: Path | None = None
root_folder: Path | None = Path(".")
modules_folder: Path | None = None
Expand All @@ -36,7 +35,7 @@ class BaseConfig(ABC):
def factory(
cls,
subcommand: Literal["deploy", "render"],
config_folder: Path | str | None = Path("."),
config_file_path: Path,
root_folder: Path | str | None = Path("."),
modules_folder: Path | str | None = None,
config_vars: str | dict | None = None,
Expand All @@ -50,16 +49,9 @@ def factory(
"config_vars did not parse correctly, please check its configuration"
) from e

if "secrets" in kwargs:
secrets.update(kwargs.pop("secrets"))

if config_folder is None:
config_folder = "."

return cls(
subcommand=subcommand,
config_folder=validate_directory(path=config_folder),
config_file_path=Path(config_folder) / cls.default_config_file_name,
config_file_path=config_file_path,
root_folder=validate_directory(path=root_folder),
modules_folder=validate_directory(path=modules_folder),
config_vars=validate_config_vars(config_vars=config_vars),
Expand All @@ -68,48 +60,6 @@ def factory(
**kwargs,
)

def asdict_exclude_defaults(self) -> dict:
retain = {}
for field in dataclasses.fields(self):
val = getattr(self, field.name)
if (
val is None
and field.default == dataclasses.MISSING
or val == field.default
):
continue
if (
field.default_factory != dataclasses.MISSING
and val == field.default_factory()
):
continue
if field.name in [
"_create_change_history_table",
"_autocommit",
"_dry_run",
"_raise_exception_on_ignored_versioned_script",
]:
retain[field.name.lstrip("_")] = val
else:
retain[field.name] = val

return retain

def merge_exclude_defaults(self: T, other: T) -> T:
self_kwargs = self.asdict_exclude_defaults()
self_kwargs.pop("config_file_path")

other_kwargs = other.asdict_exclude_defaults()
other_kwargs.pop("config_file_path")
if "change_history_table" in other_kwargs:
other_kwargs["change_history_table"] = other.change_history_table

if "secrets" in other_kwargs and len(other_kwargs["secrets"]) == 0:
other_kwargs.pop("secrets")

kwargs = {**self_kwargs, **other_kwargs}
return self.__class__.factory(**kwargs)

def log_details(self):
logger.info("Using root folder", root_folder=str(self.root_folder))
if self.modules_folder:
Expand Down
37 changes: 21 additions & 16 deletions schemachange/config/ChangeHistoryTable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dataclasses
from typing import ClassVar

from schemachange.config.utils import get_snowflake_identifier_string


@dataclasses.dataclass(frozen=True)
class ChangeHistoryTable:
Expand All @@ -18,29 +20,32 @@ def fully_qualified(self) -> str:

@classmethod
def from_str(cls, table_str: str):
details: dict[str, str] = {
"database_name": cls._default_database_name,
"schema_name": cls._default_schema_name,
"table_name": cls._default_table_name,
}
database_name = cls._default_database_name
schema_name = cls._default_schema_name
table_name = cls._default_table_name

if table_str is not None:
table_name_parts = table_str.strip().split(".")
if len(table_name_parts) == 1:
details["table_name"] = table_name_parts[0]
table_name = table_name_parts[0]
elif len(table_name_parts) == 2:
details["table_name"] = table_name_parts[1]
details["schema_name"] = table_name_parts[0]
table_name = table_name_parts[1]
schema_name = table_name_parts[0]
elif len(table_name_parts) == 3:
details["table_name"] = table_name_parts[2]
details["schema_name"] = table_name_parts[1]
details["database_name"] = table_name_parts[0]
table_name = table_name_parts[2]
schema_name = table_name_parts[1]
database_name = table_name_parts[0]
else:
raise ValueError(f"Invalid change history table name: {table_str}")

# if the name element does not include '"' raise to upper case on return
return cls(
**{
attr_name: attr_val if '"' in attr_val else attr_val.upper()
for (attr_name, attr_val) in details.items()
}
table_name=get_snowflake_identifier_string(
input_value=table_name, input_type="table_name"
),
schema_name=get_snowflake_identifier_string(
input_value=schema_name, input_type="schema_name"
),
database_name=get_snowflake_identifier_string(
input_value=database_name, input_type="database_name"
),
)
49 changes: 13 additions & 36 deletions schemachange/config/DeployConfig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
from pathlib import Path
from typing import Literal

from schemachange.config.BaseConfig import BaseConfig
Expand All @@ -21,56 +22,35 @@ class DeployConfig(BaseConfig):
change_history_table: ChangeHistoryTable | None = dataclasses.field(
default_factory=ChangeHistoryTable
)
_create_change_history_table: bool | None
_autocommit: bool | None
_dry_run: bool | None
create_change_history_table: bool = False
autocommit: bool = False
dry_run: bool = False
query_tag: str | None = None
oauth_config: dict | None = None
version_number_validation_regex: str | None = None
_raise_exception_on_ignored_versioned_script: bool | None

@property
def create_change_history_table(self) -> bool:
if self._create_change_history_table is None:
return False
return self._create_change_history_table

@property
def autocommit(self) -> bool:
if self._autocommit is None:
return False
return self._autocommit

@property
def dry_run(self) -> bool:
if self._dry_run is None:
return False
return self._dry_run

@property
def raise_exception_on_ignored_versioned_script(self) -> bool:
if self._raise_exception_on_ignored_versioned_script is None:
return False
return self._raise_exception_on_ignored_versioned_script
raise_exception_on_ignored_versioned_script: bool = False

@classmethod
def factory(
cls,
config_file_path: Path,
snowflake_role: str | None = None,
snowflake_warehouse: str | None = None,
snowflake_database: str | None = None,
snowflake_schema: str | None = None,
create_change_history_table: bool | None = None,
autocommit: bool | None = None,
dry_run: bool | None = None,
raise_exception_on_ignored_versioned_script: bool | None = None,
change_history_table: str | None = None,
**kwargs,
):
if "subcommand" in kwargs:
kwargs.pop("subcommand")

change_history_table = ChangeHistoryTable.from_str(
table_str=change_history_table
)

return super().factory(
subcommand="deploy",
config_file_path=config_file_path,
snowflake_role=get_snowflake_identifier_string(
snowflake_role, "snowflake_role"
),
Expand All @@ -83,10 +63,7 @@ def factory(
snowflake_schema=get_snowflake_identifier_string(
snowflake_schema, "snowflake_schema"
),
_create_change_history_table=create_change_history_table,
_autocommit=autocommit,
_dry_run=dry_run,
_raise_exception_on_ignored_versioned_script=raise_exception_on_ignored_versioned_script,
change_history_table=change_history_table,
**kwargs,
)

Expand Down
90 changes: 37 additions & 53 deletions schemachange/config/get_merged_config.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,15 @@
import logging
import sys
from argparse import Namespace
from enum import Enum
from pathlib import Path
from typing import Union, Optional

from schemachange.config.ChangeHistoryTable import ChangeHistoryTable
from schemachange.config.DeployConfig import DeployConfig
from schemachange.config.RenderConfig import RenderConfig
from schemachange.config.parse_cli_args import parse_cli_args
from schemachange.config.utils import load_yaml_config
from schemachange.config.utils import load_yaml_config, validate_directory


def config_factory(args: Union[Namespace, dict]) -> Union[DeployConfig, RenderConfig]:
if isinstance(args, Namespace):
subcommand = args.subcommand
kwargs = args.__dict__
else:
subcommand = args.get("subcommand")
kwargs = args

kwargs = {k: v for k, v in kwargs.items() if v is not None}

if "change_history_table" in kwargs and kwargs["change_history_table"] is not None:
kwargs["change_history_table"] = ChangeHistoryTable.from_str(
table_str=kwargs["change_history_table"]
)

if "vars" in kwargs:
kwargs["config_vars"] = kwargs.pop("vars")

if subcommand == "deploy":
return DeployConfig.factory(**kwargs)
elif subcommand == "render":
return RenderConfig.factory(**kwargs)
else:
raise Exception(f"unhandled subcommand: {subcommand}")


def get_yaml_config(
subcommand: str,
config_file_path: Optional[Path],
script_path: Optional[Path] = None,
) -> Union[DeployConfig, RenderConfig]:
def get_yaml_config_kwargs(config_file_path: Optional[Path]) -> dict:
# TODO: I think the configuration key for oauthconfig should be oauth-config.
# This looks like a bug in the current state of the repo to me

Expand All @@ -51,32 +18,49 @@ def get_yaml_config(
k.replace("-", "_").replace("oauthconfig", "oauth_config"): v
for (k, v) in load_yaml_config(config_file_path).items()
}
kwargs["subcommand"] = subcommand
if script_path:
kwargs["script_path"] = script_path

if "verbose" in kwargs:
kwargs["log_level"] = logging.DEBUG if kwargs["verbose"] else logging.ERROR
kwargs["log_level"] = logging.DEBUG
kwargs.pop("verbose")

return config_factory(kwargs)
if "vars" in kwargs:
kwargs["config_vars"] = kwargs.pop("vars")

return kwargs


def get_merged_config() -> Union[DeployConfig, RenderConfig]:
args = parse_cli_args(sys.argv[1:])
if (
"log_level" in args
and args.log_level is not None
and isinstance(args.log_level, Enum)
):
args.log_level = args.log_level.value
cli_kwargs = parse_cli_args(sys.argv[1:])
cli_config_vars = cli_kwargs.pop("config_vars", None)
if cli_config_vars is None:
cli_config_vars = {}

cli_config = config_factory(args=args)
yaml_config = get_yaml_config(
subcommand=cli_config.subcommand,
config_file_path=cli_config.config_file_path,
script_path=getattr(cli_config, "script_path", None),
config_folder = validate_directory(path=cli_kwargs.pop("config_folder", "."))
config_file_path = Path(config_folder) / "schemachange-config.yml"

yaml_kwargs = get_yaml_config_kwargs(
config_file_path=config_file_path,
)
yaml_config_vars = yaml_kwargs.pop("config_vars", None)
if yaml_config_vars is None:
yaml_config_vars = {}

config_vars = {
**yaml_config_vars,
**cli_config_vars,
}

# override the YAML config with the CLI configuration
return yaml_config.merge_exclude_defaults(other=cli_config)
kwargs = {
"config_file_path": config_file_path,
"config_vars": config_vars,
**{k: v for k, v in yaml_kwargs.items() if v is not None},
**{k: v for k, v in cli_kwargs.items() if v is not None},
}

if cli_kwargs["subcommand"] == "deploy":
return DeployConfig.factory(**kwargs)
elif cli_kwargs["subcommand"] == "render":
return RenderConfig.factory(**kwargs)
else:
raise Exception(f"unhandled subcommand: {cli_kwargs['subcommand'] }")
12 changes: 10 additions & 2 deletions schemachange/config/parse_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def deprecate_verbose(
del parsed_args.verbose


def parse_cli_args(args) -> argparse.Namespace:
def parse_cli_args(args) -> dict:
parser = argparse.ArgumentParser(
prog="schemachange",
description="Apply schema changes to a Snowflake account. Full readme at "
Expand Down Expand Up @@ -252,4 +252,12 @@ def parse_cli_args(args) -> argparse.Namespace:

deprecate_verbose(args=args, verbose=verbose, parsed_args=parsed_args)

return parsed_args
parsed_kwargs = parsed_args.__dict__

if "log_level" in parsed_kwargs and isinstance(parsed_kwargs["log_level"], Enum):
parsed_kwargs["log_level"] = parsed_kwargs["log_level"].value

if "vars" in parsed_kwargs:
parsed_kwargs["config_vars"] = parsed_kwargs.pop("vars")

return parsed_kwargs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ snowflake-schema: 'snowflake-schema-from-yaml'
change-history-table: 'change-history-table-from-yaml'
vars:
var1: 'from_yaml'
var2: 'also_from_yaml'
create-change-history-table: false
autocommit: false
verbose: false
Expand Down
Loading

0 comments on commit aa353b0

Please sign in to comment.