diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index f65a47f..33c26f7 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, TypeAlias, Union, Type +from typing import Any, Dict, TypeAlias, Union, Type from nada_dsl.source_ref import SourceRef @@ -171,3 +171,13 @@ def is_scalar(cls) -> bool: def is_literal(cls) -> bool: """Returns True if the type is a literal.""" return False + + def instantiate(self, child: Any) -> "NadaValue": + pass + + +@dataclass +class NadaValue: + @classmethod + def to_type(cls) -> NadaType: + pass diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index dbba76a..32fcbb4 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -29,7 +29,7 @@ ) from nada_dsl.nada_types.function import NadaFunction, nada_fn from nada_dsl.nada_types.generics import U, T, R -from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType +from . import AllTypes, AllTypesType, NadaTypeRepr, NadaValue, OperationType def is_primitive_integer(nada_type_str: str): @@ -47,73 +47,10 @@ def is_primitive_integer(nada_type_str: str): ) +@dataclass class Collection(NadaType): """Superclass of collection types""" - left_type: AllTypesType - right_type: AllTypesType - contained_type: AllTypesType - - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - if isinstance(self, (Array, ArrayType)): - size = {"size": self.size} if self.size else {} - contained_type = self.retrieve_inner_type() - return {"Array": {"inner_type": contained_type, **size}} - if isinstance(self, (Tuple, TupleType)): - return { - "Tuple": { - "left_type": ( - self.left_type.to_mir() - if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) - else self.left_type.class_to_mir() - ), - "right_type": ( - self.right_type.to_mir() - if isinstance( - self.right_type, - (NadaType, ArrayType, TupleType), - ) - else self.right_type.class_to_mir() - ), - } - } - if isinstance(self, NTuple): - return { - "NTuple": { - "types": [ - ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for ty in [ - type(value) - for value in self.values # pylint: disable=E1101 - ] - ] - } - } - if isinstance(self, Object): - return { - "Object": { - "types": { - name: ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for name, ty in [ - (name, type(value)) - for name, value in self.values.items() # pylint: disable=E1101 - ] - } - } - } - raise InvalidTypeError( - f"{self.__class__.__name__} is not a valid Nada Collection" - ) - def retrieve_inner_type(self): """Retrieves the child type of this collection""" if isinstance(self.contained_type, TypeVar): @@ -123,6 +60,7 @@ def retrieve_inner_type(self): return self.contained_type.to_mir() +@dataclass class Map(Generic[T, R]): """The Map operation""" @@ -203,9 +141,11 @@ def to_mir(self): } +@dataclass class Tuple(Generic[T, U], Collection): """The Tuple type""" + # TODO: T and U have to inherit from NadaType? left_type: T right_type: U @@ -215,14 +155,33 @@ def __init__(self, child, left_type: T, right_type: U): self.child = child super().__init__(self.child) + def to_mir(self): + return { + "Tuple": { + "left_type": ( + self.left_type.to_mir() + if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) + else self.left_type.class_to_mir() + ), + "right_type": ( + self.right_type.to_mir() + if isinstance( + self.right_type, + (NadaType, ArrayType, TupleType), + ) + else self.right_type.class_to_mir() + ), + } + } + @classmethod - def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]": + def new(cls, left_value: NadaValue, right_value: NadaValue) -> "Tuple[T, U]": """Constructs a new Tuple.""" return Tuple( - left_type=left_type, - right_type=right_type, + left_type=left_value.to_type(), + right_type=right_value.to_type(), child=TupleNew( - child=(left_type, right_type), + child=(left_value.to_type(), right_value.to_type()), source_ref=SourceRef.back_frame(), ), ) @@ -233,55 +192,70 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleType: return TupleType(left_type=left_type, right_type=right_type) -def _generate_accessor(value: Any, accessor: Any) -> NadaType: - ty = type(value) - +def _generate_accessor(ty: Any, accessor: Any) -> NadaType: if ty.is_scalar(): if ty.is_literal(): - return value + return ty # value.instantiate(child=accessor) ? return ty(child=accessor) if ty == Array: return Array( child=accessor, - contained_type=value.contained_type, - size=value.size, + contained_type=ty.contained_type, + size=ty.size, ) if ty == NTuple: return NTuple( child=accessor, - values=value.values, + types=ty.values, ) if ty == Object: return Object( child=accessor, - values=value.values, + values=ty.values, ) raise TypeError(f"Unsupported type for accessor: {ty}") +@dataclass +class NTupleType: + """Marker type for NTuples.""" + + types: List[NadaType] + + def to_mir(self): + """Convert a tuple object into a Nada type.""" + return { + "NTuple": { + "types": [ty.to_mir() for ty in self.types], + } + } + + +@dataclass class NTuple(Collection): """The NTuple type""" - values: List[NadaType] + types: List[NadaValue] - def __init__(self, child, values: List[NadaType]): - self.values = values + def __init__(self, child, types: List[NadaType]): + self.types = types self.child = child super().__init__(self.child) @classmethod - def new(cls, values: List[NadaType]) -> "NTuple": + def new(cls, values: List[NadaValue]) -> "NTuple": """Constructs a new NTuple.""" + types = [value.to_type() for value in values] return NTuple( - values=values, + types=types, child=NTupleNew( - child=values, + child=types, source_ref=SourceRef.back_frame(), ), ) def __getitem__(self, index: int) -> NadaType: - if index >= len(self.values): + if index >= len(self.types): raise IndexError(f"Invalid index {index} for NTuple.") accessor = NTupleAccessor( @@ -290,7 +264,21 @@ def __getitem__(self, index: int) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[index], accessor) + return _generate_accessor(self.types[index], accessor) + + def to_mir(self): + return { + "NTuple": { + "types": [ + ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for ty in self.types + ] + } + } @dataclass @@ -323,29 +311,42 @@ def store_in_ast(self, ty: object): ) +@dataclass +class ObjectType: + """Marker type for Objects.""" + + types: Dict[str, NadaType] + + def to_mir(self): + """Convert an object into a Nada type.""" + return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}} + + +@dataclass class Object(Collection): """The Object type""" - values: Dict[str, NadaType] + types: Dict[str, NadaType] - def __init__(self, child, values: Dict[str, NadaType]): - self.values = values + def __init__(self, child, types: Dict[str, NadaType]): + self.types = types self.child = child super().__init__(self.child) @classmethod def new(cls, values: Dict[str, NadaType]) -> "Object": """Constructs a new Object.""" + types = {key: value.to_type() for key, value in values.items()} return Object( - values=values, + types=types, child=ObjectNew( - child=values, + child=types, source_ref=SourceRef.back_frame(), ), ) def __getattr__(self, attr: str) -> NadaType: - if attr not in self.values: + if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" ) @@ -356,7 +357,21 @@ def __getattr__(self, attr: str) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[attr], accessor) + return _generate_accessor(self.types[attr], accessor) + + def to_mir(self): + return { + "Object": { + "types": { + name: ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for name, ty in self.types.items() + } + } + } @dataclass @@ -476,6 +491,7 @@ def to_mir(self): } +@dataclass class Array(Generic[T], Collection): """Nada Array type. @@ -576,6 +592,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Inner product is only implemented for arrays of integer types" ) + def to_mir(self): + size = {"size": self.size} if self.size else {} + contained_type = self.retrieve_inner_type() + return {"Array": {"inner_type": contained_type, **size}} + @classmethod def new(cls, *args) -> "Array[T]": """Constructs a new Array.""" @@ -601,6 +622,7 @@ def init_as_template_type(cls, contained_type) -> "Array[T]": return Array(child=None, contained_type=contained_type, size=None) +@dataclass class TupleNew(Generic[T, U]): """MIR Tuple new operation. @@ -626,6 +648,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class NTupleNew: """MIR NTuple new operation. @@ -651,6 +674,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class ObjectNew: """MIR Object new operation. @@ -692,6 +716,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: ) +@dataclass class ArrayNew(Generic[T]): """MIR Array new operation"""