From 5a6efee2531f452418d536fe02794997bd48e1b3 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 4 Jan 2024 10:16:16 -0500 Subject: [PATCH 01/26] copy dbt-core dbt/common contents to dbt/common --- dbt/common/__init__.py | 0 dbt/common/clients/__init__.py | 0 dbt/common/clients/_jinja_blocks.py | 360 +++++++++++ dbt/common/clients/agate_helper.py | 251 ++++++++ dbt/common/clients/jinja.py | 505 ++++++++++++++++ dbt/common/clients/system.py | 571 ++++++++++++++++++ dbt/common/constants.py | 1 + dbt/common/contracts/__init__.py | 0 dbt/common/contracts/config/__init__.py | 0 dbt/common/contracts/config/base.py | 259 ++++++++ .../contracts/config/materialization.py | 11 + dbt/common/contracts/config/metadata.py | 69 +++ dbt/common/contracts/config/properties.py | 63 ++ dbt/common/contracts/connection.py | 0 dbt/common/contracts/constraints.py | 43 ++ dbt/common/contracts/util.py | 7 + dbt/common/dataclass_schema.py | 165 +++++ dbt/common/events/README.md | 41 ++ dbt/common/events/__init__.py | 9 + dbt/common/events/base_types.py | 185 ++++++ dbt/common/events/contextvars.py | 114 ++++ dbt/common/events/event_handler.py | 40 ++ dbt/common/events/event_manager.py | 66 ++ dbt/common/events/event_manager_client.py | 29 + dbt/common/events/format.py | 56 ++ dbt/common/events/functions.py | 162 +++++ dbt/common/events/helpers.py | 14 + dbt/common/events/interfaces.py | 7 + dbt/common/events/logger.py | 180 ++++++ dbt/common/events/types.proto | 121 ++++ dbt/common/events/types.py | 124 ++++ dbt/common/events/types_pb2.py | 69 +++ dbt/common/exceptions/__init__.py | 4 + dbt/common/exceptions/base.py | 275 +++++++++ dbt/common/exceptions/cache.py | 68 +++ dbt/common/exceptions/contracts.py | 17 + dbt/common/exceptions/events.py | 9 + dbt/common/exceptions/macros.py | 110 ++++ dbt/common/helper_types.py | 126 ++++ dbt/common/invocation.py | 12 + dbt/common/semver.py | 473 +++++++++++++++ dbt/common/ui.py | 68 +++ dbt/common/utils/__init__.py | 26 + dbt/common/utils/casting.py | 25 + dbt/common/utils/connection.py | 33 + dbt/common/utils/dict.py | 128 ++++ dbt/common/utils/encoding.py | 56 ++ dbt/common/utils/executor.py | 67 ++ dbt/common/utils/formatting.py | 8 + dbt/common/utils/jinja.py | 33 + 50 files changed, 5060 insertions(+) create mode 100644 dbt/common/__init__.py create mode 100644 dbt/common/clients/__init__.py create mode 100644 dbt/common/clients/_jinja_blocks.py create mode 100644 dbt/common/clients/agate_helper.py create mode 100644 dbt/common/clients/jinja.py create mode 100644 dbt/common/clients/system.py create mode 100644 dbt/common/constants.py create mode 100644 dbt/common/contracts/__init__.py create mode 100644 dbt/common/contracts/config/__init__.py create mode 100644 dbt/common/contracts/config/base.py create mode 100644 dbt/common/contracts/config/materialization.py create mode 100644 dbt/common/contracts/config/metadata.py create mode 100644 dbt/common/contracts/config/properties.py create mode 100644 dbt/common/contracts/connection.py create mode 100644 dbt/common/contracts/constraints.py create mode 100644 dbt/common/contracts/util.py create mode 100644 dbt/common/dataclass_schema.py create mode 100644 dbt/common/events/README.md create mode 100644 dbt/common/events/__init__.py create mode 100644 dbt/common/events/base_types.py create mode 100644 dbt/common/events/contextvars.py create mode 100644 dbt/common/events/event_handler.py create mode 100644 dbt/common/events/event_manager.py create mode 100644 dbt/common/events/event_manager_client.py create mode 100644 dbt/common/events/format.py create mode 100644 dbt/common/events/functions.py create mode 100644 dbt/common/events/helpers.py create mode 100644 dbt/common/events/interfaces.py create mode 100644 dbt/common/events/logger.py create mode 100644 dbt/common/events/types.proto create mode 100644 dbt/common/events/types.py create mode 100644 dbt/common/events/types_pb2.py create mode 100644 dbt/common/exceptions/__init__.py create mode 100644 dbt/common/exceptions/base.py create mode 100644 dbt/common/exceptions/cache.py create mode 100644 dbt/common/exceptions/contracts.py create mode 100644 dbt/common/exceptions/events.py create mode 100644 dbt/common/exceptions/macros.py create mode 100644 dbt/common/helper_types.py create mode 100644 dbt/common/invocation.py create mode 100644 dbt/common/semver.py create mode 100644 dbt/common/ui.py create mode 100644 dbt/common/utils/__init__.py create mode 100644 dbt/common/utils/casting.py create mode 100644 dbt/common/utils/connection.py create mode 100644 dbt/common/utils/dict.py create mode 100644 dbt/common/utils/encoding.py create mode 100644 dbt/common/utils/executor.py create mode 100644 dbt/common/utils/formatting.py create mode 100644 dbt/common/utils/jinja.py diff --git a/dbt/common/__init__.py b/dbt/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt/common/clients/__init__.py b/dbt/common/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt/common/clients/_jinja_blocks.py b/dbt/common/clients/_jinja_blocks.py new file mode 100644 index 00000000..1ada0a62 --- /dev/null +++ b/dbt/common/clients/_jinja_blocks.py @@ -0,0 +1,360 @@ +import re +from collections import namedtuple + +from dbt.exceptions import ( + BlockDefinitionNotAtTopError, + DbtInternalError, + MissingCloseTagError, + MissingControlFlowStartTagError, + NestedTagsError, + UnexpectedControlFlowEndTagError, + UnexpectedMacroEOFError, +) + + +def regex(pat): + return re.compile(pat, re.DOTALL | re.MULTILINE) + + +class BlockData: + """raw plaintext data from the top level of the file.""" + + def __init__(self, contents): + self.block_type_name = "__dbt__data" + self.contents = contents + self.full_block = contents + + +class BlockTag: + def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw): + self.block_type_name = block_type_name + self.block_name = block_name + self.contents = contents + self.full_block = full_block + + def __str__(self): + return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name) + + def __repr__(self): + return str(self) + + @property + def end_block_type_name(self): + return "end{}".format(self.block_type_name) + + def end_pat(self): + # we don't want to use string formatting here because jinja uses most + # of the string formatting operators in its syntax... + pattern = "".join( + ( + r"(?P((?:\s*\{\%\-|\{\%)\s*", + self.end_block_type_name, + r"\s*(?:\-\%\}\s*|\%\})))", + ) + ) + return regex(pattern) + + +Tag = namedtuple("Tag", "block_type_name block_name start end") + + +_NAME_PATTERN = r"[A-Za-z_][A-Za-z_0-9]*" + +COMMENT_START_PATTERN = regex(r"(?:(?P(\s*\{\#)))") +COMMENT_END_PATTERN = regex(r"(.*?)(\s*\#\})") +RAW_START_PATTERN = regex(r"(?:\s*\{\%\-|\{\%)\s*(?P(raw))\s*(?:\-\%\}\s*|\%\})") +EXPR_START_PATTERN = regex(r"(?P(\{\{\s*))") +EXPR_END_PATTERN = regex(r"(?P(\s*\}\}))") + +BLOCK_START_PATTERN = regex( + "".join( + ( + r"(?:\s*\{\%\-|\{\%)\s*", + r"(?P({}))".format(_NAME_PATTERN), + # some blocks have a 'block name'. + r"(?:\s+(?P({})))?".format(_NAME_PATTERN), + ) + ) +) + + +RAW_BLOCK_PATTERN = regex( + "".join( + ( + r"(?:\s*\{\%\-|\{\%)\s*raw\s*(?:\-\%\}\s*|\%\})", + r"(?:.*?)", + r"(?:\s*\{\%\-|\{\%)\s*endraw\s*(?:\-\%\}\s*|\%\})", + ) + ) +) + +TAG_CLOSE_PATTERN = regex(r"(?:(?P(\-\%\}\s*|\%\})))") + +# stolen from jinja's lexer. Note that we've consumed all prefix whitespace by +# the time we want to use this. +STRING_PATTERN = regex(r"(?P('([^'\\]*(?:\\.[^'\\]*)*)'|" r'"([^"\\]*(?:\\.[^"\\]*)*)"))') + +QUOTE_START_PATTERN = regex(r"""(?P(['"]))""") + + +class TagIterator: + def __init__(self, data): + self.data = data + self.blocks = [] + self._parenthesis_stack = [] + self.pos = 0 + + def linepos(self, end=None) -> str: + """Given an absolute position in the input data, return a pair of + line number + relative position to the start of the line. + """ + end_val: int = self.pos if end is None else end + data = self.data[:end_val] + # if not found, rfind returns -1, and -1+1=0, which is perfect! + last_line_start = data.rfind("\n") + 1 + # it's easy to forget this, but line numbers are 1-indexed + line_number = data.count("\n") + 1 + return f"{line_number}:{end_val - last_line_start}" + + def advance(self, new_position): + self.pos = new_position + + def rewind(self, amount=1): + self.pos -= amount + + def _search(self, pattern): + return pattern.search(self.data, self.pos) + + def _match(self, pattern): + return pattern.match(self.data, self.pos) + + def _first_match(self, *patterns, **kwargs): + matches = [] + for pattern in patterns: + # default to 'search', but sometimes we want to 'match'. + if kwargs.get("method", "search") == "search": + match = self._search(pattern) + else: + match = self._match(pattern) + if match: + matches.append(match) + if not matches: + return None + # if there are multiple matches, pick the least greedy match + # TODO: do I need to account for m.start(), or is this ok? + return min(matches, key=lambda m: m.end()) + + def _expect_match(self, expected_name, *patterns, **kwargs): + match = self._first_match(*patterns, **kwargs) + if match is None: + raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :]) + return match + + def handle_expr(self, match): + """Handle an expression. At this point we're at a string like: + {{ 1 + 2 }} + ^ right here + + And the match contains "{{ " + + We expect to find a `}}`, but we might find one in a string before + that. Imagine the case of `{{ 2 * "}}" }}`... + + You're not allowed to have blocks or comments inside an expr so it is + pretty straightforward, I hope: only strings can get in the way. + """ + self.advance(match.end()) + while True: + match = self._expect_match("}}", EXPR_END_PATTERN, QUOTE_START_PATTERN) + if match.groupdict().get("expr_end") is not None: + break + else: + # it's a quote. we haven't advanced for this match yet, so + # just slurp up the whole string, no need to rewind. + match = self._expect_match("string", STRING_PATTERN) + self.advance(match.end()) + + self.advance(match.end()) + + def handle_comment(self, match): + self.advance(match.end()) + match = self._expect_match("#}", COMMENT_END_PATTERN) + self.advance(match.end()) + + def _expect_block_close(self): + """Search for the tag close marker. + To the right of the type name, there are a few possiblities: + - a name (handled by the regex's 'block_name') + - any number of: `=`, `(`, `)`, strings, etc (arguments) + - nothing + + followed eventually by a %} + + So the only characters we actually have to worry about in this context + are quote and `%}` - nothing else can hide the %} and be valid jinja. + """ + while True: + end_match = self._expect_match( + 'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN + ) + self.advance(end_match.end()) + if end_match.groupdict().get("tag_close") is not None: + return + # must be a string. Rewind to its start and advance past it. + self.rewind() + string_match = self._expect_match("string", STRING_PATTERN) + self.advance(string_match.end()) + + def handle_raw(self): + # raw blocks are super special, they are a single complete regex + match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) + self.advance(match.end()) + return match.end() + + def handle_tag(self, match): + """The tag could be one of a few things: + + {% mytag %} + {% mytag x = y %} + {% mytag x = "y" %} + {% mytag x.y() %} + {% mytag foo("a", "b", c="d") %} + + But the key here is that it's always going to be `{% mytag`! + """ + groups = match.groupdict() + # always a value + block_type_name = groups["block_type_name"] + # might be None + block_name = groups.get("block_name") + start_pos = self.pos + if block_type_name == "raw": + match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) + self.advance(match.end()) + else: + self.advance(match.end()) + self._expect_block_close() + return Tag( + block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos + ) + + def find_tags(self): + while True: + match = self._first_match( + BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN + ) + if match is None: + break + + self.advance(match.start()) + # start = self.pos + + groups = match.groupdict() + comment_start = groups.get("comment_start") + expr_start = groups.get("expr_start") + block_type_name = groups.get("block_type_name") + + if comment_start is not None: + self.handle_comment(match) + elif expr_start is not None: + self.handle_expr(match) + elif block_type_name is not None: + yield self.handle_tag(match) + else: + raise DbtInternalError( + "Invalid regex match in next_block, expected block start, " + "expr start, or comment start" + ) + + def __iter__(self): + return self.find_tags() + + +_CONTROL_FLOW_TAGS = { + "if": "endif", + "for": "endfor", +} + +_CONTROL_FLOW_END_TAGS = {v: k for k, v in _CONTROL_FLOW_TAGS.items()} + + +class BlockIterator: + def __init__(self, data): + self.tag_parser = TagIterator(data) + self.current = None + self.stack = [] + self.last_position = 0 + + @property + def current_end(self): + if self.current is None: + return 0 + else: + return self.current.end + + @property + def data(self): + return self.tag_parser.data + + def is_current_end(self, tag): + return ( + tag.block_type_name.startswith("end") + and self.current is not None + and tag.block_type_name[3:] == self.current.block_type_name + ) + + def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + """Find all top-level blocks in the data.""" + if allowed_blocks is None: + allowed_blocks = {"snapshot", "macro", "materialization", "docs"} + + for tag in self.tag_parser.find_tags(): + if tag.block_type_name in _CONTROL_FLOW_TAGS: + self.stack.append(tag.block_type_name) + elif tag.block_type_name in _CONTROL_FLOW_END_TAGS: + found = None + if self.stack: + found = self.stack.pop() + else: + expected = _CONTROL_FLOW_END_TAGS[tag.block_type_name] + raise UnexpectedControlFlowEndTagError(tag, expected, self.tag_parser) + expected = _CONTROL_FLOW_TAGS[found] + if expected != tag.block_type_name: + raise MissingControlFlowStartTagError(tag, expected, self.tag_parser) + + if tag.block_type_name in allowed_blocks: + if self.stack: + raise BlockDefinitionNotAtTopError(self.tag_parser, tag.start) + if self.current is not None: + raise NestedTagsError(outer=self.current, inner=tag) + if collect_raw_data: + raw_data = self.data[self.last_position : tag.start] + self.last_position = tag.start + if raw_data: + yield BlockData(raw_data) + self.current = tag + + elif self.is_current_end(tag): + self.last_position = tag.end + assert self.current is not None + yield BlockTag( + block_type_name=self.current.block_type_name, + block_name=self.current.block_name, + contents=self.data[self.current.end : tag.start], + full_block=self.data[self.current.start : tag.end], + ) + self.current = None + + if self.current: + linecount = self.data[: self.current.end].count("\n") + 1 + raise MissingCloseTagError(self.current.block_type_name, linecount) + + if collect_raw_data: + raw_data = self.data[self.last_position :] + if raw_data: + yield BlockData(raw_data) + + def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + return list( + self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) + ) diff --git a/dbt/common/clients/agate_helper.py b/dbt/common/clients/agate_helper.py new file mode 100644 index 00000000..d7ac0916 --- /dev/null +++ b/dbt/common/clients/agate_helper.py @@ -0,0 +1,251 @@ +from codecs import BOM_UTF8 + +import agate +import datetime +import isodate +import json +from typing import Iterable, List, Dict, Union, Optional, Any + +from dbt.common.exceptions import DbtRuntimeError +from dbt.common.utils import ForgivingJSONEncoder + +BOM = BOM_UTF8.decode("utf-8") # '\ufeff' + + +class Integer(agate.data_types.DataType): + def cast(self, d): + # by default agate will cast none as a Number + # but we need to cast it as an Integer to preserve + # the type when merging and unioning tables + if type(d) == int or d is None: + return d + else: + raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d) + + def jsonify(self, d): + return d + + +class Number(agate.data_types.Number): + # undo the change in https://github.com/wireservice/agate/pull/733 + # i.e. do not cast True and False to numeric 1 and 0 + def cast(self, d): + if type(d) == bool: + raise agate.exceptions.CastError("Do not cast True to 1 or False to 0.") + else: + return super().cast(d) + + +class ISODateTime(agate.data_types.DateTime): + def cast(self, d): + # this is agate.data_types.DateTime.cast with the "clever" bits removed + # so we only handle ISO8601 stuff + if isinstance(d, datetime.datetime) or d is None: + return d + elif isinstance(d, datetime.date): + return datetime.datetime.combine(d, datetime.time(0, 0, 0)) + elif isinstance(d, str): + d = d.strip() + if d.lower() in self.null_values: + return None + try: + return isodate.parse_datetime(d) + except: # noqa + pass + + raise agate.exceptions.CastError('Can not parse value "%s" as datetime.' % d) + + +def build_type_tester( + text_columns: Iterable[str], string_null_values: Optional[Iterable[str]] = ("null", "") +) -> agate.TypeTester: + + types = [ + Integer(null_values=("null", "")), + Number(null_values=("null", "")), + agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"), + agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"), + ISODateTime(null_values=("null", "")), + agate.data_types.Boolean( + true_values=("true",), false_values=("false",), null_values=("null", "") + ), + agate.data_types.Text(null_values=string_null_values), + ] + force = {k: agate.data_types.Text(null_values=string_null_values) for k in text_columns} + return agate.TypeTester(force=force, types=types) + + +DEFAULT_TYPE_TESTER = build_type_tester(()) + + +def table_from_rows( + rows: List[Any], + column_names: Iterable[str], + text_only_columns: Optional[Iterable[str]] = None, +) -> agate.Table: + if text_only_columns is None: + column_types = DEFAULT_TYPE_TESTER + else: + # If text_only_columns are present, prevent coercing empty string or + # literal 'null' strings to a None representation. + column_types = build_type_tester(text_only_columns, string_null_values=()) + + return agate.Table(rows, column_names, column_types=column_types) + + +def table_from_data(data, column_names: Iterable[str]) -> agate.Table: + "Convert a list of dictionaries into an Agate table" + + # The agate table is generated from a list of dicts, so the column order + # from `data` is not preserved. We can use `select` to reorder the columns + # + # If there is no data, create an empty table with the specified columns + + if len(data) == 0: + return agate.Table([], column_names=column_names) + else: + table = agate.Table.from_object(data, column_types=DEFAULT_TYPE_TESTER) + return table.select(column_names) + + +def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table: + """ + Convert a list of dictionaries into an Agate table. This method does not + coerce string values into more specific types (eg. '005' will not be + coerced to '5'). Additionally, this method does not coerce values to + None (eg. '' or 'null' will retain their string literal representations). + """ + + rows = [] + text_only_columns = set() + for _row in data: + row = [] + for col_name in column_names: + value = _row[col_name] + if isinstance(value, (dict, list, tuple)): + # Represent container types as json strings + value = json.dumps(value, cls=ForgivingJSONEncoder) + text_only_columns.add(col_name) + elif isinstance(value, str): + text_only_columns.add(col_name) + row.append(value) + + rows.append(row) + + return table_from_rows( + rows=rows, column_names=column_names, text_only_columns=text_only_columns + ) + + +def empty_table(): + "Returns an empty Agate table. To be used in place of None" + + return agate.Table(rows=[]) + + +def as_matrix(table): + "Return an agate table as a matrix of data sans columns" + + return [r.values() for r in table.rows.values()] + + +def from_csv(abspath, text_columns, delimiter=","): + type_tester = build_type_tester(text_columns=text_columns) + with open(abspath, encoding="utf-8") as fp: + if fp.read(1) != BOM: + fp.seek(0) + return agate.Table.from_csv(fp, column_types=type_tester, delimiter=delimiter) + + +class _NullMarker: + pass + + +NullableAgateType = Union[agate.data_types.DataType, _NullMarker] + + +class ColumnTypeBuilder(Dict[str, NullableAgateType]): + def __init__(self) -> None: + super().__init__() + + def __setitem__(self, key, value): + if key not in self: + super().__setitem__(key, value) + return + + existing_type = self[key] + if isinstance(existing_type, _NullMarker): + # overwrite + super().__setitem__(key, value) + elif isinstance(value, _NullMarker): + # use the existing value + return + # when one table column is Number while another is Integer, force the column to Number on merge + elif isinstance(value, Integer) and isinstance(existing_type, agate.data_types.Number): + # use the existing value + return + elif isinstance(existing_type, Integer) and isinstance(value, agate.data_types.Number): + # overwrite + super().__setitem__(key, value) + elif not isinstance(value, type(existing_type)): + # actual type mismatch! + raise DbtRuntimeError( + f"Tables contain columns with the same names ({key}), " + f"but different types ({value} vs {existing_type})" + ) + + def finalize(self) -> Dict[str, agate.data_types.DataType]: + result: Dict[str, agate.data_types.DataType] = {} + for key, value in self.items(): + if isinstance(value, _NullMarker): + # agate would make it a Number but we'll make it Integer so that if this column + # gets merged with another Integer column, it won't get forced to a Number + result[key] = Integer() + else: + result[key] = value + return result + + +def _merged_column_types(tables: List[agate.Table]) -> Dict[str, agate.data_types.DataType]: + # this is a lot like agate.Table.merge, but with handling for all-null + # rows being "any type". + new_columns: ColumnTypeBuilder = ColumnTypeBuilder() + for table in tables: + for i in range(len(table.columns)): + column_name: str = table.column_names[i] + column_type: NullableAgateType = table.column_types[i] + # avoid over-sensitive type inference + if all(x is None for x in table.columns[column_name]): + column_type = _NullMarker() + new_columns[column_name] = column_type + + return new_columns.finalize() + + +def merge_tables(tables: List[agate.Table]) -> agate.Table: + """This is similar to agate.Table.merge, but it handles rows of all 'null' + values more gracefully during merges. + """ + new_columns = _merged_column_types(tables) + column_names = tuple(new_columns.keys()) + column_types = tuple(new_columns.values()) + + rows: List[agate.Row] = [] + for table in tables: + if table.column_names == column_names and table.column_types == column_types: + rows.extend(table.rows) + else: + for row in table.rows: + data = [row.get(name, None) for name in column_names] + rows.append(agate.Row(data, column_names)) + # _is_fork to tell agate that we already made things into `Row`s. + return agate.Table(rows, column_names, column_types, _is_fork=True) + + +def get_column_value_uncased(column_name: str, row: agate.Row) -> Any: + """Get the value of a column in this row, ignoring the casing of the column name.""" + for key, value in row.items(): + if key.casefold() == column_name.casefold(): + return value + + raise KeyError diff --git a/dbt/common/clients/jinja.py b/dbt/common/clients/jinja.py new file mode 100644 index 00000000..e01ad570 --- /dev/null +++ b/dbt/common/clients/jinja.py @@ -0,0 +1,505 @@ +import codecs +import linecache +import os +import tempfile +from ast import literal_eval +from contextlib import contextmanager +from itertools import chain, islice +from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable +from typing_extensions import Protocol + +import jinja2 +import jinja2.ext +import jinja2.nativetypes # type: ignore +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox + +from dbt.common.utils import ( + get_dbt_macro_name, + get_docs_macro_name, + get_materialization_macro_name, + get_test_macro_name, +) +from dbt.common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag + +from dbt.common.exceptions import ( + CompilationError, + DbtInternalError, + CaughtMacroErrorWithNodeError, + MaterializationArgError, + JinjaRenderingError, + UndefinedCompilationError, +) +from dbt.common.exceptions.macros import MacroReturn, UndefinedMacroError, CaughtMacroError + + +SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param") + +# Global which can be set by dependents of dbt-common (e.g. core via flag parsing) +MACRO_DEBUGGING = False + + +def _linecache_inject(source, write): + if write: + # this is the only reliable way to accomplish this. Obviously, it's + # really darn noisy and will fill your temporary directory + tmp_file = tempfile.NamedTemporaryFile( + prefix="dbt-macro-compiled-", + suffix=".py", + delete=False, + mode="w+", + encoding="utf-8", + ) + tmp_file.write(source) + filename = tmp_file.name + else: + # `codecs.encode` actually takes a `bytes` as the first argument if + # the second argument is 'hex' - mypy does not know this. + rnd = codecs.encode(os.urandom(12), "hex") # type: ignore + filename = rnd.decode("ascii") + + # put ourselves in the cache + cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename) + # linecache does in fact have an attribute `cache`, thanks + linecache.cache[filename] = cache_entry # type: ignore + return filename + + +class MacroFuzzParser(jinja2.parser.Parser): + def parse_macro(self): + node = jinja2.nodes.Macro(lineno=next(self.stream).lineno) + + # modified to fuzz macros defined in the same file. this way + # dbt can understand the stack of macros being called. + # - @cmcarthur + node.name = get_dbt_macro_name(self.parse_assign_target(name_only=True).name) + + self.parse_signature(node) + node.body = self.parse_statements(("name:endmacro",), drop_needle=True) + return node + + +class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment): + def _parse(self, source, name, filename): + return MacroFuzzParser(self, source, name, filename).parse() + + def _compile(self, source, filename): + """Override jinja's compilation to stash the rendered source inside + the python linecache for debugging when the appropriate environment + variable is set. + + If the value is 'write', also write the files to disk. + WARNING: This can write a ton of data if you aren't careful. + """ + if filename == "