Skip to content

Commit

Permalink
Enable serialization context (#10094) (#10104)
Browse files Browse the repository at this point in the history
* Update __post_serialize__ signatures

* Temporarily linke dbt-common and dbt-adapters branches

* Changie

* Move fields not in artifacts to resource __post_serialize__ methods

* remove defer_relation in snapshots

* Remove references to branch changes

Co-authored-by: Gerda Shank <[email protected]>
  • Loading branch information
QMalcolm and gshank authored May 8, 2024
1 parent 062a778 commit 4034327
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240506-145511.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Enable use of context in serialization
time: 2024-05-06T14:55:11.1812-04:00
custom:
Author: gshank
Issue: "10093"
10 changes: 8 additions & 2 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ class ParsedResource(ParsedResourceMandatory):
relation_name: Optional[str] = None
raw_code: str = ""

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "config_call_dict" in dct:
del dct["config_call_dict"]
return dct


@dataclass
class CompiledResource(ParsedResource):
Expand All @@ -215,8 +221,8 @@ class CompiledResource(ParsedResource):
_pre_injected_sql: Optional[str] = None
contract: Contract = field(default_factory=Contract)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_pre_injected_sql" in dct:
del dct["_pre_injected_sql"]
# Remove compiled attributes
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, List
from typing import Dict, Literal, Optional, List
from datetime import datetime
from dbt_common.contracts.config.base import MergeBehavior
from dbt_common.contracts.constraints import ModelLevelConstraint
Expand All @@ -26,3 +26,9 @@ class Model(CompiledResource):
latest_version: Optional[NodeVersion] = None
deprecation_date: Optional[datetime] = None
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Literal
from typing import Dict, Optional, Literal
from dbt_common.dataclass_schema import ValidationError
from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.components import MacroDependsOn, DeferRelation, ParsedResource
Expand Down Expand Up @@ -28,3 +28,9 @@ class Seed(ParsedResource): # No SQLDefaults!
root_path: Optional[str] = None
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Optional, Literal
from typing import Dict, Union, List, Optional, Literal
from dataclasses import dataclass
from dbt_common.dataclass_schema import ValidationError
from dbt.artifacts.resources.types import NodeType
Expand Down Expand Up @@ -64,3 +64,9 @@ class Snapshot(CompiledResource):
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
6 changes: 3 additions & 3 deletions core/dbt/artifacts/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __str__(self) -> str:

class Writable:
def write(self, path: str):
write_json(path, self.to_dict(omit_none=False)) # type: ignore
write_json(path, self.to_dict(omit_none=False, context={"artifact": True})) # type: ignore


class Readable:
Expand All @@ -60,8 +60,8 @@ class BaseArtifactMetadata(dbtClassMixin):
invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_vars)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if dct["generated_at"] and dct["generated_at"].endswith("+00:00"):
dct["generated_at"] = dct["generated_at"].replace("+00:00", "") + "Z"
return dct
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/artifacts/schemas/catalog/v1/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class CatalogResults(dbtClassMixin):
errors: Optional[List[str]] = None
_compile_results: Optional[Any] = None

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_compile_results" in dct:
del dct["_compile_results"]
return dct
Expand Down
8 changes: 0 additions & 8 deletions core/dbt/artifacts/schemas/manifest/v12/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,3 @@ def upgrade_schema_version(cls, data):
if manifest_schema_version < cls.dbt_schema_version.version:
data = upgrade_manifest_json(data, manifest_schema_version)
return cls.from_dict(data)

def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct
8 changes: 4 additions & 4 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def _deserialize(cls, dct: Dict[str, int]):
sf = SourceFile.from_dict(dct)
return sf

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
# remove empty lists to save space
dct_keys = list(dct.keys())
for key in dct_keys:
Expand Down Expand Up @@ -226,8 +226,8 @@ def macro_patches(self):
def source_patches(self):
return self.sop

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
# Remove partial parsing specific data
for key in ("pp_test_index", "pp_dict"):
if key in dct:
Expand Down
5 changes: 2 additions & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from dataclasses import dataclass, field, replace
from itertools import chain
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from typing import (
DefaultDict,
Expand Down Expand Up @@ -803,7 +802,7 @@ class ManifestStateCheck(dbtClassMixin):


@dataclass
class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
class Manifest(MacroMethods, dbtClassMixin):
"""The manifest for the full graph, after parsing and during compilation."""

# These attributes are both positional and by keyword. If an attribute
Expand Down Expand Up @@ -870,7 +869,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

def __pre_serialize__(self):
def __pre_serialize__(self, context: Optional[Dict] = None):
# serialization won't work with anything except an empty source_patches because
# tuple keys are not supported, so ensure it's empty
self.source_patches = {}
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def write_node(self, project_root: str, compiled_path, compiled_code: str):
def _serialize(self):
return self.to_dict()

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_event_status" in dct:
del dct["_event_status"]
return dct
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down Expand Up @@ -314,8 +314,8 @@ def validate(cls, data):
def yaml_key(self) -> "str":
return "sources"

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class ManifestLoaderInfo(dbtClassMixin, Writable):
projects: List[ProjectLoaderInfo] = field(default_factory=list)
_project_index: Dict[str, ProjectLoaderInfo] = field(default_factory=dict)

def __post_serialize__(self, dct):
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
del dct["_project_index"]
return dct

Expand Down
2 changes: 1 addition & 1 deletion third-party-stubs/mashumaro/jsonschema/builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def build_json_schema(

class JSONSchemaDefinitions(DataClassJSONMixin):
definitions: Dict[str, JSONSchema]
def __post_serialize__(self, d: Dict[Any, Any]) -> List[Dict[str, Any]]: ... # type: ignore
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> List[Dict[str, Any]]: ... # type: ignore
def __init__(self, definitions) -> None: ...

class JSONSchemaBuilder:
Expand Down
4 changes: 2 additions & 2 deletions third-party-stubs/mashumaro/jsonschema/models.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class JSONSchema(DataClassJSONMixin):
serialize_by_alias: bool
aliases: Incomplete
serialization_strategy: Incomplete
def __pre_serialize__(self) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self, context: Optional[Dict]) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...
def __init__(
self,
schema,
Expand Down
6 changes: 3 additions & 3 deletions third-party-stubs/mashumaro/mixins/dict.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Type, TypeVar
from typing import Any, Dict, Mapping, Type, TypeVar, Optional

T = TypeVar("T", bound="DataClassDictMixin")

Expand All @@ -11,5 +11,5 @@ class DataClassDictMixin:
def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]: ...
@classmethod
def __post_deserialize__(cls: Type[T], obj: T) -> T: ...
def __pre_serialize__(self: T) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self: T, context: Optional[Dict]) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...

0 comments on commit 4034327

Please sign in to comment.