diff --git a/.changes/header.tpl.md b/.changes/header.tpl.md index df8faa7b..8bee7b4a 100644 --- a/.changes/header.tpl.md +++ b/.changes/header.tpl.md @@ -1,6 +1,6 @@ -# Changelog -All notable changes to this project will be documented in this file. +# dbt Core Changelog -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), -and is generated by [Changie](https://github.com/miniscruff/changie). +- This file provides a full account of all changes to `dbt-common` +- Changes are listed under the (pre)release in which they first appear. Subsequent releases include changes from previous releases. +- "Breaking changes" listed under a version may require action from end users or external maintainers when upgrading to that version. +- Do not edit this file directly. This file is auto-generated using [changie](https://github.com/miniscruff/changie). For details on how to document a change, see [the contributing guide](https://github.com/dbt-labs/dbt-common/blob/main/CONTRIBUTING.md#adding-changelog-entry) diff --git a/.changie.yaml b/.changie.yaml index b1bd5eeb..dd452733 100644 --- a/.changie.yaml +++ b/.changie.yaml @@ -1,12 +1,21 @@ changesDir: .changes unreleasedDir: unreleased headerPath: header.tpl.md +versionHeaderPath: "" changelogPath: CHANGELOG.md versionExt: md -envPrefix: CHANGIE_ -versionFormat: '## dbt-oss-template {{.Version}} - {{.Time.Format "January 02, 2006"}}' +envPrefix: "CHANGIE_" +versionFormat: '## dbt-common {{.Version}} - {{.Time.Format "January 02, 2006"}}' kindFormat: '### {{.Kind}}' -changeFormat: '* {{.Body}}' +changeFormat: |- + {{- $IssueList := list }} + {{- $changes := splitList " " $.Custom.Issue }} + {{- range $issueNbr := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-common/issues/nbr)" | replace "nbr" $issueNbr }} + {{- $IssueList = append $IssueList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $IssueList }}{{if $index}}, {{end}}{{$element}}{{end}}) + kinds: - label: Breaking Changes - label: Features @@ -14,7 +23,44 @@ kinds: - label: Docs - label: Under the Hood - label: Dependencies + changeFormat: |- + {{- $PRList := list }} + {{- $changes := splitList " " $.Custom.PR }} + {{- range $pullrequest := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-common/pull/nbr)" | replace "nbr" $pullrequest }} + {{- $PRList = append $PRList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $PRList }}{{if $index}}, {{end}}{{$element}}{{end}}) + skipGlobalChoices: true + additionalChoices: + - key: Author + label: GitHub Username(s) (separated by a single space if multiple) + type: string + minLength: 3 + - key: PR + label: GitHub Pull Request Number (separated by a single space if multiple) + type: string + minLength: 1 - label: Security + changeFormat: |- + {{- $PRList := list }} + {{- $changes := splitList " " $.Custom.PR }} + {{- range $pullrequest := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-common/pull/nbr)" | replace "nbr" $pullrequest }} + {{- $PRList = append $PRList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $PRList }}{{if $index}}, {{end}}{{$element}}{{end}}) + skipGlobalChoices: true + additionalChoices: + - key: Author + label: GitHub Username(s) (separated by a single space if multiple) + type: string + minLength: 3 + - key: PR + label: GitHub Pull Request Number (separated by a single space if 
multiple) + type: string + minLength: 1 + newlines: afterChangelogHeader: 1 afterKind: 1 @@ -31,3 +77,56 @@ custom: label: GitHub Issue Number (separated by a single space if multiple) type: string minLength: 1 + +footerFormat: | + {{- $contributorDict := dict }} + {{- /* ensure all names in this list are all lowercase for later matching purposes */}} + {{- $core_team := splitList " " .Env.CORE_TEAM }} + {{- /* ensure we always skip snyk and dependabot in addition to the core team */}} + {{- $maintainers := list "dependabot[bot]" "snyk-bot"}} + {{- range $team_member := $core_team }} + {{- $team_member_lower := lower $team_member }} + {{- $maintainers = append $maintainers $team_member_lower }} + {{- end }} + {{- range $change := .Changes }} + {{- $authorList := splitList " " $change.Custom.Author }} + {{- /* loop through all authors for a single changelog */}} + {{- range $author := $authorList }} + {{- $authorLower := lower $author }} + {{- /* we only want to include non-core team contributors */}} + {{- if not (has $authorLower $maintainers)}} + {{- $changeList := splitList " " $change.Custom.Author }} + {{- $IssueList := list }} + {{- $changeLink := $change.Kind }} + {{- if or (eq $change.Kind "Dependencies") (eq $change.Kind "Security") }} + {{- $changes := splitList " " $change.Custom.PR }} + {{- range $issueNbr := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-common/pull/nbr)" | replace "nbr" $issueNbr }} + {{- $IssueList = append $IssueList $changeLink }} + {{- end -}} + {{- else }} + {{- $changes := splitList " " $change.Custom.Issue }} + {{- range $issueNbr := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-common/issues/nbr)" | replace "nbr" $issueNbr }} + {{- $IssueList = append $IssueList $changeLink }} + {{- end -}} + {{- end }} + {{- /* check if this contributor has other changes associated with them already */}} + {{- if hasKey $contributorDict $author }} + {{- $contributionList := get $contributorDict $author }} + {{- $contributionList = concat $contributionList $IssueList }} + {{- $contributorDict := set $contributorDict $author $contributionList }} + {{- else }} + {{- $contributionList := $IssueList }} + {{- $contributorDict := set $contributorDict $author $contributionList }} + {{- end }} + {{- end}} + {{- end}} + {{- end }} + {{- /* no indentation here for formatting so the final markdown doesn't have unneeded indentations */}} + {{- if $contributorDict}} + ### Contributors + {{- range $k,$v := $contributorDict }} + - [@{{$k}}](https://github.com/{{$k}}) ({{ range $index, $element := $v }}{{if $index}}, {{end}}{{$element}}{{end}}) + {{- end }} + {{- end }} diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..26e20a5d --- /dev/null +++ b/.flake8 @@ -0,0 +1,14 @@ +[flake8] +select = + E + W + F +ignore = + W503 # makes Flake8 work like black + W504 + E203 # makes Flake8 work like black + E741 + E501 # long line checking is done in black +exclude = test/ +per-file-ignores = + */__init__.py: F401 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 00000000..6cf657a2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,65 @@ +name: 🐞 Bug +description: Report a bug or an issue you've found with dbt-common +title: "[Bug]
(['"]))""") + + +class TagIterator: + def __init__(self, data): + self.data = data + self.blocks = [] + self._parenthesis_stack = [] + self.pos = 0 + + def linepos(self, end=None) -> str: + """Given an absolute position in the input data, return a pair of + line number + relative position to the start of the line. + """ + end_val: int = self.pos if end is None else end + data = self.data[:end_val] + # if not found, rfind returns -1, and -1+1=0, which is perfect! + last_line_start = data.rfind("\n") + 1 + # it's easy to forget this, but line numbers are 1-indexed + line_number = data.count("\n") + 1 + return f"{line_number}:{end_val - last_line_start}" + + def advance(self, new_position): + self.pos = new_position + + def rewind(self, amount=1): + self.pos -= amount + + def _search(self, pattern): + return pattern.search(self.data, self.pos) + + def _match(self, pattern): + return pattern.match(self.data, self.pos) + + def _first_match(self, *patterns, **kwargs): + matches = [] + for pattern in patterns: + # default to 'search', but sometimes we want to 'match'. + if kwargs.get("method", "search") == "search": + match = self._search(pattern) + else: + match = self._match(pattern) + if match: + matches.append(match) + if not matches: + return None + # if there are multiple matches, pick the least greedy match + # TODO: do I need to account for m.start(), or is this ok? + return min(matches, key=lambda m: m.end()) + + def _expect_match(self, expected_name, *patterns, **kwargs): + match = self._first_match(*patterns, **kwargs) + if match is None: + raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :]) + return match + + def handle_expr(self, match): + """Handle an expression. At this point we're at a string like: + {{ 1 + 2 }} + ^ right here + + And the match contains "{{ " + + We expect to find a `}}`, but we might find one in a string before + that. Imagine the case of `{{ 2 * "}}" }}`... + + You're not allowed to have blocks or comments inside an expr so it is + pretty straightforward, I hope: only strings can get in the way. + """ + self.advance(match.end()) + while True: + match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN) + if match.groupdict().get("expr_end") is not None: + break + else: + # it's a quote. we haven't advanced for this match yet, so + # just slurp up the whole string, no need to rewind. + match = self._expect_match("string", STRING_PATTERN) + self.advance(match.end()) + + self.advance(match.end()) + + def handle_comment(self, match): + self.advance(match.end()) + match = self._expect_match("#}", COMMENT_END_PATTERN) + self.advance(match.end()) + + def _expect_block_close(self): + """Search for the tag close marker. + To the right of the type name, there are a few possiblities: + - a name (handled by the regex's 'block_name') + - any number of: `=`, `(`, `)`, strings, etc (arguments) + - nothing + + followed eventually by a %} + + So the only characters we actually have to worry about in this context + are quote and `%}` - nothing else can hide the %} and be valid jinja. + """ + while True: + end_match = self._expect_match('tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN) + self.advance(end_match.end()) + if end_match.groupdict().get("tag_close") is not None: + return + # must be a string. Rewind to its start and advance past it. 
+ self.rewind() + string_match = self._expect_match("string", STRING_PATTERN) + self.advance(string_match.end()) + + def handle_raw(self): + # raw blocks are super special, they are a single complete regex + match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) + self.advance(match.end()) + return match.end() + + def handle_tag(self, match): + """The tag could be one of a few things: + + {% mytag %} + {% mytag x = y %} + {% mytag x = "y" %} + {% mytag x.y() %} + {% mytag foo("a", "b", c="d") %} + + But the key here is that it's always going to be `{% mytag`! + """ + groups = match.groupdict() + # always a value + block_type_name = groups["block_type_name"] + # might be None + block_name = groups.get("block_name") + start_pos = self.pos + if block_type_name == "raw": + match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) + self.advance(match.end()) + else: + self.advance(match.end()) + self._expect_block_close() + return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos) + + def find_tags(self): + while True: + match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN) + if match is None: + break + + self.advance(match.start()) + # start = self.pos + + groups = match.groupdict() + comment_start = groups.get("comment_start") + expr_start = groups.get("expr_start") + block_type_name = groups.get("block_type_name") + + if comment_start is not None: + self.handle_comment(match) + elif expr_start is not None: + self.handle_expr(match) + elif block_type_name is not None: + yield self.handle_tag(match) + else: + raise DbtInternalError( + "Invalid regex match in next_block, expected block start, " "expr start, or comment start" + ) + + def __iter__(self): + return self.find_tags() + + +_CONTROL_FLOW_TAGS = { + "if": "endif", + "for": "endfor", +} + +_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()} + + +class BlockIterator: + def __init__(self, data): + self.tag_parser = TagIterator(data) + self.current = None + self.stack = [] + self.last_position = 0 + + @property + def current_end(self): + if self.current is None: + return 0 + else: + return self.current.end + + @property + def data(self): + return self.tag_parser.data + + def is_current_end(self, tag): + return ( + tag.block_type_name.startswith("end") + and self.current is not None + and tag.block_type_name[3:] == self.current.block_type_name + ) + + def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + """Find all top-level blocks in the data.""" + if allowed_blocks is None: + allowed_blocks = {"snapshot", "macro", "materialization", "docs"} + + for tag in self.tag_parser.find_tags(): + if tag.block_type_name in _CONTROL_FLOW_TAGS: + self.stack.append(tag.block_type_name) + elif tag.block_type_name in _CONTROL_FLOW_END_TAGS: + found = None + if self.stack: + found = self.stack.pop() + else: + expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name] + raise UnexpectedControlFlowEndTagError(tag, expected, self.tag_parser) + expected = _CONTROL_FLOW_TAGS[found] + if expected != tag.block_type_name: + raise MissingControlFlowStartTagError(tag, expected, self.tag_parser) + + if tag.block_type_name in allowed_blocks: + if self.stack: + raise BlockDefinitionNotAtTopError(self.tag_parser, tag.start) + if self.current is not None: + raise NestedTagsError(outer=self.current, inner=tag) + if collect_raw_data: + raw_data = self.data[self.last_position : tag.start] + self.last_position = tag.start + if raw_data: 
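A minimal usage sketch with a hypothetical input (assuming the BlockData and BlockTag classes defined earlier in this module): find_blocks yields BlockData for raw text between blocks and a BlockTag, carrying contents and full_block, for each allowed top-level block:

    data = "-- a comment\n{% macro greet(name) %}select '{{ name }}'{% endmacro %}"
    for block in BlockIterator(data).find_blocks(allowed_blocks={"macro"}):
        if isinstance(block, BlockTag):
            print(block.block_type_name, block.block_name)  # -> macro greet
            print(block.contents)  # the body between the open and close tags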
+ yield BlockData(raw_data) + self.current = tag + + elif self.is_current_end(tag): + self.last_position = tag.end + assert self.current is not None + yield BlockTag( + block_type_name=self.current.block_type_name, + block_name=self.current.block_name, + contents=self.data[self.current.end : tag.start], + full_block=self.data[self.current.start : tag.end], + ) + self.current = None + + if self.current: + linecount = self.data[: self.current.end].count("\n") + 1 + raise MissingCloseTagError(self.current.block_type_name, linecount) + + if collect_raw_data: + raw_data = self.data[self.last_position :] + if raw_data: + yield BlockData(raw_data) + + def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) diff --git a/dbt_common/clients/agate_helper.py b/dbt_common/clients/agate_helper.py new file mode 100644 index 00000000..4f937c2c --- /dev/null +++ b/dbt_common/clients/agate_helper.py @@ -0,0 +1,247 @@ +from codecs import BOM_UTF8 + +import agate +import datetime +import isodate +import json +from typing import Iterable, List, Dict, Union, Optional, Any + +from dbt_common.exceptions import DbtRuntimeError +from dbt_common.utils import ForgivingJSONEncoder + +BOM = BOM_UTF8.decode("utf-8") # '\ufeff' + + +class Integer(agate.data_types.DataType): + def cast(self, d): + # by default agate will cast none as a Number + # but we need to cast it as an Integer to preserve + # the type when merging and unioning tables + if type(d) == int or d is None: + return d + else: + raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d) + + def jsonify(self, d): + return d + + +class Number(agate.data_types.Number): + # undo the change in https://github.com/wireservice/agate/pull/733 + # i.e. do not cast True and False to numeric 1 and 0 + def cast(self, d): + if type(d) == bool: + raise agate.exceptions.CastError("Do not cast True to 1 or False to 0.") + else: + return super().cast(d) + + +class ISODateTime(agate.data_types.DateTime): + def cast(self, d): + # this is agate.data_types.DateTime.cast with the "clever" bits removed + # so we only handle ISO8601 stuff + if isinstance(d, datetime.datetime) or d is None: + return d + elif isinstance(d, datetime.date): + return datetime.datetime.combine(d, datetime.time(0, 0, 0)) + elif isinstance(d, str): + d = d.strip() + if d.lower() in self.null_values: + return None + try: + return isodate.parse_datetime(d) + except: # noqa + pass + + raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' 
% d) + + +def build_type_tester( + text_columns: Iterable[str], string_null_values: Optional[Iterable[str]] = ("null", "") +) -> agate.TypeTester: + + types = [ + Integer(null_values=("null", "")), + Number(null_values=("null", "")), + agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"), + agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"), + ISODateTime(null_values=("null", "")), + agate.data_types.Boolean(true_values=("true",), false_values=("false",), null_values=("null", "")), + agate.data_types.Text(null_values=string_null_values), + ] + force = {k: agate.data_types.Text(null_values=string_null_values) for k in text_columns} + return agate.TypeTester(force=force, types=types) + + +DEFAULT_TYPE_TESTER = build_type_tester(()) + + +def table_from_rows( + rows: List[Any], + column_names: Iterable[str], + text_only_columns: Optional[Iterable[str]] = None, +) -> agate.Table: + if text_only_columns is None: + column_types = DEFAULT_TYPE_TESTER + else: + # If text_only_columns are present, prevent coercing empty string or + # literal 'null' strings to a None representation. + column_types = build_type_tester(text_only_columns, string_null_values=()) + + return agate.Table(rows, column_names, column_types=column_types) + + +def table_from_data(data, column_names: Iterable[str]) -> agate.Table: + "Convert a list of dictionaries into an Agate table" + + # The agate table is generated from a list of dicts, so the column order + # from `data` is not preserved. We can use `select` to reorder the columns + # + # If there is no data, create an empty table with the specified columns + + if len(data) == 0: + return agate.Table([], column_names=column_names) + else: + table = agate.Table.from_object(data, column_types=DEFAULT_TYPE_TESTER) + return table.select(column_names) + + +def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table: + """ + Convert a list of dictionaries into an Agate table. This method does not + coerce string values into more specific types (eg. '005' will not be + coerced to '5'). Additionally, this method does not coerce values to + None (eg. '' or 'null' will retain their string literal representations). + """ + + rows = [] + text_only_columns = set() + for _row in data: + row = [] + for col_name in column_names: + value = _row[col_name] + if isinstance(value, (dict, list, tuple)): + # Represent container types as json strings + value = json.dumps(value, cls=ForgivingJSONEncoder) + text_only_columns.add(col_name) + elif isinstance(value, str): + text_only_columns.add(col_name) + row.append(value) + + rows.append(row) + + return table_from_rows(rows=rows, column_names=column_names, text_only_columns=text_only_columns) + + +def empty_table(): + "Returns an empty Agate table. 
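For illustration (hypothetical rows), table_from_data_flat keeps string values literal and JSON-encodes container values, while other values still go through type inference in table_from_rows:

    rows = [{"id": 1, "tags": ["a", "b"], "code": "005"}]
    tbl = table_from_data_flat(rows, column_names=["id", "tags", "code"])
    # "tags" becomes the text '["a", "b"]', "code" stays the literal string "005",
    # and "id" is inferred (as Integer) by the type tester built in table_from_rows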
To be used in place of None" + + return agate.Table(rows=[]) + + +def as_matrix(table): + "Return an agate table as a matrix of data sans columns" + + return [r.values() for r in table.rows.values()] + + +def from_csv(abspath, text_columns, delimiter=","): + type_tester = build_type_tester(text_columns=text_columns) + with open(abspath, encoding="utf-8") as fp: + if fp.read(1) != BOM: + fp.seek(0) + return agate.Table.from_csv(fp, column_types=type_tester, delimiter=delimiter) + + +class _NullMarker: + pass + + +NullableAgateType = Union[agate.data_types.DataType, _NullMarker] + + +class ColumnTypeBuilder(Dict[str, NullableAgateType]): + def __init__(self) -> None: + super().__init__() + + def __setitem__(self, key, value): + if key not in self: + super().__setitem__(key, value) + return + + existing_type = self[key] + if isinstance(existing_type, _NullMarker): + # overwrite + super().__setitem__(key, value) + elif isinstance(value, _NullMarker): + # use the existing value + return + # when one table column is Number while another is Integer, force the column to Number on merge + elif isinstance(value, Integer) and isinstance(existing_type, agate.data_types.Number): + # use the existing value + return + elif isinstance(existing_type, Integer) and isinstance(value, agate.data_types.Number): + # overwrite + super().__setitem__(key, value) + elif not isinstance(value, type(existing_type)): + # actual type mismatch! + raise DbtRuntimeError( + f"Tables contain columns with the same names ({key}), " + f"but different types ({value} vs {existing_type})" + ) + + def finalize(self) -> Dict[str, agate.data_types.DataType]: + result: Dict[str, agate.data_types.DataType] = {} + for key, value in self.items(): + if isinstance(value, _NullMarker): + # agate would make it a Number but we'll make it Integer so that if this column + # gets merged with another Integer column, it won't get forced to a Number + result[key] = Integer() + else: + result[key] = value + return result + + +def _merged_column_types(tables: List[agate.Table]) -> Dict[str, agate.data_types.DataType]: + # this is a lot like agate.Table.merge, but with handling for all-null + # rows being "any type". + new_columns: ColumnTypeBuilder = ColumnTypeBuilder() + for table in tables: + for i in range(len(table.columns)): + column_name: str = table.column_names[i] + column_type: NullableAgateType = table.column_types[i] + # avoid over-sensitive type inference + if all(x is None for x in table.columns[column_name]): + column_type = _NullMarker() + new_columns[column_name] = column_type + + return new_columns.finalize() + + +def merge_tables(tables: List[agate.Table]) -> agate.Table: + """This is similar to agate.Table.merge, but it handles rows of all 'null' + values more gracefully during merges. + """ + new_columns = _merged_column_types(tables) + column_names = tuple(new_columns.keys()) + column_types = tuple(new_columns.values()) + + rows: List[agate.Row] = [] + for table in tables: + if table.column_names == column_names and table.column_types == column_types: + rows.extend(table.rows) + else: + for row in table.rows: + data = [row.get(name, None) for name in column_names] + rows.append(agate.Row(data, column_names)) + # _is_fork to tell agate that we already made things into `Row`s. 
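A small sketch of the merge rule above: when the same column is Integer in one table and Number in another, the builder settles on Number, and columns seen only as all-null fall back to Integer:

    builder = ColumnTypeBuilder()
    builder["n"] = Integer(null_values=("null", ""))
    builder["n"] = agate.data_types.Number(null_values=("null", ""))
    builder["m"] = _NullMarker()
    print(builder.finalize())  # {"n": <Number>, "m": <Integer>}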
+ return agate.Table(rows, column_names, column_types, _is_fork=True) + + +def get_column_value_uncased(column_name: str, row: agate.Row) -> Any: + """Get the value of a column in this row, ignoring the casing of the column name.""" + for key, value in row.items(): + if key.casefold() == column_name.casefold(): + return value + + raise KeyError diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py new file mode 100644 index 00000000..1b6de92b --- /dev/null +++ b/dbt_common/clients/jinja.py @@ -0,0 +1,501 @@ +import codecs +import linecache +import os +import tempfile +from ast import literal_eval +from contextlib import contextmanager +from itertools import chain, islice +from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable +from typing_extensions import Protocol + +import jinja2 +import jinja2.ext +import jinja2.nativetypes # type: ignore +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox + +from dbt_common.utils import ( + get_dbt_macro_name, + get_docs_macro_name, + get_materialization_macro_name, + get_test_macro_name, +) +from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag + +from dbt_common.exceptions import ( + CompilationError, + DbtInternalError, + CaughtMacroErrorWithNodeError, + MaterializationArgError, + JinjaRenderingError, + UndefinedCompilationError, +) +from dbt_common.exceptions.macros import MacroReturn, UndefinedMacroError, CaughtMacroError + + +SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param") + +# Global which can be set by dependents of dbt-common (e.g. core via flag parsing) +MACRO_DEBUGGING = False + + +def _linecache_inject(source, write): + if write: + # this is the only reliable way to accomplish this. Obviously, it's + # really darn noisy and will fill your temporary directory + tmp_file = tempfile.NamedTemporaryFile( + prefix="dbt-macro-compiled-", + suffix=".py", + delete=False, + mode="w+", + encoding="utf-8", + ) + tmp_file.write(source) + filename = tmp_file.name + else: + # `codecs.encode` actually takes a `bytes` as the first argument if + # the second argument is 'hex' - mypy does not know this. + rnd = codecs.encode(os.urandom(12), "hex") # type: ignore + filename = rnd.decode("ascii") + + # put ourselves in the cache + cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename) + # linecache does in fact have an attribute `cache`, thanks + linecache.cache[filename] = cache_entry # type: ignore + return filename + + +class MacroFuzzParser(jinja2.parser.Parser): + def parse_macro(self): + node = jinja2.nodes.Macro(lineno=next(self.stream).lineno) + + # modified to fuzz macros defined in the same file. this way + # dbt can understand the stack of macros being called. + # - @cmcarthur + node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name) + + self.parse_signature(node) + node.body = self.parse_statements(("name:endmacro",), drop_needle=True) + return node + + +class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment): + def _parse(self, source, name, filename): + return MacroFuzzParser(self, source, name, filename).parse() + + def _compile(self, source, filename): + """Override jinja's compilation to stash the rendered source inside + the python linecache for debugging when the appropriate environment + variable is set. + + If the value is 'write', also write the files to disk. + WARNING: This can write a ton of data if you aren't careful. 
+ """ + if filename == "" and MACRO_DEBUGGING: + write = MACRO_DEBUGGING == "write" + filename = _linecache_inject(source, write) + + return super()._compile(source, filename) # type: ignore + + +class NativeSandboxEnvironment(MacroFuzzEnvironment): + code_generator_class = jinja2.nativetypes.NativeCodeGenerator + + +class TextMarker(str): + """A special native-env marker that indicates a value is text and is + not to be evaluated. Use this to prevent your numbery-strings from becoming + numbers! + """ + + +class NativeMarker(str): + """A special native-env marker that indicates the field should be passed to + literal_eval. + """ + + +class BoolMarker(NativeMarker): + pass + + +class NumberMarker(NativeMarker): + pass + + +def _is_number(value) -> bool: + return isinstance(value, (int, float)) and not isinstance(value, bool) + + +def quoted_native_concat(nodes): + """This is almost native_concat from the NativeTemplate, except in the + special case of a single argument that is a quoted string and returns a + string, the quotes are re-inserted. + """ + head = list(islice(nodes, 2)) + + if not head: + return "" + + if len(head) == 1: + raw = head[0] + if isinstance(raw, TextMarker): + return str(raw) + elif not isinstance(raw, NativeMarker): + # return non-strings as-is + return raw + else: + # multiple nodes become a string. + return "".join([str(v) for v in chain(head, nodes)]) + + try: + result = literal_eval(raw) + except (ValueError, SyntaxError, MemoryError): + result = raw + if isinstance(raw, BoolMarker) and not isinstance(result, bool): + raise JinjaRenderingError(f"Could not convert value '{raw!s}' into type 'bool'") + if isinstance(raw, NumberMarker) and not _is_number(result): + raise JinjaRenderingError(f"Could not convert value '{raw!s}' into type 'number'") + + return result + + +class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore + environment_class = NativeSandboxEnvironment # type: ignore + + def render(self, *args, **kwargs): + """Render the template to produce a native Python type. If the + result is a single node, its value is returned. Otherwise, the + nodes are concatenated as strings. If the result can be parsed + with :func:`ast.literal_eval`, the parsed value is returned. + Otherwise, the string is returned. + """ + vars = dict(*args, **kwargs) + + try: + return quoted_native_concat(self.root_render_func(self.new_context(vars))) + except Exception: + return self.environment.handle_exception() + + +NativeSandboxEnvironment.template_class = NativeSandboxTemplate # type: ignore + + +class TemplateCache: + def __init__(self) -> None: + self.file_cache: Dict[str, jinja2.Template] = {} + + def get_node_template(self, node) -> jinja2.Template: + key = node.macro_sql + + if key in self.file_cache: + return self.file_cache[key] + + template = get_template( + string=node.macro_sql, + ctx={}, + node=node, + ) + + self.file_cache[key] = template + return template + + def clear(self): + self.file_cache.clear() + + +template_cache = TemplateCache() + + +class BaseMacroGenerator: + def __init__(self, context: Optional[Dict[str, Any]] = None) -> None: + self.context: Optional[Dict[str, Any]] = context + + def get_template(self): + raise NotImplementedError("get_template not implemented!") + + def get_name(self) -> str: + raise NotImplementedError("get_name not implemented!") + + def get_macro(self): + name = self.get_name() + template = self.get_template() + # make the module. 
previously we set both vars and local, but that's + # redundant: They both end up in the same place + # make_module is in jinja2.environment. It returns a TemplateModule + module = template.make_module(vars=self.context, shared=False) + macro = module.__dict__[get_dbt_macro_name(name)] + module.__dict__.update(self.context) + return macro + + @contextmanager + def exception_handler(self) -> Iterator[None]: + try: + yield + except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e: + raise CaughtMacroError(e) + + def call_macro(self, *args, **kwargs): + # called from __call__ methods + if self.context is None: + raise DbtInternalError("Context is still None in call_macro!") + assert self.context is not None + + macro = self.get_macro() + + with self.exception_handler(): + try: + return macro(*args, **kwargs) + except MacroReturn as e: + return e.value + + +class MacroProtocol(Protocol): + name: str + macro_sql: str + + +class CallableMacroGenerator(BaseMacroGenerator): + def __init__( + self, + macro: MacroProtocol, + context: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(context) + self.macro = macro + + def get_template(self): + return template_cache.get_node_template(self.macro) + + def get_name(self) -> str: + return self.macro.name + + @contextmanager + def exception_handler(self) -> Iterator[None]: + try: + yield + except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e: + raise CaughtMacroErrorWithNodeError(exc=e, node=self.macro) + except CompilationError as e: + e.stack.append(self.macro) + raise e + + # this makes MacroGenerator objects callable like functions + def __call__(self, *args, **kwargs): + return self.call_macro(*args, **kwargs) + + +class MaterializationExtension(jinja2.ext.Extension): + tags = ["materialization"] + + def parse(self, parser): + node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) + materialization_name = parser.parse_assign_target(name_only=True).name + + adapter_name = "default" + node.args = [] + node.defaults = [] + + while parser.stream.skip_if("comma"): + target = parser.parse_assign_target(name_only=True) + + if target.name == "default": + pass + + elif target.name == "adapter": + parser.stream.expect("assign") + value = parser.parse_expression() + adapter_name = value.value + + elif target.name == "supported_languages": + target.set_ctx("param") + node.args.append(target) + parser.stream.expect("assign") + languages = parser.parse_expression() + node.defaults.append(languages) + + else: + raise MaterializationArgError(materialization_name, target.name) + + if SUPPORTED_LANG_ARG not in node.args: + node.args.append(SUPPORTED_LANG_ARG) + node.defaults.append(jinja2.nodes.List([jinja2.nodes.Const("sql")])) + + node.name = get_materialization_macro_name(materialization_name, adapter_name) + + node.body = parser.parse_statements(("name:endmaterialization",), drop_needle=True) + + return node + + +class DocumentationExtension(jinja2.ext.Extension): + tags = ["docs"] + + def parse(self, parser): + node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) + docs_name = parser.parse_assign_target(name_only=True).name + + node.args = [] + node.defaults = [] + node.name = get_docs_macro_name(docs_name) + node.body = parser.parse_statements(("name:enddocs",), drop_needle=True) + return node + + +class TestExtension(jinja2.ext.Extension): + tags = ["test"] + + def parse(self, parser): + node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) + test_name = parser.parse_assign_target(name_only=True).name + + 
parser.parse_signature(node) + node.name = get_test_macro_name(test_name) + node.body = parser.parse_statements(("name:endtest",), drop_needle=True) + return node + + +def _is_dunder_name(name): + return name.startswith("__") and name.endswith("__") + + +def create_undefined(node=None): + class Undefined(jinja2.Undefined): + def __init__(self, hint=None, obj=None, name=None, exc=None): + super().__init__(hint=hint, name=name) + self.node = node + self.name = name + self.hint = hint + # jinja uses these for safety, so we have to override them. + # see https://github.com/pallets/jinja/blob/master/jinja2/sandbox.py#L332-L339 # noqa + self.unsafe_callable = False + self.alters_data = False + + def __getitem__(self, name): + # Propagate the undefined value if a caller accesses this as if it + # were a dictionary + return self + + def __getattr__(self, name): + if name == "name" or _is_dunder_name(name): + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) + + self.name = name + + return self.__class__(hint=self.hint, name=self.name) + + def __call__(self, *args, **kwargs): + return self + + def __reduce__(self): + raise UndefinedCompilationError(name=self.name, node=node) + + return Undefined + + +NATIVE_FILTERS: Dict[str, Callable[[Any], Any]] = { + "as_text": TextMarker, + "as_bool": BoolMarker, + "as_native": NativeMarker, + "as_number": NumberMarker, +} + + +TEXT_FILTERS: Dict[str, Callable[[Any], Any]] = { + "as_text": lambda x: x, + "as_bool": lambda x: x, + "as_native": lambda x: x, + "as_number": lambda x: x, +} + + +def get_environment( + node=None, + capture_macros: bool = False, + native: bool = False, +) -> jinja2.Environment: + args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = { + "extensions": ["jinja2.ext.do", "jinja2.ext.loopcontrols"] + } + + if capture_macros: + args["undefined"] = create_undefined(node) + + args["extensions"].append(MaterializationExtension) + args["extensions"].append(DocumentationExtension) + args["extensions"].append(TestExtension) + + env_cls: Type[jinja2.Environment] + text_filter: Type + if native: + env_cls = NativeSandboxEnvironment + filters = NATIVE_FILTERS + else: + env_cls = MacroFuzzEnvironment + filters = TEXT_FILTERS + + env = env_cls(**args) + env.filters.update(filters) + + return env + + +@contextmanager +def catch_jinja(node=None) -> Iterator[None]: + try: + yield + except jinja2.exceptions.TemplateSyntaxError as e: + e.translated = False + raise CompilationError(str(e), node) from e + except jinja2.exceptions.UndefinedError as e: + raise UndefinedMacroError(str(e), node) from e + except CompilationError as exc: + exc.add_node(node) + raise + + +def parse(string): + with catch_jinja(): + return get_environment().parse(str(string)) + + +def get_template( + string: str, + ctx: Dict[str, Any], + node=None, + capture_macros: bool = False, + native: bool = False, +): + with catch_jinja(node): + env = get_environment(node, capture_macros, native=native) + + template_source = str(string) + return env.from_string(template_source, globals=ctx) + + +def render_template(template, ctx: Dict[str, Any], node=None) -> str: + with catch_jinja(node): + return template.render(ctx) + + +def extract_toplevel_blocks( + data: str, + allowed_blocks: Optional[Set[str]] = None, + collect_raw_data: bool = True, +) -> List[Union[BlockData, BlockTag]]: + """Extract the top-level blocks with matching block types from a jinja + file, with some special handling for block nesting. 
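A brief sketch of the two environments these helpers produce (values are illustrative):

    t = get_template("select {{ n }}", ctx={}, native=False)
    render_template(t, ctx={"n": 2})    # -> "select 2" (always rendered as text)

    nt = get_template("{{ n | as_number }}", ctx={}, native=True)
    render_template(nt, ctx={"n": "2"})  # -> 2 (a real int, via the native sandbox)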
+ + :param data: The data to extract blocks from. + :param allowed_blocks: The names of the blocks to extract from the file. + They may not be nested within if/for blocks. If None, use the default + values. + :param collect_raw_data: If set, raw data between matched blocks will also + be part of the results, as `BlockData` objects. They have a + `block_type_name` field of `'__dbt_data'` and will never have a + `block_name`. + :return: A list of `BlockTag`s matching the allowed block types and (if + `collect_raw_data` is `True`) `BlockData` objects. + """ + return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py new file mode 100644 index 00000000..f637af68 --- /dev/null +++ b/dbt_common/clients/system.py @@ -0,0 +1,557 @@ +import dbt_common.exceptions.base +import errno +import fnmatch +import functools +import json +import os +import os.path +import re +import shutil +import stat +import subprocess +import sys +import tarfile +from pathlib import Path +from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Type, Union + +import dbt_common.exceptions +import requests +from dbt_common.events.functions import fire_event +from dbt_common.events.types import ( + SystemCouldNotWrite, + SystemExecutingCmd, + SystemStdOut, + SystemStdErr, + SystemReportReturnCode, +) +from dbt_common.exceptions import DbtInternalError +from dbt_common.utils.connection import connection_exception_retry +from pathspec import PathSpec # type: ignore + +if sys.platform == "win32": + from ctypes import WinDLL, c_bool +else: + WinDLL = None + c_bool = None + + +def find_matching( + root_path: str, + relative_paths_to_search: List[str], + file_pattern: str, + ignore_spec: Optional[PathSpec] = None, +) -> List[Dict[str, Any]]: + """ + Given an absolute `root_path`, a list of relative paths to that + absolute root path (`relative_paths_to_search`), and a `file_pattern` + like '*.sql', returns information about the files. 
For example: + + > find_matching('/root/path', ['models'], '*.sql') + + [ { 'absolute_path': '/root/path/models/model_one.sql', + 'relative_path': 'model_one.sql', + 'searched_path': 'models' }, + { 'absolute_path': '/root/path/models/subdirectory/model_two.sql', + 'relative_path': 'subdirectory/model_two.sql', + 'searched_path': 'models' } ] + """ + matching = [] + root_path = os.path.normpath(root_path) + regex = fnmatch.translate(file_pattern) + reobj = re.compile(regex, re.IGNORECASE) + + for relative_path_to_search in relative_paths_to_search: + # potential speedup for ignore_spec + # if ignore_spec.matches(relative_path_to_search): + # continue + absolute_path_to_search = os.path.join(root_path, relative_path_to_search) + walk_results = os.walk(absolute_path_to_search) + + for current_path, subdirectories, local_files in walk_results: + # potential speedup for ignore_spec + # relative_dir = os.path.relpath(current_path, root_path) + os.sep + # if ignore_spec.match(relative_dir): + # continue + for local_file in local_files: + absolute_path = os.path.join(current_path, local_file) + relative_path = os.path.relpath(absolute_path, absolute_path_to_search) + relative_path_to_root = os.path.join(relative_path_to_search, relative_path) + + modification_time = os.path.getmtime(absolute_path) + if reobj.match(local_file) and (not ignore_spec or not ignore_spec.match_file(relative_path_to_root)): + matching.append( + { + "searched_path": relative_path_to_search, + "absolute_path": absolute_path, + "relative_path": relative_path, + "modification_time": modification_time, + } + ) + + return matching + + +def load_file_contents(path: str, strip: bool = True) -> str: + path = convert_path(path) + with open(path, "rb") as handle: + to_return = handle.read().decode("utf-8") + + if strip: + to_return = to_return.strip() + + return to_return + + +@functools.singledispatch +def make_directory(path=None) -> None: + """ + Make a directory and any intermediate directories that don't already + exist. This function handles the case where two threads try to create + a directory at once. + """ + raise DbtInternalError(f"Can not create directory from {type(path)} ") + + +@make_directory.register +def _(path: str) -> None: + path = convert_path(path) + if not os.path.exists(path): + # concurrent writes that try to create the same dir can fail + try: + os.makedirs(path) + + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise e + + +@make_directory.register +def _(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool: + """ + Make a file at `path` assuming that the directory it resides in already + exists. The file is saved with contents `contents` + """ + if overwrite or not os.path.exists(path): + path = convert_path(path) + with open(path, "w") as fh: + fh.write(contents) + return True + + return False + + +def make_symlink(source: str, link_path: str) -> None: + """ + Create a symlink at `link_path` referring to `source`. + """ + if not supports_symlinks(): + # TODO: why not import these at top? 
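make_directory is a singledispatch function, so both of the call styles below reach the right implementation (paths are hypothetical):

    make_directory("target/compiled/my_project")   # str overload: os.makedirs, tolerant of concurrent creation
    make_directory(Path("target/run/my_project"))  # Path overload: mkdir(parents=True, exist_ok=True)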
+ raise dbt_common.exceptions.SymbolicLinkError() + + os.symlink(source, link_path) + + +def supports_symlinks() -> bool: + return getattr(os, "symlink", None) is not None + + +def write_file(path: str, contents: str = "") -> bool: + path = convert_path(path) + try: + make_directory(os.path.dirname(path)) + with open(path, "w", encoding="utf-8") as f: + f.write(str(contents)) + except Exception as exc: + # note that you can't just catch FileNotFound, because sometimes + # windows apparently raises something else. + # It's also not sufficient to look at the path length, because + # sometimes windows fails to write paths that are less than the length + # limit. So on windows, suppress all errors that happen from writing + # to disk. + if os.name == "nt": + # sometimes we get a winerror of 3 which means the path was + # definitely too long, but other times we don't and it means the + # path was just probably too long. This is probably based on the + # windows/python version. + if getattr(exc, "winerror", 0) == 3: + reason = "Path was too long" + else: + reason = "Path was possibly too long" + # all our hard work and the path was still too long. Log and + # continue. + fire_event(SystemCouldNotWrite(path=path, reason=reason, exc=str(exc))) + else: + raise + return True + + +def read_json(path: str) -> Dict[str, Any]: + return json.loads(load_file_contents(path)) + + +def write_json(path: str, data: Dict[str, Any]) -> bool: + return write_file(path, json.dumps(data, cls=dbt_common.utils.encoding.JSONEncoder)) + + +def _windows_rmdir_readonly(func: Callable[[str], Any], path: str, exc: Tuple[Any, OSError, Any]): + exception_val = exc[1] + if exception_val.errno == errno.EACCES: + os.chmod(path, stat.S_IWUSR) + func(path) + else: + raise + + +def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str: + """ + If path_to_resolve is a relative path, create an absolute path + with base_path as the base. + + If path_to_resolve is an absolute path or a user path (~), just + resolve it to an absolute path and return. + """ + return os.path.abspath(os.path.join(base_path, os.path.expanduser(path_to_resolve))) + + +def rmdir(path: str) -> None: + """ + Recursively deletes a directory. Includes an error handler to retry with + different permissions on Windows. Otherwise, removing directories (eg. + cloned via git) can cause rmtree to throw a PermissionError exception + """ + path = convert_path(path) + if sys.platform == "win32": + onerror = _windows_rmdir_readonly + else: + onerror = None + + shutil.rmtree(path, onerror=onerror) + + +def _win_prepare_path(path: str) -> str: + """Given a windows path, prepare it for use by making sure it is absolute + and normalized. + """ + path = os.path.normpath(path) + + # if a path starts with '\', splitdrive() on it will return '' for the + # drive, but the prefix requires a drive letter. So let's add the drive + # letter back in. + # Unless it starts with '\\'. In that case, the path is a UNC mount point + # and splitdrive will be fine. + if not path.startswith("\\\\") and path.startswith("\\"): + curdrive = os.path.splitdrive(os.getcwd())[0] + path = curdrive + path + + # now our path is either an absolute UNC path or relative to the current + # directory. If it's relative, we need to make it absolute or the prefix + # won't work. `ntpath.abspath` allegedly doesn't always play nice with long + # paths, so do this instead. 
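resolve_path_from_base in action (the home directory shown is illustrative):

    resolve_path_from_base("models", "/opt/project")      # -> "/opt/project/models"
    resolve_path_from_base("/tmp/seeds", "/opt/project")  # -> "/tmp/seeds" (absolute paths win)
    resolve_path_from_base("~/dbt", "/opt/project")       # -> "/home/me/dbt" (after expanduser)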
+ if not os.path.splitdrive(path)[0]: + path = os.path.join(os.getcwd(), path) + + return path + + +def _supports_long_paths() -> bool: + if sys.platform != "win32": + return True + # Eryk Sun says to use `WinDLL('ntdll')` instead of `windll.ntdll` because + # of pointer caching in a comment here: + # https://stackoverflow.com/a/35097999/11262881 + # I don't know exaclty what he means, but I am inclined to believe him as + # he's pretty active on Python windows bugs! + else: + try: + dll = WinDLL("ntdll") + except OSError: # I don't think this happens? you need ntdll to run python + return False + # not all windows versions have it at all + if not hasattr(dll, "RtlAreLongPathsEnabled"): + return False + # tell windows we want to get back a single unsigned byte (a bool). + dll.RtlAreLongPathsEnabled.restype = c_bool + return dll.RtlAreLongPathsEnabled() + + +def convert_path(path: str) -> str: + """Convert a path that dbt has, which might be >260 characters long, to one + that will be writable/readable on Windows. + + On other platforms, this is a no-op. + """ + # some parts of python seem to append '\*.*' to strings, better safe than + # sorry. + if len(path) < 250: + return path + if _supports_long_paths(): + return path + + prefix = "\\\\?\\" + # Nothing to do + if path.startswith(prefix): + return path + + path = _win_prepare_path(path) + + # add the prefix. The check is just in case os.getcwd() does something + # unexpected - I believe this if-state should always be True though! + if not path.startswith(prefix): + path = prefix + path + return path + + +def remove_file(path: str) -> None: + path = convert_path(path) + os.remove(path) + + +def path_exists(path: str) -> bool: + path = convert_path(path) + return os.path.lexists(path) + + +def path_is_symlink(path: str) -> bool: + path = convert_path(path) + return os.path.islink(path) + + +def open_dir_cmd() -> str: + # https://docs.python.org/2/library/sys.html#sys.platform + if sys.platform == "win32": + return "start" + + elif sys.platform == "darwin": + return "open" + + else: + return "xdg-open" + + +def _handle_posix_cwd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn: + if exc.errno == errno.ENOENT: + message = "Directory does not exist" + elif exc.errno == errno.EACCES: + message = "Current user cannot access directory, check permissions" + elif exc.errno == errno.ENOTDIR: + message = "Not a directory" + else: + message = "Unknown OSError: {} - cwd".format(str(exc)) + raise dbt_common.exceptions.WorkingDirectoryError(cwd, cmd, message) + + +def _handle_posix_cmd_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn: + if exc.errno == errno.ENOENT: + message = "Could not find command, ensure it is in the user's PATH" + elif exc.errno == errno.EACCES: + message = "User does not have permissions for this command" + else: + message = "Unknown OSError: {} - cmd".format(str(exc)) + raise dbt_common.exceptions.ExecutableError(cwd, cmd, message) + + +def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn: + """OSError handling for POSIX systems. + + Some things that could happen to trigger an OSError: + - cwd could not exist + - exc.errno == ENOENT + - exc.filename == cwd + - cwd could have permissions that prevent the current user moving to it + - exc.errno == EACCES + - exc.filename == cwd + - cwd could exist but not be a directory + - exc.errno == ENOTDIR + - exc.filename == cwd + - cmd[0] could not exist + - exc.errno == ENOENT + - exc.filename == None(?) 
+ - cmd[0] could exist but have permissions that prevents the current + user from executing it (executable bit not set for the user) + - exc.errno == EACCES + - exc.filename == None(?) + """ + if getattr(exc, "filename", None) == cwd: + _handle_posix_cwd_error(exc, cwd, cmd) + else: + _handle_posix_cmd_error(exc, cwd, cmd) + + +def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn: + cls: Type[dbt_common.exceptions.DbtBaseException] = dbt_common.exceptions.base.CommandError + if exc.errno == errno.ENOENT: + message = ( + "Could not find command, ensure it is in the user's PATH " "and that the user has permissions to run it" + ) + cls = dbt_common.exceptions.ExecutableError + elif exc.errno == errno.ENOEXEC: + message = "Command was not executable, ensure it is valid" + cls = dbt_common.exceptions.ExecutableError + elif exc.errno == errno.ENOTDIR: + message = "Unable to cd: path does not exist, user does not have" " permissions, or not a directory" + cls = dbt_common.exceptions.WorkingDirectoryError + else: + message = 'Unknown error: {} (errno={}: "{}")'.format( + str(exc), exc.errno, errno.errorcode.get(exc.errno, "") + ) + raise cls(cwd, cmd, message) + + +def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn: + """Interpret an OSError exception and raise the appropriate dbt exception.""" + if len(cmd) == 0: + raise dbt_common.exceptions.base.CommandError(cwd, cmd) + + # all of these functions raise unconditionally + if os.name == "nt": + _handle_windows_error(exc, cwd, cmd) + else: + _handle_posix_error(exc, cwd, cmd) + + # this should not be reachable, raise _something_ at least! + raise dbt_common.exceptions.DbtInternalError("Unhandled exception in _interpret_oserror: {}".format(exc)) + + +def run_cmd(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None) -> Tuple[bytes, bytes]: + fire_event(SystemExecutingCmd(cmd=cmd)) + if len(cmd) == 0: + raise dbt_common.exceptions.base.CommandError(cwd, cmd) + + # the env argument replaces the environment entirely, which has exciting + # consequences on Windows! Do an update instead. 
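A usage sketch (the command and env are illustrative); the return value is raw bytes from the subprocess, and the env argument is merged into os.environ rather than replacing it:

    out, err = run_cmd("/opt/project", ["git", "status"], env={"GIT_PAGER": "cat"})
    print(out.decode("utf-8"))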
+ full_env = env + if env is not None: + full_env = os.environ.copy() + full_env.update(env) + + try: + exe_pth = shutil.which(cmd[0]) + if exe_pth: + cmd = [os.path.abspath(exe_pth)] + list(cmd[1:]) + proc = subprocess.Popen(cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env) + + out, err = proc.communicate() + except OSError as exc: + _interpret_oserror(exc, cwd, cmd) + + fire_event(SystemStdOut(bmsg=str(out))) + fire_event(SystemStdErr(bmsg=str(err))) + + if proc.returncode != 0: + fire_event(SystemReportReturnCode(returncode=proc.returncode)) + raise dbt_common.exceptions.CommandResultError(cwd, cmd, proc.returncode, out, err) + + return out, err + + +def download_with_retries(url: str, path: str, timeout: Optional[Union[float, tuple]] = None) -> None: + download_fn = functools.partial(download, url, path, timeout) + connection_exception_retry(download_fn, 5) + + +def download( + url: str, + path: str, + timeout: Optional[Union[float, Tuple[float, float], Tuple[float, None]]] = None, +) -> None: + path = convert_path(path) + connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10)) + response = requests.get(url, timeout=connection_timeout) + with open(path, "wb") as handle: + for block in response.iter_content(1024 * 64): + handle.write(block) + + +def rename(from_path: str, to_path: str, force: bool = False) -> None: + from_path = convert_path(from_path) + to_path = convert_path(to_path) + is_symlink = path_is_symlink(to_path) + + if os.path.exists(to_path) and force: + if is_symlink: + remove_file(to_path) + else: + rmdir(to_path) + + shutil.move(from_path, to_path) + + +def untar_package(tar_path: str, dest_dir: str, rename_to: Optional[str] = None) -> None: + tar_path = convert_path(tar_path) + tar_dir_name = None + with tarfile.open(tar_path, "r:gz") as tarball: + tarball.extractall(dest_dir) + tar_dir_name = os.path.commonprefix(tarball.getnames()) + if rename_to: + downloaded_path = os.path.join(dest_dir, tar_dir_name) + desired_path = os.path.join(dest_dir, rename_to) + dbt_common.clients.system.rename(downloaded_path, desired_path, force=True) + + +def chmod_and_retry(func, path, exc_info): + """Define an error handler to pass to shutil.rmtree. + On Windows, when a file is marked read-only as git likes to do, rmtree will + fail. To handle that, on errors try to make the file writable. + We want to retry most operations here, but listdir is one that we know will + be useless. + """ + if func is os.listdir or os.name != "nt": + raise + os.chmod(path, stat.S_IREAD | stat.S_IWRITE) + # on error,this will raise. + func(path) + + +def _absnorm(path): + return os.path.normcase(os.path.abspath(path)) + + +def move(src, dst): + """A re-implementation of shutil.move that properly removes the source + directory on windows when it has read-only files in it and the move is + between two drives. + + This is almost identical to the real shutil.move, except it, uses our rmtree + and skips handling non-windows OSes since the existing one works ok there. 
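For example (the URL is hypothetical), download_with_retries wraps download in connection_exception_retry with 5 attempts, and the timeout is forwarded straight to requests.get:

    download_with_retries(
        "https://example.com/dbt-utils-1.0.0.tar.gz",
        "/tmp/dbt-utils-1.0.0.tar.gz",
        timeout=(5, 30),  # (connect, read) seconds
    )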
+ """ + src = convert_path(src) + dst = convert_path(dst) + if os.name != "nt": + return shutil.move(src, dst) + + if os.path.isdir(dst): + if _absnorm(src) == _absnorm(dst): + os.rename(src, dst) + return + + dst = os.path.join(dst, os.path.basename(src.rstrip("/\\"))) + if os.path.exists(dst): + raise EnvironmentError("Path '{}' already exists".format(dst)) + + try: + os.rename(src, dst) + except OSError: + # probably different drives + if os.path.isdir(src): + if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")): + # dst is inside src + raise EnvironmentError("Cannot move a directory '{}' into itself '{}'".format(src, dst)) + shutil.copytree(src, dst, symlinks=True) + rmtree(src) + else: + shutil.copy2(src, dst) + os.unlink(src) + + +def rmtree(path): + """Recursively remove the path. On permissions errors on windows, try to remove + the read-only flag and try again. + """ + path = convert_path(path) + return shutil.rmtree(path, onerror=chmod_and_retry) diff --git a/dbt_common/constants.py b/dbt_common/constants.py new file mode 100644 index 00000000..ca591e05 --- /dev/null +++ b/dbt_common/constants.py @@ -0,0 +1 @@ +SECRET_ENV_PREFIX = "DBT_ENV_SECRET_" diff --git a/dbt_common/contracts/__init__.py b/dbt_common/contracts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt_common/contracts/config/__init__.py b/dbt_common/contracts/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt_common/contracts/config/base.py b/dbt_common/contracts/config/base.py new file mode 100644 index 00000000..899c7643 --- /dev/null +++ b/dbt_common/contracts/config/base.py @@ -0,0 +1,255 @@ +# necessary for annotating constructors +from __future__ import annotations + +from dataclasses import dataclass, Field + +from itertools import chain +from typing import Callable, Dict, Any, List, TypeVar, Type + +from dbt_common.contracts.config.metadata import Metadata +from dbt_common.exceptions import CompilationError, DbtInternalError +from dbt_common.contracts.config.properties import AdditionalPropertiesAllowed +from dbt_common.contracts.util import Replaceable + +T = TypeVar("T", bound="BaseConfig") + + +@dataclass +class BaseConfig(AdditionalPropertiesAllowed, Replaceable): + # enable syntax like: config['key'] + def __getitem__(self, key): + return self.get(key) + + # like doing 'get' on a dictionary + def get(self, key, default=None): + if hasattr(self, key): + return getattr(self, key) + elif key in self._extra: + return self._extra[key] + else: + return default + + # enable syntax like: config['key'] = value + def __setitem__(self, key, value): + if hasattr(self, key): + setattr(self, key, value) + else: + self._extra[key] = value + + def __delitem__(self, key): + if hasattr(self, key): + msg = ('Error, tried to delete config key "{}": Cannot delete ' "built-in keys").format(key) + raise CompilationError(msg) + else: + del self._extra[key] + + def _content_iterator(self, include_condition: Callable[[Field], bool]): + seen = set() + for fld, _ in self._get_fields(): + seen.add(fld.name) + if include_condition(fld): + yield fld.name + + for key in self._extra: + if key not in seen: + seen.add(key) + yield key + + def __iter__(self): + yield from self._content_iterator(include_condition=lambda f: True) + + def __len__(self): + return len(self._get_fields()) + len(self._extra) + + @staticmethod + def compare_key( + unrendered: Dict[str, Any], + other: Dict[str, Any], + key: str, + ) -> bool: + if key not in unrendered and key not in other: + return True + 
elif key not in unrendered and key in other: + return False + elif key in unrendered and key not in other: + return False + else: + return unrendered[key] == other[key] + + @classmethod + def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool: + """This is like __eq__, except it ignores some fields.""" + seen = set() + for fld, target_name in cls._get_fields(): + key = target_name + seen.add(key) + if CompareBehavior.should_include(fld): + if not cls.compare_key(unrendered, other, key): + return False + + for key in chain(unrendered, other): + if key not in seen: + seen.add(key) + if not cls.compare_key(unrendered, other, key): + return False + return True + + # This is used in 'add_config_call' to create the combined config_call_dict. + # 'meta' moved here from node + mergebehavior = { + "append": ["pre-hook", "pre_hook", "post-hook", "post_hook", "tags"], + "update": [ + "quoting", + "column_types", + "meta", + "docs", + "contract", + ], + "dict_key_append": ["grants"], + } + + @classmethod + def _merge_dicts(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: + """Find all the items in data that match a target_field on this class, + and merge them with the data found in `src` for target_field, using the + field's specified merge behavior. Matching items will be removed from + `data` (but _not_ `src`!). + + Returns a dict with the merge results. + + That means this method mutates its input! Any remaining values in data + were not merged. + """ + result = {} + + for fld, target_field in cls._get_fields(): + if target_field not in data: + continue + + data_attr = data.pop(target_field) + if target_field not in src: + result[target_field] = data_attr + continue + + merge_behavior = MergeBehavior.from_field(fld) + self_attr = src[target_field] + + result[target_field] = _merge_field_value( + merge_behavior=merge_behavior, + self_value=self_attr, + other_value=data_attr, + ) + return result + + def update_from(self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True) -> T: + """Given a dict of keys, update the current config from them, validate + it, and return a new config with the updated values + """ + dct = self.to_dict(omit_none=False) + + self_merged = self._merge_dicts(dct, data) + dct.update(self_merged) + + adapter_merged = config_cls._merge_dicts(dct, data) + dct.update(adapter_merged) + + # any remaining fields must be "clobber" + dct.update(data) + + # any validation failures must have come from the update + if validate: + self.validate(dct) + return self.from_dict(dct) + + def finalize_and_validate(self: T) -> T: + dct = self.to_dict(omit_none=False) + self.validate(dct) + return self.from_dict(dct) + + +class MergeBehavior(Metadata): + Append = 1 + Update = 2 + Clobber = 3 + DictKeyAppend = 4 + + @classmethod + def default_field(cls) -> "MergeBehavior": + return cls.Clobber + + @classmethod + def metadata_key(cls) -> str: + return "merge" + + +class CompareBehavior(Metadata): + Include = 1 + Exclude = 2 + + @classmethod + def default_field(cls) -> "CompareBehavior": + return cls.Include + + @classmethod + def metadata_key(cls) -> str: + return "compare" + + @classmethod + def should_include(cls, fld: Field) -> bool: + return cls.from_field(fld) == cls.Include + + +def _listify(value: Any) -> List: + if isinstance(value, list): + return value[:] + else: + return [value] + + +# There are two versions of this code. 
The one here is for config +# objects, the one in _add_config_call in core context_config.py is for +# config_call_dict dictionaries. +def _merge_field_value( + merge_behavior: MergeBehavior, + self_value: Any, + other_value: Any, +): + if merge_behavior == MergeBehavior.Clobber: + return other_value + elif merge_behavior == MergeBehavior.Append: + return _listify(self_value) + _listify(other_value) + elif merge_behavior == MergeBehavior.Update: + if not isinstance(self_value, dict): + raise DbtInternalError(f"expected dict, got {self_value}") + if not isinstance(other_value, dict): + raise DbtInternalError(f"expected dict, got {other_value}") + value = self_value.copy() + value.update(other_value) + return value + elif merge_behavior == MergeBehavior.DictKeyAppend: + if not isinstance(self_value, dict): + raise DbtInternalError(f"expected dict, got {self_value}") + if not isinstance(other_value, dict): + raise DbtInternalError(f"expected dict, got {other_value}") + new_dict = {} + for key in self_value.keys(): + new_dict[key] = _listify(self_value[key]) + for key in other_value.keys(): + extend = False + new_key = key + # This might start with a +, to indicate we should extend the list + # instead of just clobbering it + if new_key.startswith("+"): + new_key = key.lstrip("+") + extend = True + if new_key in new_dict and extend: + # extend the list + value = other_value[key] + new_dict[new_key].extend(_listify(value)) + else: + # clobber the list + new_dict[new_key] = _listify(other_value[key]) + return new_dict + + else: + raise DbtInternalError(f"Got an invalid merge_behavior: {merge_behavior}") diff --git a/dbt_common/contracts/config/materialization.py b/dbt_common/contracts/config/materialization.py new file mode 100644 index 00000000..5f7f536b --- /dev/null +++ b/dbt_common/contracts/config/materialization.py @@ -0,0 +1,11 @@ +from dbt_common.dataclass_schema import StrEnum + + +class OnConfigurationChangeOption(StrEnum): + Apply = "apply" + Continue = "continue" + Fail = "fail" + + @classmethod + def default(cls) -> "OnConfigurationChangeOption": + return cls.Apply diff --git a/dbt_common/contracts/config/metadata.py b/dbt_common/contracts/config/metadata.py new file mode 100644 index 00000000..83f3457e --- /dev/null +++ b/dbt_common/contracts/config/metadata.py @@ -0,0 +1,69 @@ +from dataclasses import Field +from enum import Enum +from typing import TypeVar, Type, Optional, Dict, Any + +from dbt_common.exceptions import DbtInternalError + +M = TypeVar("M", bound="Metadata") + + +class Metadata(Enum): + @classmethod + def from_field(cls: Type[M], fld: Field) -> M: + default = cls.default_field() + key = cls.metadata_key() + + return _get_meta_value(cls, fld, key, default) + + def meta(self, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + key = self.metadata_key() + return _set_meta_value(self, key, existing) + + @classmethod + def default_field(cls) -> "Metadata": + raise NotImplementedError("Not implemented") + + @classmethod + def metadata_key(cls) -> str: + raise NotImplementedError("Not implemented") + + +def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M: + # a metadata field might exist. If it does, it might have a matching key. + # If it has both, make sure the value is valid and return it. If it + # doesn't, return the default. 
+ if fld.metadata: + value = fld.metadata.get(key, default) + else: + value = default + + try: + return cls(value) + except ValueError as exc: + raise DbtInternalError(f"Invalid {cls} value: {value}") from exc + + +def _set_meta_value(obj: M, key: str, existing: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if existing is None: + result = {} + else: + result = existing.copy() + result.update({key: obj}) + return result + + +class ShowBehavior(Metadata): + Show = 1 + Hide = 2 + + @classmethod + def default_field(cls) -> "ShowBehavior": + return cls.Show + + @classmethod + def metadata_key(cls) -> str: + return "show_hide" + + @classmethod + def should_show(cls, fld: Field) -> bool: + return cls.from_field(fld) == cls.Show diff --git a/dbt_common/contracts/config/properties.py b/dbt_common/contracts/config/properties.py new file mode 100644 index 00000000..ce669623 --- /dev/null +++ b/dbt_common/contracts/config/properties.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass, field +from typing import Dict, Any + +from dbt_common.dataclass_schema import ExtensibleDbtClassMixin + + +class AdditionalPropertiesMixin: + """Make this class an extensible property. + + The underlying class definition must include a type definition for a field + named '_extra' that is of type `Dict[str, Any]`. + """ + + ADDITIONAL_PROPERTIES = True + + # This takes attributes in the dictionary that are + # not in the class definitions and puts them in an + # _extra dict in the class + @classmethod + def __pre_deserialize__(cls, data): + # dir() did not work because fields with + # metadata settings are not found + # The original version of this would create the + # object first and then update extra with the + # extra keys, but that won't work here, so + # we're copying the dict so we don't insert the + # _extra in the original data. This also requires + # that Mashumaro actually build the '_extra' field + cls_keys = cls._get_field_names() + new_dict = {} + for key, value in data.items(): + # The pre-hook/post-hook mess hasn't been converted yet... That happens in + # the super().__pre_deserialize__ below... 
+ if key not in cls_keys and key not in ["_extra", "pre-hook", "post-hook"]: + if "_extra" not in new_dict: + new_dict["_extra"] = {} + new_dict["_extra"][key] = value + else: + new_dict[key] = value + data = new_dict + data = super().__pre_deserialize__(data) + return data + + def __post_serialize__(self, dct): + data = super().__post_serialize__(dct) + data.update(self.extra) + if "_extra" in data: + del data["_extra"] + return data + + def replace(self, **kwargs): + dct = self.to_dict(omit_none=False) + dct.update(kwargs) + return self.from_dict(dct) + + @property + def extra(self): + return self._extra + + +@dataclass +class AdditionalPropertiesAllowed(AdditionalPropertiesMixin, ExtensibleDbtClassMixin): + _extra: Dict[str, Any] = field(default_factory=dict) diff --git a/dbt_common/contracts/constraints.py b/dbt_common/contracts/constraints.py new file mode 100644 index 00000000..ce3d1513 --- /dev/null +++ b/dbt_common/contracts/constraints.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional, List + +from dbt_common.dataclass_schema import dbtClassMixin + + +class ConstraintType(str, Enum): + check = "check" + not_null = "not_null" + unique = "unique" + primary_key = "primary_key" + foreign_key = "foreign_key" + custom = "custom" + + @classmethod + def is_valid(cls, item) -> bool: + try: + cls(item) + except ValueError: + return False + return True + + +@dataclass +class ColumnLevelConstraint(dbtClassMixin): + type: ConstraintType + name: Optional[str] = None + # expression is a user-provided field that will depend on the constraint type. + # It could be a predicate (check type), or a sequence sql keywords (e.g. unique type), + # so the vague naming of 'expression' is intended to capture this range. 
+ expression: Optional[str] = None + warn_unenforced: bool = True # Warn if constraint cannot be enforced by platform but will be in DDL + warn_unsupported: bool = True # Warn if constraint is not supported by the platform and won't be in DDL + + +@dataclass +class ModelLevelConstraint(ColumnLevelConstraint): + columns: List[str] = field(default_factory=list) diff --git a/dbt_common/contracts/util.py b/dbt_common/contracts/util.py new file mode 100644 index 00000000..1467e4d8 --- /dev/null +++ b/dbt_common/contracts/util.py @@ -0,0 +1,7 @@ +import dataclasses + + +# TODO: remove from dbt_common.contracts.util:: Replaceable + references +class Replaceable: + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py new file mode 100644 index 00000000..d718604b --- /dev/null +++ b/dbt_common/dataclass_schema.py @@ -0,0 +1,165 @@ +from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +import re +import jsonschema +from dataclasses import fields, Field +from enum import Enum +from datetime import datetime +from dateutil.parser import parse + +# type: ignore +from mashumaro import DataClassDictMixin +from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig +from mashumaro.types import SerializableType, SerializationStrategy +from mashumaro.jsonschema import build_json_schema + +import functools + + +class ValidationError(jsonschema.ValidationError): + pass + + +class DateTimeSerialization(SerializationStrategy): + def serialize(self, value) -> str: + out = value.isoformat() + # Assume UTC if timezone is missing + if value.tzinfo is None: + out += "Z" + return out + + def deserialize(self, value) -> datetime: + return value if isinstance(value, datetime) else parse(cast(str, value)) + + +class dbtMashConfig(MashBaseConfig): + code_generation_options = [ + TO_DICT_ADD_OMIT_NONE_FLAG, + ] + serialization_strategy = { + datetime: DateTimeSerialization(), + } + json_schema = { + "additionalProperties": False, + } + serialize_by_alias = True + + +# This class pulls in DataClassDictMixin from Mashumaro. The 'to_dict' +# and 'from_dict' methods come from Mashumaro. +class dbtClassMixin(DataClassDictMixin): + """The Mixin adds methods to generate a JSON schema and + convert to and from JSON encodable dicts with validation + against the schema + """ + + _mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field, str]]]]] = None + + # Config class used by Mashumaro + class Config(dbtMashConfig): + pass + + ADDITIONAL_PROPERTIES: ClassVar[bool] = False + + # This is called by the mashumaro from_dict in order to handle + # nested classes. We no longer do any munging here, but leaving here + # so that subclasses can leave super() in place for possible future needs. + @classmethod + def __pre_deserialize__(cls, data): + return data + + # This is called by the mashumaro to_dict in order to handle + # nested classes. We no longer do any munging here, but leaving here + # so that subclasses can leave super() in place for possible future needs. 
+ def __post_serialize__(self, data): + return data + + @classmethod + @functools.lru_cache + def json_schema(cls): + json_schema_obj = build_json_schema(cls) + json_schema = json_schema_obj.to_dict() + return json_schema + + @classmethod + def validate(cls, data): + json_schema = cls.json_schema() + validator = jsonschema.Draft7Validator(json_schema) + error = next(iter(validator.iter_errors(data)), None) + if error is not None: + raise ValidationError.create_from(error) from error + + # This method was copied from hologram. Used in model_config.py and relation.py + @classmethod + def _get_fields(cls) -> List[Tuple[Field, str]]: + if cls._mapped_fields is None: + cls._mapped_fields = {} + if cls.__name__ not in cls._mapped_fields: + mapped_fields = [] + type_hints = get_type_hints(cls) + + for f in fields(cls): # type: ignore + # Skip internal fields + if f.name.startswith("_"): + continue + + # Note fields() doesn't resolve forward refs + f.type = type_hints[f.name] + + # hologram used the "field_mapping" here, but we use the + # the field's metadata "alias". Since this method is mainly + # just used in merging config dicts, it mostly applies to + # pre-hook and post-hook. + field_name = f.metadata.get("alias", f.name) + mapped_fields.append((f, field_name)) + cls._mapped_fields[cls.__name__] = mapped_fields + return cls._mapped_fields[cls.__name__] + + # copied from hologram. Used in tests + @classmethod + def _get_field_names(cls): + return [element[1] for element in cls._get_fields()] + + +class ValidatedStringMixin(str, SerializableType): + ValidationRegex = "" + + @classmethod + def _deserialize(cls, value: str) -> "ValidatedStringMixin": + cls.validate(value) + return ValidatedStringMixin(value) + + def _serialize(self) -> str: + return str(self) + + @classmethod + def validate(cls, value): + res = re.match(cls.ValidationRegex, value) + + if res is None: + raise ValidationError(f"Invalid value: {value}") # TODO + + +# These classes must be in this order or it doesn't work +class StrEnum(str, SerializableType, Enum): + def __str__(self): + return self.value + + # https://docs.python.org/3.6/library/enum.html#using-automatic-values + def _generate_next_value_(name, *_): + return name + + def _serialize(self) -> str: + return self.value + + @classmethod + def _deserialize(cls, value: str): + return cls(value) + + +class ExtensibleDbtClassMixin(dbtClassMixin): + ADDITIONAL_PROPERTIES = True + + class Config(dbtMashConfig): + json_schema = { + "additionalProperties": True, + } diff --git a/dbt_common/events/README.md b/dbt_common/events/README.md new file mode 100644 index 00000000..a857508d --- /dev/null +++ b/dbt_common/events/README.md @@ -0,0 +1,41 @@ +# Events Module +The Events module is responsible for communicating internal dbt structures into a consumable interface. Because the "event" classes are based entirely on protobuf definitions, the interface is really clearly defined, whether or not protobufs are used to consume it. We use Betterproto for compiling the protobuf message definitions into Python classes. + +# Using the Events Module +The event module provides types that represent what is happening in dbt in `events.types`. These types are intended to represent an exhaustive list of all things happening within dbt that will need to be logged, streamed, or printed. To fire an event, `events.functions::fire_event` is the entry point to the module from everywhere in dbt. + +# Logging +When events are processed via `fire_event`, nearly everything is logged. 
Whether or not the user has enabled the debug flag, all debug messages are still logged to the file. However, some events are particularly time-consuming to construct because they return a huge amount of data. Today, the only messages in this category are cache events, and they are only logged if the `--log-cache-events` flag is on. This is important because these messages should not be created unless they are going to be logged, since they cause a noticeable performance degradation. These events use the "fire_event_if" function.
+
+# Adding a New Event
+* Add a new message in types.proto, and a second message with the same name + "Msg". The "Msg" message should have two fields: an "info" field of EventInfo, and a "data" field referring to the message name without "Msg".
+* Run the protoc compiler to update types_pb2.py: make proto_types
+* Add a wrapping class in dbt_common/events/types.py with a Level superclass plus code and message methods
+* Add the class to tests/unit/test_events.py
+
+We have switched from using betterproto to using google protobuf because of a lack of support for Struct fields in betterproto.
+
+The google protobuf interface is janky and very much non-Pythonic. The "generated" classes in types_pb2.py do not resemble regular Python classes. They do not have normal constructors; they can only be constructed empty. They can be "filled" by setting fields individually or by using a json_format method like ParseDict. We have wrapped the logging events with a class (in types.py) which allows using a constructor -- keywords only, no positional parameters.
+
+## Required for Every Event
+
+- a `code` method that is unique across events
+- a log level, assigned by using a Level mixin: `DebugLevel`, `InfoLevel`, `WarnLevel`, or `ErrorLevel`
+- a `message()` method
+
+Example
+```
+class PartialParsingDeletedExposure(DebugLevel):
+    def code(self):
+        return "I049"
+
+    def message(self) -> str:
+        return f"Partial parsing: deleted exposure {self.unique_id}"
+
+```
+
+## Compiling types.proto
+
+After adding a new message in `types.proto`, either:
+- In the repository root directory: `make proto_types`
+- In the `dbt_common/events` directory: `protoc -I=. --python_out=. 
types.proto` diff --git a/dbt_common/events/__init__.py b/dbt_common/events/__init__.py new file mode 100644 index 00000000..b200d081 --- /dev/null +++ b/dbt_common/events/__init__.py @@ -0,0 +1,7 @@ +from dbt_common.events.base_types import EventLevel +from dbt_common.events.event_manager_client import get_event_manager +from dbt_common.events.functions import get_stdout_config +from dbt_common.events.logger import LineFormat + +# make sure event manager starts with a logger +get_event_manager().add_logger(get_stdout_config(LineFormat.PlainText, True, EventLevel.INFO, False)) diff --git a/dbt_common/events/base_types.py b/dbt_common/events/base_types.py new file mode 100644 index 00000000..98b15738 --- /dev/null +++ b/dbt_common/events/base_types.py @@ -0,0 +1,181 @@ +from enum import Enum +import os +import threading +from dbt_common.events import types_pb2 +import sys +from google.protobuf.json_format import ParseDict, MessageToDict, MessageToJson +from google.protobuf.message import Message +from dbt_common.events.helpers import get_json_string_utcnow +from typing import Optional + +from dbt_common.invocation import get_invocation_id + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + + +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# These base types define the _required structure_ for the concrete event # +# types defined in types.py # +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + + +def get_global_metadata_vars() -> dict: + from dbt_common.events.functions import get_metadata_vars + + return get_metadata_vars() + + +# exactly one pid per concrete event +def get_pid() -> int: + return os.getpid() + + +# in theory threads can change, so we don't cache them. +def get_thread_name() -> str: + return threading.current_thread().name + + +# EventLevel is an Enum, but mixing in the 'str' type is suggested in the Python +# documentation, and provides support for json conversion, which fails otherwise. 
+class EventLevel(str, Enum): + DEBUG = "debug" + TEST = "test" + INFO = "info" + WARN = "warn" + ERROR = "error" + + +class BaseEvent: + """BaseEvent for proto message generated python events""" + + PROTO_TYPES_MODULE = types_pb2 + + def __init__(self, *args, **kwargs) -> None: + class_name = type(self).__name__ + msg_cls = getattr(self.PROTO_TYPES_MODULE, class_name) + if class_name == "Formatting" and len(args) > 0: + kwargs["msg"] = args[0] + args = () + assert len(args) == 0, f"[{class_name}] Don't use positional arguments when constructing logging events" + if "base_msg" in kwargs: + kwargs["base_msg"] = str(kwargs["base_msg"]) + if "msg" in kwargs: + kwargs["msg"] = str(kwargs["msg"]) + try: + self.pb_msg = ParseDict(kwargs, msg_cls()) + except Exception: + # Imports need to be here to avoid circular imports + from dbt_common.events.types import Note + from dbt_common.events.functions import fire_event + + error_msg = f"[{class_name}]: Unable to parse dict {kwargs}" + # If we're testing throw an error so that we notice failures + if "pytest" in sys.modules: + raise Exception(error_msg) + else: + fire_event(Note(msg=error_msg), level=EventLevel.WARN) + self.pb_msg = msg_cls() + + def __setattr__(self, key, value): + if key == "pb_msg": + super().__setattr__(key, value) + else: + super().__getattribute__("pb_msg").__setattr__(key, value) + + def __getattr__(self, key): + if key == "pb_msg": + return super().__getattribute__(key) + else: + return super().__getattribute__("pb_msg").__getattribute__(key) + + def to_dict(self): + return MessageToDict(self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True) + + def to_json(self) -> str: + return MessageToJson( + self.pb_msg, + preserving_proto_field_name=True, + including_default_value_fields=True, + indent=None, + ) + + def level_tag(self) -> EventLevel: + return EventLevel.DEBUG + + def message(self) -> str: + raise Exception("message() not implemented for event") + + def code(self) -> str: + raise Exception("code() not implemented for event") + + +class EventInfo(Protocol): + level: str + name: str + ts: str + code: str + + +class EventMsg(Protocol): + info: EventInfo + data: Message + + +def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None): + + msg_class_name = f"{type(event).__name__}Msg" + msg_cls = getattr(event.PROTO_TYPES_MODULE, msg_class_name) + + # level in EventInfo must be a string, not an EventLevel + msg_level: str = level.value if level else event.level_tag().value + assert msg_level is not None + event_info = { + "level": msg_level, + "msg": event.message(), + "invocation_id": get_invocation_id(), + "extra": get_global_metadata_vars(), + "ts": get_json_string_utcnow(), + "pid": get_pid(), + "thread": get_thread_name(), + "code": event.code(), + "name": type(event).__name__, + } + new_event = ParseDict({"info": event_info}, msg_cls()) + new_event.data.CopyFrom(event.pb_msg) + return new_event + + +# DynamicLevel requires that the level be supplied on the +# event construction call using the "info" function from functions.py +class DynamicLevel(BaseEvent): + pass + + +class TestLevel(BaseEvent): + __test__ = False + + def level_tag(self) -> EventLevel: + return EventLevel.TEST + + +class DebugLevel(BaseEvent): + def level_tag(self) -> EventLevel: + return EventLevel.DEBUG + + +class InfoLevel(BaseEvent): + def level_tag(self) -> EventLevel: + return EventLevel.INFO + + +class WarnLevel(BaseEvent): + def level_tag(self) -> EventLevel: + return EventLevel.WARN + + +class 
ErrorLevel(BaseEvent): + def level_tag(self) -> EventLevel: + return EventLevel.ERROR diff --git a/dbt_common/events/contextvars.py b/dbt_common/events/contextvars.py new file mode 100644 index 00000000..5bdb78fe --- /dev/null +++ b/dbt_common/events/contextvars.py @@ -0,0 +1,114 @@ +import contextlib +import contextvars + +from typing import Any, Generator, Mapping, Dict + + +LOG_PREFIX = "log_" +TASK_PREFIX = "task_" + +_context_vars: Dict[str, contextvars.ContextVar] = {} + + +def get_contextvars(prefix: str) -> Dict[str, Any]: + rv = {} + ctx = contextvars.copy_context() + + prefix_len = len(prefix) + for k in ctx: + if k.name.startswith(prefix) and ctx[k] is not Ellipsis: + rv[k.name[prefix_len:]] = ctx[k] + + return rv + + +def get_node_info(): + cvars = get_contextvars(LOG_PREFIX) + if "node_info" in cvars: + return cvars["node_info"] + else: + return {} + + +def get_project_root(): + cvars = get_contextvars(TASK_PREFIX) + if "project_root" in cvars: + return cvars["project_root"] + else: + return None + + +def clear_contextvars(prefix: str) -> None: + ctx = contextvars.copy_context() + for k in ctx: + if k.name.startswith(prefix): + k.set(Ellipsis) + + +def set_log_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: + return set_contextvars(LOG_PREFIX, **kwargs) + + +def set_task_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: + return set_contextvars(TASK_PREFIX, **kwargs) + + +# put keys and values into context. Returns the contextvar.Token mapping +# Save and pass to reset_contextvars +def set_contextvars(prefix: str, **kwargs: Any) -> Mapping[str, contextvars.Token]: + cvar_tokens = {} + for k, v in kwargs.items(): + log_key = f"{prefix}{k}" + try: + var = _context_vars[log_key] + except KeyError: + var = contextvars.ContextVar(log_key, default=Ellipsis) + _context_vars[log_key] = var + + cvar_tokens[k] = var.set(v) + + return cvar_tokens + + +# reset by Tokens +def reset_contextvars(prefix: str, **kwargs: contextvars.Token) -> None: + for k, v in kwargs.items(): + log_key = f"{prefix}{k}" + var = _context_vars[log_key] + var.reset(v) + + +# remove from contextvars +def unset_contextvars(prefix: str, *keys: str) -> None: + for k in keys: + if k in _context_vars: + log_key = f"{prefix}{k}" + _context_vars[log_key].set(Ellipsis) + + +# Context manager or decorator to set and unset the context vars +@contextlib.contextmanager +def log_contextvars(**kwargs: Any) -> Generator[None, None, None]: + context = get_contextvars(LOG_PREFIX) + saved = {k: context[k] for k in context.keys() & kwargs.keys()} + + set_contextvars(LOG_PREFIX, **kwargs) + try: + yield + finally: + unset_contextvars(LOG_PREFIX, *kwargs.keys()) + set_contextvars(LOG_PREFIX, **saved) + + +# Context manager for earlier in task.run +@contextlib.contextmanager +def task_contextvars(**kwargs: Any) -> Generator[None, None, None]: + context = get_contextvars(TASK_PREFIX) + saved = {k: context[k] for k in context.keys() & kwargs.keys()} + + set_contextvars(TASK_PREFIX, **kwargs) + try: + yield + finally: + unset_contextvars(TASK_PREFIX, *kwargs.keys()) + set_contextvars(TASK_PREFIX, **saved) diff --git a/dbt_common/events/event_handler.py b/dbt_common/events/event_handler.py new file mode 100644 index 00000000..58e23a13 --- /dev/null +++ b/dbt_common/events/event_handler.py @@ -0,0 +1,40 @@ +import logging +from typing import Union + +from dbt_common.events.base_types import EventLevel +from dbt_common.events.types import Note +from dbt_common.events.event_manager import IEventManager + + 
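As an illustrative aside (editorial, not part of this diff), the contextvars helpers above might be used roughly like this; the `node_info` payload shown here is made up:

```
from dbt_common.events.contextvars import log_contextvars, get_node_info

# Inside the context manager, events fired by dbt pick up this node_info via
# get_node_info(); on exit the value is cleared and any previous value restored.
with log_contextvars(node_info={"node_name": "my_model", "unique_id": "model.proj.my_model"}):
    assert get_node_info()["node_name"] == "my_model"

assert get_node_info() == {}
```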
+_log_level_to_event_level_map = { + logging.DEBUG: EventLevel.DEBUG, + logging.INFO: EventLevel.INFO, + logging.WARN: EventLevel.WARN, + logging.WARNING: EventLevel.WARN, + logging.ERROR: EventLevel.ERROR, + logging.CRITICAL: EventLevel.ERROR, +} + + +class DbtEventLoggingHandler(logging.Handler): + """A logging handler that wraps the EventManager + This allows non-dbt packages to log to the dbt event stream. + All logs are generated as "Note" events. + """ + + def __init__(self, event_manager: IEventManager, level): + super().__init__(level) + self.event_manager = event_manager + + def emit(self, record: logging.LogRecord): + note = Note(msg=record.getMessage()) + level = _log_level_to_event_level_map[record.levelno] + self.event_manager.fire_event(e=note, level=level) + + +def set_package_logging(package_name: str, log_level: Union[str, int], event_mgr: IEventManager): + """Attach dbt's custom logging handler to the package's logger.""" + log = logging.getLogger(package_name) + log.setLevel(log_level) + event_handler = DbtEventLoggingHandler(event_manager=event_mgr, level=log_level) + log.addHandler(event_handler) diff --git a/dbt_common/events/event_manager.py b/dbt_common/events/event_manager.py new file mode 100644 index 00000000..c41b0983 --- /dev/null +++ b/dbt_common/events/event_manager.py @@ -0,0 +1,64 @@ +import os +import traceback +from typing import Callable, List, Optional, Protocol, Tuple + +from dbt_common.events.base_types import BaseEvent, EventLevel, msg_from_base_event, EventMsg +from dbt_common.events.logger import LoggerConfig, _Logger, _TextLogger, _JsonLogger, LineFormat + + +class EventManager: + def __init__(self) -> None: + self.loggers: List[_Logger] = [] + self.callbacks: List[Callable[[EventMsg], None]] = [] + + def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + msg = msg_from_base_event(e, level=level) + + if os.environ.get("DBT_TEST_BINARY_SERIALIZATION"): + print(f"--- {msg.info.name}") + try: + msg.SerializeToString() + except Exception as exc: + raise Exception( + f"{msg.info.name} is not serializable to binary. Originating exception: {exc}, {traceback.format_exc()}" + ) + + for logger in self.loggers: + if logger.filter(msg): # type: ignore + logger.write_line(msg) + + for callback in self.callbacks: + callback(msg) + + def add_logger(self, config: LoggerConfig) -> None: + logger = _JsonLogger(config) if config.line_format == LineFormat.Json else _TextLogger(config) + self.loggers.append(logger) + + def flush(self) -> None: + for logger in self.loggers: + logger.flush() + + +class IEventManager(Protocol): + callbacks: List[Callable[[EventMsg], None]] + loggers: List[_Logger] + + def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + ... + + def add_logger(self, config: LoggerConfig) -> None: + ... 
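To make the handler above concrete, here is a minimal sketch (editorial, not part of this diff) of routing a third-party package's standard-library logging into the dbt event stream; the package name `some_library` is a placeholder:

```
import logging

from dbt_common.events.event_handler import set_package_logging
from dbt_common.events.event_manager_client import get_event_manager

# Records emitted by "some_library" are re-fired as Note events at the
# EventLevel mapped from the stdlib level (WARNING -> warn, etc.).
set_package_logging("some_library", logging.DEBUG, get_event_manager())

logging.getLogger("some_library").warning("this now flows through dbt's loggers")
```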
+ + +class TestEventManager(IEventManager): + __test__ = False + + def __init__(self) -> None: + self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = [] + self.loggers = [] + + def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + self.event_history.append((e, level)) + + def add_logger(self, config: LoggerConfig) -> None: + raise NotImplementedError() diff --git a/dbt_common/events/event_manager_client.py b/dbt_common/events/event_manager_client.py new file mode 100644 index 00000000..1b674f6e --- /dev/null +++ b/dbt_common/events/event_manager_client.py @@ -0,0 +1,29 @@ +# Since dbt-rpc does not do its own log setup, and since some events can +# currently fire before logs can be configured by setup_event_logger(), we +# create a default configuration with default settings and no file output. +from dbt_common.events.event_manager import IEventManager, EventManager + +_EVENT_MANAGER: IEventManager = EventManager() + + +def get_event_manager() -> IEventManager: + global _EVENT_MANAGER + return _EVENT_MANAGER + + +def add_logger_to_manager(logger) -> None: + global _EVENT_MANAGER + _EVENT_MANAGER.add_logger(logger) + + +def ctx_set_event_manager(event_manager: IEventManager) -> None: + global _EVENT_MANAGER + _EVENT_MANAGER = event_manager + + +def cleanup_event_logger() -> None: + # Reset to a no-op manager to release streams associated with logs. This is + # especially important for tests, since pytest replaces the stdout stream + # during test runs, and closes the stream after the test is over. + _EVENT_MANAGER.loggers.clear() + _EVENT_MANAGER.callbacks.clear() diff --git a/dbt_common/events/format.py b/dbt_common/events/format.py new file mode 100644 index 00000000..f87e464b --- /dev/null +++ b/dbt_common/events/format.py @@ -0,0 +1,54 @@ +from dbt_common import ui + +from typing import Optional, Union +from datetime import datetime + +from dbt_common.events.interfaces import LoggableDbtObject + + +def format_fancy_output_line( + msg: str, + status: str, + index: Optional[int], + total: Optional[int], + execution_time: Optional[float] = None, + truncate: bool = False, +) -> str: + if index is None or total is None: + progress = "" + else: + progress = "{} of {} ".format(index, total) + prefix = "{progress}{message} ".format(progress=progress, message=msg) + + truncate_width = ui.printer_width() - 3 + justified = prefix.ljust(ui.printer_width(), ".") + if truncate and len(justified) > truncate_width: + justified = justified[:truncate_width] + "..." 
+ + if execution_time is None: + status_time = "" + else: + status_time = " in {execution_time:0.2f}s".format(execution_time=execution_time) + + output = "{justified} [{status}{status_time}]".format(justified=justified, status=status, status_time=status_time) + + return output + + +def _pluralize(string: Union[str, LoggableDbtObject]) -> str: + if isinstance(string, LoggableDbtObject): + return string.pluralize() + else: + return f"{string}s" + + +def pluralize(count, string: Union[str, LoggableDbtObject]) -> str: + pluralized: str = str(string) + if count != 1: + pluralized = _pluralize(string) + return f"{count} {pluralized}" + + +def timestamp_to_datetime_string(ts) -> str: + timestamp_dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9) + return timestamp_dt.strftime("%H:%M:%S.%f") diff --git a/dbt_common/events/functions.py b/dbt_common/events/functions.py new file mode 100644 index 00000000..60ef4d27 --- /dev/null +++ b/dbt_common/events/functions.py @@ -0,0 +1,150 @@ +from pathlib import Path + +from dbt_common.events.event_manager_client import get_event_manager +from dbt_common.invocation import get_invocation_id +from dbt_common.helper_types import WarnErrorOptions +from dbt_common.utils import ForgivingJSONEncoder +from dbt_common.events.base_types import BaseEvent, EventLevel, EventMsg +from dbt_common.events.logger import LoggerConfig, LineFormat +from dbt_common.exceptions import scrub_secrets, env_secrets +from dbt_common.events.types import Note +from functools import partial +import json +import os +import sys +from typing import Callable, Dict, Optional, TextIO, Union +from google.protobuf.json_format import MessageToDict + +LOG_VERSION = 3 +metadata_vars: Optional[Dict[str, str]] = None +_METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_" +WARN_ERROR_OPTIONS = WarnErrorOptions(include=[], exclude=[]) +WARN_ERROR = False + +# This global, and the following two functions for capturing stdout logs are +# an unpleasant hack we intend to remove as part of API-ification. The GitHub +# issue #6350 was opened for that work. +CAPTURE_STREAM: Optional[TextIO] = None + + +def stdout_filter( + log_cache_events: bool, + line_format: LineFormat, + msg: EventMsg, +) -> bool: + return msg.info.name not in ["CacheAction", "CacheDumpGraph"] or log_cache_events + + +def get_stdout_config( + line_format: LineFormat, + use_colors: bool, + level: EventLevel, + log_cache_events: bool, +) -> LoggerConfig: + return LoggerConfig( + name="stdout_log", + level=level, + use_colors=use_colors, + line_format=line_format, + scrubber=env_scrubber, + filter=partial( + stdout_filter, + log_cache_events, + line_format, + ), + invocation_id=get_invocation_id(), + output_stream=sys.stdout, + ) + + +def make_log_dir_if_missing(log_path: Union[Path, str]) -> None: + if isinstance(log_path, str): + log_path = Path(log_path) + log_path.mkdir(parents=True, exist_ok=True) + + +def env_scrubber(msg: str) -> str: + return scrub_secrets(msg, env_secrets()) + + +# used for integration tests +def capture_stdout_logs(stream: TextIO) -> None: + global CAPTURE_STREAM + CAPTURE_STREAM = stream + + +def stop_capture_stdout_logs() -> None: + global CAPTURE_STREAM + CAPTURE_STREAM = None + + +def get_capture_stream() -> Optional[TextIO]: + return CAPTURE_STREAM + + +# returns a dictionary representation of the event fields. +# the message may contain secrets which must be scrubbed at the usage site. 
+def msg_to_json(msg: EventMsg) -> str: + msg_dict = msg_to_dict(msg) + raw_log_line = json.dumps(msg_dict, sort_keys=True, cls=ForgivingJSONEncoder) + return raw_log_line + + +def msg_to_dict(msg: EventMsg) -> dict: + msg_dict = dict() + try: + msg_dict = MessageToDict( + msg, preserving_proto_field_name=True, including_default_value_fields=True # type: ignore + ) + except Exception as exc: + event_type = type(msg).__name__ + fire_event(Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN) + # We don't want an empty NodeInfo in output + if "data" in msg_dict and "node_info" in msg_dict["data"] and msg_dict["data"]["node_info"]["node_name"] == "": + del msg_dict["data"]["node_info"] + return msg_dict + + +def warn_or_error(event, node=None) -> None: + if WARN_ERROR or WARN_ERROR_OPTIONS.includes(type(event).__name__): + + # TODO: resolve this circular import when at top + from dbt_common.exceptions import EventCompilationError + + raise EventCompilationError(event.message(), node) + else: + fire_event(event) + + +# an alternative to fire_event which only creates and logs the event value +# if the condition is met. Does nothing otherwise. +def fire_event_if(conditional: bool, lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None) -> None: + if conditional: + fire_event(lazy_e(), level=level) + + +# a special case of fire_event_if, to only fire events in our unit/functional tests +def fire_event_if_test(lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None) -> None: + fire_event_if(conditional=("pytest" in sys.modules), lazy_e=lazy_e, level=level) + + +# top-level method for accessing the new eventing system +# this is where all the side effects happen branched by event type +# (i.e. - mutating the event history, printing to stdout, logging +# to files, etc.) +def fire_event(e: BaseEvent, level: Optional[EventLevel] = None) -> None: + get_event_manager().fire_event(e, level=level) + + +def get_metadata_vars() -> Dict[str, str]: + global metadata_vars + if not metadata_vars: + metadata_vars = { + k[len(_METADATA_ENV_PREFIX) :]: v for k, v in os.environ.items() if k.startswith(_METADATA_ENV_PREFIX) + } + return metadata_vars + + +def reset_metadata_vars() -> None: + global metadata_vars + metadata_vars = None diff --git a/dbt_common/events/helpers.py b/dbt_common/events/helpers.py new file mode 100644 index 00000000..25ff6bde --- /dev/null +++ b/dbt_common/events/helpers.py @@ -0,0 +1,14 @@ +from datetime import datetime + + +# This converts a datetime to a json format datetime string which +# is used in constructing protobuf message timestamps. +def datetime_to_json_string(dt: datetime) -> str: + return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + +# preformatted time stamp +def get_json_string_utcnow() -> str: + ts = datetime.utcnow() + ts_rfc3339 = datetime_to_json_string(ts) + return ts_rfc3339 diff --git a/dbt_common/events/interfaces.py b/dbt_common/events/interfaces.py new file mode 100644 index 00000000..13c7df9d --- /dev/null +++ b/dbt_common/events/interfaces.py @@ -0,0 +1,7 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class LoggableDbtObject(Protocol): + def pluralize(self) -> str: + ... 
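As a short illustrative sketch (editorial, not part of this diff), the entry points defined in functions.py above might be exercised like this; the message strings are made up:

```
from dbt_common.events.functions import fire_event, fire_event_if
from dbt_common.events.types import Formatting, Note

# Eager path: the event is always constructed, then fanned out to every
# configured logger and callback by the event manager.
fire_event(Note(msg="something worth noting"))

# Lazy path: the callable is invoked (and the event built) only when the
# condition is true, which is how expensive cache events are guarded.
log_cache_events = False
fire_event_if(log_cache_events, lambda: Formatting(msg="expensive payload here"))
```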
diff --git a/dbt_common/events/logger.py b/dbt_common/events/logger.py new file mode 100644 index 00000000..ece7f283 --- /dev/null +++ b/dbt_common/events/logger.py @@ -0,0 +1,176 @@ +import json +import logging +import threading +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from logging.handlers import RotatingFileHandler +from typing import Optional, TextIO, Any, Callable + +from colorama import Style + +from dbt_common.events.base_types import EventLevel, EventMsg +from dbt_common.events.format import timestamp_to_datetime_string +from dbt_common.utils import ForgivingJSONEncoder + +# A Filter is a function which takes a BaseEvent and returns True if the event +# should be logged, False otherwise. +Filter = Callable[[EventMsg], bool] + + +# Default filter which logs every event +def NoFilter(_: EventMsg) -> bool: + return True + + +# A Scrubber removes secrets from an input string, returning a sanitized string. +Scrubber = Callable[[str], str] + + +# Provide a pass-through scrubber implementation, also used as a default +def NoScrubber(s: str) -> str: + return s + + +class LineFormat(Enum): + PlainText = 1 + DebugText = 2 + Json = 3 + + +# Map from dbt event levels to python log levels +_log_level_map = { + EventLevel.DEBUG: 10, + EventLevel.TEST: 10, + EventLevel.INFO: 20, + EventLevel.WARN: 30, + EventLevel.ERROR: 40, +} + + +# We need this function for now because the numeric log severity levels in +# Python do not match those for logbook, so we have to explicitly call the +# correct function by name. +def send_to_logger(l, level: str, log_line: str): + if level == "test": + l.debug(log_line) + elif level == "debug": + l.debug(log_line) + elif level == "info": + l.info(log_line) + elif level == "warn": + l.warning(log_line) + elif level == "error": + l.error(log_line) + else: + raise AssertionError(f"While attempting to log {log_line}, encountered the unhandled level: {level}") + + +@dataclass +class LoggerConfig: + name: str + filter: Filter = NoFilter + scrubber: Scrubber = NoScrubber + line_format: LineFormat = LineFormat.PlainText + level: EventLevel = EventLevel.WARN + invocation_id: Optional[str] = None + use_colors: bool = False + output_stream: Optional[TextIO] = None + output_file_name: Optional[str] = None + output_file_max_bytes: Optional[int] = 10 * 1024 * 1024 # 10 mb + logger: Optional[Any] = None + + +class _Logger: + def __init__(self, config: LoggerConfig) -> None: + self.name: str = config.name + self.filter: Filter = config.filter + self.scrubber: Scrubber = config.scrubber + self.level: EventLevel = config.level + self.invocation_id: Optional[str] = config.invocation_id + self._python_logger: Optional[logging.Logger] = config.logger + + if config.output_stream is not None: + stream_handler = logging.StreamHandler(config.output_stream) + self._python_logger = self._get_python_log_for_handler(stream_handler) + + if config.output_file_name: + file_handler = RotatingFileHandler( + filename=str(config.output_file_name), + encoding="utf8", + maxBytes=config.output_file_max_bytes, # type: ignore + backupCount=5, + ) + self._python_logger = self._get_python_log_for_handler(file_handler) + + def _get_python_log_for_handler(self, handler: logging.Handler): + log = logging.getLogger(self.name) + log.setLevel(_log_level_map[self.level]) + handler.setFormatter(logging.Formatter(fmt="%(message)s")) + log.handlers.clear() + log.propagate = False + log.addHandler(handler) + return log + + def create_line(self, msg: EventMsg) -> str: + 
raise NotImplementedError() + + def write_line(self, msg: EventMsg): + line = self.create_line(msg) + if self._python_logger is not None: + send_to_logger(self._python_logger, msg.info.level, line) + + def flush(self): + if self._python_logger is not None: + for handler in self._python_logger.handlers: + handler.flush() + + +class _TextLogger(_Logger): + def __init__(self, config: LoggerConfig) -> None: + super().__init__(config) + self.use_colors = config.use_colors + self.use_debug_format = config.line_format == LineFormat.DebugText + + def create_line(self, msg: EventMsg) -> str: + return self.create_debug_line(msg) if self.use_debug_format else self.create_info_line(msg) + + def create_info_line(self, msg: EventMsg) -> str: + ts: str = datetime.utcnow().strftime("%H:%M:%S") + scrubbed_msg: str = self.scrubber(msg.info.msg) # type: ignore + return f"{self._get_color_tag()}{ts} {scrubbed_msg}" + + def create_debug_line(self, msg: EventMsg) -> str: + log_line: str = "" + # Create a separator if this is the beginning of an invocation + # TODO: This is an ugly hack, get rid of it if we can + ts: str = timestamp_to_datetime_string(msg.info.ts) + if msg.info.name == "MainReportVersion": + separator = 30 * "=" + log_line = f"\n\n{separator} {ts} | {self.invocation_id} {separator}\n" + scrubbed_msg: str = self.scrubber(msg.info.msg) # type: ignore + level = msg.info.level + log_line += f"{self._get_color_tag()}{ts} [{level:<5}]{self._get_thread_name()} {scrubbed_msg}" + return log_line + + def _get_color_tag(self) -> str: + return "" if not self.use_colors else Style.RESET_ALL + + def _get_thread_name(self) -> str: + thread_name = "" + if threading.current_thread().name: + thread_name = threading.current_thread().name + thread_name = thread_name[:10] + thread_name = thread_name.ljust(10, " ") + thread_name = f" [{thread_name}]:" + return thread_name + + +class _JsonLogger(_Logger): + def create_line(self, msg: EventMsg) -> str: + from dbt_common.events.functions import msg_to_dict + + msg_dict = msg_to_dict(msg) + raw_log_line = json.dumps(msg_dict, sort_keys=True, cls=ForgivingJSONEncoder) + line = self.scrubber(raw_log_line) # type: ignore + return line diff --git a/dbt_common/events/types.proto b/dbt_common/events/types.proto new file mode 100644 index 00000000..ad791315 --- /dev/null +++ b/dbt_common/events/types.proto @@ -0,0 +1,121 @@ +syntax = "proto3"; + +package proto_types; + +import "google/protobuf/timestamp.proto"; + +// Common event info +message EventInfo { + string name = 1; + string code = 2; + string msg = 3; + string level = 4; + string invocation_id = 5; + int32 pid = 6; + string thread = 7; + google.protobuf.Timestamp ts = 8; + map extra = 9; + string category = 10; +} + +// GenericMessage, used for deserializing only +message GenericMessage { + EventInfo info = 1; +} + +// M - Deps generation + +// M020 +message RetryExternalCall { + int32 attempt = 1; + int32 max = 2; +} + +message RetryExternalCallMsg { + EventInfo info = 1; + RetryExternalCall data = 2; +} + +// M021 +message RecordRetryException { + string exc = 1; +} + +message RecordRetryExceptionMsg { + EventInfo info = 1; + RecordRetryException data = 2; +} + +// Z - Misc + +// Z005 +message SystemCouldNotWrite { + string path = 1; + string reason = 2; + string exc = 3; +} + +message SystemCouldNotWriteMsg { + EventInfo info = 1; + SystemCouldNotWrite data = 2; +} + +// Z006 +message SystemExecutingCmd { + repeated string cmd = 1; +} + +message SystemExecutingCmdMsg { + EventInfo info = 1; + SystemExecutingCmd data 
= 2; +} + +// Z007 +message SystemStdOut{ + string bmsg = 1; +} + +message SystemStdOutMsg { + EventInfo info = 1; + SystemStdOut data = 2; +} + +// Z008 +message SystemStdErr { + string bmsg = 1; +} + +message SystemStdErrMsg { + EventInfo info = 1; + SystemStdErr data = 2; +} + +// Z009 +message SystemReportReturnCode { + int32 returncode = 1; +} + +message SystemReportReturnCodeMsg { + EventInfo info = 1; + SystemReportReturnCode data = 2; +} + +// Z017 +message Formatting { + string msg = 1; +} + +message FormattingMsg { + EventInfo info = 1; + Formatting data = 2; +} + +// Z050 +message Note { + string msg = 1; +} + +message NoteMsg { + EventInfo info = 1; + Note data = 2; +} diff --git a/dbt_common/events/types.py b/dbt_common/events/types.py new file mode 100644 index 00000000..0ee5cd00 --- /dev/null +++ b/dbt_common/events/types.py @@ -0,0 +1,124 @@ +from dbt_common.events.base_types import ( + DebugLevel, + InfoLevel, +) + + +# The classes in this file represent the data necessary to describe a +# particular event to both human readable logs, and machine reliable +# event streams. classes extend superclasses that indicate what +# destinations they are intended for, which mypy uses to enforce +# that the necessary methods are defined. + + +# Event codes have prefixes which follow this table +# +# | Code | Description | +# |:----:|:-------------------:| +# | A | Pre-project loading | +# | D | Deprecations | +# | E | DB adapter | +# | I | Project parsing | +# | M | Deps generation | +# | P | Artifacts | +# | Q | Node execution | +# | W | Node testing | +# | Z | Misc | +# | T | Test only | +# +# The basic idea is that event codes roughly translate to the natural order of running a dbt task + +# ======================================================= +# M - Deps generation +# ======================================================= + + +class RetryExternalCall(DebugLevel): + def code(self) -> str: + return "M020" + + def message(self) -> str: + return f"Retrying external call. Attempt: {self.attempt} Max attempts: {self.max}" + + +class RecordRetryException(DebugLevel): + def code(self) -> str: + return "M021" + + def message(self) -> str: + return f"External call exception: {self.exc}" + + +# ======================================================= +# Z - Misc +# ======================================================= + + +class SystemCouldNotWrite(DebugLevel): + def code(self) -> str: + return "Z005" + + def message(self) -> str: + return ( + f"Could not write to path {self.path}({len(self.path)} characters): " + f"{self.reason}\nexception: {self.exc}" + ) + + +class SystemExecutingCmd(DebugLevel): + def code(self) -> str: + return "Z006" + + def message(self) -> str: + return f'Executing "{" ".join(self.cmd)}"' + + +class SystemStdOut(DebugLevel): + def code(self) -> str: + return "Z007" + + def message(self) -> str: + return f'STDOUT: "{str(self.bmsg)}"' + + +class SystemStdErr(DebugLevel): + def code(self) -> str: + return "Z008" + + def message(self) -> str: + return f'STDERR: "{str(self.bmsg)}"' + + +class SystemReportReturnCode(DebugLevel): + def code(self) -> str: + return "Z009" + + def message(self) -> str: + return f"command return code={self.returncode}" + + +# We use events to create console output, but also think of them as a sequence of important and +# meaningful occurrences to be used for debugging and monitoring. 
The Formatting event helps eases +# the tension between these two goals by allowing empty lines, heading separators, and other +# formatting to be written to the console, while they can be ignored for other purposes. For +# general information that isn't simple formatting, the Note event should be used instead. + + +class Formatting(InfoLevel): + def code(self) -> str: + return "Z017" + + def message(self) -> str: + return self.msg + + +class Note(InfoLevel): + """The Note event provides a way to log messages which aren't likely to be + useful as more structured events. For console formatting text like empty + lines and separator bars, use the Formatting event instead.""" + + def code(self) -> str: + return "Z050" + + def message(self) -> str: + return self.msg diff --git a/dbt_common/events/types_pb2.py b/dbt_common/events/types_pb2.py new file mode 100644 index 00000000..3d7a33e3 --- /dev/null +++ b/dbt_common/events/types_pb2.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: types.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0btypes.proto\x12\x0bproto_types\x1a\x1fgoogle/protobuf/timestamp.proto"\x91\x02\n\tEventInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\x12\r\n\x05level\x18\x04 \x01(\t\x12\x15\n\rinvocation_id\x18\x05 \x01(\t\x12\x0b\n\x03pid\x18\x06 \x01(\x05\x12\x0e\n\x06thread\x18\x07 \x01(\t\x12&\n\x02ts\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x30\n\x05\x65xtra\x18\t \x03(\x0b\x32!.proto_types.EventInfo.ExtraEntry\x12\x10\n\x08\x63\x61tegory\x18\n \x01(\t\x1a,\n\nExtraEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"6\n\x0eGenericMessage\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo"1\n\x11RetryExternalCall\x12\x0f\n\x07\x61ttempt\x18\x01 \x01(\x05\x12\x0b\n\x03max\x18\x02 \x01(\x05"j\n\x14RetryExternalCallMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12,\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1e.proto_types.RetryExternalCall"#\n\x14RecordRetryException\x12\x0b\n\x03\x65xc\x18\x01 \x01(\t"p\n\x17RecordRetryExceptionMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12/\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32!.proto_types.RecordRetryException"@\n\x13SystemCouldNotWrite\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\x12\x0b\n\x03\x65xc\x18\x03 \x01(\t"n\n\x16SystemCouldNotWriteMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12.\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32 .proto_types.SystemCouldNotWrite"!\n\x12SystemExecutingCmd\x12\x0b\n\x03\x63md\x18\x01 \x03(\t"l\n\x15SystemExecutingCmdMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12-\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1f.proto_types.SystemExecutingCmd"\x1c\n\x0cSystemStdOut\x12\x0c\n\x04\x62msg\x18\x01 \x01(\t"`\n\x0fSystemStdOutMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\'\n\x04\x64\x61ta\x18\x02 
\x01(\x0b\x32\x19.proto_types.SystemStdOut"\x1c\n\x0cSystemStdErr\x12\x0c\n\x04\x62msg\x18\x01 \x01(\t"`\n\x0fSystemStdErrMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\'\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x19.proto_types.SystemStdErr",\n\x16SystemReportReturnCode\x12\x12\n\nreturncode\x18\x01 \x01(\x05"t\n\x19SystemReportReturnCodeMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\x31\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32#.proto_types.SystemReportReturnCode"\x19\n\nFormatting\x12\x0b\n\x03msg\x18\x01 \x01(\t"\\\n\rFormattingMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12%\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x17.proto_types.Formatting"\x13\n\x04Note\x12\x0b\n\x03msg\x18\x01 \x01(\t"P\n\x07NoteMsg\x12$\n\x04info\x18\x01 \x01(\x0b\x32\x16.proto_types.EventInfo\x12\x1f\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x11.proto_types.Noteb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "types_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _EVENTINFO_EXTRAENTRY._options = None + _EVENTINFO_EXTRAENTRY._serialized_options = b"8\001" + _globals["_EVENTINFO"]._serialized_start = 62 + _globals["_EVENTINFO"]._serialized_end = 335 + _globals["_EVENTINFO_EXTRAENTRY"]._serialized_start = 291 + _globals["_EVENTINFO_EXTRAENTRY"]._serialized_end = 335 + _globals["_GENERICMESSAGE"]._serialized_start = 337 + _globals["_GENERICMESSAGE"]._serialized_end = 391 + _globals["_RETRYEXTERNALCALL"]._serialized_start = 393 + _globals["_RETRYEXTERNALCALL"]._serialized_end = 442 + _globals["_RETRYEXTERNALCALLMSG"]._serialized_start = 444 + _globals["_RETRYEXTERNALCALLMSG"]._serialized_end = 550 + _globals["_RECORDRETRYEXCEPTION"]._serialized_start = 552 + _globals["_RECORDRETRYEXCEPTION"]._serialized_end = 587 + _globals["_RECORDRETRYEXCEPTIONMSG"]._serialized_start = 589 + _globals["_RECORDRETRYEXCEPTIONMSG"]._serialized_end = 701 + _globals["_SYSTEMCOULDNOTWRITE"]._serialized_start = 703 + _globals["_SYSTEMCOULDNOTWRITE"]._serialized_end = 767 + _globals["_SYSTEMCOULDNOTWRITEMSG"]._serialized_start = 769 + _globals["_SYSTEMCOULDNOTWRITEMSG"]._serialized_end = 879 + _globals["_SYSTEMEXECUTINGCMD"]._serialized_start = 881 + _globals["_SYSTEMEXECUTINGCMD"]._serialized_end = 914 + _globals["_SYSTEMEXECUTINGCMDMSG"]._serialized_start = 916 + _globals["_SYSTEMEXECUTINGCMDMSG"]._serialized_end = 1024 + _globals["_SYSTEMSTDOUT"]._serialized_start = 1026 + _globals["_SYSTEMSTDOUT"]._serialized_end = 1054 + _globals["_SYSTEMSTDOUTMSG"]._serialized_start = 1056 + _globals["_SYSTEMSTDOUTMSG"]._serialized_end = 1152 + _globals["_SYSTEMSTDERR"]._serialized_start = 1154 + _globals["_SYSTEMSTDERR"]._serialized_end = 1182 + _globals["_SYSTEMSTDERRMSG"]._serialized_start = 1184 + _globals["_SYSTEMSTDERRMSG"]._serialized_end = 1280 + _globals["_SYSTEMREPORTRETURNCODE"]._serialized_start = 1282 + _globals["_SYSTEMREPORTRETURNCODE"]._serialized_end = 1326 + _globals["_SYSTEMREPORTRETURNCODEMSG"]._serialized_start = 1328 + _globals["_SYSTEMREPORTRETURNCODEMSG"]._serialized_end = 1444 + _globals["_FORMATTING"]._serialized_start = 1446 + _globals["_FORMATTING"]._serialized_end = 1471 + _globals["_FORMATTINGMSG"]._serialized_start = 1473 + _globals["_FORMATTINGMSG"]._serialized_end = 1565 + _globals["_NOTE"]._serialized_start = 1567 + _globals["_NOTE"]._serialized_end = 1586 + _globals["_NOTEMSG"]._serialized_start = 1588 + 
_globals["_NOTEMSG"]._serialized_end = 1668 +# @@protoc_insertion_point(module_scope) diff --git a/dbt_common/exceptions/__init__.py b/dbt_common/exceptions/__init__.py new file mode 100644 index 00000000..437ef6c0 --- /dev/null +++ b/dbt_common/exceptions/__init__.py @@ -0,0 +1,7 @@ +from dbt_common.exceptions.base import * # noqa +from dbt_common.exceptions.events import * # noqa +from dbt_common.exceptions.macros import * # noqa +from dbt_common.exceptions.contracts import * # noqa +from dbt_common.exceptions.connection import * # noqa +from dbt_common.exceptions.system import * # noqa +from dbt_common.exceptions.jinja import * # noqa diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py new file mode 100644 index 00000000..0ec53d2a --- /dev/null +++ b/dbt_common/exceptions/base.py @@ -0,0 +1,270 @@ +import builtins +from typing import List, Any, Optional +import os + +from dbt_common.constants import SECRET_ENV_PREFIX +from dbt_common.dataclass_schema import ValidationError + + +def env_secrets() -> List[str]: + return [v for k, v in os.environ.items() if k.startswith(SECRET_ENV_PREFIX) and v.strip()] + + +def scrub_secrets(msg: str, secrets: List[str]) -> str: + scrubbed = str(msg) + + for secret in secrets: + scrubbed = scrubbed.replace(secret, "*****") + + return scrubbed + + +class DbtBaseException(Exception): + CODE = -32000 + MESSAGE = "Server Error" + + def data(self): + # if overriding, make sure the result is json-serializable. + return { + "type": self.__class__.__name__, + "message": str(self), + } + + +class DbtInternalError(DbtBaseException): + def __init__(self, msg: str): + self.stack: List = [] + self.msg = scrub_secrets(msg, env_secrets()) + + @property + def type(self): + return "Internal" + + def process_stack(self): + lines = [] + stack = self.stack + first = True + + if len(stack) > 1: + lines.append("") + + for item in stack: + msg = "called by" + + if first: + msg = "in" + first = False + + lines.append(f"> {msg}") + + return lines + + def __str__(self): + if hasattr(self.msg, "split"): + split_msg = self.msg.split("\n") + else: + split_msg = str(self.msg).split("\n") + + lines = ["{}".format(self.type + " Error")] + split_msg + + lines += self.process_stack() + + return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]]) + + +class DbtRuntimeError(RuntimeError, DbtBaseException): + CODE = 10001 + MESSAGE = "Runtime error" + + def __init__(self, msg: str, node=None) -> None: + self.stack: List = [] + self.node = node + self.msg = scrub_secrets(msg, env_secrets()) + + def add_node(self, node=None): + if node is not None and node is not self.node: + if self.node is not None: + self.stack.append(self.node) + self.node = node + + @property + def type(self): + return "Runtime" + + def node_to_string(self, node: Any): + """ + Given a node-like object we attempt to create the best identifier we can + """ + result = "" + if hasattr(node, "resource_type"): + result += node.resource_type + if hasattr(node, "name"): + result += f" {node.name}" + if hasattr(node, "original_file_path"): + result += f" ({node.original_file_path})" + + return result.strip() if result != "" else " " + + def process_stack(self): + lines = [] + stack = self.stack + [self.node] + first = True + + if len(stack) > 1: + lines.append("") + + for item in stack: + msg = "called by" + + if first: + msg = "in" + first = False + + lines.append(f"> {msg} {self.node_to_string(item)}") + + return lines + + def validator_error_message(self, exc: builtins.Exception): + """Given a 
dbt.dataclass_schema.ValidationError (which is basically a + jsonschema.ValidationError), return the relevant parts as a string + """ + if not isinstance(exc, ValidationError): + return str(exc) + path = "[%s]" % "][".join(map(repr, exc.relative_path)) + return f"at path {path}: {exc.message}" + + def __str__(self, prefix: str = "! "): + node_string = "" + + if self.node is not None: + node_string = f" in {self.node_to_string(self.node)}" + + if hasattr(self.msg, "split"): + split_msg = self.msg.split("\n") + else: + split_msg = str(self.msg).split("\n") + + lines = ["{}{}".format(self.type + " Error", node_string)] + split_msg + + lines += self.process_stack() + + return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]]) + + def data(self): + result = DbtBaseException.data(self) + if self.node is None: + return result + + result.update( + { + "raw_code": self.node.raw_code, + # the node isn't always compiled, but if it is, include that! + "compiled_code": getattr(self.node, "compiled_code", None), + } + ) + return result + + +class CompilationError(DbtRuntimeError): + CODE = 10004 + MESSAGE = "Compilation Error" + + @property + def type(self): + return "Compilation" + + def _fix_dupe_msg(self, path_1: str, path_2: str, name: str, type_name: str) -> str: + if path_1 == path_2: + return f"remove one of the {type_name} entries for {name} in this file:\n - {path_1!s}\n" + else: + return f"remove the {type_name} entry for {name} in one of these files:\n" f" - {path_1!s}\n{path_2!s}" + + +class RecursionError(DbtRuntimeError): + pass + + +class DbtConfigError(DbtRuntimeError): + CODE = 10007 + MESSAGE = "DBT Configuration Error" + + # ToDo: Can we remove project? + def __init__(self, msg: str, project=None, result_type="invalid_project", path=None) -> None: + self.project = project + super().__init__(msg) + self.result_type = result_type + self.path = path + + def __str__(self, prefix="! 
") -> str: + msg = super().__str__(prefix) + if self.path is None: + return msg + else: + return f"{msg}\n\nError encountered in {self.path}" + + +class NotImplementedError(DbtBaseException): + def __init__(self, msg: str) -> None: + self.msg = msg + self.formatted_msg = f"ERROR: {self.msg}" + super().__init__(self.formatted_msg) + + +class SemverError(Exception): + def __init__(self, msg: Optional[str] = None) -> None: + self.msg = msg + if msg is not None: + super().__init__(msg) + else: + super().__init__() + + +class VersionsNotCompatibleError(SemverError): + pass + + +class DbtValidationError(DbtRuntimeError): + CODE = 10005 + MESSAGE = "Validation Error" + + +class DbtDatabaseError(DbtRuntimeError): + CODE = 10003 + MESSAGE = "Database Error" + + def process_stack(self): + lines = [] + + if hasattr(self.node, "build_path") and self.node.build_path: + lines.append(f"compiled Code at {self.node.build_path}") + + return lines + DbtRuntimeError.process_stack(self) + + @property + def type(self): + return "Database" + + +class UnexpectedNullError(DbtDatabaseError): + def __init__(self, field_name: str, source): + self.field_name = field_name + self.source = source + msg = ( + f"Expected a non-null value when querying field '{self.field_name}' of table " + f" {self.source} but received value 'null' instead" + ) + super().__init__(msg) + + +class CommandError(DbtRuntimeError): + def __init__(self, cwd: str, cmd: List[str], msg: str = "Error running command") -> None: + cmd_scrubbed = list(scrub_secrets(cmd_txt, env_secrets()) for cmd_txt in cmd) + super().__init__(msg) + self.cwd = cwd + self.cmd = cmd_scrubbed + self.args = (cwd, cmd_scrubbed, msg) + + def __str__(self): + if len(self.cmd) == 0: + return f"{self.msg}: No arguments given" + return f'{self.msg}: "{self.cmd[0]}"' diff --git a/dbt_common/exceptions/cache.py b/dbt_common/exceptions/cache.py new file mode 100644 index 00000000..6dc21539 --- /dev/null +++ b/dbt_common/exceptions/cache.py @@ -0,0 +1,66 @@ +import re +from typing import Dict + +from dbt_common.exceptions import DbtInternalError + + +class CacheInconsistencyError(DbtInternalError): + def __init__(self, msg: str): + self.msg = msg + formatted_msg = f"Cache inconsistency detected: {self.msg}" + super().__init__(msg=formatted_msg) + + +class NewNameAlreadyInCacheError(CacheInconsistencyError): + def __init__(self, old_key: str, new_key: str): + self.old_key = old_key + self.new_key = new_key + msg = f'in rename of "{self.old_key}" -> "{self.new_key}", new name is in the cache already' + super().__init__(msg) + + +class ReferencedLinkNotCachedError(CacheInconsistencyError): + def __init__(self, referenced_key: str): + self.referenced_key = referenced_key + msg = f"in add_link, referenced link key {self.referenced_key} not in cache!" + super().__init__(msg) + + +class DependentLinkNotCachedError(CacheInconsistencyError): + def __init__(self, dependent_key: str): + self.dependent_key = dependent_key + msg = f"in add_link, dependent link key {self.dependent_key} not in cache!" + super().__init__(msg) + + +class TruncatedModelNameCausedCollisionError(CacheInconsistencyError): + def __init__(self, new_key, relations: Dict): + self.new_key = new_key + self.relations = relations + super().__init__(self.get_message()) + + def get_message(self) -> str: + # Tell user when collision caused by model names truncated during + # materialization. 
+ match = re.search("__dbt_backup|__dbt_tmp$", self.new_key.identifier) + if match: + truncated_model_name_prefix = self.new_key.identifier[: match.start()] + message_addendum = ( + "\n\nName collisions can occur when the length of two " + "models' names approach your database's builtin limit. " + "Try restructuring your project such that no two models " + f"share the prefix '{truncated_model_name_prefix}'. " + "Then, clean your warehouse of any removed models." + ) + else: + message_addendum = "" + + msg = f"in rename, new key {self.new_key} already in cache: {list(self.relations.keys())}{message_addendum}" + + return msg + + +class NoneRelationFoundError(CacheInconsistencyError): + def __init__(self): + msg = "in get_relations, a None relation was found in the cache!" + super().__init__(msg) diff --git a/dbt_common/exceptions/connection.py b/dbt_common/exceptions/connection.py new file mode 100644 index 00000000..2638f32b --- /dev/null +++ b/dbt_common/exceptions/connection.py @@ -0,0 +1,7 @@ +class ConnectionError(Exception): + """ + There was a problem with the connection that returned a bad response, + timed out, or resulted in a file that is corrupt. + """ + + pass diff --git a/dbt_common/exceptions/contracts.py b/dbt_common/exceptions/contracts.py new file mode 100644 index 00000000..2ea0053c --- /dev/null +++ b/dbt_common/exceptions/contracts.py @@ -0,0 +1,17 @@ +from typing import Any +from dbt_common.exceptions import CompilationError + + +# this is part of the context and also raised in dbt.contracts.relation.py +class DataclassNotDictError(CompilationError): + def __init__(self, obj: Any): + self.obj = obj + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = ( + f'The object ("{self.obj}") was used as a dictionary. This ' + "capability has been removed from objects of this type." + ) + + return msg diff --git a/dbt_common/exceptions/events.py b/dbt_common/exceptions/events.py new file mode 100644 index 00000000..862922a1 --- /dev/null +++ b/dbt_common/exceptions/events.py @@ -0,0 +1,9 @@ +from dbt_common.exceptions import CompilationError, scrub_secrets, env_secrets + + +# event level exception +class EventCompilationError(CompilationError): + def __init__(self, msg: str, node) -> None: + self.msg = scrub_secrets(msg, env_secrets()) + self.node = node + super().__init__(msg=self.msg) diff --git a/dbt_common/exceptions/jinja.py b/dbt_common/exceptions/jinja.py new file mode 100644 index 00000000..f689b2ad --- /dev/null +++ b/dbt_common/exceptions/jinja.py @@ -0,0 +1,87 @@ +from dbt_common.exceptions import CompilationError + + +class BlockDefinitionNotAtTopError(CompilationError): + def __init__(self, tag_parser, tag_start) -> None: + self.tag_parser = tag_parser + self.tag_start = tag_start + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + position = self.tag_parser.linepos(self.tag_start) + msg = ( + f"Got a block definition inside control flow at {position}. 
" + "All dbt block definitions must be at the top level" + ) + return msg + + +class MissingCloseTagError(CompilationError): + def __init__(self, block_type_name: str, linecount: int) -> None: + self.block_type_name = block_type_name + self.linecount = linecount + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = ( + f"Reached EOF without finding a close tag for {self.block_type_name} (searched from line {self.linecount})" + ) + return msg + + +class MissingControlFlowStartTagError(CompilationError): + def __init__(self, tag, expected_tag: str, tag_parser) -> None: + self.tag = tag + self.expected_tag = expected_tag + self.tag_parser = tag_parser + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + linepos = self.tag_parser.linepos(self.tag.start) + msg = ( + f"Got an unexpected control flow end tag, got {self.tag.block_type_name} but " + f"expected {self.expected_tag} next (@ {linepos})" + ) + return msg + + +class NestedTagsError(CompilationError): + def __init__(self, outer, inner) -> None: + self.outer = outer + self.inner = inner + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = ( + f"Got nested tags: {self.outer.block_type_name} (started at {self.outer.start}) did " + f"not have a matching {{{{% end{self.outer.block_type_name} %}}}} before a " + f"subsequent {self.inner.block_type_name} was found (started at {self.inner.start})" + ) + return msg + + +class UnexpectedControlFlowEndTagError(CompilationError): + def __init__(self, tag, expected_tag: str, tag_parser) -> None: + self.tag = tag + self.expected_tag = expected_tag + self.tag_parser = tag_parser + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + linepos = self.tag_parser.linepos(self.tag.start) + msg = ( + f"Got an unexpected control flow end tag, got {self.tag.block_type_name} but " + f"never saw a preceeding {self.expected_tag} (@ {linepos})" + ) + return msg + + +class UnexpectedMacroEOFError(CompilationError): + def __init__(self, expected_name: str, actual_name: str) -> None: + self.expected_name = expected_name + self.actual_name = actual_name + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f'unexpected EOF, expected {self.expected_name}, got "{self.actual_name}"' + return msg diff --git a/dbt_common/exceptions/macros.py b/dbt_common/exceptions/macros.py new file mode 100644 index 00000000..5fbefce3 --- /dev/null +++ b/dbt_common/exceptions/macros.py @@ -0,0 +1,107 @@ +from typing import Any + +from dbt_common.exceptions import CompilationError, DbtBaseException + + +class MacroReturn(DbtBaseException): + """ + Hack of all hacks + This is not actually an exception. + It's how we return a value from a macro. + """ + + def __init__(self, value) -> None: + self.value = value + + +class UndefinedMacroError(CompilationError): + def __str__(self, prefix: str = "! ") -> str: + msg = super().__str__(prefix) + return ( + f"{msg}. This can happen when calling a macro that does " + "not exist. Check for typos and/or install package dependencies " + 'with "dbt deps".' 
+ ) + + +class UndefinedCompilationError(CompilationError): + def __init__(self, name: str, node) -> None: + self.name = name + self.node = node + self.msg = f"{self.name} is undefined" + super().__init__(msg=self.msg) + + +class CaughtMacroError(CompilationError): + def __init__(self, exc) -> None: + self.exc = exc + super().__init__(msg=str(exc)) + + +class CaughtMacroErrorWithNodeError(CompilationError): + def __init__(self, exc, node) -> None: + self.exc = exc + self.node = node + super().__init__(msg=str(exc)) + + +class JinjaRenderingError(CompilationError): + pass + + +class MaterializationArgError(CompilationError): + def __init__(self, name: str, argument: str) -> None: + self.name = name + self.argument = argument + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f"materialization '{self.name}' received unknown argument '{self.argument}'." + return msg + + +class MacroNameNotStringError(CompilationError): + def __init__(self, kwarg_value) -> None: + self.kwarg_value = kwarg_value + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f"The macro_name parameter ({self.kwarg_value}) " "to adapter.dispatch was not a string" + return msg + + +class MacrosSourcesUnWriteableError(CompilationError): + def __init__(self, node) -> None: + self.node = node + msg = 'cannot "write" macros or sources' + super().__init__(msg=msg) + + +class MacroArgTypeError(CompilationError): + def __init__(self, method_name: str, arg_name: str, got_value: Any, expected_type) -> None: + self.method_name = method_name + self.arg_name = arg_name + self.got_value = got_value + self.expected_type = expected_type + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + got_type = type(self.got_value) + msg = ( + f"'adapter.{self.method_name}' expects argument " + f"'{self.arg_name}' to be of type '{self.expected_type}', instead got " + f"{self.got_value} ({got_type})" + ) + return msg + + +class MacroResultError(CompilationError): + def __init__(self, freshness_macro_name: str, table): + self.freshness_macro_name = freshness_macro_name + self.table = table + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f'Got an invalid result from "{self.freshness_macro_name}" macro: {[tuple(r) for r in self.table]}' + + return msg diff --git a/dbt_common/exceptions/system.py b/dbt_common/exceptions/system.py new file mode 100644 index 00000000..b0062f63 --- /dev/null +++ b/dbt_common/exceptions/system.py @@ -0,0 +1,50 @@ +from typing import List, Union, Any + +from dbt_common.exceptions import CompilationError, CommandError, scrub_secrets, env_secrets + + +class SymbolicLinkError(CompilationError): + def __init__(self) -> None: + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = ( + "dbt encountered an error when attempting to create a symbolic link. 
" + "If this error persists, please create an issue at: \n\n" + "https://github.com/dbt-labs/dbt-core" + ) + + return msg + + +class ExecutableError(CommandError): + def __init__(self, cwd: str, cmd: List[str], msg: str) -> None: + super().__init__(cwd, cmd, msg) + + +class WorkingDirectoryError(CommandError): + def __init__(self, cwd: str, cmd: List[str], msg: str) -> None: + super().__init__(cwd, cmd, msg) + + def __str__(self): + return f'{self.msg}: "{self.cwd}"' + + +class CommandResultError(CommandError): + def __init__( + self, + cwd: str, + cmd: List[str], + returncode: Union[int, Any], + stdout: bytes, + stderr: bytes, + msg: str = "Got a non-zero returncode", + ) -> None: + super().__init__(cwd, cmd, msg) + self.returncode = returncode + self.stdout = scrub_secrets(stdout.decode("utf-8"), env_secrets()) + self.stderr = scrub_secrets(stderr.decode("utf-8"), env_secrets()) + self.args = (cwd, self.cmd, returncode, self.stdout, self.stderr, msg) + + def __str__(self): + return f"{self.msg} running: {self.cmd}" diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py new file mode 100644 index 00000000..d8631f38 --- /dev/null +++ b/dbt_common/helper_types.py @@ -0,0 +1,116 @@ +# never name this package "types", or mypy will crash in ugly ways + +# necessary for annotating constructors +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Tuple, AbstractSet, Union +from typing import Callable, cast, Generic, Optional, TypeVar, List, NewType, Set + +from dbt_common.dataclass_schema import ( + dbtClassMixin, + ValidationError, + StrEnum, +) + +Port = NewType("Port", int) + + +class NVEnum(StrEnum): + novalue = "novalue" + + def __eq__(self, other): + return isinstance(other, NVEnum) + + +@dataclass +class NoValue(dbtClassMixin): + """Sometimes, you want a way to say none that isn't None""" + + novalue: NVEnum = field(default_factory=lambda: NVEnum.novalue) + + +@dataclass +class IncludeExclude(dbtClassMixin): + INCLUDE_ALL = ("all", "*") + + include: Union[str, List[str]] + exclude: List[str] = field(default_factory=list) + + def __post_init__(self): + if isinstance(self.include, str) and self.include not in self.INCLUDE_ALL: + raise ValidationError(f"include must be one of {self.INCLUDE_ALL} or a list of strings") + + if self.exclude and self.include not in self.INCLUDE_ALL: + raise ValidationError(f"exclude can only be specified if include is one of {self.INCLUDE_ALL}") + + if isinstance(self.include, list): + self._validate_items(self.include) + + if isinstance(self.exclude, list): + self._validate_items(self.exclude) + + def includes(self, item_name: str): + return (item_name in self.include or self.include in self.INCLUDE_ALL) and item_name not in self.exclude + + def _validate_items(self, items: List[str]): + pass + + +class WarnErrorOptions(IncludeExclude): + def __init__( + self, + include: Union[str, List[str]], + exclude: Optional[List[str]] = None, + valid_error_names: Optional[Set[str]] = None, + ): + self._valid_error_names: Set[str] = valid_error_names or set() + super().__init__(include=include, exclude=(exclude or [])) + + def _validate_items(self, items: List[str]): + for item in items: + if item not in self._valid_error_names: + raise ValidationError(f"{item} is not a valid dbt error name.") + + +FQNPath = Tuple[str, ...] +PathSet = AbstractSet[FQNPath] + +T = TypeVar("T") + + +# A data type for representing lazily evaluated values. 
+#
+# usage:
+# x = Lazy.defer(lambda: expensive_fn())
+# y = x.force()
+#
+# inspired by the purescript data type
+# https://pursuit.purescript.org/packages/purescript-lazy/5.0.0/docs/Data.Lazy
+@dataclass
+class Lazy(Generic[T]):
+    _f: Callable[[], T]
+    memo: Optional[T] = None
+
+    # constructor for lazy values
+    @classmethod
+    def defer(cls, f: Callable[[], T]) -> Lazy[T]:
+        return Lazy(f)
+
+    # workaround for open mypy issue:
+    # https://github.com/python/mypy/issues/6910
+    def _typed_eval_f(self) -> T:
+        return cast(Callable[[], T], getattr(self, "_f"))()
+
+    # evaluates the function if the value has not been memoized already
+    def force(self) -> T:
+        if self.memo is None:
+            self.memo = self._typed_eval_f()
+        return self.memo
+
+
+# This class is used in to_target_dict, so that accesses to missing keys
+# will return an empty string instead of Undefined
+class DictDefaultEmptyStr(dict):
+    def __getitem__(self, key):
+        return dict.get(self, key, "")
diff --git a/dbt_common/invocation.py b/dbt_common/invocation.py
new file mode 100644
index 00000000..0e5d3206
--- /dev/null
+++ b/dbt_common/invocation.py
@@ -0,0 +1,12 @@
+import uuid
+
+_INVOCATION_ID = str(uuid.uuid4())
+
+
+def get_invocation_id() -> str:
+    return _INVOCATION_ID
+
+
+def reset_invocation_id():
+    global _INVOCATION_ID
+    _INVOCATION_ID = str(uuid.uuid4())
diff --git a/dbt_common/py.typed b/dbt_common/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/dbt_common/semver.py b/dbt_common/semver.py
new file mode 100644
index 00000000..64620c53
--- /dev/null
+++ b/dbt_common/semver.py
@@ -0,0 +1,455 @@
+from dataclasses import dataclass
+import re
+from typing import List
+
+import dbt_common.exceptions.base
+from dbt_common.exceptions import VersionsNotCompatibleError
+
+from dbt_common.dataclass_schema import dbtClassMixin, StrEnum
+from typing import Optional
+
+
+class Matchers(StrEnum):
+    GREATER_THAN = ">"
+    GREATER_THAN_OR_EQUAL = ">="
+    LESS_THAN = "<"
+    LESS_THAN_OR_EQUAL = "<="
+    EXACT = "="
+
+
+@dataclass
+class VersionSpecification(dbtClassMixin):
+    major: Optional[str] = None
+    minor: Optional[str] = None
+    patch: Optional[str] = None
+    prerelease: Optional[str] = None
+    build: Optional[str] = None
+    matcher: Matchers = Matchers.EXACT
+
+
+_MATCHERS = r"(?P<matcher>\>=|\>|\<|\<=|=)?"
+_NUM_NO_LEADING_ZEROS = r"(0|[1-9]\d*)"
+_ALPHA = r"[0-9A-Za-z-]*"
+_ALPHA_NO_LEADING_ZEROS = r"(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)"
+
+_BASE_VERSION_REGEX = r"""
+(?P<major>{num_no_leading_zeros})\.
+(?P<minor>{num_no_leading_zeros})\.
+(?P<patch>{num_no_leading_zeros})
+""".format(
+    num_no_leading_zeros=_NUM_NO_LEADING_ZEROS
+)
+
+_VERSION_EXTRA_REGEX = r"""
+(\-?
+  (?P<prerelease>
+    {alpha_no_leading_zeros}(\.{alpha_no_leading_zeros})*))?
+(\+
+  (?P<build>
+    {alpha}(\.{alpha})*))?
+""".format( + alpha_no_leading_zeros=_ALPHA_NO_LEADING_ZEROS, alpha=_ALPHA +) + + +_VERSION_REGEX_PAT_STR = r""" +^ +{matchers} +{base_version_regex} +{version_extra_regex} +$ +""".format( + matchers=_MATCHERS, + base_version_regex=_BASE_VERSION_REGEX, + version_extra_regex=_VERSION_EXTRA_REGEX, +) + +_VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE) + + +def _cmp(a, b): + """Return negative if ab.""" + return (a > b) - (a < b) + + +@dataclass +class VersionSpecifier(VersionSpecification): + def to_version_string(self, skip_matcher=False): + prerelease = "" + build = "" + matcher = "" + + if self.prerelease: + prerelease = "-" + self.prerelease + + if self.build: + build = "+" + self.build + + if not skip_matcher: + matcher = self.matcher + return "{}{}.{}.{}{}{}".format(matcher, self.major, self.minor, self.patch, prerelease, build) + + @classmethod + def from_version_string(cls, version_string): + match = _VERSION_REGEX.match(version_string) + + if not match: + raise dbt_common.exceptions.base.SemverError(f'"{version_string}" is not a valid semantic version.') + + matched = {k: v for k, v in match.groupdict().items() if v is not None} + + return cls.from_dict(matched) + + def __str__(self): + return self.to_version_string() + + def to_range(self) -> "VersionRange": + range_start: VersionSpecifier = UnboundedVersionSpecifier() + range_end: VersionSpecifier = UnboundedVersionSpecifier() + + if self.matcher == Matchers.EXACT: + range_start = self + range_end = self + + elif self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL]: + range_start = self + + elif self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL]: + range_end = self + + return VersionRange(start=range_start, end=range_end) + + def compare(self, other): + if self.is_unbounded or other.is_unbounded: + return 0 + + for key in ["major", "minor", "patch", "prerelease"]: + (a, b) = (getattr(self, key), getattr(other, key)) + if key == "prerelease": + if a is None and b is None: + continue + if a is None: + if self.matcher == Matchers.LESS_THAN: + # If 'a' is not a pre-release but 'b' is, and b must be + # less than a, return -1 to prevent installations of + # pre-releases with greater base version than a + # maximum specified non-pre-release version. 
+ return -1 + # Otherwise, stable releases are considered greater than + # pre-release + return 1 + if b is None: + return -1 + + # Check the prerelease component only + prcmp = self._nat_cmp(a, b) + if prcmp != 0: # either -1 or 1 + return prcmp + # else is equal and will fall through + + else: # major/minor/patch, should all be numbers + if int(a) > int(b): + return 1 + elif int(a) < int(b): + return -1 + # else is equal and will fall through + + equal = (self.matcher == Matchers.GREATER_THAN_OR_EQUAL and other.matcher == Matchers.LESS_THAN_OR_EQUAL) or ( + self.matcher == Matchers.LESS_THAN_OR_EQUAL and other.matcher == Matchers.GREATER_THAN_OR_EQUAL + ) + if equal: + return 0 + + lt = ( + (self.matcher == Matchers.LESS_THAN and other.matcher == Matchers.LESS_THAN_OR_EQUAL) + or (other.matcher == Matchers.GREATER_THAN and self.matcher == Matchers.GREATER_THAN_OR_EQUAL) + or (self.is_upper_bound and other.is_lower_bound) + ) + if lt: + return -1 + + gt = ( + (other.matcher == Matchers.LESS_THAN and self.matcher == Matchers.LESS_THAN_OR_EQUAL) + or (self.matcher == Matchers.GREATER_THAN and other.matcher == Matchers.GREATER_THAN_OR_EQUAL) + or (self.is_lower_bound and other.is_upper_bound) + ) + if gt: + return 1 + + return 0 + + def __lt__(self, other): + return self.compare(other) == -1 + + def __gt__(self, other): + return self.compare(other) == 1 + + def __eq___(self, other): + return self.compare(other) == 0 + + def __cmp___(self, other): + return self.compare(other) + + @property + def is_unbounded(self): + return False + + @property + def is_lower_bound(self): + return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] + + @property + def is_upper_bound(self): + return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] + + @property + def is_exact(self): + return self.matcher == Matchers.EXACT + + @classmethod + def _nat_cmp(cls, a, b): + def cmp_prerelease_tag(a, b): + if isinstance(a, int) and isinstance(b, int): + return _cmp(a, b) + elif isinstance(a, int): + return -1 + elif isinstance(b, int): + return 1 + else: + return _cmp(a, b) + + a, b = a or "", b or "" + a_parts, b_parts = a.split("."), b.split(".") + a_parts = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts] + b_parts = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts] + for sub_a, sub_b in zip(a_parts, b_parts): + cmp_result = cmp_prerelease_tag(sub_a, sub_b) + if cmp_result != 0: + return cmp_result + else: + return _cmp(len(a), len(b)) + + +@dataclass +class VersionRange: + start: VersionSpecifier + end: VersionSpecifier + + def _try_combine_exact(self, a, b): + if a.compare(b) == 0: + return a + else: + raise VersionsNotCompatibleError() + + def _try_combine_lower_bound_with_exact(self, lower, exact): + comparison = lower.compare(exact) + + if comparison < 0 or (comparison == 0 and lower.matcher == Matchers.GREATER_THAN_OR_EQUAL): + return exact + + raise VersionsNotCompatibleError() + + def _try_combine_lower_bound(self, a, b): + if b.is_unbounded: + return a + elif a.is_unbounded: + return b + + if not (a.is_exact or b.is_exact): + comparison = a.compare(b) < 0 + + if comparison: + return b + else: + return a + + elif a.is_exact: + return self._try_combine_lower_bound_with_exact(b, a) + + elif b.is_exact: + return self._try_combine_lower_bound_with_exact(a, b) + + def _try_combine_upper_bound_with_exact(self, upper, exact): + comparison = upper.compare(exact) + + if comparison > 0 or (comparison == 0 and upper.matcher == Matchers.LESS_THAN_OR_EQUAL): + return 
exact + + raise VersionsNotCompatibleError() + + def _try_combine_upper_bound(self, a, b): + if b.is_unbounded: + return a + elif a.is_unbounded: + return b + + if not (a.is_exact or b.is_exact): + comparison = a.compare(b) > 0 + + if comparison: + return b + else: + return a + + elif a.is_exact: + return self._try_combine_upper_bound_with_exact(b, a) + + elif b.is_exact: + return self._try_combine_upper_bound_with_exact(a, b) + + def reduce(self, other): + start = None + + if self.start.is_exact and other.start.is_exact: + start = end = self._try_combine_exact(self.start, other.start) + + else: + start = self._try_combine_lower_bound(self.start, other.start) + end = self._try_combine_upper_bound(self.end, other.end) + + if start.compare(end) > 0: + raise VersionsNotCompatibleError() + + return VersionRange(start=start, end=end) + + def __str__(self): + result = [] + + if self.start.is_unbounded and self.end.is_unbounded: + return "ANY" + + if not self.start.is_unbounded: + result.append(self.start.to_version_string()) + + if not self.end.is_unbounded: + result.append(self.end.to_version_string()) + + return ", ".join(result) + + def to_version_string_pair(self): + to_return = [] + + if not self.start.is_unbounded: + to_return.append(self.start.to_version_string()) + + if not self.end.is_unbounded: + to_return.append(self.end.to_version_string()) + + return to_return + + +class UnboundedVersionSpecifier(VersionSpecifier): + def __init__(self, *args, **kwargs) -> None: + super().__init__(matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None) + + def __str__(self): + return "*" + + @property + def is_unbounded(self): + return True + + @property + def is_lower_bound(self): + return False + + @property + def is_upper_bound(self): + return False + + @property + def is_exact(self): + return False + + +def reduce_versions(*args): + version_specifiers = [] + + for version in args: + if isinstance(version, UnboundedVersionSpecifier) or version is None: + continue + + elif isinstance(version, VersionSpecifier): + version_specifiers.append(version) + + elif isinstance(version, VersionRange): + if not isinstance(version.start, UnboundedVersionSpecifier): + version_specifiers.append(version.start) + + if not isinstance(version.end, UnboundedVersionSpecifier): + version_specifiers.append(version.end) + + else: + version_specifiers.append(VersionSpecifier.from_version_string(version)) + + for version_specifier in version_specifiers: + if not isinstance(version_specifier, VersionSpecifier): + raise Exception(version_specifier) + + if not version_specifiers: + return VersionRange(start=UnboundedVersionSpecifier(), end=UnboundedVersionSpecifier()) + + try: + to_return = version_specifiers.pop().to_range() + + for version_specifier in version_specifiers: + to_return = to_return.reduce(version_specifier.to_range()) + except VersionsNotCompatibleError: + raise VersionsNotCompatibleError( + "Could not find a satisfactory version from options: {}".format([str(a) for a in args]) + ) + + return to_return + + +def versions_compatible(*args): + if len(args) == 1: + return True + + try: + reduce_versions(*args) + return True + except VersionsNotCompatibleError: + return False + + +def find_possible_versions(requested_range, available_versions): + possible_versions = [] + + for version_string in available_versions: + version = VersionSpecifier.from_version_string(version_string) + + if versions_compatible(version, requested_range.start, requested_range.end): + 
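+            # keep any advertised version that satisfies both ends of the
+            # requested range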
possible_versions.append(version) + + sorted_versions = sorted(possible_versions, reverse=True) + return [v.to_version_string(skip_matcher=True) for v in sorted_versions] + + +def resolve_to_specific_version(requested_range, available_versions): + max_version = None + max_version_string = None + + for version_string in available_versions: + version = VersionSpecifier.from_version_string(version_string) + + if versions_compatible(version, requested_range.start, requested_range.end) and ( + max_version is None or max_version.compare(version) < 0 + ): + max_version = version + max_version_string = version_string + + return max_version_string + + +def filter_installable(versions: List[str], install_prerelease: bool) -> List[str]: + installable = [] + installable_dict = {} + for version_string in versions: + version = VersionSpecifier.from_version_string(version_string) + if install_prerelease or not version.prerelease: + installable.append(version) + installable_dict[str(version)] = version_string + sorted_installable = sorted(installable) + sorted_installable_original_versions = [str(installable_dict.get(str(version))) for version in sorted_installable] + return sorted_installable_original_versions diff --git a/dbt_common/ui.py b/dbt_common/ui.py new file mode 100644 index 00000000..2cc7c5ef --- /dev/null +++ b/dbt_common/ui.py @@ -0,0 +1,81 @@ +from os import getenv as os_getenv +import sys +import textwrap +from typing import Dict + +import colorama + +# Colorama is needed for colored logs on Windows because we're using logger.info +# intead of print(). If the Windows env doesn't have a TERM var set or it is set to None +# (i.e. in the case of Git Bash on Windows- this emulates Unix), then it's safe to initialize +# Colorama with wrapping turned on which allows us to strip ANSI sequences from stdout. +# You can safely initialize Colorama for any OS and the coloring stays the same except +# when piped to another process for Linux and MacOS, then it loses the coloring. To combat +# that, we will just initialize Colorama when needed on Windows using a non-Unix terminal. + +if sys.platform == "win32" and (not os_getenv("TERM") or os_getenv("TERM") == "None"): + colorama.init(wrap=True) + +COLORS: Dict[str, str] = { + "red": colorama.Fore.RED, + "green": colorama.Fore.GREEN, + "yellow": colorama.Fore.YELLOW, + "reset_all": colorama.Style.RESET_ALL, +} + + +COLOR_FG_RED = COLORS["red"] +COLOR_FG_GREEN = COLORS["green"] +COLOR_FG_YELLOW = COLORS["yellow"] +COLOR_RESET_ALL = COLORS["reset_all"] + + +USE_COLOR = True +PRINTER_WIDTH = 80 + + +def color(text: str, color_code: str) -> str: + if USE_COLOR: + return "{}{}{}".format(color_code, text, COLOR_RESET_ALL) + else: + return text + + +def printer_width() -> int: + return PRINTER_WIDTH + + +def green(text: str) -> str: + return color(text, COLOR_FG_GREEN) + + +def yellow(text: str) -> str: + return color(text, COLOR_FG_YELLOW) + + +def red(text: str) -> str: + return color(text, COLOR_FG_RED) + + +def line_wrap_message(msg: str, subtract: int = 0, dedent: bool = True, prefix: str = "") -> str: + """ + Line wrap the given message to PRINTER_WIDTH - {subtract}. Convert double + newlines to newlines and avoid calling textwrap.fill() on them (like + markdown) + """ + width = printer_width() - subtract + if dedent: + msg = textwrap.dedent(msg) + + if prefix: + msg = f"{prefix}{msg}" + + # If the input had an explicit double newline, we want to preserve that + # (we'll turn it into a single line soon). Support windows, too. 
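+    # e.g. "intro\n\ndetails" becomes two paragraphs, each wrapped to `width`
+    # characters, which are then rejoined with single newlines.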
+ splitter = "\r\n\r\n" if "\r\n\r\n" in msg else "\n\n" + chunks = msg.split(splitter) + return "\n".join(textwrap.fill(chunk, width=width, break_on_hyphens=False) for chunk in chunks) + + +def warning_tag(msg: str) -> str: + return f'[{yellow("WARNING")}]: {msg}' diff --git a/dbt_common/utils/__init__.py b/dbt_common/utils/__init__.py new file mode 100644 index 00000000..16523f5c --- /dev/null +++ b/dbt_common/utils/__init__.py @@ -0,0 +1,26 @@ +from dbt_common.utils.encoding import md5, JSONEncoder, ForgivingJSONEncoder + +from dbt_common.utils.casting import ( + cast_to_str, + cast_to_int, + cast_dict_to_dict_of_strings, +) + +from dbt_common.utils.dict import ( + AttrDict, + filter_null_values, + merge, + deep_merge, + deep_merge_item, + deep_map_render, +) + +from dbt_common.utils.executor import executor + +from dbt_common.utils.jinja import ( + get_dbt_macro_name, + get_docs_macro_name, + get_materialization_macro_name, + get_test_macro_name, + MACRO_PREFIX, +) diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py new file mode 100644 index 00000000..811ea376 --- /dev/null +++ b/dbt_common/utils/casting.py @@ -0,0 +1,25 @@ +# This is useful for proto generated classes in particular, since +# the default for protobuf for strings is the empty string, so +# Optional[str] types don't work for generated Python classes. +from typing import Optional + + +def cast_to_str(string: Optional[str]) -> str: + if string is None: + return "" + else: + return string + + +def cast_to_int(integer: Optional[int]) -> int: + if integer is None: + return 0 + else: + return integer + + +def cast_dict_to_dict_of_strings(dct): + new_dct = {} + for k, v in dct.items(): + new_dct[str(k)] = str(v) + return new_dct diff --git a/dbt_common/utils/connection.py b/dbt_common/utils/connection.py new file mode 100644 index 00000000..890c3e99 --- /dev/null +++ b/dbt_common/utils/connection.py @@ -0,0 +1,33 @@ +import time + +from dbt_common.events.types import RecordRetryException, RetryExternalCall +from dbt_common.exceptions import ConnectionError +from tarfile import ReadError + +import requests + + +def connection_exception_retry(fn, max_attempts: int, attempt: int = 0): + """Attempts to run a function that makes an external call, if the call fails + on a Requests exception or decompression issue (ReadError), it will be tried + up to 5 more times. All exceptions that Requests explicitly raises inherit from + requests.exceptions.RequestException. See https://github.com/dbt-labs/dbt-core/issues/4579 + for context on this decompression issues specifically. 
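+
+    For example, a download can be wrapped in a no-argument callable:
+    connection_exception_retry(lambda: requests.get(url).content, max_attempts=5),
+    where `url` here is only a placeholder.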
+ """ + try: + return fn() + except ( + requests.exceptions.RequestException, + ReadError, + EOFError, + ) as exc: + if attempt <= max_attempts - 1: + # This import needs to be inline to avoid circular dependency + from dbt_common.events.functions import fire_event + + fire_event(RecordRetryException(exc=str(exc))) + fire_event(RetryExternalCall(attempt=attempt, max=max_attempts)) + time.sleep(1) + return connection_exception_retry(fn, max_attempts, attempt + 1) + else: + raise ConnectionError("External connection exception occurred: " + str(exc)) diff --git a/dbt_common/utils/dict.py b/dbt_common/utils/dict.py new file mode 100644 index 00000000..ff97d185 --- /dev/null +++ b/dbt_common/utils/dict.py @@ -0,0 +1,126 @@ +import copy +import datetime +from typing import Dict, Optional, TypeVar, Callable, Any, Tuple, Union, Type + +from dbt_common.exceptions import DbtConfigError, RecursionError + +K_T = TypeVar("K_T") +V_T = TypeVar("V_T") + + +def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]: + return {k: v for k, v in input.items() if v is not None} + + +class AttrDict(dict): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def merge(*args): + if len(args) == 0: + return None + + if len(args) == 1: + return args[0] + + lst = list(args) + last = lst.pop(len(lst) - 1) + + return _merge(merge(*lst), last) + + +def _merge(a, b): + to_return = a.copy() + to_return.update(b) + return to_return + + +# http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data +def deep_merge(*args): + """ + >>> dbt_common.utils.deep_merge({'a': 1, 'b': 2, 'c': 3}, {'a': 2}, {'a': 3, 'b': 1}) # noqa + {'a': 3, 'b': 1, 'c': 3} + """ + if len(args) == 0: + return None + + if len(args) == 1: + return copy.deepcopy(args[0]) + + lst = list(args) + last = copy.deepcopy(lst.pop(len(lst) - 1)) + + return _deep_merge(deep_merge(*lst), last) + + +def _deep_merge(destination, source): + if isinstance(source, dict): + for key, value in source.items(): + deep_merge_item(destination, key, value) + return destination + + +def deep_merge_item(destination, key, value): + if isinstance(value, dict): + node = destination.setdefault(key, {}) + destination[key] = deep_merge(node, value) + elif isinstance(value, tuple) or isinstance(value, list): + if key in destination: + destination[key] = list(value) + list(destination[key]) + else: + destination[key] = value + else: + destination[key] = value + + +def _deep_map_render( + func: Callable[[Any, Tuple[Union[str, int], ...]], Any], + value: Any, + keypath: Tuple[Union[str, int], ...], +) -> Any: + atomic_types: Tuple[Type[Any], ...] = (int, float, str, type(None), bool, datetime.date) + + ret: Any + + if isinstance(value, list): + ret = [_deep_map_render(func, v, (keypath + (idx,))) for idx, v in enumerate(value)] + elif isinstance(value, dict): + ret = {k: _deep_map_render(func, v, (keypath + (str(k),))) for k, v in value.items()} + elif isinstance(value, atomic_types): + ret = func(value, keypath) + else: + container_types: Tuple[Type[Any], ...] = (list, dict) + ok_types = container_types + atomic_types + raise DbtConfigError("in _deep_map_render, expected one of {!r}, got {!r}".format(ok_types, type(value))) + + return ret + + +def deep_map_render(func: Callable[[Any, Tuple[Union[str, int], ...]], Any], value: Any) -> Any: + """This function renders a nested dictionary derived from a yaml + file. It is used to render dbt_project.yml, profiles.yml, and + schema files. 
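+
+    A minimal illustration (the names below are placeholders): to upper-case every
+    string value in a parsed YAML dict while leaving other values unchanged, call
+    deep_map_render(lambda value, _keypath: value.upper() if isinstance(value, str) else value, data).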
+ + It maps the function func() onto each non-container value in 'value' + recursively, returning a new value. As long as func does not manipulate + the value, then deep_map_render will also not manipulate it. + + value should be a value returned by `yaml.safe_load` or `json.load` - the + only expected types are list, dict, native python number, str, NoneType, + and bool. + + func() will be called on numbers, strings, Nones, and booleans. Its first + parameter will be the value, and the second will be its keypath, an + iterable over the __getitem__ keys needed to get to it. + + :raises: If there are cycles in the value, raises a + dbt_common.exceptions.RecursionError + """ + try: + return _deep_map_render(func, value, ()) + except RuntimeError as exc: + if "maximum recursion depth exceeded" in str(exc): + raise RecursionError("Cycle detected in deep_map_render") + raise diff --git a/dbt_common/utils/encoding.py b/dbt_common/utils/encoding.py new file mode 100644 index 00000000..c741e52f --- /dev/null +++ b/dbt_common/utils/encoding.py @@ -0,0 +1,56 @@ +import datetime +import decimal +import hashlib +import json +from typing import Tuple, Type, Any + +import jinja2 +import sys + +DECIMALS: Tuple[Type[Any], ...] +try: + import cdecimal # typing: ignore +except ImportError: + DECIMALS = (decimal.Decimal,) +else: + DECIMALS = (decimal.Decimal, cdecimal.Decimal) + + +def md5(string, charset="utf-8"): + if sys.version_info >= (3, 9): + return hashlib.md5(string.encode(charset), usedforsecurity=False).hexdigest() + else: + return hashlib.md5(string.encode(charset)).hexdigest() + + +class JSONEncoder(json.JSONEncoder): + """A 'custom' json encoder that does normal json encoder things, but also + handles `Decimal`s and `Undefined`s. Decimals can lose precision because + they get converted to floats. Undefined's are serialized to an empty string + """ + + def default(self, obj): + if isinstance(obj, DECIMALS): + return float(obj) + elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): + return obj.isoformat() + elif isinstance(obj, jinja2.Undefined): + return "" + elif isinstance(obj, Exception): + return repr(obj) + elif hasattr(obj, "to_dict"): + # if we have a to_dict we should try to serialize the result of + # that! + return obj.to_dict(omit_none=True) + else: + return super().default(obj) + + +class ForgivingJSONEncoder(JSONEncoder): + def default(self, obj): + # let dbt's default JSON encoder handle it if possible, fallback to + # str() + try: + return super().default(obj) + except TypeError: + return str(obj) diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py new file mode 100644 index 00000000..afe5d6da --- /dev/null +++ b/dbt_common/utils/executor.py @@ -0,0 +1,63 @@ +import concurrent.futures +from contextlib import contextmanager +from typing import Protocol, Optional + + +class ConnectingExecutor(concurrent.futures.Executor): + def submit_connected(self, adapter, conn_name, func, *args, **kwargs): + def connected(conn_name, func, *args, **kwargs): + with self.connection_named(adapter, conn_name): + return func(*args, **kwargs) + + return self.submit(connected, conn_name, func, *args, **kwargs) + + +# a little concurrent.futures.Executor for single-threaded mode +class SingleThreadedExecutor(ConnectingExecutor): + def submit(*args, **kwargs): + # this basic pattern comes from concurrent.futures.Executor itself, + # but without handling the `fn=` form. 
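+        # The unpacking below treats the first positional argument after `self` as
+        # the callable, runs it inline, and returns an already-completed Future
+        # holding either its result or the exception it raised.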
+ if len(args) >= 2: + self, fn, *args = args + elif not args: + raise TypeError("descriptor 'submit' of 'SingleThreadedExecutor' object needs an argument") + else: + raise TypeError("submit expected at least 1 positional argument, got %d" % (len(args) - 1)) + fut = concurrent.futures.Future() + try: + result = fn(*args, **kwargs) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(result) + return fut + + @contextmanager + def connection_named(self, adapter, name): + yield + + +class MultiThreadedExecutor( + ConnectingExecutor, + concurrent.futures.ThreadPoolExecutor, +): + @contextmanager + def connection_named(self, adapter, name): + with adapter.connection_named(name): + yield + + +class ThreadedArgs(Protocol): + single_threaded: bool + + +class HasThreadingConfig(Protocol): + args: ThreadedArgs + threads: Optional[int] + + +def executor(config: HasThreadingConfig) -> ConnectingExecutor: + if config.args.single_threaded: + return SingleThreadedExecutor() + else: + return MultiThreadedExecutor(max_workers=config.threads) diff --git a/dbt_common/utils/formatting.py b/dbt_common/utils/formatting.py new file mode 100644 index 00000000..08354b23 --- /dev/null +++ b/dbt_common/utils/formatting.py @@ -0,0 +1,8 @@ +from typing import Optional + + +def lowercase(value: Optional[str]) -> Optional[str]: + if value is None: + return None + else: + return value.lower() diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py new file mode 100644 index 00000000..36464cbe --- /dev/null +++ b/dbt_common/utils/jinja.py @@ -0,0 +1,33 @@ +from dbt_common.exceptions import DbtInternalError + + +MACRO_PREFIX = "dbt_macro__" +DOCS_PREFIX = "dbt_docs__" + + +def get_dbt_macro_name(name): + if name is None: + raise DbtInternalError("Got None for a macro name!") + return f"{MACRO_PREFIX}{name}" + + +def get_dbt_docs_name(name): + if name is None: + raise DbtInternalError("Got None for a doc name!") + return f"{DOCS_PREFIX}{name}" + + +def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): + if adapter_type is None: + adapter_type = "default" + name = f"materialization_{materialization_name}_{adapter_type}" + return get_dbt_macro_name(name) if with_prefix else name + + +def get_docs_macro_name(docs_name, with_prefix=True): + return get_dbt_docs_name(docs_name) if with_prefix else docs_name + + +def get_test_macro_name(test_name, with_prefix=True): + name = f"test_{test_name}" + return get_dbt_macro_name(name) if with_prefix else name diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index 79cfec94..00000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -black==23.11.0 -bumpversion -flake8 -flaky -freezegun==1.3.1 -hypothesis -ipdb -mypy==1.7.1 -pip-tools -pre-commit diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..4a54150c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,114 @@ +[project] +name = "dbt-common" +version = "0.0.1" +description = "The shared common utilities that dbt-core and adapter implementations use" +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "dbt Labs", email = "info@dbtlabs.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 
Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "agate~=1.7.0", + "colorama>=0.3.9,<0.5", + "isodate>=0.6,<0.7", + "jsonschema~=4.0", + "Jinja2~=3.0", + "mashumaro[msgpack]~=3.9", + "pathspec>=0.9,<0.12", # TODO: I'm not sure this is needed. check search.py? + "protobuf>=4.0.0", + "python-dateutil~=2.0", + "requests<3.0.0", + "typing-extensions~=4.4", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.sdist] +exclude = [ + "/.github", + "/.changes", + ".changie.yaml", + ".gitignore", + ".pre-commit-config.yaml", + "CONTRIBUTING.md", + "MAKEFILE", + "/tests", +] + +[tool.hatch.build.targets.wheel] +packages = ["dbt_common"] + +[tool.hatch.envs.dev-env.scripts] +all = ["pre-commit run --all-files"] + +[tool.hatch.envs.dev-env] +description = "Env for running development commands like pytest / pre-commit" +dependencies = [ + "pytest~=7.3", + "pytest-xdist~=3.2", + "httpx~=0.24", + "hypothesis~=6.87", + "pre-commit~=3.2", + "isort~=5.12", + "black~=23.3", + "ruff==0.0.260", + "mypy~=1.3", + "pytest~=7.3", + "types-Jinja2~=2.11", + "types-jsonschema~=4.17", + "types-python-dateutil~=2.8", + "types-PyYAML~=6.0", +] + +[tool.ruff] +line-length = 120 +select = [ + "E", # Pycodestyle + "F", # Pyflakes + "W", # Whitespace + "D", # Pydocs +] +ignore = [ + # Missing docstring in public module -- often docs handled within classes + "D100", + # Missing docstring in public package -- often docs handled within files not __init__.py + "D104" +] +# Let ruff autofix these errors. +# F401 - Unused imports. +fixable = ["F401"] + +[tool.ruff.pydocstyle] +convention = "google" + +[tool.mypy] +mypy_path = "third-party-stubs/" +namespace_packages = true +warn_unused_configs = true +disallow_untyped_defs = true +warn_redundant_casts = true + +# Don't run the extensive mypy checks on custom stubs +[[tool.mypy.overrides]] +module = ["logbook.*"] +disallow_untyped_defs = false + +[tool.isort] +profile = "black" + +[tool.black] +line-length = 120 diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py new file mode 100644 index 00000000..2e3595a1 --- /dev/null +++ b/tests/unit/test_agate_helper.py @@ -0,0 +1,221 @@ +import unittest + +import agate + +from datetime import datetime +from decimal import Decimal +from isodate import tzinfo +import os +from shutil import rmtree +from tempfile import mkdtemp +from dbt_common.clients import agate_helper + +SAMPLE_CSV_DATA = """a,b,c,d,e,f,g +1,n,test,3.2,20180806T11:33:29.320Z,True,NULL +2,y,asdf,900,20180806T11:35:29.320Z,False,a string""" + +SAMPLE_CSV_BOM_DATA = "\ufeff" + SAMPLE_CSV_DATA + + +EXPECTED = [ + [ + 1, + "n", + "test", + Decimal("3.2"), + datetime(2018, 8, 6, 11, 33, 29, 320000, tzinfo=tzinfo.Utc()), + True, + None, + ], + [ + 2, + "y", + "asdf", + 900, + datetime(2018, 8, 6, 11, 35, 29, 320000, tzinfo=tzinfo.Utc()), + False, + "a string", + ], +] + + +EXPECTED_STRINGS = [ + ["1", "n", "test", "3.2", "20180806T11:33:29.320Z", "True", None], + ["2", "y", "asdf", "900", "20180806T11:35:29.320Z", "False", "a string"], +] + + +class TestAgateHelper(unittest.TestCase): + def setUp(self): + self.tempdir = mkdtemp() + + def tearDown(self): + rmtree(self.tempdir) + + def test_from_csv(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in 
enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_bom_from_csv(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_from_csv_all_reserved(self): + path = os.path.join(self.tempdir, "input.csv") + with open(path, "wb") as fp: + fp.write(SAMPLE_CSV_DATA.encode("utf-8")) + tbl = agate_helper.from_csv(path, tuple("abcdefg")) + self.assertEqual(len(tbl), len(EXPECTED_STRINGS)) + for expected, row in zip(EXPECTED_STRINGS, tbl): + self.assertEqual(list(row), expected) + + def test_from_data(self): + column_names = ["a", "b", "c", "d", "e", "f", "g"] + data = [ + { + "a": "1", + "b": "n", + "c": "test", + "d": "3.2", + "e": "20180806T11:33:29.320Z", + "f": "True", + "g": "NULL", + }, + { + "a": "2", + "b": "y", + "c": "asdf", + "d": "900", + "e": "20180806T11:35:29.320Z", + "f": "False", + "g": "a string", + }, + ] + tbl = agate_helper.table_from_data(data, column_names) + self.assertEqual(len(tbl), len(EXPECTED)) + for idx, row in enumerate(tbl): + self.assertEqual(list(row), EXPECTED[idx]) + + def test_datetime_formats(self): + path = os.path.join(self.tempdir, "input.csv") + datetimes = [ + "20180806T11:33:29.000Z", + "20180806T11:33:29Z", + "20180806T113329Z", + ] + expected = datetime(2018, 8, 6, 11, 33, 29, 0, tzinfo=tzinfo.Utc()) + for dt in datetimes: + with open(path, "wb") as fp: + fp.write("a\n{}".format(dt).encode("utf-8")) + tbl = agate_helper.from_csv(path, ()) + self.assertEqual(tbl[0][0], expected) + + def test_merge_allnull(self): + t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) + t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ("a", "b", "c")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate_helper.Integer) + self.assertEqual(len(result), 4) + + def test_merge_mixed(self): + t1 = agate_helper.table_from_rows([(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d")) + t2 = agate_helper.table_from_rows([(3, "c", "dog", 1), (4, "d", "cat", 5)], ("a", "b", "c", "d")) + t3 = agate_helper.table_from_rows([(3, "c", None, 1.5), (4, "d", None, 3.5)], ("a", "b", "c", "d")) + + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate_helper.Integer) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t1, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate_helper.Integer) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t2, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert 
isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t3, t2]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t1, t2, t3]) + self.assertEqual(result.column_names, ("a", "b", "c", "d")) + assert isinstance(result.column_types[0], agate_helper.Integer) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + assert isinstance(result.column_types[3], agate.data_types.Number) + self.assertEqual(len(result), 6) + + def test_nocast_string_types(self): + # String fields should not be coerced into a representative type + # See: https://github.com/dbt-labs/dbt-core/issues/2984 + + column_names = ["a", "b", "c", "d", "e"] + result_set = [ + {"a": "0005", "b": "01T00000aabbccdd", "c": "true", "d": 10, "e": False}, + {"a": "0006", "b": "01T00000aabbccde", "c": "false", "d": 11, "e": True}, + ] + + tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names) + self.assertEqual(len(tbl), len(result_set)) + + expected = [ + ["0005", "01T00000aabbccdd", "true", Decimal(10), False], + ["0006", "01T00000aabbccde", "false", Decimal(11), True], + ] + + for i, row in enumerate(tbl): + self.assertEqual(list(row), expected[i]) + + def test_nocast_bool_01(self): + # True and False values should not be cast to 1 and 0, and vice versa + # See: https://github.com/dbt-labs/dbt-core/issues/4511 + + column_names = ["a", "b"] + result_set = [ + {"a": True, "b": 1}, + {"a": False, "b": 0}, + ] + + tbl = agate_helper.table_from_data_flat(data=result_set, column_names=column_names) + self.assertEqual(len(tbl), len(result_set)) + + assert isinstance(tbl.column_types[0], agate.data_types.Boolean) + assert isinstance(tbl.column_types[1], agate_helper.Integer) + + expected = [ + [True, Decimal(1)], + [False, Decimal(0)], + ] + + for i, row in enumerate(tbl): + self.assertEqual(list(row), expected[i]) diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py new file mode 100644 index 00000000..817af7a2 --- /dev/null +++ b/tests/unit/test_connection_retries.py @@ -0,0 +1,59 @@ +import functools +import pytest +from requests.exceptions import RequestException +from dbt_common.exceptions import ConnectionError +from dbt_common.utils.connection import connection_exception_retry + + +def no_retry_fn(): + return "success" + + +class TestNoRetries: + def test_no_retry(self): + fn_to_retry = functools.partial(no_retry_fn) + result = connection_exception_retry(fn_to_retry, 3) + + expected = "success" + + assert result == expected + + +def no_success_fn(): + raise RequestException("You'll never pass") + return "failure" + + +class TestMaxRetries: + def test_no_retry(self): + fn_to_retry = functools.partial(no_success_fn) + + with pytest.raises(ConnectionError): + connection_exception_retry(fn_to_retry, 3) + + +def single_retry_fn(): + global counter + if counter == 0: + 
counter += 1 + raise RequestException("You won't pass this one time") + elif counter == 1: + counter += 1 + return "success on 2" + + return "How did we get here?" + + +class TestSingleRetry: + def test_no_retry(self): + global counter + counter = 0 + + fn_to_retry = functools.partial(single_retry_fn) + result = connection_exception_retry(fn_to_retry, 3) + expected = "success on 2" + + # We need to test the return value here, not just that it did not throw an error. + # If the value is not being passed it causes cryptic errors + assert result == expected + assert counter == 2 diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py new file mode 100644 index 00000000..3a31c60d --- /dev/null +++ b/tests/unit/test_core_dbt_utils.py @@ -0,0 +1,73 @@ +import requests +import tarfile +import unittest + +from dbt_common.exceptions import ConnectionError +from dbt_common.utils.connection import connection_exception_retry + + +class TestCommonDbtUtils(unittest.TestCase): + def test_connection_exception_retry_none(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) + self.assertEqual(1, counter) + + def test_connection_exception_retry_success_requests_exception(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry + + def test_connection_exception_retry_max(self): + Counter._reset() + with self.assertRaises(ConnectionError): + connection_exception_retry(lambda: Counter._add_with_exception(), 5) + self.assertEqual(6, counter) # 6 = original attempt plus 5 retries + + def test_connection_exception_retry_success_failed_untar(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry + + def test_connection_exception_retry_success_failed_eofexception(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry + + +counter: int = 0 + + +class Counter: + def _add(): + global counter + counter += 1 + + # All exceptions that Requests explicitly raises inherit from + # requests.exceptions.RequestException so we want to make sure that raises plus one exception + # that inherit from it for sanity + def _add_with_requests_exception(): + global counter + counter += 1 + if counter < 2: + raise requests.exceptions.RequestException + + def _add_with_exception(): + global counter + counter += 1 + raise requests.exceptions.ConnectionError + + def _add_with_untar_exception(): + global counter + counter += 1 + if counter < 2: + raise tarfile.ReadError + + def _add_with_eof_exception(): + global counter + counter += 1 + if counter < 2: + raise EOFError + + def _reset(): + global counter + counter = 0 diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py new file mode 100644 index 00000000..80d5ae2b --- /dev/null +++ b/tests/unit/test_event_handler.py @@ -0,0 +1,40 @@ +import logging + +from dbt_common.events.base_types import EventLevel +from dbt_common.events.event_handler import DbtEventLoggingHandler, set_package_logging +from dbt_common.events.event_manager import TestEventManager + + +def test_event_logging_handler_emits_records_correctly(): + event_manager = TestEventManager() + handler = DbtEventLoggingHandler(event_manager=event_manager, 
level=logging.DEBUG) + log = logging.getLogger("test") + log.setLevel(logging.DEBUG) + log.addHandler(handler) + + log.debug("test") + log.info("test") + log.warning("test") + log.error("test") + log.exception("test") + log.critical("test") + assert len(event_manager.event_history) == 6 + assert event_manager.event_history[0][1] == EventLevel.DEBUG + assert event_manager.event_history[1][1] == EventLevel.INFO + assert event_manager.event_history[2][1] == EventLevel.WARN + assert event_manager.event_history[3][1] == EventLevel.ERROR + assert event_manager.event_history[4][1] == EventLevel.ERROR + assert event_manager.event_history[5][1] == EventLevel.ERROR + + +def test_set_package_logging_sets_level_correctly(): + event_manager = TestEventManager() + log = logging.getLogger("test") + set_package_logging("test", logging.DEBUG, event_manager) + log.debug("debug") + assert len(event_manager.event_history) == 1 + set_package_logging("test", logging.WARN, event_manager) + log.debug("debug 2") + assert len(event_manager.event_history) == 1 + log.warning("warning") + assert len(event_manager.event_history) == 3 # warning logs create two events diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py new file mode 100644 index 00000000..f3337478 --- /dev/null +++ b/tests/unit/test_helper_types.py @@ -0,0 +1,55 @@ +import pytest + +from dbt_common.helper_types import IncludeExclude, WarnErrorOptions +from dbt_common.dataclass_schema import ValidationError + + +class TestIncludeExclude: + def test_init_invalid(self): + with pytest.raises(ValidationError): + IncludeExclude(include="invalid") + + with pytest.raises(ValidationError): + IncludeExclude(include=["ItemA"], exclude=["ItemB"]) + + @pytest.mark.parametrize( + "include,exclude,expected_includes", + [ + ("all", [], True), + ("*", [], True), + ("*", ["ItemA"], False), + (["ItemA"], [], True), + (["ItemA", "ItemB"], [], True), + ], + ) + def test_includes(self, include, exclude, expected_includes): + include_exclude = IncludeExclude(include=include, exclude=exclude) + + assert include_exclude.includes("ItemA") == expected_includes + + +class TestWarnErrorOptions: + def test_init_invalid_error(self): + with pytest.raises(ValidationError): + WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) + + with pytest.raises(ValidationError): + WarnErrorOptions(include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"])) + + def test_init_invalid_error_default_valid_error_names(self): + with pytest.raises(ValidationError): + WarnErrorOptions(include=["InvalidError"]) + + with pytest.raises(ValidationError): + WarnErrorOptions(include="*", exclude=["InvalidError"]) + + def test_init_valid_error(self): + warn_error_options = WarnErrorOptions(include=["ValidError"], valid_error_names=set(["ValidError"])) + assert warn_error_options.include == ["ValidError"] + assert warn_error_options.exclude == [] + + warn_error_options = WarnErrorOptions( + include="*", exclude=["ValidError"], valid_error_names=set(["ValidError"]) + ) + assert warn_error_options.include == "*" + assert warn_error_options.exclude == ["ValidError"] diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py new file mode 100644 index 00000000..1fafa4a8 --- /dev/null +++ b/tests/unit/test_jinja.py @@ -0,0 +1,398 @@ +import unittest + +from dbt_common.clients.jinja import extract_toplevel_blocks +from dbt_common.exceptions import CompilationError + + +class TestBlockLexer(unittest.TestCase): + def test_basic(self): + body = 
'{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks(block_data, allowed_blocks={"mytype"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_multiple(self): + body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + body_two = "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " "{% else %} other_thing {% endif %}" + + block_data = ( + " {% mytype foo %}" + + body_one + + "{% endmytype %}" + + "\r\n{% othertype bar %}" + + body_two + + "{% endothertype %}" + ) + blocks = extract_toplevel_blocks(block_data, allowed_blocks={"mytype", "othertype"}, collect_raw_data=False) + self.assertEqual(len(blocks), 2) + + def test_comments(self): + body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + comment = "{# my comment #}" + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks(comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_evil_comments(self): + body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' + comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}" + block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" + blocks = extract_toplevel_blocks(comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_nested_comments(self): + body = '{# my comment #} {{ config(foo="bar") }}\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n' + block_data = " \n\r\t{%- mytype foo %}" + body + "{% endmytype -%}" + comment = "{# external comment {% othertype bar %} select * from thing.other_thing{% endothertype %} #}" + blocks = extract_toplevel_blocks(comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, body) + self.assertEqual(blocks[0].full_block, block_data) + + def test_complex_file(self): + blocks = extract_toplevel_blocks( + complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False + ) + self.assertEqual(len(blocks), 3) + self.assertEqual(blocks[0].block_type_name, "mytype") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(blocks[0].contents, " some stuff ") + self.assertEqual(blocks[1].block_type_name, "mytype") + self.assertEqual(blocks[1].block_name, "bar") + self.assertEqual(blocks[1].full_block, bar_block) + self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) + self.assertEqual(blocks[2].block_type_name, "myothertype") + 
self.assertEqual(blocks[2].block_name, "x") + self.assertEqual(blocks[2].full_block, x_block.strip()) + self.assertEqual( + blocks[2].contents, + x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], + ) + + def test_peaceful_macro_coexistence(self): + body = ( + "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %}" + ) + blocks = extract_toplevel_blocks(body, allowed_blocks={"macro", "a"}, collect_raw_data=True) + self.assertEqual(len(blocks), 4) + self.assertEqual(blocks[0].full_block, "{# my macro #} ") + self.assertEqual(blocks[1].block_type_name, "macro") + self.assertEqual(blocks[1].block_name, "foo") + self.assertEqual(blocks[1].contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") + self.assertEqual(blocks[3].block_type_name, "a") + self.assertEqual(blocks[3].block_name, "b") + self.assertEqual(blocks[3].contents, " test ") + + def test_macro_with_trailing_data(self): + body = "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} {# my model #} {% a b %} test {% enda %} raw data so cool" + blocks = extract_toplevel_blocks(body, allowed_blocks={"macro", "a"}, collect_raw_data=True) + self.assertEqual(len(blocks), 5) + self.assertEqual(blocks[0].full_block, "{# my macro #} ") + self.assertEqual(blocks[1].block_type_name, "macro") + self.assertEqual(blocks[1].block_name, "foo") + self.assertEqual(blocks[1].contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") + self.assertEqual(blocks[3].block_type_name, "a") + self.assertEqual(blocks[3].block_name, "b") + self.assertEqual(blocks[3].contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") + + def test_macro_with_crazy_args(self): + body = """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}cool{# block comment with {% endmacro %} in it #} stuff here {% endmacro %}""" + blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "macro") + self.assertEqual(blocks[0].block_name, "foo") + self.assertEqual(blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here ") + + def test_materialization_parse(self): + body = "{% materialization xxx, default %} ... {% endmaterialization %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"materialization"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "materialization") + self.assertEqual(blocks[0].block_name, "xxx") + self.assertEqual(blocks[0].full_block, body) + + body = '{% materialization xxx, adapter="other" %} ... 
{% endmaterialization %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"materialization"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "materialization") + self.assertEqual(blocks[0].block_name, "xxx") + self.assertEqual(blocks[0].full_block, body) + + def test_nested_not_ok(self): + # we don't allow nesting same blocks + body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock"}) + + def test_incomplete_block_failure(self): + fullbody = "{% myblock foo %} {% endmyblock %}" + for length in range(len("{% myblock foo %}"), len(fullbody) - 1): + body = fullbody[:length] + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock"}) + + def test_wrong_end_failure(self): + body = "{% myblock foo %} {% endotherblock %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) + + def test_comment_no_end_failure(self): + body = "{# " + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_comment_only(self): + body = "{# myblock #}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + blocks = extract_toplevel_blocks(body, collect_raw_data=False) + self.assertEqual(len(blocks), 0) + + def test_comment_block_self_closing(self): + # test the case where a comment start looks a lot like it closes itself + # (but it doesn't in jinja!) + body = "{#} {% myblock foo %} {#}" + blocks = extract_toplevel_blocks(body, collect_raw_data=False) + self.assertEqual(len(blocks), 0) + + def test_embedded_self_closing_comment_block(self): + body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") + + def test_set_statement(self): + body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_set_block(self): + body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_crazy_set_statement(self): + body = '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% set y = otherthing("{% myblock foo %}") %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"otherblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") + self.assertEqual(blocks[0].block_type_name, "otherblock") + + def test_do_statement(self): + body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_deceptive_do_statement(self): + body = "{% do thing %}{% 
myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_do_block(self): + body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" + blocks = extract_toplevel_blocks(body, allowed_blocks={"do", "myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].contents, "thing.update()") + self.assertEqual(blocks[0].block_type_name, "do") + self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") + + def test_crazy_do_statement(self): + body = '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") + self.assertEqual(blocks[0].block_type_name, "otherblock") + self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") + self.assertEqual(blocks[1].block_type_name, "myblock") + + def test_awful_jinja(self): + blocks = extract_toplevel_blocks( + if_you_do_this_you_are_awful, + allowed_blocks={"snapshot", "materialization"}, + collect_raw_data=False, + ) + self.assertEqual(len(blocks), 2) + self.assertEqual(len([b for b in blocks if b.block_type_name == "__dbt__data"]), 0) + self.assertEqual(blocks[0].block_type_name, "snapshot") + self.assertEqual( + blocks[0].contents, + "\n ".join( + [ + """{% set x = ("{% endsnapshot %}" + (40 * '%})')) %}""", + "{# {% endsnapshot %} #}", + "{% embedded %}", + " some block data right here", + "{% endembedded %}", + ] + ), + ) + self.assertEqual(blocks[1].block_type_name, "materialization") + self.assertEqual(blocks[1].contents, "\nhi\n") + + def test_quoted_endblock_within_block(self): + body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "myblock") + self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') + + def test_docs_block(self): + body = '{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %} {% docs __my_other_doc__ %} asdf "{% enddocs %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) + self.assertEqual(len(blocks), 2) + self.assertEqual(blocks[0].block_type_name, "docs") + self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(blocks[0].block_name, "__my_doc__") + self.assertEqual(blocks[1].block_type_name, "docs") + self.assertEqual(blocks[1].contents, ' asdf "') + self.assertEqual(blocks[1].block_name, "__my_other_doc__") + + def test_docs_block_expr(self): + body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "docs") + self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(blocks[0].block_name, "more_doc") + + def test_unclosed_model_quotes(self): + # test case for 
https://github.com/dbt-labs/dbt-core/issues/1533 + body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' + blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].block_type_name, "model") + self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') + self.assertEqual(blocks[0].block_name, "my_model") + + def test_if(self): + # if you conditionally define your macros/models, don't + body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_if_innocuous(self): + body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + + def test_for(self): + # no for-loops over macros. + body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" + with self.assertRaises(CompilationError): + extract_toplevel_blocks(body) + + def test_for_innocuous(self): + # no for-loops over macros. + body = "{% for x in range(10) %}{% something my_something %} adsf {% endsomething %}{% endfor %}" + blocks = extract_toplevel_blocks(body) + self.assertEqual(len(blocks), 1) + self.assertEqual(blocks[0].full_block, body) + + def test_endif(self): + body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endif but never saw a preceeding if (@ 1:53)", + str(err.exception), + ) + + def test_if_endfor(self): + body = "{% if x %}...{% endfor %}{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endfor but expected endif next (@ 1:13)", + str(err.exception), + ) + + def test_if_endfor_newlines(self): + body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" + with self.assertRaises(CompilationError) as err: + extract_toplevel_blocks(body) + self.assertIn( + "Got an unexpected control flow end tag, got endfor but expected endif next (@ 3:4)", + str(err.exception), + ) + + +bar_block = """{% mytype bar %} +{# a comment + that inside it has + {% mytype baz %} +{% endmyothertype %} +{% endmytype %} +{% endmytype %} + {# +{% endmytype %}#} + +some other stuff + +{%- endmytype%}""" + +x_block = """ +{% myothertype x %} +before +{##} +and after +{% endmyothertype %} +""" + +complex_snapshot_file = ( + """ +{#some stuff {% mytype foo %} #} +{% mytype foo %} some stuff {% endmytype %} + +""" + + bar_block + + x_block +) + + +if_you_do_this_you_are_awful = """ +{#} here is a comment with a block inside {% block x %} asdf {% endblock %} {#} +{% do + set('foo="bar"') +%} +{% set x = ("100" + "hello'" + '%}') %} +{% snapshot something -%} + {% set x = ("{% endsnapshot %}" + (40 * '%})')) %} + {# {% endsnapshot %} #} + {% embedded %} + some block data right here + {% endembedded %} +{%- endsnapshot %} + +{% raw %} + {% set x = SYNTAX ERROR} +{% endraw %} + + +{% materialization whatever, adapter='thing' %} +hi +{% endmaterialization %} +""" diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py new file mode 100644 index 00000000..0cc1e711 --- /dev/null +++ b/tests/unit/test_model_config.py @@ -0,0 +1,92 @@ +from dataclasses import 
dataclass, field +from dbt_common.dataclass_schema import dbtClassMixin +from typing import List, Dict +from dbt_common.contracts.config.metadata import ShowBehavior +from dbt_common.contracts.config.base import MergeBehavior, CompareBehavior + + +@dataclass +class ThingWithMergeBehavior(dbtClassMixin): + default_behavior: int + appended: List[str] = field(metadata={"merge": MergeBehavior.Append}) + updated: Dict[str, int] = field(metadata={"merge": MergeBehavior.Update}) + clobbered: str = field(metadata={"merge": MergeBehavior.Clobber}) + keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) + + +def test_merge_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(MergeBehavior) == { + MergeBehavior.Append, + MergeBehavior.Update, + MergeBehavior.Clobber, + MergeBehavior.DictKeyAppend, + } + for behavior in MergeBehavior: + assert behavior.meta() == {"merge": behavior} + assert behavior.meta(existing) == {"merge": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_merge_behavior_from_field(): + fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] + fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend + + +@dataclass +class ThingWithShowBehavior(dbtClassMixin): + default_behavior: int + hidden: str = field(metadata={"show_hide": ShowBehavior.Hide}) + shown: float = field(metadata={"show_hide": ShowBehavior.Show}) + + +def test_show_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} + for behavior in ShowBehavior: + assert behavior.meta() == {"show_hide": behavior} + assert behavior.meta(existing) == {"show_hide": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_show_behavior_from_field(): + fields = [f[0] for f in ThingWithShowBehavior._get_fields()] + fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show + + +@dataclass +class ThingWithCompareBehavior(dbtClassMixin): + default_behavior: int + included: float = field(metadata={"compare": CompareBehavior.Include}) + excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) + + +def test_compare_behavior_meta(): + existing = {"foo": "bar"} + initial_existing = existing.copy() + assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} + for behavior in CompareBehavior: + assert behavior.meta() == {"compare": behavior} + assert behavior.meta(existing) == {"compare": behavior, "foo": "bar"} + assert existing == initial_existing + + +def test_compare_behavior_from_field(): + fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] + fields = {name: f for f, 
name in ThingWithCompareBehavior._get_fields()} + assert set(fields) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py new file mode 100644 index 00000000..198802d6 --- /dev/null +++ b/tests/unit/test_system_client.py @@ -0,0 +1,263 @@ +import os +import shutil +import stat +import unittest +import tarfile +import pathspec +from pathlib import Path +from tempfile import mkdtemp, NamedTemporaryFile + +from dbt_common.exceptions import ExecutableError, WorkingDirectoryError +import dbt_common.clients.system + + +class SystemClient(unittest.TestCase): + def setUp(self): + super().setUp() + self.tmp_dir = mkdtemp() + self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) + + def set_up_profile(self): + with open(self.profiles_path, "w") as f: + f.write("ORIGINAL_TEXT") + + def get_profile_text(self): + with open(self.profiles_path, "r") as f: + return f.read() + + def tearDown(self): + try: + shutil.rmtree(self.tmp_dir) + except Exception as e: # noqa: [F841] + pass + + def test__make_file_when_exists(self): + self.set_up_profile() + written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") + + self.assertFalse(written) + self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") + + def test__make_file_when_not_exists(self): + written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") + + self.assertTrue(written) + self.assertEqual(self.get_profile_text(), "NEW_TEXT") + + def test__make_file_with_overwrite(self): + self.set_up_profile() + written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT", overwrite=True) + + self.assertTrue(written) + self.assertEqual(self.get_profile_text(), "NEW_TEXT") + + def test__make_dir_from_str(self): + test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" + dbt_common.clients.system.make_directory(test_dir_str) + self.assertTrue(Path(test_dir_str).is_dir()) + + def test__make_dir_from_pathobj(self): + test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") + dbt_common.clients.system.make_directory(test_dir_pathobj) + self.assertTrue(test_dir_pathobj.is_dir()) + + +class TestRunCmd(unittest.TestCase): + """Test `run_cmd`. + + Don't mock out subprocess, in order to expose any OS-level differences. 
+ """ + + not_a_file = "zzzbbfasdfasdfsdaq" + + def setUp(self): + self.tempdir = mkdtemp() + self.run_dir = os.path.join(self.tempdir, "run_dir") + self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") + self.empty_file = os.path.join(self.tempdir, "empty_file") + if os.name == "nt": + self.exists_cmd = ["cmd", "/C", "echo", "hello"] + else: + self.exists_cmd = ["echo", "hello"] + + os.mkdir(self.run_dir) + with open(self.empty_file, "w") as fp: # noqa: [F841] + pass # "touch" + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test__executable_does_not_exist(self): + with self.assertRaises(ExecutableError) as exc: + dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) + + msg = str(exc.exception).lower() + + self.assertIn("path", msg) + self.assertIn("could not find", msg) + self.assertIn(self.does_not_exist.lower(), msg) + + def test__not_exe(self): + with self.assertRaises(ExecutableError) as exc: + dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) + + msg = str(exc.exception).lower() + if os.name == "nt": + # on windows, this means it's not an executable at all! + self.assertIn("not executable", msg) + else: + # on linux, this means you don't have executable permissions on it + self.assertIn("permissions", msg) + self.assertIn(self.empty_file.lower(), msg) + + def test__cwd_does_not_exist(self): + with self.assertRaises(WorkingDirectoryError) as exc: + dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) + msg = str(exc.exception).lower() + self.assertIn("does not exist", msg) + self.assertIn(self.does_not_exist.lower(), msg) + + def test__cwd_not_directory(self): + with self.assertRaises(WorkingDirectoryError) as exc: + dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) + + msg = str(exc.exception).lower() + self.assertIn("not a directory", msg) + self.assertIn(self.empty_file.lower(), msg) + + def test__cwd_no_permissions(self): + # it would be nice to add a windows test. Possible path to that is via + # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on + # the directory for the test user. I'm pretty sure windows users can't + # create files that they themselves cannot access. 
+ if os.name == "nt": + return + + # read-only -> cannot cd to it + os.chmod(self.run_dir, stat.S_IRUSR) + + with self.assertRaises(WorkingDirectoryError) as exc: + dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) + + msg = str(exc.exception).lower() + self.assertIn("permissions", msg) + self.assertIn(self.run_dir.lower(), msg) + + def test__ok(self): + out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) + self.assertEqual(out.strip(), b"hello") + self.assertEqual(err.strip(), b"") + + +class TestFindMatching(unittest.TestCase): + def setUp(self): + self.base_dir = mkdtemp() + self.tempdir = mkdtemp(dir=self.base_dir) + + def test_find_matching_lowercase_file_pattern(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: + file_path = os.path.dirname(named_file.name) + relative_path = os.path.basename(file_path) + out = dbt_common.clients.system.find_matching( + self.base_dir, + [relative_path], + "*.sql", + ) + expected_output = [ + { + "searched_path": relative_path, + "absolute_path": named_file.name, + "relative_path": os.path.basename(named_file.name), + "modification_time": out[0]["modification_time"], + } + ] + self.assertEqual(out, expected_output) + + def test_find_matching_uppercase_file_pattern(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: + file_path = os.path.dirname(named_file.name) + relative_path = os.path.basename(file_path) + out = dbt_common.clients.system.find_matching(self.base_dir, [relative_path], "*.sql") + expected_output = [ + { + "searched_path": relative_path, + "absolute_path": named_file.name, + "relative_path": os.path.basename(named_file.name), + "modification_time": out[0]["modification_time"], + } + ] + self.assertEqual(out, expected_output) + + def test_find_matching_file_pattern_not_found(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): + out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") + self.assertEqual(out, []) + + def test_ignore_spec(self): + with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): + out = dbt_common.clients.system.find_matching( + self.tempdir, + [""], + "*.sql", + pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, "sql-files*".splitlines()), + ) + self.assertEqual(out, []) + + def tearDown(self): + try: + shutil.rmtree(self.base_dir) + except Exception as e: # noqa: [F841] + pass + + +class TestUntarPackage(unittest.TestCase): + def setUp(self): + self.base_dir = mkdtemp() + self.tempdir = mkdtemp(dir=self.base_dir) + self.tempdest = mkdtemp(dir=self.base_dir) + + def tearDown(self): + try: + shutil.rmtree(self.base_dir) + except Exception as e: # noqa: [F841] + pass + + def test_untar_package_success(self): + # set up a valid tarball to test against + with NamedTemporaryFile( + prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False + ) as named_tar_file: + tar_file_full_path = named_tar_file.name + with NamedTemporaryFile(prefix="a", suffix=".txt", dir=self.tempdir) as file_a: + file_a.write(b"some text in the text file") + relative_file_a = os.path.basename(file_a.name) + with tarfile.open(fileobj=named_tar_file, mode="w:gz") as tar: + tar.addfile(tarfile.TarInfo(relative_file_a), open(file_a.name)) + + # now we test can test that we can untar the file successfully + assert tarfile.is_tarfile(tar.name) + dbt_common.clients.system.untar_package(tar_file_full_path, 
self.tempdest) + path = Path(os.path.join(self.tempdest, relative_file_a)) + assert path.is_file() + + def test_untar_package_failure(self): + # create a text file then rename it as a tar (so it's invalid) + with NamedTemporaryFile(prefix="a", suffix=".txt", dir=self.tempdir, delete=False) as file_a: + file_a.write(b"some text in the text file") + txt_file_name = file_a.name + file_path = os.path.dirname(txt_file_name) + tar_file_path = os.path.join(file_path, "mypackage.2.tar.gz") + os.rename(txt_file_name, tar_file_path) + + # now that we're set up, test that untarring the file fails + with self.assertRaises(tarfile.ReadError) as exc: # noqa: [F841] + dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) + + def test_untar_package_empty(self): + # create a tarball with nothing in it + with NamedTemporaryFile(prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir) as named_file: + + # make sure we throw an error for the empty file + with self.assertRaises(tarfile.ReadError) as exc: + dbt_common.clients.system.untar_package(named_file.name, self.tempdest) + self.assertEqual("empty file", str(exc.exception)) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 00000000..0b417052 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,139 @@ +import unittest + +import dbt_common.exceptions +import dbt_common.utils + + +class TestDeepMerge(unittest.TestCase): + def test__simple_cases(self): + cases = [ + {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, + { + "args": [{}, {"b": 1}, {"a": 1}], + "expected": {"a": 1, "b": 1}, + "description": "three merges", + }, + ] + + for case in cases: + actual = dbt_common.utils.deep_merge(*case["args"]) + self.assertEqual( + case["expected"], + actual, + "failed on {} (actual {}, expected {})".format(case["description"], actual, case["expected"]), + ) + + +class TestMerge(unittest.TestCase): + def test__simple_cases(self): + cases = [ + {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, + { + "args": [{}, {"b": 1}, {"a": 1}], + "expected": {"a": 1, "b": 1}, + "description": "three merges", + }, + ] + + for case in cases: + actual = dbt_common.utils.deep_merge(*case["args"]) + self.assertEqual( + case["expected"], + actual, + "failed on {} (actual {}, expected {})".format(case["description"], actual, case["expected"]), + ) + + +class TestDeepMap(unittest.TestCase): + def setUp(self): + self.input_value = { + "foo": { + "bar": "hello", + "baz": [1, 90.5, "990", "89.9"], + }, + "nested": [ + { + "test": "90", + "other_test": None, + }, + { + "test": 400, + "other_test": 4.7e9, + }, + ], + } + + @staticmethod + def intify_all(value, _): + try: + return int(value) + except (TypeError, ValueError): + return -1 + + def test__simple_cases(self): + expected = { + "foo": { + "bar": -1, + "baz": [1, 90, 990, -1], + }, + "nested": [ + { + "test": 90, + "other_test": -1, + }, + { + "test": 400, + "other_test": 4700000000, + }, + ], + } + actual = dbt_common.utils.deep_map_render(self.intify_all, self.input_value) + self.assertEqual(actual, expected) + + actual = dbt_common.utils.deep_map_render(self.intify_all, expected) + self.assertEqual(actual, expected) + + @staticmethod + def special_keypath(value, keypath): + + if tuple(keypath) == ("foo", "baz", 1): + return "hello" + else: + return value + + def test__keypath(self): + expected = { + "foo": { + "bar": "hello", + # the only change from input is the second entry here + "baz": [1, 
"hello", "990", "89.9"], + }, + "nested": [ + { + "test": "90", + "other_test": None, + }, + { + "test": 400, + "other_test": 4.7e9, + }, + ], + } + actual = dbt_common.utils.deep_map_render(self.special_keypath, self.input_value) + self.assertEqual(actual, expected) + + actual = dbt_common.utils.deep_map_render(self.special_keypath, expected) + self.assertEqual(actual, expected) + + def test__noop(self): + actual = dbt_common.utils.deep_map_render(lambda x, _: x, self.input_value) + self.assertEqual(actual, self.input_value) + + def test_trivial(self): + cases = [[], {}, 1, "abc", None, True] + for case in cases: + result = dbt_common.utils.deep_map_render(lambda x, _: x, case) + self.assertEqual(result, case) + + with self.assertRaises(dbt_common.exceptions.DbtConfigError): + dbt_common.utils.deep_map_render(lambda x, _: x, {"foo": object()}) diff --git a/third-party-stubs/colorama/__init__.pyi b/third-party-stubs/colorama/__init__.pyi new file mode 100644 index 00000000..4502880e --- /dev/null +++ b/third-party-stubs/colorama/__init__.pyi @@ -0,0 +1,16 @@ +from typing import Optional, Any + +class Fore: + RED: str = ... + GREEN: str = ... + YELLOW: str = ... + +class Style: + RESET_ALL: str = ... + +def init( + autoreset: bool = ..., + convert: Optional[Any] = ..., + strip: Optional[Any] = ..., + wrap: bool = ..., +) -> None: ... diff --git a/third-party-stubs/isodate/__init__.pyi b/third-party-stubs/isodate/__init__.pyi new file mode 100644 index 00000000..96b67c34 --- /dev/null +++ b/third-party-stubs/isodate/__init__.pyi @@ -0,0 +1,4 @@ +import datetime + +def parse_datetime(datetimestring: str): + datetime.datetime