diff --git a/.changes/unreleased/Under the Hood-20240506-145511.yaml b/.changes/unreleased/Under the Hood-20240506-145511.yaml new file mode 100644 index 00000000000..f5bad25d797 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240506-145511.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Enable use of context in serialization +time: 2024-05-06T14:55:11.1812-04:00 +custom: + Author: gshank + Issue: "10093" diff --git a/core/dbt/artifacts/resources/v1/components.py b/core/dbt/artifacts/resources/v1/components.py index 119e5fee1bb..6e6605c18ab 100644 --- a/core/dbt/artifacts/resources/v1/components.py +++ b/core/dbt/artifacts/resources/v1/components.py @@ -195,6 +195,12 @@ class ParsedResource(ParsedResourceMandatory): relation_name: Optional[str] = None raw_code: str = "" + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) + if context and context.get("artifact") and "config_call_dict" in dct: + del dct["config_call_dict"] + return dct + @dataclass class CompiledResource(ParsedResource): @@ -214,8 +220,8 @@ class CompiledResource(ParsedResource): _pre_injected_sql: Optional[str] = None contract: Contract = field(default_factory=Contract) - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if "_pre_injected_sql" in dct: del dct["_pre_injected_sql"] # Remove compiled attributes diff --git a/core/dbt/artifacts/resources/v1/model.py b/core/dbt/artifacts/resources/v1/model.py index f575f360aa3..ed88bb34f8b 100644 --- a/core/dbt/artifacts/resources/v1/model.py +++ b/core/dbt/artifacts/resources/v1/model.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List, Literal, Optional +from typing import Dict, List, Literal, Optional from dbt.artifacts.resources.types import AccessType, NodeType from dbt.artifacts.resources.v1.components import ( @@ -31,3 +31,9 @@ class Model(CompiledResource): latest_version: Optional[NodeVersion] = None deprecation_date: Optional[datetime] = None defer_relation: Optional[DeferRelation] = None + + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) + if context and context.get("artifact") and "defer_relation" in dct: + del dct["defer_relation"] + return dct diff --git a/core/dbt/artifacts/resources/v1/seed.py b/core/dbt/artifacts/resources/v1/seed.py index 09d9233710f..5328488b3c5 100644 --- a/core/dbt/artifacts/resources/v1/seed.py +++ b/core/dbt/artifacts/resources/v1/seed.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import Dict, Literal, Optional from dbt.artifacts.resources.types import NodeType from dbt.artifacts.resources.v1.components import ( @@ -33,3 +33,9 @@ class Seed(ParsedResource): # No SQLDefaults! root_path: Optional[str] = None depends_on: MacroDependsOn = field(default_factory=MacroDependsOn) defer_relation: Optional[DeferRelation] = None + + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) + if context and context.get("artifact") and "defer_relation" in dct: + del dct["defer_relation"] + return dct diff --git a/core/dbt/artifacts/resources/v1/snapshot.py b/core/dbt/artifacts/resources/v1/snapshot.py index c20911ad3af..6164d953184 100644 --- a/core/dbt/artifacts/resources/v1/snapshot.py +++ b/core/dbt/artifacts/resources/v1/snapshot.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from dbt.artifacts.resources.types import NodeType from dbt.artifacts.resources.v1.components import CompiledResource, DeferRelation @@ -65,3 +65,9 @@ class Snapshot(CompiledResource): resource_type: Literal[NodeType.Snapshot] config: SnapshotConfig defer_relation: Optional[DeferRelation] = None + + def __post_serialize__(self, dct, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) + if context and context.get("artifact") and "defer_relation" in dct: + del dct["defer_relation"] + return dct diff --git a/core/dbt/artifacts/schemas/base.py b/core/dbt/artifacts/schemas/base.py index 2ee9f09eb5e..c1a8f0b65de 100644 --- a/core/dbt/artifacts/schemas/base.py +++ b/core/dbt/artifacts/schemas/base.py @@ -33,7 +33,7 @@ def __str__(self) -> str: class Writable: def write(self, path: str): - write_json(path, self.to_dict(omit_none=False)) # type: ignore + write_json(path, self.to_dict(omit_none=False, context={"artifact": True})) # type: ignore class Readable: @@ -59,8 +59,8 @@ class BaseArtifactMetadata(dbtClassMixin): invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id) env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_vars) - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if dct["generated_at"] and dct["generated_at"].endswith("+00:00"): dct["generated_at"] = dct["generated_at"].replace("+00:00", "") + "Z" return dct diff --git a/core/dbt/artifacts/schemas/catalog/v1/catalog.py b/core/dbt/artifacts/schemas/catalog/v1/catalog.py index d1c692e7573..d6d02608bca 100644 --- a/core/dbt/artifacts/schemas/catalog/v1/catalog.py +++ b/core/dbt/artifacts/schemas/catalog/v1/catalog.py @@ -81,8 +81,8 @@ class CatalogResults(dbtClassMixin): errors: Optional[List[str]] = None _compile_results: Optional[Any] = None - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if "_compile_results" in dct: del dct["_compile_results"] return dct diff --git a/core/dbt/artifacts/schemas/manifest/v12/manifest.py b/core/dbt/artifacts/schemas/manifest/v12/manifest.py index 4a4314ceab9..cc13fca43f5 100644 --- a/core/dbt/artifacts/schemas/manifest/v12/manifest.py +++ b/core/dbt/artifacts/schemas/manifest/v12/manifest.py @@ -180,11 +180,3 @@ def upgrade_schema_version(cls, data): if manifest_schema_version < cls.dbt_schema_version.version: data = upgrade_manifest_json(data, manifest_schema_version) return cls.from_dict(data) - - def __post_serialize__(self, dct): - for unique_id, node in dct["nodes"].items(): - if "config_call_dict" in node: - del node["config_call_dict"] - if "defer_relation" in node: - del node["defer_relation"] - return dct diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py index 512dc1533d0..2c78e97f977 100644 --- a/core/dbt/contracts/files.py +++ b/core/dbt/contracts/files.py @@ -139,8 +139,8 @@ def _deserialize(cls, dct: Dict[str, int]): sf = SourceFile.from_dict(dct) return sf - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) # remove empty lists to save space dct_keys = list(dct.keys()) for key in dct_keys: @@ -226,8 +226,8 @@ def macro_patches(self): def source_patches(self): return self.sop - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) # Remove partial parsing specific data for key in ("pp_test_index", "pp_dict"): if key in dct: diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 08000eb5ad9..9ca11166388 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -20,7 +20,6 @@ Union, ) -from mashumaro.mixins.msgpack import DataClassMessagePackMixin from typing_extensions import Protocol import dbt_common.exceptions @@ -805,7 +804,7 @@ class ManifestStateCheck(dbtClassMixin): @dataclass -class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin): +class Manifest(MacroMethods, dbtClassMixin): """The manifest for the full graph, after parsing and during compilation.""" # These attributes are both positional and by keyword. If an attribute @@ -872,7 +871,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin): metadata={"serialize": lambda x: None, "deserialize": lambda x: None}, ) - def __pre_serialize__(self): + def __pre_serialize__(self, context: Optional[Dict] = None): # serialization won't work with anything except an empty source_patches because # tuple keys are not supported, so ensure it's empty self.source_patches = {} diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 8edb307f242..4cc72327332 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -258,8 +258,8 @@ def write_node(self, project_root: str, compiled_path, compiled_code: str): def _serialize(self): return self.to_dict() - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if "_event_status" in dct: del dct["_event_status"] return dct diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 0cddc3139b0..f2fb390c69e 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -278,8 +278,8 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps): external: Optional[ExternalTable] = None tags: List[str] = field(default_factory=list) - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if "freshness" not in dct and self.freshness is None: dct["freshness"] = None return dct @@ -316,8 +316,8 @@ def validate(cls, data): def yaml_key(self) -> "str": return "sources" - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): + dct = super().__post_serialize__(dct, context) if "freshness" not in dct and self.freshness is None: dct["freshness"] = None return dct diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index adb1072107d..0ab35110a66 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -211,7 +211,7 @@ class ManifestLoaderInfo(dbtClassMixin, Writable): projects: List[ProjectLoaderInfo] = field(default_factory=list) _project_index: Dict[str, ProjectLoaderInfo] = field(default_factory=dict) - def __post_serialize__(self, dct): + def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): del dct["_project_index"] return dct diff --git a/tests/unit/artifacts/test_base_resource.py b/tests/unit/artifacts/test_base_resource.py index 9067bf7a205..6809d524cd1 100644 --- a/tests/unit/artifacts/test_base_resource.py +++ b/tests/unit/artifacts/test_base_resource.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + import pytest from dbt.artifacts.resources.base import BaseResource diff --git a/third-party-stubs/mashumaro/jsonschema/builder.pyi b/third-party-stubs/mashumaro/jsonschema/builder.pyi index 98bbc860298..8f973240a85 100644 --- a/third-party-stubs/mashumaro/jsonschema/builder.pyi +++ b/third-party-stubs/mashumaro/jsonschema/builder.pyi @@ -16,7 +16,7 @@ def build_json_schema( class JSONSchemaDefinitions(DataClassJSONMixin): definitions: Dict[str, JSONSchema] - def __post_serialize__(self, d: Dict[Any, Any]) -> List[Dict[str, Any]]: ... # type: ignore + def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> List[Dict[str, Any]]: ... # type: ignore def __init__(self, definitions) -> None: ... class JSONSchemaBuilder: diff --git a/third-party-stubs/mashumaro/jsonschema/models.pyi b/third-party-stubs/mashumaro/jsonschema/models.pyi index b67db67b20b..6022d3d129f 100644 --- a/third-party-stubs/mashumaro/jsonschema/models.pyi +++ b/third-party-stubs/mashumaro/jsonschema/models.pyi @@ -106,8 +106,8 @@ class JSONSchema(DataClassJSONMixin): serialize_by_alias: bool aliases: Incomplete serialization_strategy: Incomplete - def __pre_serialize__(self) -> JSONSchema: ... - def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ... + def __pre_serialize__(self, context: Optional[Dict]) -> JSONSchema: ... + def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ... def __init__( self, schema, diff --git a/third-party-stubs/mashumaro/mixins/dict.pyi b/third-party-stubs/mashumaro/mixins/dict.pyi index 877283960a9..c6ec9accad1 100644 --- a/third-party-stubs/mashumaro/mixins/dict.pyi +++ b/third-party-stubs/mashumaro/mixins/dict.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Mapping, Type, TypeVar +from typing import Any, Dict, Mapping, Type, TypeVar, Optional T = TypeVar("T", bound="DataClassDictMixin") @@ -11,5 +11,5 @@ class DataClassDictMixin: def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]: ... @classmethod def __post_deserialize__(cls: Type[T], obj: T) -> T: ... - def __pre_serialize__(self: T) -> T: ... - def __post_serialize__(self, d: Dict[Any, Any]) -> Dict[Any, Any]: ... + def __pre_serialize__(self: T, context: Optional[Dict]) -> T: ... + def __post_serialize__(self, d: Dict[Any, Any], context: Optional[Dict]) -> Dict[Any, Any]: ...