From 1665f6959f30b9422c1a96b371339e623aa04253 Mon Sep 17 00:00:00 2001 From: OwenKephart Date: Tue, 31 Dec 2024 11:13:39 -0500 Subject: [PATCH] [components] Replace RequiredScope with RenderedModel (#26759) ## Summary & Motivation This changes up how we manage marking objects as having deferred fields. The result is a lot more terse, and generally easier to work around. The base class exposes a method to get all of the rendered properties, which can be overridden in scenarios where we want to customize this. For now, our only existing usage can just use this raw properties dictionary directly ## How I Tested These Changes ## Changelog NOCHANGELOG --- .../core/component_rendering.py | 94 +++++++++---- .../dagster_components/core/dsl_schema.py | 36 +++-- .../dagster_components/lib/dbt_project.py | 8 +- .../lib/pipes_subprocess_script_collection.py | 2 +- .../test_component_rendering.py | 129 ++++++++++++------ 5 files changed, 181 insertions(+), 88 deletions(-) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py index 3038b0069cadf..7677d32bd72b8 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py @@ -1,7 +1,18 @@ import functools -import json import os -from typing import AbstractSet, Any, Callable, Mapping, Optional, Sequence, Type, TypeVar, Union +import typing +from typing import ( + Annotated, + Any, + Callable, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + get_origin, +) import dagster._check as check from dagster._core.definitions.declarative_automation.automation_condition import ( @@ -9,15 +20,14 @@ ) from dagster._record import record from jinja2.nativetypes import NativeTemplate -from pydantic import BaseModel, Field -from pydantic.fields import FieldInfo +from pydantic import BaseModel, ConfigDict, TypeAdapter T = TypeVar("T") REF_BASE = "#/$defs/" REF_TEMPLATE = f"{REF_BASE}{{model}}" -CONTEXT_KEY = "required_rendering_scope" +JSON_SCHEMA_EXTRA_KEY = "requires_rendering_scope" def automation_condition_scope() -> Mapping[str, Any]: @@ -27,31 +37,64 @@ def automation_condition_scope() -> Mapping[str, Any]: } -def RenderingScope(field: Optional[FieldInfo] = None, *, required_scope: AbstractSet[str]) -> Any: - """Defines a Pydantic Field that requires a specific scope to be available before rendering. +def requires_additional_scope(subschema: Mapping[str, Any]) -> bool: + raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_KEY), bool) + return raw or False + + +def _env(key: str) -> Optional[str]: + return os.environ.get(key) + + +ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool] + + +@record(checked=False) +class RenderingMetadata: + """Stores metadata about how a field should be rendered. Examples: ```python - class Schema(BaseModel): - a: str = RenderingScope(required_scope={"foo", "bar"}) - b: Optional[int] = RenderingScope(Field(default=None), required_scope={"baz"}) + class MyModel(BaseModel): + some_field: Annotated[str, RenderingMetadata(output_type=MyOtherModel)] + opt_field: Annotated[Optional[str], RenderingMetadata(output_type=(None, MyOtherModel))] ``` """ - return FieldInfo.merge_field_infos( - field or Field(), Field(json_schema_extra={CONTEXT_KEY: json.dumps(list(required_scope))}) - ) + output_type: Type -def get_required_rendering_context(subschema: Mapping[str, Any]) -> Optional[AbstractSet[str]]: - raw = check.opt_inst(subschema.get(CONTEXT_KEY), str) - return set(json.loads(raw)) if raw else None +def _get_expected_type(annotation: Type) -> Optional[Type]: + origin = get_origin(annotation) + if origin is Annotated: + _, f_metadata, *_ = typing.get_args(annotation) + if isinstance(f_metadata, RenderingMetadata): + return f_metadata.output_type + else: + return annotation + return None -def _env(key: str) -> Optional[str]: - return os.environ.get(key) +class RenderedModel(BaseModel): + """Base class for models that get rendered lazily.""" -ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool] + model_config = ConfigDict(json_schema_extra={JSON_SCHEMA_EXTRA_KEY: True}) + + def render_properties(self, value_resolver: "TemplatedValueResolver") -> Mapping[str, Any]: + """Returns a dictionary of rendered properties for this class.""" + rendered_properties = value_resolver.render_obj(self.model_dump(exclude_unset=True)) + + # validate that the rendered properties match the output type + for k, v in rendered_properties.items(): + annotation = self.__annotations__[k] + expected_type = _get_expected_type(annotation) + if expected_type is not None: + # hook into pydantic's type validation to handle complicated stuff like Optional[Mapping[str, int]] + TypeAdapter( + expected_type, config={"arbitrary_types_allowed": True} + ).validate_python(v) + + return rendered_properties @record @@ -110,12 +153,12 @@ def render_params(self, val: T, target_type: Type) -> T: should_render = lambda _: True else: should_render = functools.partial( - has_rendering_scope, json_schema=json_schema, subschema=json_schema + can_render_with_default_scope, json_schema=json_schema, subschema=json_schema ) return self._render_obj(val, [], should_render=should_render) -def has_rendering_scope( +def can_render_with_default_scope( valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any] ) -> bool: """Given a valpath and the json schema of a given target type, determines if there is a rendering scope @@ -126,14 +169,17 @@ def has_rendering_scope( if "$ref" in subschema: subschema = json_schema["$defs"].get(subschema["$ref"][len(REF_BASE) :]) - if get_required_rendering_context(subschema) is not None: + if requires_additional_scope(subschema): return False elif len(valpath) == 0: return True # Optional[ComplexType] (e.g.) will contain multiple schemas in the "anyOf" field if "anyOf" in subschema: - return all(has_rendering_scope(valpath, json_schema, inner) for inner in subschema["anyOf"]) + return all( + can_render_with_default_scope(valpath, json_schema, inner) + for inner in subschema["anyOf"] + ) el = valpath[0] if isinstance(el, str): @@ -152,4 +198,4 @@ def has_rendering_scope( return subschema.get("additionalProperties", True) _, *rest = valpath - return has_rendering_scope(rest, json_schema, inner) + return can_render_with_default_scope(rest, json_schema, inner) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py index 558ffef7aa8ab..86d43a49a0cdc 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py @@ -12,7 +12,11 @@ from dagster._record import replace from pydantic import BaseModel, Field -from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver +from dagster_components.core.component_rendering import ( + RenderedModel, + RenderingMetadata, + TemplatedValueResolver, +) class OpSpecBaseModel(BaseModel): @@ -20,26 +24,23 @@ class OpSpecBaseModel(BaseModel): tags: Optional[Dict[str, str]] = None -class AssetAttributesModel(BaseModel): +class AssetAttributesModel(RenderedModel): key: Optional[str] = None deps: Sequence[str] = [] description: Optional[str] = None - metadata: Union[str, Mapping[str, Any]] = {} + metadata: Annotated[ + Union[str, Mapping[str, Any]], RenderingMetadata(output_type=Mapping[str, Any]) + ] = {} group_name: Optional[str] = None skippable: bool = False code_version: Optional[str] = None owners: Sequence[str] = [] - tags: Union[str, Mapping[str, str]] = {} - automation_condition: Optional[Union[str, AutomationCondition]] = RenderingScope( - Field(None), required_scope={"automation_condition"} - ) - - class Config: - # required for AutomationCondition - arbitrary_types_allowed = True - - def get_resolved_attributes(self, value_resolver: TemplatedValueResolver) -> Mapping[str, Any]: - return value_resolver.render_obj(self.model_dump(exclude_unset=True)) + tags: Annotated[ + Union[str, Mapping[str, str]], RenderingMetadata(output_type=Mapping[str, str]) + ] = {} + automation_condition: Annotated[ + Optional[str], RenderingMetadata(output_type=Optional[AutomationCondition]) + ] = None class AssetSpecProcessor(ABC, BaseModel): @@ -62,7 +63,7 @@ def apply_to_spec( # add the original spec to the context and resolve values return self._apply_to_spec( - spec, self.attributes.get_resolved_attributes(value_resolver.with_context(asset=spec)) + spec, self.attributes.render_properties(value_resolver.with_context(asset=spec)) ) def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions: @@ -102,8 +103,5 @@ def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> Asse AssetAttributes = Sequence[ - Annotated[ - Union[MergeAttributes, ReplaceAttributes], - RenderingScope(Field(union_mode="left_to_right"), required_scope={"asset"}), - ] + Annotated[Union[MergeAttributes, ReplaceAttributes], Field(union_mode="left_to_right")] ] diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py index 2928be971b39b..3644de3f71741 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py @@ -17,12 +17,12 @@ TemplatedValueResolver, component_type, ) -from dagster_components.core.component_rendering import RenderingScope +from dagster_components.core.component_rendering import RenderedModel from dagster_components.core.dsl_schema import AssetAttributes, AssetSpecProcessor, OpSpecBaseModel from dagster_components.generate import generate_component_yaml -class DbtNodeTranslatorParams(BaseModel): +class DbtNodeTranslatorParams(RenderedModel): key: Optional[str] = None group: Optional[str] = None @@ -30,9 +30,7 @@ class DbtNodeTranslatorParams(BaseModel): class DbtProjectParams(BaseModel): dbt: DbtCliResource op: Optional[OpSpecBaseModel] = None - translator: Optional[DbtNodeTranslatorParams] = RenderingScope( - Field(default=None), required_scope={"node"} - ) + translator: Optional[DbtNodeTranslatorParams] = None asset_attributes: Optional[AssetAttributes] = None diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py b/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py index e69340633c179..eeeb67e1cfe68 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py @@ -52,7 +52,7 @@ def load(cls, context: ComponentLoadContext) -> "PipesSubprocessScriptCollection if not script_path.exists(): raise FileNotFoundError(f"Script {script_path} does not exist") path_specs[script_path] = [ - AssetSpec(**asset.get_resolved_attributes(context.templated_value_resolver)) + AssetSpec(**asset.render_properties(context.templated_value_resolver)) for asset in script.assets ] diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py index b60e4f9941ec5..74a2bb29cb1e4 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py @@ -1,64 +1,74 @@ -from typing import Optional, Sequence +from typing import Annotated, Optional, Sequence import pytest from dagster_components.core.component_rendering import ( - RenderingScope, + RenderedModel, + RenderingMetadata, TemplatedValueResolver, - has_rendering_scope, + can_render_with_default_scope, ) -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError -class Inner(BaseModel): +class InnerRendered(RenderedModel): a: Optional[str] = None - deferred: Optional[str] = RenderingScope(required_scope={"foo", "bar", "baz"}) -class Outer(BaseModel): +class Container(BaseModel): a: str - deferred: str = RenderingScope(required_scope={"a"}) - inner: Sequence[Inner] - inner_deferred: Sequence[Inner] = RenderingScope(required_scope={"b"}) + inner: InnerRendered - inner_optional: Optional[Sequence[Inner]] = None - inner_deferred_optional: Optional[Sequence[Inner]] = RenderingScope( - Field(default=None), required_scope={"b"} - ) + +class Outer(BaseModel): + a: str + inner: InnerRendered + container: Container + container_optional: Optional[Container] = None + inner_seq: Sequence[InnerRendered] + inner_optional: Optional[InnerRendered] = None + inner_optional_seq: Optional[Sequence[InnerRendered]] = None @pytest.mark.parametrize( "path,expected", [ (["a"], True), - (["deferred"], False), - (["inner", 0, "a"], True), - (["inner", 0, "deferred"], False), - (["inner_deferred", 0, "a"], False), - (["inner_deferred", 0, "deferred"], False), + (["inner"], False), + (["inner", "a"], False), + (["container", "a"], True), + (["container", "inner"], False), + (["container", "inner", "a"], False), + (["container_optional", "a"], True), + (["container_optional", "inner"], False), + (["container_optional", "inner", "a"], False), + (["inner_seq"], True), + (["inner_seq", 0], False), + (["inner_seq", 0, "a"], False), (["inner_optional"], True), - (["inner_optional", 0, "a"], True), - (["inner_optional", 0, "deferred"], False), - (["inner_deferred_optional", 0], False), - (["inner_deferred_optional", 0, "a"], False), + (["inner_optional", "a"], False), + (["inner_optional_seq"], True), + (["inner_optional_seq", 0], False), + (["inner_optional_seq", 0, "a"], False), ], ) -def test_should_render(path, expected: bool) -> None: +def test_can_render(path, expected: bool) -> None: assert ( - has_rendering_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected + can_render_with_default_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) + == expected ) def test_render() -> None: data = { "a": "{{ foo_val }}", - "deferred": "{{ deferred }}", - "inner": [ - {"a": "{{ bar_val }}", "deferred": "{{ deferred }}"}, - {"a": "zzz", "deferred": "zzz"}, - ], - "inner_deferred": [ - {"a": "{{ deferred }}", "deferred": "zzz"}, + "inner": {"a": "{{ deferred }}"}, + "inner_seq": [ + {"a": "{{ deferred }}"}, ], + "container": { + "a": "{{ bar_val }}", + "inner": {"a": "{{ deferred }}"}, + }, } renderer = TemplatedValueResolver(context={"foo_val": "foo", "bar_val": "bar"}) @@ -66,14 +76,55 @@ def test_render() -> None: assert rendered_data == { "a": "foo", - "deferred": "{{ deferred }}", - "inner": [ - {"a": "bar", "deferred": "{{ deferred }}"}, - {"a": "zzz", "deferred": "zzz"}, - ], - "inner_deferred": [ - {"a": "{{ deferred }}", "deferred": "zzz"}, + "inner": {"a": "{{ deferred }}"}, + "inner_seq": [ + {"a": "{{ deferred }}"}, ], + "container": { + "a": "bar", + "inner": {"a": "{{ deferred }}"}, + }, } TypeAdapter(Outer).validate_python(rendered_data) + + +class RM(RenderedModel): + the_renderable_int: Annotated[str, RenderingMetadata(output_type=int)] + the_unrenderable_int: int + + the_str: str + the_opt_int: Annotated[Optional[str], RenderingMetadata(output_type=Optional[int])] = None + + +def test_valid_rendering() -> None: + rm = RM( + the_renderable_int="{{ some_int }}", + the_unrenderable_int=1, + the_str="{{ some_str }}", + the_opt_int="{{ some_int }}", + ) + resolver = TemplatedValueResolver(context={"some_int": 1, "some_str": "aaa"}) + resolved_properties = rm.render_properties(resolver) + + assert resolved_properties == { + "the_renderable_int": 1, + "the_unrenderable_int": 1, + "the_str": "aaa", + "the_opt_int": 1, + } + + +def test_invalid_rendering() -> None: + rm = RM( + the_renderable_int="{{ some_int }}", + the_unrenderable_int=1, + the_str="{{ some_str }}", + the_opt_int="{{ some_str }}", + ) + + resolver = TemplatedValueResolver(context={"some_int": 1, "some_str": "aaa"}) + + with pytest.raises(ValidationError): + # string is not a valid output type for the_opt_int + rm.render_properties(resolver)