Skip to content

Commit

Permalink
Support nested datamodels in the datamodel frontend generator
Browse files Browse the repository at this point in the history
Nested datamodels are stored within the datamodel frontend as instances
of the frontend of the nested datamodel model.
  • Loading branch information
johningve committed Dec 5, 2024
1 parent 56718cc commit 7e2dac4
Showing 1 changed file with 111 additions and 20 deletions.
131 changes: 111 additions & 20 deletions continuous-integration/code-generation/datamodel_frontend_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ def _get_underscore_name(base_type: str, node_path: Path) -> str:
return f"{underscored_base_type}_{'_'.join([_to_snake_case(part) for part in node_path.parts])}"


@dataclass
class NestedDatamodelInfo:
name: str
snake: str
to_internal: str
from_internal: str
internal_type: str


@dataclass
class MemberInfo:
is_optional: bool
is_nested_datamodel: bool
underlying_type: str


Expand All @@ -63,6 +73,7 @@ class NodeData:
underlying_zivid_class: Any
member_info: Optional[MemberInfo] = None
container_info: Optional[ContainerInfo] = None
nested_datamodel_info: Optional[NestedDatamodelInfo] = None


def _inner_classes_list(cls: Any) -> List:
Expand Down Expand Up @@ -97,6 +108,8 @@ def _imports(extra_imports: Sequence[str]) -> str:


def _create_init_special_member_function(node_data: NodeData, base_type: str) -> str:
# pylint: disable=too-many-branches

full_dot_path = _get_dot_path(base_type, node_data.path)
signature_vars = ""
member_variable_set = ""
Expand Down Expand Up @@ -135,7 +148,18 @@ def _create_init_special_member_function(node_data: NodeData, base_type: str) ->
else:
is_none_check = ""
none_message = ""
signature_vars += f"{member.snake_case}={full_dot_path}.{member.name}().value,"

if member.member_info.is_nested_datamodel:
if member.member_info.is_optional:
signature_vars += f"{member.snake_case}=None,"
else:
signature_vars += (
f"{member.snake_case}={member.member_info.underlying_type}(),"
)
else:
signature_vars += (
f"{member.snake_case}={full_dot_path}.{member.name}().value,"
)

if member.is_enum:
member_variable_set += dedent(
Expand All @@ -149,6 +173,18 @@ def _create_init_special_member_function(node_data: NodeData, base_type: str) ->
"""
)

elif member.member_info.is_nested_datamodel:
underlying_type = member.member_info.underlying_type
member_variable_set += dedent(
f"""
if isinstance({member.snake_case}, {member.member_info.underlying_type} ) {is_none_check}:
self._{member.snake_case} = {member.snake_case}
else:
raise TypeError(
'Unsupported type, expected: {underlying_type}{none_message}, got {{value_type}}'.format(value_type=type({member.snake_case}))
)
"""
)
else:
underlying_type = member.member_info.underlying_type
member_variable_set += dedent(
Expand Down Expand Up @@ -230,6 +266,8 @@ def __str__(self):


def _create_properties(node_data: NodeData, base_type: str) -> str:
# pylint: disable=too-many-branches

get_properties = "\n"
set_properties = "\n"

Expand Down Expand Up @@ -302,6 +340,29 @@ def {member.snake_case}(self, value):
"""
)

elif member.member_info.is_nested_datamodel:
underlying_type_str = member.member_info.underlying_type
expected_types_str = underlying_type_str + can_be_none_error_message_part

get_properties += dedent(
f"""
@property
def {member.snake_case}(self):
return self._{member.snake_case}
"""
)

set_properties += dedent(
f"""
@{member.snake_case}.setter
def {member.snake_case}(self, value):
if isinstance(value, {member.member_info.underlying_type}) {is_none_check}:
self._{member.snake_case} = value
else:
raise TypeError('Unsupported type, expected: {expected_types_str}, got {{value_type}}'.format(value_type=type(value)))
"""
)

