Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable serialization context #10094

Merged
merged 6 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240506-145511.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Enable use of context in serialization
time: 2024-05-06T14:55:11.1812-04:00
custom:
Author: gshank
Issue: "10093"
10 changes: 8 additions & 2 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@
relation_name: Optional[str] = None
raw_code: str = ""

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "config_call_dict" in dct:
del dct["config_call_dict"]

Check warning on line 201 in core/dbt/artifacts/resources/v1/components.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/resources/v1/components.py#L201

Added line #L201 was not covered by tests
return dct


@dataclass
class CompiledResource(ParsedResource):
Expand All @@ -214,8 +220,8 @@
_pre_injected_sql: Optional[str] = None
contract: Contract = field(default_factory=Contract)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_pre_injected_sql" in dct:
del dct["_pre_injected_sql"]
# Remove compiled attributes
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Literal, Optional
from typing import Dict, List, Literal, Optional

from dbt.artifacts.resources.types import AccessType, NodeType
from dbt.artifacts.resources.v1.components import (
Expand Down Expand Up @@ -31,3 +31,9 @@
latest_version: Optional[NodeVersion] = None
deprecation_date: Optional[datetime] = None
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this strictly a refactor or would we be starting to include the defer_relation in the programmatic response from dbtRunner invocations with this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just maintains current functionality, which is that when generating artifacts we remove "defer_relation" and "call_config_dict" from all nodes. We could easily change it now if we want, but in order to minimize disruption I went with preserving the way it currently works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of those are included in the msgpack versions (for internal use/partial parsing)

del dct["defer_relation"]

Check warning on line 38 in core/dbt/artifacts/resources/v1/model.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/resources/v1/model.py#L38

Added line #L38 was not covered by tests
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Dict, Literal, Optional

from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.components import (
Expand Down Expand Up @@ -33,3 +33,9 @@
root_path: Optional[str] = None
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]

Check warning on line 40 in core/dbt/artifacts/resources/v1/seed.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/resources/v1/seed.py#L40

Added line #L40 was not covered by tests
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union

from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.components import CompiledResource, DeferRelation
Expand Down Expand Up @@ -65,3 +65,9 @@
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]

Check warning on line 72 in core/dbt/artifacts/resources/v1/snapshot.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/resources/v1/snapshot.py#L72

Added line #L72 was not covered by tests
return dct
6 changes: 3 additions & 3 deletions core/dbt/artifacts/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class Writable:
def write(self, path: str):
write_json(path, self.to_dict(omit_none=False)) # type: ignore
write_json(path, self.to_dict(omit_none=False, context={"artifact": True})) # type: ignore

Check warning on line 36 in core/dbt/artifacts/schemas/base.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/schemas/base.py#L36

Added line #L36 was not covered by tests


class Readable:
Expand All @@ -59,8 +59,8 @@
invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_vars)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if dct["generated_at"] and dct["generated_at"].endswith("+00:00"):
dct["generated_at"] = dct["generated_at"].replace("+00:00", "") + "Z"
return dct
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/artifacts/schemas/catalog/v1/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class CatalogResults(dbtClassMixin):
errors: Optional[List[str]] = None
_compile_results: Optional[Any] = None

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_compile_results" in dct:
del dct["_compile_results"]
return dct
Expand Down
8 changes: 0 additions & 8 deletions core/dbt/artifacts/schemas/manifest/v12/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,3 @@ def upgrade_schema_version(cls, data):
if manifest_schema_version < cls.dbt_schema_version.version:
data = upgrade_manifest_json(data, manifest_schema_version)
return cls.from_dict(data)

def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct
8 changes: 4 additions & 4 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@
sf = SourceFile.from_dict(dct)
return sf

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
# remove empty lists to save space
dct_keys = list(dct.keys())
for key in dct_keys:
Expand Down Expand Up @@ -226,8 +226,8 @@
def source_patches(self):
return self.sop

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)

Check warning on line 230 in core/dbt/contracts/files.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/files.py#L230

Added line #L230 was not covered by tests
# Remove partial parsing specific data
for key in ("pp_test_index", "pp_dict"):
if key in dct:
Expand Down
5 changes: 2 additions & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Union,
)

from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from typing_extensions import Protocol

import dbt_common.exceptions
Expand Down Expand Up @@ -805,7 +804,7 @@ class ManifestStateCheck(dbtClassMixin):


@dataclass
class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
class Manifest(MacroMethods, dbtClassMixin):
"""The manifest for the full graph, after parsing and during compilation."""

# These attributes are both positional and by keyword. If an attribute
Expand Down Expand Up @@ -872,7 +871,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

def __pre_serialize__(self):
def __pre_serialize__(self, context: Optional[Dict] = None):
# serialization won't work with anything except an empty source_patches because
# tuple keys are not supported, so ensure it's empty
self.source_patches = {}
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def write_node(self, project_root: str, compiled_path, compiled_code: str):
def _serialize(self):
return self.to_dict()

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_event_status" in dct:
del dct["_event_status"]
return dct
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down Expand Up @@ -316,8 +316,8 @@ def validate(cls, data):
def yaml_key(self) -> "str":
return "sources"

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class ManifestLoaderInfo(dbtClassMixin, Writable):
projects: List[ProjectLoaderInfo] = field(default_factory=list)
_project_index: Dict[str, ProjectLoaderInfo] = field(default_factory=dict)

def __post_serialize__(self, dct):
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
del dct["_project_index"]
return dct

Expand Down
1 change: 1 addition & 0 deletions tests/unit/artifacts/test_base_resource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass

import pytest

from dbt.artifacts.resources.base import BaseResource
Expand Down
2 changes: 1 addition & 1 deletion third-party-stubs/mashumaro/jsonschema/builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def build_json_schema(

class JSONSchemaDefinitions(DataClassJSONMixin):
definitions: Dict[str, JSONSchema]
def __post_serialize__(self, d: Dict[Any, Any]) -> List[Dict[str, Any]]: ... # type: ignore
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> List[Dict[str, Any]]: ... # type: ignore
def __init__(self, definitions) -> None: ...

class JSONSchemaBuilder:
Expand Down
4 changes: 2 additions & 2 deletions third-party-stubs/mashumaro/jsonschema/models.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class JSONSchema(DataClassJSONMixin):
serialize_by_alias: bool
aliases: Incomplete
serialization_strategy: Incomplete
def __pre_serialize__(self) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self, context: Optional[Dict]) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...
def __init__(
self,
schema,
Expand Down
6 changes: 3 additions & 3 deletions third-party-stubs/mashumaro/mixins/dict.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Type, TypeVar
from typing import Any, Dict, Mapping, Type, TypeVar, Optional

T = TypeVar("T", bound="DataClassDictMixin")

Expand All @@ -11,5 +11,5 @@ class DataClassDictMixin:
def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]: ...
@classmethod
def __post_deserialize__(cls: Type[T], obj: T) -> T: ...
def __pre_serialize__(self: T) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self: T, context: Optional[Dict]) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...
Loading