Skip to content

Commit

Permalink
Enable serialization context (#10094)
Browse files Browse the repository at this point in the history
* Update __post_serialize__ signatures

* Temporarily linke dbt-common and dbt-adapters branches

* Changie

* Move fields not in artifacts to resource __post_serialize__ methods

* remove defer_relation in snapshots

* Remove references to branch changes
  • Loading branch information
gshank authored May 7, 2024
1 parent e349e01 commit 760e4ce
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240506-145511.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Enable use of context in serialization
time: 2024-05-06T14:55:11.1812-04:00
custom:
Author: gshank
Issue: "10093"
10 changes: 8 additions & 2 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class ParsedResource(ParsedResourceMandatory):
relation_name: Optional[str] = None
raw_code: str = ""

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "config_call_dict" in dct:
del dct["config_call_dict"]
return dct


@dataclass
class CompiledResource(ParsedResource):
Expand All @@ -214,8 +220,8 @@ class CompiledResource(ParsedResource):
_pre_injected_sql: Optional[str] = None
contract: Contract = field(default_factory=Contract)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_pre_injected_sql" in dct:
del dct["_pre_injected_sql"]
# Remove compiled attributes
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Literal, Optional
from typing import Dict, List, Literal, Optional

from dbt.artifacts.resources.types import AccessType, NodeType
from dbt.artifacts.resources.v1.components import (
Expand Down Expand Up @@ -31,3 +31,9 @@ class Model(CompiledResource):
latest_version: Optional[NodeVersion] = None
deprecation_date: Optional[datetime] = None
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Dict, Literal, Optional

from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.components import (
Expand Down Expand Up @@ -33,3 +33,9 @@ class Seed(ParsedResource): # No SQLDefaults!
root_path: Optional[str] = None
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
8 changes: 7 additions & 1 deletion core/dbt/artifacts/resources/v1/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union

from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.components import CompiledResource, DeferRelation
Expand Down Expand Up @@ -65,3 +65,9 @@ class Snapshot(CompiledResource):
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

def __post_serialize__(self, dct, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if context and context.get("artifact") and "defer_relation" in dct:
del dct["defer_relation"]
return dct
6 changes: 3 additions & 3 deletions core/dbt/artifacts/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __str__(self) -> str:

class Writable:
def write(self, path: str):
write_json(path, self.to_dict(omit_none=False)) # type: ignore
write_json(path, self.to_dict(omit_none=False, context={"artifact": True})) # type: ignore


class Readable:
Expand All @@ -59,8 +59,8 @@ class BaseArtifactMetadata(dbtClassMixin):
invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_vars)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if dct["generated_at"] and dct["generated_at"].endswith("+00:00"):
dct["generated_at"] = dct["generated_at"].replace("+00:00", "") + "Z"
return dct
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/artifacts/schemas/catalog/v1/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class CatalogResults(dbtClassMixin):
errors: Optional[List[str]] = None
_compile_results: Optional[Any] = None

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_compile_results" in dct:
del dct["_compile_results"]
return dct
Expand Down
8 changes: 0 additions & 8 deletions core/dbt/artifacts/schemas/manifest/v12/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,3 @@ def upgrade_schema_version(cls, data):
if manifest_schema_version < cls.dbt_schema_version.version:
data = upgrade_manifest_json(data, manifest_schema_version)
return cls.from_dict(data)

def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct
8 changes: 4 additions & 4 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def _deserialize(cls, dct: Dict[str, int]):
sf = SourceFile.from_dict(dct)
return sf

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
# remove empty lists to save space
dct_keys = list(dct.keys())
for key in dct_keys:
Expand Down Expand Up @@ -226,8 +226,8 @@ def macro_patches(self):
def source_patches(self):
return self.sop

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
# Remove partial parsing specific data
for key in ("pp_test_index", "pp_dict"):
if key in dct:
Expand Down
5 changes: 2 additions & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Union,
)

from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from typing_extensions import Protocol

import dbt_common.exceptions
Expand Down Expand Up @@ -805,7 +804,7 @@ class ManifestStateCheck(dbtClassMixin):


@dataclass
class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
class Manifest(MacroMethods, dbtClassMixin):
"""The manifest for the full graph, after parsing and during compilation."""

# These attributes are both positional and by keyword. If an attribute
Expand Down Expand Up @@ -872,7 +871,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

def __pre_serialize__(self):
def __pre_serialize__(self, context: Optional[Dict] = None):
# serialization won't work with anything except an empty source_patches because
# tuple keys are not supported, so ensure it's empty
self.source_patches = {}
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def write_node(self, project_root: str, compiled_path, compiled_code: str):
def _serialize(self):
return self.to_dict()

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "_event_status" in dct:
del dct["_event_status"]
return dct
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down Expand Up @@ -316,8 +316,8 @@ def validate(cls, data):
def yaml_key(self) -> "str":
return "sources"

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
if "freshness" not in dct and self.freshness is None:
dct["freshness"] = None
return dct
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class ManifestLoaderInfo(dbtClassMixin, Writable):
projects: List[ProjectLoaderInfo] = field(default_factory=list)
_project_index: Dict[str, ProjectLoaderInfo] = field(default_factory=dict)

def __post_serialize__(self, dct):
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
del dct["_project_index"]
return dct

Expand Down
2 changes: 1 addition & 1 deletion third-party-stubs/mashumaro/jsonschema/builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def build_json_schema(

class JSONSchemaDefinitions(DataClassJSONMixin):
definitions: Dict[str, JSONSchema]
def __post_serialize__(self, d: Dict[Any, Any]) -> List[Dict[str, Any]]: ... # type: ignore
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> List[Dict[str, Any]]: ... # type: ignore
def __init__(self, definitions) -> None: ...

class JSONSchemaBuilder:
Expand Down
4 changes: 2 additions & 2 deletions third-party-stubs/mashumaro/jsonschema/models.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class JSONSchema(DataClassJSONMixin):
serialize_by_alias: bool
aliases: Incomplete
serialization_strategy: Incomplete
def __pre_serialize__(self) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self, context: Optional[Dict]) -> JSONSchema: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...
def __init__(
self,
schema,
Expand Down
6 changes: 3 additions & 3 deletions third-party-stubs/mashumaro/mixins/dict.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Type, TypeVar
from typing import Any, Dict, Mapping, Type, TypeVar, Optional

T = TypeVar("T", bound="DataClassDictMixin")

Expand All @@ -11,5 +11,5 @@ class DataClassDictMixin:
def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]: ...
@classmethod
def __post_deserialize__(cls: Type[T], obj: T) -> T: ...
def __pre_serialize__(self: T) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ...
def __pre_serialize__(self: T, context: Optional[Dict]) -> T: ...
def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...

0 comments on commit 760e4ce

Please sign in to comment.