From e2635fa3f0465740cae7b447f026d057a560fd9d Mon Sep 17 00:00:00 2001 From: Jmgr Date: Wed, 20 Nov 2024 21:50:38 +0000 Subject: [PATCH] chore: split NadaType and NadaValue / refactoring --- Makefile | 2 +- nada_dsl/nada_types/__init__.py | 27 +-- nada_dsl/nada_types/collections.py | 348 +++++++++++++++------------- nada_dsl/nada_types/function.py | 6 +- nada_dsl/nada_types/scalar_types.py | 117 ++++++++++ test-programs/ntuple_accessor.py | 12 +- tests/compiler_frontend_test.py | 4 +- tests/nada_type_test.py | 18 -- 8 files changed, 334 insertions(+), 200 deletions(-) delete mode 100644 tests/nada_type_test.py diff --git a/Makefile b/Makefile index bb49512..f9b7cab 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ test-dependencies: pip install .'[test]' test: test-dependencies - pytest + uv run pytest # Build protocol buffers definitions. build_proto: diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index f65a47f..b313729 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -2,8 +2,9 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, TypeAlias, Union, Type +from typing import Any, Dict, TypeAlias, Union, Type from nada_dsl.source_ref import SourceRef +from abc import abstractmethod @dataclass @@ -144,20 +145,16 @@ def __init__(self, child: OperationType): """ self.child = child if self.child is not None: - self.child.store_in_ast(self.to_mir()) - - def to_mir(self): - """Default implementation for the Conversion of a type into MIR representation.""" - return self.__class__.class_to_mir() - - @classmethod - def class_to_mir(cls) -> str: - """Converts a class into a MIR Nada type.""" - name = cls.__name__ - # Rename public variables so they are considered as the same as literals. - if name.startswith("Public"): - name = name[len("Public") :].lstrip() - return name + self.child.store_in_ast(self.metatype().to_mir()) + + # @classmethod + # def class_to_mir(cls) -> str: + # """Converts a class into a MIR Nada type.""" + # name = cls.__name__ + # # Rename public variables so they are considered as the same as literals. + # if name.startswith("Public"): + # name = name[len("Public") :].lstrip() + # return name def __bool__(self): raise NotImplementedError diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index dbba76a..a9aa911 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -3,6 +3,7 @@ import copy from dataclasses import dataclass import inspect +import traceback from typing import Any, Dict, Generic, List import typing from typing import TypeVar @@ -47,82 +48,7 @@ def is_primitive_integer(nada_type_str: str): ) -class Collection(NadaType): - """Superclass of collection types""" - - left_type: AllTypesType - right_type: AllTypesType - contained_type: AllTypesType - - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - if isinstance(self, (Array, ArrayType)): - size = {"size": self.size} if self.size else {} - contained_type = self.retrieve_inner_type() - return {"Array": {"inner_type": contained_type, **size}} - if isinstance(self, (Tuple, TupleType)): - return { - "Tuple": { - "left_type": ( - self.left_type.to_mir() - if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) - else self.left_type.class_to_mir() - ), - "right_type": ( - self.right_type.to_mir() - if isinstance( - self.right_type, - (NadaType, ArrayType, TupleType), - ) - else self.right_type.class_to_mir() - ), - } - } - if isinstance(self, NTuple): - return { - "NTuple": { - "types": [ - ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for ty in [ - type(value) - for value in self.values # pylint: disable=E1101 - ] - ] - } - } - if isinstance(self, Object): - return { - "Object": { - "types": { - name: ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for name, ty in [ - (name, type(value)) - for name, value in self.values.items() # pylint: disable=E1101 - ] - } - } - } - raise InvalidTypeError( - f"{self.__class__.__name__} is not a valid Nada Collection" - ) - - def retrieve_inner_type(self): - """Retrieves the child type of this collection""" - if isinstance(self.contained_type, TypeVar): - return "T" - if inspect.isclass(self.contained_type): - return self.contained_type.class_to_mir() - return self.contained_type.to_mir() - - +@dataclass class Map(Generic[T, R]): """The Map operation""" @@ -187,12 +113,15 @@ def store_in_ast(self, ty): @dataclass -class TupleType: +class TupleMetaType(MetaType): """Marker type for Tuples.""" left_type: NadaType right_type: NadaType + def instantiate(self, child): + return Tuple(child, self.left_type, self.right_type) + def to_mir(self): """Convert a tuple object into a Nada type.""" return { @@ -203,7 +132,8 @@ def to_mir(self): } -class Tuple(Generic[T, U], Collection): +@dataclass +class Tuple(Generic[T, U], NadaType): """The Tuple type""" left_type: T @@ -215,65 +145,112 @@ def __init__(self, child, left_type: T, right_type: U): self.child = child super().__init__(self.child) + """TODO this should be deleted and use MetaType.to_mir""" + + # def to_mir(self): + # return { + # "Tuple": { + # "left_type": ( + # self.left_type.to_mir() + # if isinstance( + # self.left_type, (NadaType, ArrayMetaType, TupleMetaType) + # ) + # else self.left_type.class_to_mir() + # ), + # "right_type": ( + # self.right_type.to_mir() + # if isinstance( + # self.right_type, + # (NadaType, ArrayMetaType, TupleMetaType), + # ) + # else self.right_type.class_to_mir() + # ), + # } + # } + @classmethod - def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]": + def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": """Constructs a new Tuple.""" return Tuple( - left_type=left_type, - right_type=right_type, + left_type=left_value.metatype(), + right_type=right_value.metatype(), child=TupleNew( - child=(left_type, right_type), + child=(left_value, right_value), source_ref=SourceRef.back_frame(), ), ) @classmethod - def generic_type(cls, left_type: U, right_type: T) -> TupleType: + def generic_type(cls, left_type: U, right_type: T) -> TupleMetaType: """Returns the generic type for this Tuple""" - return TupleType(left_type=left_type, right_type=right_type) + return TupleMetaType(left_type=left_type, right_type=right_type) + + def metatype(self): + return TupleMetaType(self.left_type, self.right_type) + + +def _generate_accessor(ty: Any, accessor: Any) -> NadaType: + if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix + raise TypeError("Literals are not supported in accessors") + return ty.instantiate(accessor) + # if ty.is_scalar(): + # if ty.is_literal(): + # return ty # value.instantiate(child=accessor) ? + # return ty(child=accessor) + # if ty == Array: + # return Array( + # child=accessor, + # contained_type=ty.contained_type, + # size=ty.size, + # ) + # if ty == NTuple: + # return NTuple( + # child=accessor, + # types=ty.types, + # ) + # if ty == Object: + # return Object( + # child=accessor, + # types=ty.types, + # ) + # raise TypeError(f"Unsupported type for accessor: {ty}") -def _generate_accessor(value: Any, accessor: Any) -> NadaType: - ty = type(value) +@dataclass +class NTupleMetaType(MetaType): + """Marker type for NTuples.""" - if ty.is_scalar(): - if ty.is_literal(): - return value - return ty(child=accessor) - if ty == Array: - return Array( - child=accessor, - contained_type=value.contained_type, - size=value.size, - ) - if ty == NTuple: - return NTuple( - child=accessor, - values=value.values, - ) - if ty == Object: - return Object( - child=accessor, - values=value.values, - ) - raise TypeError(f"Unsupported type for accessor: {ty}") + types: List[NadaType] + + def instantiate(self, child): + return NTuple(child, self.types) + + def to_mir(self): + """Convert a tuple object into a Nada type.""" + return { + "NTuple": { + "types": [ty.to_mir() for ty in self.types], + } + } -class NTuple(Collection): +@dataclass +class NTuple(NadaType): """The NTuple type""" - values: List[NadaType] + types: List[Any] - def __init__(self, child, values: List[NadaType]): - self.values = values + def __init__(self, child, types: List[Any]): + self.types = types self.child = child super().__init__(self.child) @classmethod - def new(cls, values: List[NadaType]) -> "NTuple": + def new(cls, values: List[Any]) -> "NTuple": """Constructs a new NTuple.""" + types = [value.metatype() for value in values] return NTuple( - values=values, + types=types, child=NTupleNew( child=values, source_ref=SourceRef.back_frame(), @@ -281,7 +258,7 @@ def new(cls, values: List[NadaType]) -> "NTuple": ) def __getitem__(self, index: int) -> NadaType: - if index >= len(self.values): + if index >= len(self.types): raise IndexError(f"Invalid index {index} for NTuple.") accessor = NTupleAccessor( @@ -290,7 +267,26 @@ def __getitem__(self, index: int) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[index], accessor) + return _generate_accessor(self.types[index], accessor) + + """TODO this should be deleted and use MetaType.to_mir""" + + # def to_mir(self): + # return { + # "NTuple": { + # "types": [ + # ( + # ty.to_mir() + # if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType)) + # else ty.class_to_mir() + # ) + # for ty in self.types + # ] + # } + # } + + def metatype(self): + return NTupleMetaType(self.types) @dataclass @@ -323,29 +319,45 @@ def store_in_ast(self, ty: object): ) -class Object(Collection): +@dataclass +class ObjectMetaType(MetaType): + """Marker type for Objects.""" + + types: Dict[str, Any] + + def to_mir(self): + """Convert an object into a Nada type.""" + return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}} + + def instantiate(self, child): + return Object(child, self.types) + + +@dataclass +class Object(NadaType): """The Object type""" - values: Dict[str, NadaType] + types: Dict[str, Any] - def __init__(self, child, values: Dict[str, NadaType]): - self.values = values + def __init__(self, child, types: Dict[str, Any]): + self.types = types self.child = child super().__init__(self.child) @classmethod - def new(cls, values: Dict[str, NadaType]) -> "Object": + def new(cls, values: Dict[str, Any]) -> "Object": """Constructs a new Object.""" + types = {key: value.metatype() for key, value in values.items()} return Object( - values=values, + types=types, child=ObjectNew( - child=values, + child=types, source_ref=SourceRef.back_frame(), ), ) def __getattr__(self, attr: str) -> NadaType: - if attr not in self.values: + if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" ) @@ -356,7 +368,26 @@ def __getattr__(self, attr: str) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[attr], accessor) + return _generate_accessor(self.types[attr], accessor) + + """TODO delete this use Meta.to_mir""" + + # def to_mir(self): + # return { + # "Object": { + # "types": { + # name: ( + # ty.to_mir() + # if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType)) + # else ty.class_to_mir() + # ) + # for name, ty in self.types.items() + # } + # } + # } + + def metatype(self): + return ObjectMetaType(types=self.types) @dataclass @@ -389,15 +420,6 @@ def store_in_ast(self, ty: object): ) -# pylint: disable=W0511 -# TODO: remove this -def get_inner_type(inner_type): - """Utility that returns the inner type for a composite type.""" - inner_type = copy.copy(inner_type) - setattr(inner_type, "inner", None) - return inner_type - - class Zip: """The Zip operation.""" @@ -460,7 +482,7 @@ def store_in_ast(self, ty: NadaTypeRepr): @dataclass -class ArrayType: +class ArrayMetaType(MetaType): """Marker type for arrays.""" contained_type: AllTypesType @@ -468,15 +490,17 @@ class ArrayType: def to_mir(self): """Convert this generic type into a MIR Nada type.""" + size = {"size": self.size} if self.size else {} return { - "Array": { - "inner_type": self.contained_type.to_mir(), - "size": self.size, - } + "Array": {"inner_type": self.contained_type.to_mir(), **size} # TODO: why? } + def instantiate(self, child): + return Array(child, self.size, self.contained_type) + -class Array(Generic[T], Collection): +@dataclass +class Array(Generic[T], NadaType): """Nada Array type. This is the representation of arrays in Nada MIR. @@ -497,16 +521,22 @@ class Array(Generic[T], Collection): def __init__(self, child, size: int, contained_type: T = None): self.contained_type = ( - contained_type - if (child is None or contained_type is not None) - else get_inner_type(child) + contained_type if (child is None or contained_type is not None) else child ) + + # TODO: can we simplify the following 10 lines? + # If it's not a metatype, fetch it + if self.contained_type is not None and not isinstance( + self.contained_type, MetaType + ): + self.contained_type = self.contained_type.metatype() + self.size = size self.child = ( child if contained_type is not None else getattr(child, "child", None) ) if self.child is not None: - self.child.store_in_ast(self.to_mir()) + self.child.store_in_ast(self.metatype().to_mir()) def __iter__(self): raise NotAllowedException( @@ -543,10 +573,9 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]": raise IncompatibleTypesError("Cannot zip arrays of different size") return Array( size=self.size, - contained_type=Tuple( + contained_type=TupleMetaType( left_type=self.contained_type, right_type=other.contained_type, - child=None, ), child=Zip(left=self, right=other, source_ref=SourceRef.back_frame()), ) @@ -558,15 +587,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Cannot do child product of arrays of different size" ) - if is_primitive_integer(self.retrieve_inner_type()) and is_primitive_integer( - other.retrieve_inner_type() + if is_primitive_integer(self.contained_type) and is_primitive_integer( + other.contained_type ): - contained_type = ( - self.contained_type - if inspect.isclass(self.contained_type) - else self.contained_type.__class__ - ) - return contained_type( + + return self.contained_type.instantiate( child=InnerProduct( left=self, right=other, source_ref=SourceRef.back_frame() ) @@ -576,6 +601,12 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Inner product is only implemented for arrays of integer types" ) + # TODO delete + + # def to_mir(self): + # size = {"size": self.size} if self.size else {} + # return {"Array": {"inner_type": self.contained_type, **size}} + @classmethod def new(cls, *args) -> "Array[T]": """Constructs a new Array.""" @@ -600,7 +631,11 @@ def init_as_template_type(cls, contained_type) -> "Array[T]": """Construct an empty template array with the given child type.""" return Array(child=None, contained_type=contained_type, size=None) + def metatype(self): + return ArrayMetaType(self.contained_type, self.size) + +@dataclass class TupleNew(Generic[T, U]): """MIR Tuple new operation. @@ -626,6 +661,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class NTupleNew: """MIR NTuple new operation. @@ -651,6 +687,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class ObjectNew: """MIR Object new operation. @@ -678,10 +715,10 @@ def store_in_ast(self, ty: object): def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: """The Unzip operation for Arrays.""" - right_type = ArrayType( + right_type = ArrayMetaType( contained_type=array.contained_type.right_type, size=array.size ) - left_type = ArrayType( + left_type = ArrayMetaType( contained_type=array.contained_type.left_type, size=array.size ) @@ -692,6 +729,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: ) +@dataclass class ArrayNew(Generic[T]): """MIR Array new operation""" diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index 4209d8d..7d8ba7f 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -35,7 +35,7 @@ def __init__(self, function_id: int, name: str, arg_type: T, source_ref: SourceR self.name = name self.type = arg_type self.source_ref = source_ref - self.store_in_ast(arg_type.to_mir()) + self.store_in_ast(arg_type.metatype().to_mir()) def store_in_ast(self, ty): """Store object in AST.""" @@ -101,7 +101,7 @@ def store_in_ast(self): name=self.function.__name__, args=[arg.id for arg in self.args], id=self.id, - ty=self.return_type.class_to_mir(), + ty=self.return_type.metatype().to_mir(), source_ref=self.source_ref, child=self.child.child.id, ) @@ -125,7 +125,7 @@ def __init__(self, nada_function, args, source_ref): self.args = args self.fn = nada_function self.source_ref = source_ref - self.store_in_ast(nada_function.return_type.class_to_mir()) + self.store_in_ast(nada_function.return_type.metatype().to_mir()) def store_in_ast(self, ty): """Store this function call in the AST.""" diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 0c6bae8..4f8553a 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -1,6 +1,7 @@ # pylint:disable=W0401,W0614 """The Nada Scalar type definitions.""" +from abc import ABC from dataclasses import dataclass from typing import Union, TypeVar from typing_extensions import Self @@ -348,6 +349,26 @@ def binary_logical_operation( return SecretBoolean(child=operation) +@dataclass +class MetaType(ABC): + pass + + +@dataclass +class MetaTypePassthroughMixin(MetaType): + @classmethod + def instantiate(cls, child_or_value): + cls.ty(child_or_value) + + @classmethod + def to_mir(cls): + name = cls.ty.__name__ + # Rename public variables so they are considered as the same as literals. + if name.startswith("Public"): + name = name[len("Public") :].lstrip() + return name + + @register_scalar_type(Mode.CONSTANT, BaseType.INTEGER) class Integer(NumericType): """The Nada Integer type. @@ -370,6 +391,14 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True + @classmethod + def metatype(cls): + return IntegerMetaType() + + +class IntegerMetaType(MetaTypePassthroughMixin): + ty = Integer + @dataclass @register_scalar_type(Mode.CONSTANT, BaseType.UNSIGNED_INTEGER) @@ -396,6 +425,14 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True + @classmethod + def metatype(cls): + return UnsignedIntegerMetaType() + + +class UnsignedIntegerMetaType(MetaTypePassthroughMixin): + ty = UnsignedInteger + @register_scalar_type(Mode.CONSTANT, BaseType.BOOLEAN) class Boolean(BooleanType): @@ -427,6 +464,14 @@ def __invert__(self: "Boolean") -> "Boolean": def is_literal(cls) -> bool: return True + @classmethod + def metatype(cls): + return BooleanMetaType() + + +class BooleanMetaType(MetaTypePassthroughMixin): + ty = Boolean + @register_scalar_type(Mode.PUBLIC, BaseType.INTEGER) class PublicInteger(NumericType): @@ -447,6 +492,14 @@ def public_equals( """Implementation of public equality for Public integer types.""" return public_equals_operation(self, other) + @classmethod + def metatype(cls): + return PublicIntegerMetaType() + + +class PublicIntegerMetaType(MetaTypePassthroughMixin): + ty = PublicInteger + @register_scalar_type(Mode.PUBLIC, BaseType.UNSIGNED_INTEGER) class PublicUnsignedInteger(NumericType): @@ -467,6 +520,14 @@ def public_equals( """Implementation of public equality for Public unsigned integer types.""" return public_equals_operation(self, other) + @classmethod + def metatype(cls): + return PublicUnsignedIntegerMetaType() + + +class PublicUnsignedIntegerMetaType(MetaTypePassthroughMixin): + ty = PublicUnsignedInteger + @dataclass @register_scalar_type(Mode.PUBLIC, BaseType.BOOLEAN) @@ -492,6 +553,14 @@ def public_equals( """Implementation of public equality for Public boolean types.""" return public_equals_operation(self, other) + @classmethod + def metatype(cls): + return PublicBooleanMetaType() + + +class PublicBooleanMetaType(MetaTypePassthroughMixin): + ty = PublicBoolean + @dataclass @register_scalar_type(Mode.SECRET, BaseType.INTEGER) @@ -537,6 +606,14 @@ def to_public(self: "SecretInteger") -> "PublicInteger": operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicInteger(child=operation) + @classmethod + def metatype(cls): + return SecretIntegerMetaType() + + +class SecretIntegerMetaType(MetaTypePassthroughMixin): + ty = SecretInteger + @dataclass @register_scalar_type(Mode.SECRET, BaseType.UNSIGNED_INTEGER) @@ -584,6 +661,14 @@ def to_public( operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicUnsignedInteger(child=operation) + @classmethod + def metatype(cls): + return SecretUnsignedIntegerMetaType() + + +class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): + ty = SecretUnsignedInteger + @dataclass @register_scalar_type(Mode.SECRET, BaseType.BOOLEAN) @@ -610,6 +695,14 @@ def random(cls) -> "SecretBoolean": """Generate a random secret boolean.""" return SecretBoolean(child=Random(source_ref=SourceRef.back_frame())) + @classmethod + def metatype(cls): + return SecretBooleanMetaType() + + +class SecretBooleanMetaType(MetaTypePassthroughMixin): + ty = SecretBoolean + @dataclass class EcdsaSignature(NadaType): @@ -618,6 +711,14 @@ class EcdsaSignature(NadaType): def __init__(self, child: OperationType): super().__init__(child=child) + @classmethod + def metatype(cls): + return EcdsaSignatureMetaType() + + +class EcdsaSignatureMetaType(MetaTypePassthroughMixin): + ty = EcdsaSignature + @dataclass class EcdsaDigestMessage(NadaType): @@ -626,6 +727,14 @@ class EcdsaDigestMessage(NadaType): def __init__(self, child: OperationType): super().__init__(child=child) + @classmethod + def metatype(cls): + return EcdsaDigestMessageMetaType() + + +class EcdsaDigestMessageMetaType(MetaTypePassthroughMixin): + ty = EcdsaDigestMessage + @dataclass class EcdsaPrivateKey(NadaType): @@ -639,3 +748,11 @@ def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature": return EcdsaSignature( child=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame()) ) + + @classmethod + def metatype(cls): + return EcdsaPrivateKeyMetaType() + + +class EcdsaPrivateKeyMetaType(MetaTypePassthroughMixin): + ty = EcdsaPrivateKey diff --git a/test-programs/ntuple_accessor.py b/test-programs/ntuple_accessor.py index 6edef15..9701830 100644 --- a/test-programs/ntuple_accessor.py +++ b/test-programs/ntuple_accessor.py @@ -9,16 +9,16 @@ def nada_main(): array = Array.new(my_int1, my_int1) # Store a scalar, a compound type and a literal. - tuple = NTuple.new([my_int1, array, Integer(42)]) + tup = NTuple.new([my_int1, array, my_int2]) - scalar = tuple[0] - array = tuple[1] - literal = tuple[2] + scalar = tup[0] + array = tup[1] + scalar2 = tup[2] @nada_fn def add(a: PublicInteger) -> PublicInteger: return a + my_int2 - sum = array.reduce(add, Integer(0)) + result = array.reduce(add, Integer(0)) - return [Output(scalar + literal + sum, "my_output", party1)] + return [Output(scalar + scalar2 + result, "my_output", party1)] diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index 6469cbc..b53d0f7 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -125,7 +125,7 @@ def test_duplicated_inputs_checks(): def test_array_type_conversion(input_type, type_name, size): inner_input = create_input(SecretInteger, "name", "party", **{}) collection = create_collection(input_type, inner_input, size, **{}) - converted_input = collection.to_mir() + converted_input = collection.metatype().to_mir() assert list(converted_input.keys()) == [type_name] @@ -513,7 +513,7 @@ def test_tuple_new_empty(): Tuple.new() assert ( str(e.value) - == "Tuple.new() missing 2 required positional arguments: 'left_type' and 'right_type'" + == "Tuple.new() missing 2 required positional arguments: 'left_value' and 'right_value'" ) diff --git a/tests/nada_type_test.py b/tests/nada_type_test.py deleted file mode 100644 index ad20ea5..0000000 --- a/tests/nada_type_test.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for NadaType.""" - -import pytest -from nada_dsl.nada_types import NadaType -from nada_dsl.nada_types.scalar_types import Integer, PublicBoolean, SecretInteger - - -@pytest.mark.parametrize( - ("cls", "expected"), - [ - (SecretInteger, "SecretInteger"), - (Integer, "Integer"), - (PublicBoolean, "Boolean"), - ], -) -def test_class_to_mir(cls: NadaType, expected: str): - """Tests `NadaType.class_to_mir()""" - assert cls.class_to_mir() == expected