Skip to content

Commit

Permalink
fix: .to_public() now doesn't underline for pyright (#49)
Browse files Browse the repository at this point in the history
* fix: `.to_public()` now doesn't underline for pyright

plus some other typing improvements

* fix: import `Self` from typing_extensions to keep it 3.10

* fix: address pylint errors

* feat: add typing_extensions to dependencies

* fix: address pylint errors, take 2

* fix: tests now should accept that all scalar types can call `to_public()`

* fix: remove all public scalar types not allowed to call `to_public()`

* feat: add `typing_extensions` to tool.uv

* feat: bump version
  • Loading branch information
cyberglot authored Nov 11, 2024
1 parent ec3406e commit a815088
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 44 deletions.
89 changes: 61 additions & 28 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""The Nada Scalar type definitions."""

from dataclasses import dataclass
from typing import Union
from typing import Any, Union, TypeVar
from typing_extensions import Self
from nada_dsl.operations import *
from nada_dsl.program_io import Literal
from nada_dsl import SourceRef
Expand All @@ -11,8 +12,34 @@
# Constant dictionary that stores all the Nada types and is use to
# convert from the (mode, base_type) representation to the concrete Nada type
# (Integer, SecretBoolean,...)
SCALAR_TYPES = {}

# pylint: disable=invalid-name
_AnyScalarType = TypeVar("_AnyScalarType",
'Integer',
'UnsignedInteger',
'Boolean',
'PublicInteger',
'PublicUnsignedInteger',
'PublicBoolean',
'SecretInteger',
'SecretUnsignedInteger',
'SecretBoolean')
# pylint: enable=invalid-name

AnyScalarType = Union['Integer',
'UnsignedInteger',
'Boolean',
'PublicInteger',
'PublicUnsignedInteger',
'PublicBoolean',
'SecretInteger',
'SecretUnsignedInteger',
'SecretBoolean']

# pylint: disable=global-variable-not-assigned
SCALAR_TYPES: dict[tuple[Mode, BaseType], type[AnyScalarType]] = {}

AnyBoolean = Union['Boolean', 'PublicBoolean', 'SecretBoolean']

class ScalarType(NadaType):
"""The Nada Scalar type.
Expand All @@ -27,26 +54,30 @@ class ScalarType(NadaType):

base_type: BaseType
mode: Mode
value: Any

def __init__(self, inner: OperationType, base_type: BaseType, mode: Mode):
super().__init__(inner=inner)
self.base_type = base_type
self.mode = mode

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean: # type: ignore
return equals_operation(
"Equals", "==", self, other, lambda lhs, rhs: lhs == rhs
)

def __ne__(self, other):
def __ne__(self, other) -> AnyBoolean: # type: ignore
return equals_operation(
"NotEquals", "!=", self, other, lambda lhs, rhs: lhs != rhs
)

def to_public(self) -> Self:
"""Convert this scalar type into a public variable."""
return self

def equals_operation(
operation, operator, left: ScalarType, right: ScalarType, f
) -> ScalarType:
) -> AnyBoolean:
"""This function is an abstraction for the equality operations"""
base_type = left.base_type
if base_type != right.base_type:
Expand All @@ -66,11 +97,10 @@ def equals_operation(
)
return SecretBoolean(inner=operation)


def register_scalar_type(mode: Mode, base_type: BaseType):
"""Decorator used to register a new scalar type in the `SCALAR_TYPES` dictionary."""

def decorator(scalar_type: ScalarType):
def decorator(scalar_type: type[_AnyScalarType]) -> type[_AnyScalarType]:
SCALAR_TYPES[(mode, base_type)] = scalar_type
scalar_type.mode = mode
scalar_type.base_type = base_type
Expand All @@ -79,9 +109,10 @@ def decorator(scalar_type: ScalarType):
return decorator


def new_scalar_type(mode: Mode, base_type: BaseType) -> ScalarType:
def new_scalar_type(mode: Mode, base_type: BaseType) -> type[AnyScalarType]:
"""Returns the corresponding MIR Nada Type"""
return SCALAR_TYPES.get((mode, base_type))
global SCALAR_TYPES
return SCALAR_TYPES[(mode, base_type)]


class NumericType(ScalarType):
Expand Down Expand Up @@ -143,22 +174,22 @@ def __rshift__(self, other):
"RightShift", ">>", self, other, lambda lhs, rhs: lhs >> rhs
)

def __lt__(self, other):
def __lt__(self, other) -> AnyBoolean:
return binary_relational_operation(
"LessThan", "<", self, other, lambda lhs, rhs: lhs < rhs
)

def __gt__(self, other):
def __gt__(self, other) -> AnyBoolean:
return binary_relational_operation(
"GreaterThan", ">", self, other, lambda lhs, rhs: lhs > rhs
)

def __le__(self, other):
def __le__(self, other) -> AnyBoolean:
return binary_relational_operation(
"LessOrEqualThan", "<=", self, other, lambda lhs, rhs: lhs <= rhs
)

def __ge__(self, other):
def __ge__(self, other) -> AnyBoolean:
return binary_relational_operation(
"GreaterOrEqualThan", ">=", self, other, lambda lhs, rhs: lhs >= rhs
)
Expand Down Expand Up @@ -218,23 +249,23 @@ def shift_operation(

def binary_relational_operation(
operation, operator, left: ScalarType, right: ScalarType, f
) -> ScalarType:
) -> AnyBoolean:
"""This function is an abstraction for the binary relational operations"""
base_type = left.base_type
if base_type != right.base_type or not base_type.is_numeric():
raise TypeError(f"Invalid operation: {left} {operator} {right}")
mode = Mode(max([left.mode.value, right.mode.value]))
match mode:
case Mode.CONSTANT:
return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value))
return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value)) # type: ignore
case Mode.PUBLIC | Mode.SECRET:
inner = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, BaseType.BOOLEAN)(inner)
return new_scalar_type(mode, BaseType.BOOLEAN)(inner) # type: ignore


