diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py index 3038b0069cadf..7677d32bd72b8 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py @@ -1,7 +1,18 @@ import functools -import json import os -from typing import AbstractSet, Any, Callable, Mapping, Optional, Sequence, Type, TypeVar, Union +import typing +from typing import ( + Annotated, + Any, + Callable, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + get_origin, +) import dagster._check as check from dagster._core.definitions.declarative_automation.automation_condition import ( @@ -9,15 +20,14 @@ ) from dagster._record import record from jinja2.nativetypes import NativeTemplate -from pydantic import BaseModel, Field -from pydantic.fields import FieldInfo +from pydantic import BaseModel, ConfigDict, TypeAdapter T = TypeVar("T") REF_BASE = "#/$defs/" REF_TEMPLATE = f"{REF_BASE}{{model}}" -CONTEXT_KEY = "required_rendering_scope" +JSON_SCHEMA_EXTRA_KEY = "requires_rendering_scope" def automation_condition_scope() -> Mapping[str, Any]: @@ -27,31 +37,64 @@ def automation_condition_scope() -> Mapping[str, Any]: } -def RenderingScope(field: Optional[FieldInfo] = None, *, required_scope: AbstractSet[str]) -> Any: - """Defines a Pydantic Field that requires a specific scope to be available before rendering. +def requires_additional_scope(subschema: Mapping[str, Any]) -> bool: + raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_KEY), bool) + return raw or False + + +def _env(key: str) -> Optional[str]: + return os.environ.get(key) + + +ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool] + + +@record(checked=False) +class RenderingMetadata: + """Stores metadata about how a field should be rendered. 
Examples: ```python - class Schema(BaseModel): - a: str = RenderingScope(required_scope={"foo", "bar"}) - b: Optional[int] = RenderingScope(Field(default=None), required_scope={"baz"}) + class MyModel(BaseModel): + some_field: Annotated[str, RenderingMetadata(output_type=MyOtherModel)] + opt_field: Annotated[Optional[str], RenderingMetadata(output_type=Optional[MyOtherModel])] ``` """ - return FieldInfo.merge_field_infos( - field or Field(), Field(json_schema_extra={CONTEXT_KEY: json.dumps(list(required_scope))}) - ) + output_type: Type -def get_required_rendering_context(subschema: Mapping[str, Any]) -> Optional[AbstractSet[str]]: - raw = check.opt_inst(subschema.get(CONTEXT_KEY), str) - return set(json.loads(raw)) if raw else None +def _get_expected_type(annotation: Type) -> Optional[Type]: + origin = get_origin(annotation) + if origin is Annotated: + _, f_metadata, *_ = typing.get_args(annotation) + if isinstance(f_metadata, RenderingMetadata): + return f_metadata.output_type + else: + return annotation + return None -def _env(key: str) -> Optional[str]: - return os.environ.get(key) +class RenderedModel(BaseModel): + """Base class for models that get rendered lazily.""" -ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool] + model_config = ConfigDict(json_schema_extra={JSON_SCHEMA_EXTRA_KEY: True}) + + def render_properties(self, value_resolver: "TemplatedValueResolver") -> Mapping[str, Any]: + """Returns a dictionary of rendered properties for this class.""" + rendered_properties = value_resolver.render_obj(self.model_dump(exclude_unset=True)) + + # validate that the rendered properties match the output type + for k, v in rendered_properties.items(): + annotation = self.__annotations__[k] + expected_type = _get_expected_type(annotation) + if expected_type is not None: + # hook into pydantic's type validation to handle complicated stuff like Optional[Mapping[str, int]] + TypeAdapter( + expected_type, config={"arbitrary_types_allowed": True} +
).validate_python(v) + + return rendered_properties @record @@ -110,12 +153,12 @@ def render_params(self, val: T, target_type: Type) -> T: should_render = lambda _: True else: should_render = functools.partial( - has_rendering_scope, json_schema=json_schema, subschema=json_schema + can_render_with_default_scope, json_schema=json_schema, subschema=json_schema ) return self._render_obj(val, [], should_render=should_render) -def has_rendering_scope( +def can_render_with_default_scope( valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any] ) -> bool: """Given a valpath and the json schema of a given target type, determines if there is a rendering scope @@ -126,14 +169,17 @@ def has_rendering_scope( if "$ref" in subschema: subschema = json_schema["$defs"].get(subschema["$ref"][len(REF_BASE) :]) - if get_required_rendering_context(subschema) is not None: + if requires_additional_scope(subschema): return False elif len(valpath) == 0: return True # Optional[ComplexType] (e.g.) 
will contain multiple schemas in the "anyOf" field if "anyOf" in subschema: - return all(has_rendering_scope(valpath, json_schema, inner) for inner in subschema["anyOf"]) + return all( + can_render_with_default_scope(valpath, json_schema, inner) + for inner in subschema["anyOf"] + ) el = valpath[0] if isinstance(el, str): @@ -152,4 +198,4 @@ def has_rendering_scope( return subschema.get("additionalProperties", True) _, *rest = valpath - return has_rendering_scope(rest, json_schema, inner) + return can_render_with_default_scope(rest, json_schema, inner) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py index 558ffef7aa8ab..86d43a49a0cdc 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py @@ -12,7 +12,11 @@ from dagster._record import replace from pydantic import BaseModel, Field -from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver +from dagster_components.core.component_rendering import ( + RenderedModel, + RenderingMetadata, + TemplatedValueResolver, +) class OpSpecBaseModel(BaseModel): @@ -20,26 +24,23 @@ class OpSpecBaseModel(BaseModel): tags: Optional[Dict[str, str]] = None -class AssetAttributesModel(BaseModel): +class AssetAttributesModel(RenderedModel): key: Optional[str] = None deps: Sequence[str] = [] description: Optional[str] = None - metadata: Union[str, Mapping[str, Any]] = {} + metadata: Annotated[ + Union[str, Mapping[str, Any]], RenderingMetadata(output_type=Mapping[str, Any]) + ] = {} group_name: Optional[str] = None skippable: bool = False code_version: Optional[str] = None owners: Sequence[str] = [] - tags: Union[str, Mapping[str, str]] = {} - automation_condition: Optional[Union[str, AutomationCondition]] = RenderingScope( - Field(None), 
required_scope={"automation_condition"} - ) - - class Config: - # required for AutomationCondition - arbitrary_types_allowed = True - - def get_resolved_attributes(self, value_resolver: TemplatedValueResolver) -> Mapping[str, Any]: - return value_resolver.render_obj(self.model_dump(exclude_unset=True)) + tags: Annotated[ + Union[str, Mapping[str, str]], RenderingMetadata(output_type=Mapping[str, str]) + ] = {} + automation_condition: Annotated[ + Optional[str], RenderingMetadata(output_type=Optional[AutomationCondition]) + ] = None class AssetSpecProcessor(ABC, BaseModel): @@ -62,7 +63,7 @@ def apply_to_spec( # add the original spec to the context and resolve values return self._apply_to_spec( - spec, self.attributes.get_resolved_attributes(value_resolver.with_context(asset=spec)) + spec, self.attributes.render_properties(value_resolver.with_context(asset=spec)) ) def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions: @@ -102,8 +103,5 @@ def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> Asse AssetAttributes = Sequence[ - Annotated[ - Union[MergeAttributes, ReplaceAttributes], - RenderingScope(Field(union_mode="left_to_right"), required_scope={"asset"}), - ] + Annotated[Union[MergeAttributes, ReplaceAttributes], Field(union_mode="left_to_right")] ] diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py index 2928be971b39b..3644de3f71741 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py @@ -17,12 +17,12 @@ TemplatedValueResolver, component_type, ) -from dagster_components.core.component_rendering import RenderingScope +from dagster_components.core.component_rendering import RenderedModel from dagster_components.core.dsl_schema import AssetAttributes, 
AssetSpecProcessor, OpSpecBaseModel from dagster_components.generate import generate_component_yaml -class DbtNodeTranslatorParams(BaseModel): +class DbtNodeTranslatorParams(RenderedModel): key: Optional[str] = None group: Optional[str] = None @@ -30,9 +30,7 @@ class DbtNodeTranslatorParams(BaseModel): class DbtProjectParams(BaseModel): dbt: DbtCliResource op: Optional[OpSpecBaseModel] = None - translator: Optional[DbtNodeTranslatorParams] = RenderingScope( - Field(default=None), required_scope={"node"} - ) + translator: Optional[DbtNodeTranslatorParams] = None asset_attributes: Optional[AssetAttributes] = None diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py b/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py index e69340633c179..eeeb67e1cfe68 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/pipes_subprocess_script_collection.py @@ -52,7 +52,7 @@ def load(cls, context: ComponentLoadContext) -> "PipesSubprocessScriptCollection if not script_path.exists(): raise FileNotFoundError(f"Script {script_path} does not exist") path_specs[script_path] = [ - AssetSpec(**asset.get_resolved_attributes(context.templated_value_resolver)) + AssetSpec(**asset.render_properties(context.templated_value_resolver)) for asset in script.assets ] diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py index b60e4f9941ec5..74a2bb29cb1e4 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py +++ 
b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py @@ -1,64 +1,74 @@ -from typing import Optional, Sequence +from typing import Annotated, Optional, Sequence import pytest from dagster_components.core.component_rendering import ( - RenderingScope, + RenderedModel, + RenderingMetadata, TemplatedValueResolver, - has_rendering_scope, + can_render_with_default_scope, ) -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError -class Inner(BaseModel): +class InnerRendered(RenderedModel): a: Optional[str] = None - deferred: Optional[str] = RenderingScope(required_scope={"foo", "bar", "baz"}) -class Outer(BaseModel): +class Container(BaseModel): a: str - deferred: str = RenderingScope(required_scope={"a"}) - inner: Sequence[Inner] - inner_deferred: Sequence[Inner] = RenderingScope(required_scope={"b"}) + inner: InnerRendered - inner_optional: Optional[Sequence[Inner]] = None - inner_deferred_optional: Optional[Sequence[Inner]] = RenderingScope( - Field(default=None), required_scope={"b"} - ) + +class Outer(BaseModel): + a: str + inner: InnerRendered + container: Container + container_optional: Optional[Container] = None + inner_seq: Sequence[InnerRendered] + inner_optional: Optional[InnerRendered] = None + inner_optional_seq: Optional[Sequence[InnerRendered]] = None @pytest.mark.parametrize( "path,expected", [ (["a"], True), - (["deferred"], False), - (["inner", 0, "a"], True), - (["inner", 0, "deferred"], False), - (["inner_deferred", 0, "a"], False), - (["inner_deferred", 0, "deferred"], False), + (["inner"], False), + (["inner", "a"], False), + (["container", "a"], True), + (["container", "inner"], False), + (["container", "inner", "a"], False), + (["container_optional", "a"], True), + (["container_optional", "inner"], False), + (["container_optional", "inner", "a"], False), + (["inner_seq"], True), + (["inner_seq", 0], False), + (["inner_seq", 0, 
"a"], False), (["inner_optional"], True), - (["inner_optional", 0, "a"], True), - (["inner_optional", 0, "deferred"], False), - (["inner_deferred_optional", 0], False), - (["inner_deferred_optional", 0, "a"], False), + (["inner_optional", "a"], False), + (["inner_optional_seq"], True), + (["inner_optional_seq", 0], False), + (["inner_optional_seq", 0, "a"], False), ], ) -def test_should_render(path, expected: bool) -> None: +def test_can_render(path, expected: bool) -> None: assert ( - has_rendering_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected + can_render_with_default_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) + == expected ) def test_render() -> None: data = { "a": "{{ foo_val }}", - "deferred": "{{ deferred }}", - "inner": [ - {"a": "{{ bar_val }}", "deferred": "{{ deferred }}"}, - {"a": "zzz", "deferred": "zzz"}, - ], - "inner_deferred": [ - {"a": "{{ deferred }}", "deferred": "zzz"}, + "inner": {"a": "{{ deferred }}"}, + "inner_seq": [ + {"a": "{{ deferred }}"}, ], + "container": { + "a": "{{ bar_val }}", + "inner": {"a": "{{ deferred }}"}, + }, } renderer = TemplatedValueResolver(context={"foo_val": "foo", "bar_val": "bar"}) @@ -66,14 +76,55 @@ def test_render() -> None: assert rendered_data == { "a": "foo", - "deferred": "{{ deferred }}", - "inner": [ - {"a": "bar", "deferred": "{{ deferred }}"}, - {"a": "zzz", "deferred": "zzz"}, - ], - "inner_deferred": [ - {"a": "{{ deferred }}", "deferred": "zzz"}, + "inner": {"a": "{{ deferred }}"}, + "inner_seq": [ + {"a": "{{ deferred }}"}, ], + "container": { + "a": "bar", + "inner": {"a": "{{ deferred }}"}, + }, } TypeAdapter(Outer).validate_python(rendered_data) + + +class RM(RenderedModel): + the_renderable_int: Annotated[str, RenderingMetadata(output_type=int)] + the_unrenderable_int: int + + the_str: str + the_opt_int: Annotated[Optional[str], RenderingMetadata(output_type=Optional[int])] = None + + +def test_valid_rendering() -> None: + rm = RM( + 
the_renderable_int="{{ some_int }}", + the_unrenderable_int=1, + the_str="{{ some_str }}", + the_opt_int="{{ some_int }}", + ) + resolver = TemplatedValueResolver(context={"some_int": 1, "some_str": "aaa"}) + resolved_properties = rm.render_properties(resolver) + + assert resolved_properties == { + "the_renderable_int": 1, + "the_unrenderable_int": 1, + "the_str": "aaa", + "the_opt_int": 1, + } + + +def test_invalid_rendering() -> None: + rm = RM( + the_renderable_int="{{ some_int }}", + the_unrenderable_int=1, + the_str="{{ some_str }}", + the_opt_int="{{ some_str }}", + ) + + resolver = TemplatedValueResolver(context={"some_int": 1, "some_str": "aaa"}) + + with pytest.raises(ValidationError): + # string is not a valid output type for the_opt_int + rm.render_properties(resolver)