Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable inline csv format in unit testing #8743

Merged
merged 11 commits into from
Oct 5, 2023
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
UnparsedSourceTableDefinition,
UnparsedColumn,
UnitTestOverrides,
InputFixture,
OutputFixture,
UnitTestInputFixture,
UnitTestOutputFixture,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
Expand Down Expand Up @@ -1071,8 +1071,8 @@ class UnitTestNode(CompiledNode):
@dataclass
class UnitTestDefinition(GraphNode):
model: str
given: Sequence[InputFixture]
expect: OutputFixture
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
Expand Down
33 changes: 29 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re
import csv
from io import StringIO

from dbt import deprecations
from dbt.node_types import NodeType
Expand Down Expand Up @@ -741,15 +743,38 @@ class UnitTestFormat(StrEnum):
Dict = "dict"


class UnitTestFixture:
    """Shared base for unit test fixtures.

    Supplies fallback ``format`` and ``rows`` accessors as read-only
    properties.
    """

    @property
    def format(self) -> UnitTestFormat:
        """Return the fixture format; this base always reports ``UnitTestFormat.Dict``."""
        return UnitTestFormat.Dict

    @property
    def rows(self) -> Union[str, List[Dict[str, Any]]]:
        """Return the fixture rows; this base always yields a new empty list."""
        return []
Comment on lines +747 to +753
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary for these to be properties? Could they instead be attributes that are inherited by UnitTestInputFixture and UnitTestOutputFixture?

e.g.

rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that, we run into the frustrating dataclass restriction that fields without defaults can't come after fields with defaults, and we'd have to split them out into a special class and use a different order. I'd kind of rather not.


def get_rows(self) -> List[Dict[str, Any]]:
    """Return the fixture rows as a list of dictionaries.

    When ``self.format`` is ``UnitTestFormat.Dict`` the ``self.rows`` list is
    returned unchanged. When it is ``UnitTestFormat.CSV`` the ``self.rows``
    string is parsed with ``csv.DictReader``, so the first line supplies the
    column names and every value is a string.

    Raises:
        TypeError: if ``self.rows`` is not a list for the dict format, or
            not a string for the CSV format.
        ValueError: if ``self.format`` is neither of the two formats above.
    """
    if self.format == UnitTestFormat.Dict:
        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O``, which would let a mistyped fixture through.
        if not isinstance(self.rows, list):
            raise TypeError(
                f"Expected a list of rows for the dict format, got {type(self.rows).__name__}"
            )
        return self.rows

    if self.format == UnitTestFormat.CSV:
        if not isinstance(self.rows, str):
            raise TypeError(
                f"Expected a CSV string for the csv format, got {type(self.rows).__name__}"
            )
        # DictReader is iterable; materialize it directly rather than
        # appending row by row.
        return list(csv.DictReader(StringIO(self.rows)))

    # Previously an unknown format fell through and returned None, which
    # contradicts the declared return type.
    raise ValueError(f"Unsupported unit test fixture format: {self.format!r}")


@dataclass
class InputFixture(dbtClassMixin):
class UnitTestInputFixture(dbtClassMixin, UnitTestFixture):
input: str
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict


@dataclass
class OutputFixture(dbtClassMixin):
class UnitTestOutputFixture(dbtClassMixin, UnitTestFixture):
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict

Expand All @@ -764,8 +789,8 @@ class UnitTestOverrides(dbtClassMixin):
@dataclass
class UnparsedUnitTestDefinition(dbtClassMixin):
name: str
given: Sequence[InputFixture]
expect: OutputFixture
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
Expand Down
34 changes: 4 additions & 30 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import List, Set, Dict, Any
import csv
from io import StringIO

from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
Expand Down Expand Up @@ -61,25 +59,16 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# for selection.
# Note: no depends_on, that's added later using input nodes
name = f"{test_case.model}__{test_case.name}"
if test_case.expect.format == UnitTestFormat.Dict:
if isinstance(test_case.expect.rows, List):
expected_rows = test_case.expect.rows
else:
raise ParsingError("Wrong format for expected rows")
else: # test_case.expect.format == UnitTestFormat.CSV:
# build a dictionary from the csv string
if isinstance(test_case.expect.rows, str):
expected_rows = self._build_rows_from_csv(test_case.expect.rows)
else:
raise ParsingError("Wrong format for expected rows")
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
package_name=package_name,
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=UnitTestNodeConfig(materialized="unit", expected_rows=expected_rows),
config=UnitTestNodeConfig(
materialized="unit", expected_rows=test_case.expect.get_rows()
),
raw_code=actual_node.raw_code,
database=actual_node.database,
schema=actual_node.schema,
Expand Down Expand Up @@ -131,15 +120,8 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# TODO: package_name?
input_name = f"{test_case.model}__{test_case.name}__{original_input_node.name}"
input_unique_id = f"model.{package_name}.{input_name}"
rows: List[Dict[str, Any]]
if given.format == UnitTestFormat.CSV:
rows = self._build_rows_from_csv(given.rows)
else: # format == "dict"
# Should always be a dictionary.
rows = given.rows # type:ignore

input_node = ModelNode(
raw_code=self._build_raw_code(rows, original_input_node_columns),
raw_code=self._build_raw_code(given.get_rows(), original_input_node_columns),
resource_type=NodeType.Model,
package_name=package_name,
path=original_input_node.path,
Expand All @@ -157,14 +139,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# Add unique ids of input_nodes to depends_on
unit_test_node.depends_on.nodes.append(input_node.unique_id)

def _build_rows_from_csv(self, csv_string) -> List[Dict[str, Any]]:
dummy_file = StringIO(csv_string)
reader = csv.DictReader(dummy_file)
rows = []
for row in reader:
rows.append(row)
return rows

def _build_raw_code(self, rows, column_name_to_data_types) -> str:
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
Expand Down