def public_equals_operation(left: ScalarType, right: ScalarType) -> ScalarType:
def public_equals_operation(left: ScalarType, right: ScalarType) -> "PublicBoolean":
"""This function is an abstraction for the public_equals operation for all types."""
base_type = left.base_type
if base_type != right.base_type:
Expand All @@ -245,7 +276,7 @@ def public_equals_operation(left: ScalarType, right: ScalarType) -> ScalarType:
return PublicBoolean(
inner=PublicOutputEquality(
left=left, right=right, source_ref=SourceRef.back_frame()
)
) # type: ignore
)


Expand All @@ -270,7 +301,7 @@ def __xor__(self, other):
"BooleanXor", "^", self, other, lambda lhs, rhs: lhs ^ rhs
)

def if_else(self, arg_0: ScalarType, arg_1: ScalarType) -> ScalarType:
def if_else(self, arg_0: _AnyScalarType, arg_1: _AnyScalarType) -> _AnyScalarType:
"""This function implements the function 'if_else' for every class that extends 'BooleanType'."""
base_type = arg_0.base_type
if (
Expand Down Expand Up @@ -325,7 +356,7 @@ def __init__(self, value):
)
self.value = value

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)


Expand All @@ -336,6 +367,8 @@ class UnsignedInteger(NumericType):
Represents a constant (literal) unsigned integer."""

value: int

def __init__(self, value):
value = int(value)
super().__init__(
Expand All @@ -345,7 +378,7 @@ def __init__(self, value):
)
self.value = value

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)


Expand All @@ -369,7 +402,7 @@ def __init__(self, value):
def __bool__(self) -> bool:
return self.value

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def __invert__(self: "Boolean") -> "Boolean":
Expand All @@ -386,7 +419,7 @@ class PublicInteger(NumericType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.INTEGER, Mode.PUBLIC)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def public_equals(
Expand All @@ -406,7 +439,7 @@ class PublicUnsignedInteger(NumericType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.UNSIGNED_INTEGER, Mode.PUBLIC)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def public_equals(
Expand All @@ -427,7 +460,7 @@ class PublicBoolean(BooleanType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.BOOLEAN, Mode.PUBLIC)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def __invert__(self: "PublicBoolean") -> "PublicBoolean":
Expand All @@ -449,7 +482,7 @@ class SecretInteger(NumericType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.INTEGER, Mode.SECRET)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def public_equals(
Expand Down Expand Up @@ -494,7 +527,7 @@ class SecretUnsignedInteger(NumericType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.UNSIGNED_INTEGER, Mode.SECRET)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def public_equals(
Expand Down Expand Up @@ -541,7 +574,7 @@ class SecretBoolean(BooleanType):
def __init__(self, inner: NadaType):
super().__init__(inner, BaseType.BOOLEAN, Mode.SECRET)

def __eq__(self, other):
def __eq__(self, other) -> AnyBoolean:
return ScalarType.__eq__(self, other)

def __invert__(self: "SecretBoolean") -> "SecretBoolean":
Expand Down
15 changes: 1 addition & 14 deletions nada_dsl/scalar_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from nada_dsl.nada_types.scalar_types import Integer, PublicInteger, SecretInteger, Boolean, PublicBoolean, \
SecretBoolean, UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger, ScalarType, BooleanType


def combine_lists(list1, list2):
"""This returns all combinations for the items of two lists"""
result = []
Expand Down Expand Up @@ -266,7 +265,7 @@ def test_random(operand):


# Allowed types that can invoke the to_public() function.
to_public_operands = secret_integers + secret_unsigned_integers + secret_booleans
to_public_operands = integers + unsigned_integers + booleans


@pytest.mark.parametrize("operand", to_public_operands)
Expand Down Expand Up @@ -443,18 +442,6 @@ def test_not_allowed_random(operand):
operand.random()
assert invalid_operation.type == AttributeError


# List of types that cannot invoke the function to_public()
to_public_operands = public_booleans + public_integers + public_unsigned_integers


@pytest.mark.parametrize("operand", to_public_operands)
def test_not_to_public(operand):
with pytest.raises(Exception) as invalid_operation:
operand.to_public()
assert invalid_operation.type == AttributeError


# List of operands that the function if_else does not accept
not_allowed_if_else_operands = (
# Boolean branches
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"

[project]
name = "nada_dsl"
version = "0.7.0"
version = "0.7.1"
description = "Nillion Nada DSL to create Nillion MPC programs."
requires-python = ">=3.10"
readme = "README.pyproject.md"
dependencies = ["asttokens~=2.4", "richreports~=0.2", "parsial~=0.1", "sortedcontainers~=2.4"]
dependencies = ["asttokens~=2.4", "richreports~=0.2", "parsial~=0.1", "sortedcontainers~=2.4", "typing_extensions~=4.12.2"]
classifiers = ["License :: OSI Approved :: Apache Software License"]
license = { file = "LICENSE" }

Expand All @@ -34,6 +34,7 @@ dev-dependencies = [
"nada-mir-proto[dev]",
"tomli",
"requests",
"typing_extensions~=4.12.2",
]

[tool.uv.sources]
Expand Down

0 comments on commit a815088

Please sign in to comment.