else:
underlying_type_str = member.member_info.underlying_type[1:-2].split(",")
expected_types_str = (
Expand Down Expand Up @@ -497,6 +558,8 @@ class {class_name}:


def _parse_internal_datamodel(current_class: Any) -> NodeData:
# pylint: disable=too-many-branches, too-many-locals

child_classes = []
if hasattr(current_class, "valid_values") and hasattr(current_class, "enum"):
is_leaf = True
Expand All @@ -520,9 +583,27 @@ def _parse_internal_datamodel(current_class: Any) -> NodeData:
to_be_removed.append(child)

elif child.is_leaf:
if child.underlying_zivid_class.value_type.startswith("_zivid"):
datamodel_name = child.underlying_zivid_class.value_type.split(".")[-1]
datamodel_name_snake = _to_snake_case(datamodel_name)
child.is_nested_datamodel = True
underlying_type = f"zivid.{datamodel_name_snake}.{datamodel_name}"
child.nested_datamodel_info = NestedDatamodelInfo(
name=datamodel_name,
snake=datamodel_name_snake,
to_internal=f"zivid.{datamodel_name_snake}._to_internal_{datamodel_name_snake}",
from_internal=f"zivid.{datamodel_name_snake}._to_{datamodel_name_snake}",
internal_type=child.underlying_zivid_class.value_type,
)

else:
child.is_nested_datamodel = False
underlying_type = child.underlying_zivid_class.value_type

child.member_info = MemberInfo(
is_optional=child.underlying_zivid_class.is_optional,
underlying_type=child.underlying_zivid_class.value_type,
is_nested_datamodel=child.is_nested_datamodel,
underlying_type=underlying_type,
)
member_variables.append(child)
to_be_removed.append(child)
Expand Down Expand Up @@ -563,6 +644,8 @@ def _parse_internal_datamodel(current_class: Any) -> NodeData:


def _create_to_frontend_converter(node_data: NodeData, base_type: str) -> str:
# pylint: disable=too-many-locals

base_typename = base_type.split(".")[-1]
temp_internal_name = f"internal_{node_data.snake_case}"
nested_converters = [
Expand All @@ -581,23 +664,22 @@ def _create_to_frontend_converter(node_data: NodeData, base_type: str) -> str:

member_convert_logic = ""
for member in node_data.member_variables:
member_convert_logic += (
"{member} = {temp_internal_name}.{member}.value,".format(
member=member.snake_case,
temp_internal_name=temp_internal_name,
if member.member_info.is_nested_datamodel:
if member.member_info.is_optional:
if_is_none_check = f"if {temp_internal_name}.{member.snake_case}.value is not None else None"
else:
if_is_none_check = ""

member_convert_logic += f"{member.snake_case} = {member.nested_datamodel_info.from_internal}({temp_internal_name}.{member.snake_case}.value) {if_is_none_check},"
else:
member_convert_logic += (
f"{member.snake_case} = {temp_internal_name}.{member.snake_case}.value,"
)
)

child_convert_logic = ""
for child in node_data.children:
if not child.is_uninstantiated_node:
child_convert_logic += (
"{child_name}=_to_{child}({temp_internal_name}.{child_name}),".format(
child_name=child.snake_case,
child=f"{underscored_path}_{child.snake_case}",
temp_internal_name=temp_internal_name,
)
)
child_convert_logic += f"{child.snake_case}=_to_{f'{underscored_path}_{child.snake_case}'}({temp_internal_name}.{child.snake_case}),"

base_function = dedent(
f"""
Expand All @@ -615,6 +697,8 @@ def _to_{underscored_path}(internal_{node_data.snake_case}):


def _create_to_internal_converter(node_data: NodeData, base_type: str) -> str:
# pylint: disable=too-many-locals

temp_internal_name = f"internal_{node_data.snake_case}"
nested_converters = [
_create_to_internal_converter(element, base_type=base_type)
Expand All @@ -641,12 +725,19 @@ def _create_to_internal_converter(node_data: NodeData, base_type: str) -> str:

if node_data.member_variables:
for member in node_data.member_variables:
constructor_arg = (
f"{node_data.snake_case}._{member.snake_case}.value"
if member.is_enum
else f"{node_data.snake_case}.{member.snake_case}"
)
convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({constructor_arg})"
if member.member_info.is_nested_datamodel:
if member.member_info.is_optional:
if_is_none_check = f"if {node_data.snake_case}.{member.snake_case} is not None else None"
else:
if_is_none_check = ""
convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({member.nested_datamodel_info.to_internal}({node_data.snake_case}.{member.snake_case}) {if_is_none_check})"
else:
constructor_arg = (
f"{node_data.snake_case}._{member.snake_case}.value"
if member.is_enum
else f"{node_data.snake_case}.{member.snake_case}"
)
convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({constructor_arg})"

convert_children_logic = ""
if node_data.children:
Expand Down

0 comments on commit 7e2dac4

Please sign in to comment.