diff --git a/continuous-integration/code-generation/datamodel_frontend_generator.py b/continuous-integration/code-generation/datamodel_frontend_generator.py index d43fb21c..4a95052d 100644 --- a/continuous-integration/code-generation/datamodel_frontend_generator.py +++ b/continuous-integration/code-generation/datamodel_frontend_generator.py @@ -34,9 +34,19 @@ def _get_underscore_name(base_type: str, node_path: Path) -> str: return f"{underscored_base_type}_{'_'.join([_to_snake_case(part) for part in node_path.parts])}" +@dataclass +class NestedDatamodelInfo: + name: str + snake: str + to_internal: str + from_internal: str + internal_type: str + + @dataclass class MemberInfo: is_optional: bool + is_nested_datamodel: bool underlying_type: str @@ -63,6 +73,7 @@ class NodeData: underlying_zivid_class: Any member_info: Optional[MemberInfo] = None container_info: Optional[ContainerInfo] = None + nested_datamodel_info: Optional[NestedDatamodelInfo] = None def _inner_classes_list(cls: Any) -> List: @@ -135,7 +146,18 @@ def _create_init_special_member_function(node_data: NodeData, base_type: str) -> else: is_none_check = "" none_message = "" - signature_vars += f"{member.snake_case}={full_dot_path}.{member.name}().value," + + if member.member_info.is_nested_datamodel: + if member.member_info.is_optional: + signature_vars += f"{member.snake_case}=None," + else: + signature_vars += ( + f"{member.snake_case}={member.member_info.underlying_type}()," + ) + else: + signature_vars += ( + f"{member.snake_case}={full_dot_path}.{member.name}().value," + ) if member.is_enum: member_variable_set += dedent( @@ -149,6 +171,18 @@ def _create_init_special_member_function(node_data: NodeData, base_type: str) -> """ ) + elif member.member_info.is_nested_datamodel: + underlying_type = member.member_info.underlying_type + member_variable_set += dedent( + f""" + if isinstance({member.snake_case}, {member.member_info.underlying_type} ) {is_none_check}: + self._{member.snake_case} = {member.snake_case} + else: + raise TypeError( + 'Unsupported type, expected: {underlying_type}{none_message}, got {{value_type}}'.format(value_type=type({member.snake_case})) + ) + """ + ) else: underlying_type = member.member_info.underlying_type member_variable_set += dedent( @@ -302,6 +336,29 @@ def {member.snake_case}(self, value): """ ) + elif member.member_info.is_nested_datamodel: + underlying_type_str = member.member_info.underlying_type + expected_types_str = underlying_type_str + can_be_none_error_message_part + + get_properties += dedent( + f""" + @property + def {member.snake_case}(self): + return self._{member.snake_case} + """ + ) + + set_properties += dedent( + f""" + @{member.snake_case}.setter + def {member.snake_case}(self, value): + if isinstance(value, {member.member_info.underlying_type}) {is_none_check}: + self._{member.snake_case} = value + else: + raise TypeError('Unsupported type, expected: {expected_types_str}, got {{value_type}}'.format(value_type=type(value))) + """ + ) + else: underlying_type_str = member.member_info.underlying_type[1:-2].split(",") expected_types_str = ( @@ -520,9 +577,28 @@ def _parse_internal_datamodel(current_class: Any) -> NodeData: to_be_removed.append(child) elif child.is_leaf: + if child.underlying_zivid_class.value_type.startswith("_zivid"): + datamodel_name = child.underlying_zivid_class.value_type.split(".")[-1] + datamodel_name_snake = _to_snake_case(datamodel_name) + child.is_nested_datamodel = True + datamodel_snake_case = _to_snake_case(datamodel_name) + underlying_type = f"zivid.{datamodel_snake_case}.{datamodel_name}" + child.nested_datamodel_info = NestedDatamodelInfo( + name=datamodel_name, + snake=datamodel_name_snake, + to_internal=f"zivid.{datamodel_name_snake}._to_internal_{datamodel_name_snake}", + from_internal=f"zivid.{datamodel_name_snake}._to_{datamodel_name_snake}", + internal_type=child.underlying_zivid_class.value_type, + ) + + else: + child.is_nested_datamodel = False + underlying_type = child.underlying_zivid_class.value_type + child.member_info = MemberInfo( is_optional=child.underlying_zivid_class.is_optional, - underlying_type=child.underlying_zivid_class.value_type, + is_nested_datamodel=child.is_nested_datamodel, + underlying_type=underlying_type, ) member_variables.append(child) to_be_removed.append(child) @@ -581,23 +657,22 @@ def _create_to_frontend_converter(node_data: NodeData, base_type: str) -> str: member_convert_logic = "" for member in node_data.member_variables: - member_convert_logic += ( - "{member} = {temp_internal_name}.{member}.value,".format( - member=member.snake_case, - temp_internal_name=temp_internal_name, + if member.member_info.is_nested_datamodel: + if member.member_info.is_optional: + if_is_none_check = f"if {temp_internal_name}.{member.snake_case}.value is not None else None" + else: + if_is_none_check = "" + + member_convert_logic += f"{member.snake_case} = {member.nested_datamodel_info.from_internal}({temp_internal_name}.{member.snake_case}.value) {if_is_none_check}," + else: + member_convert_logic += ( + f"{member.snake_case} = {temp_internal_name}.{member.snake_case}.value," ) - ) child_convert_logic = "" for child in node_data.children: if not child.is_uninstantiated_node: - child_convert_logic += ( - "{child_name}=_to_{child}({temp_internal_name}.{child_name}),".format( - child_name=child.snake_case, - child=f"{underscored_path}_{child.snake_case}", - temp_internal_name=temp_internal_name, - ) - ) + child_convert_logic += f"{child.snake_case}=_to_{f'{underscored_path}_{child.snake_case}'}({temp_internal_name}.{child.snake_case})," base_function = dedent( f""" @@ -641,12 +716,19 @@ def _create_to_internal_converter(node_data: NodeData, base_type: str) -> str: if node_data.member_variables: for member in node_data.member_variables: - constructor_arg = ( - f"{node_data.snake_case}._{member.snake_case}.value" - if member.is_enum - else f"{node_data.snake_case}.{member.snake_case}" - ) - convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({constructor_arg})" + if member.member_info.is_nested_datamodel: + if member.member_info.is_optional: + if_is_none_check = f"if {node_data.snake_case}.{member.snake_case} is not None else None" + else: + if_is_none_check = "" + convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({member.nested_datamodel_info.to_internal}({node_data.snake_case}.{member.snake_case}) {if_is_none_check})" + else: + constructor_arg = ( + f"{node_data.snake_case}._{member.snake_case}.value" + if member.is_enum + else f"{node_data.snake_case}.{member.snake_case}" + ) + convert_member_logic += f"\n{temp_internal_name}.{member.snake_case} = {full_dot_path}.{member.name}({constructor_arg})" convert_children_logic = "" if node_data.children: