Skip to content

Commit

Permalink
poc: add Decimal[bits, places] syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Apr 9, 2024
1 parent 6f34b45 commit 52d1be4
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT, DecimalT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
from vyper.utils import checksum_encode, int_to_fourbytes

Expand Down Expand Up @@ -303,7 +303,7 @@ def types_from_Constant(self, node):
# special handling for bytestrings since their
# class objects are in the type map, not the type itself
# (worth rethinking this design at some point.)
if t in (BytesT, StringT):
if t in (BytesT, StringT, DecimalT):
t = t.from_literal(node)

# any more validation which needs to occur
Expand Down
4 changes: 3 additions & 1 deletion vyper/semantics/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def _get_primitive_types():
res = [BoolT(), DecimalT()]
res = [BoolT()]

res.extend(IntegerT.all())
res.extend(BytesM_T.all())
Expand All @@ -21,6 +21,8 @@ def _get_primitive_types():
# note: since bytestrings are parametrizable, the *class* objects
# are in the namespace instead of concrete type objects.
res.extend([BytesT, StringT])
# ditto for Decimals
res.append(DecimalT)

ret = {t._id: t for t in res}
ret.update(_get_sequence_types())
Expand Down
42 changes: 39 additions & 3 deletions vyper/semantics/types/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CompilerPanic,
InvalidLiteral,
InvalidOperation,
InvalidType, UnimplementedException,
OverflowException,
VyperException,
)
Expand Down Expand Up @@ -313,9 +314,7 @@ def SINT(bits):
class DecimalT(NumericT):
typeclass = "decimal"

_bits = 168 # TODO generalize
_decimal_places = 10 # TODO generalize
_id = "decimal"
_id = "Decimal"
_is_signed = True
_invalid_ops = (
vy_ast.Pow,
Expand All @@ -331,6 +330,43 @@ class DecimalT(NumericT):

ast_type = Decimal

def __init__(self, bits=168, decimal_places=10):
self._bits = bits
self._decimal_places = decimal_places

if bits != 168 or decimal_places != 10:
raise UnimplementedException("Not implemented: {repr(self)}", hint="only Decimal[168, 10] is currently available")

def __repr__(self):
return f"Decimal[{self._bits}, {self._decimal_places}]"

@classmethod
def from_annotation(cls, node):
def _fail():
raise InvalidType("not a valid Decimal", hint="expected: Decimal[<bits>, <places}")

if not isinstance(node, vy_ast.Subscript):
_fail()
if not isinstance(node.slice, vy_ast.Tuple):
_fail()
if len(node.slice.elements) != 2:
_fail()
bits = node.slice.elements[0].get_folded_value().value
places = node.slice.elements[1].get_folded_value().value
if not isinstance(bits, int) or not isinstance(places, int):
_fail()

return cls(bits, places)

def validate_literal(self, node) -> None:
if not isinstance(node, vy_ast.Decimal):
# TODO: check bits, places
raise TypeMismatch("Not a decimal")

@classmethod
def from_literal(cls, node):
return DecimalT(168, 10)

def validate_numeric_op(self, node) -> None:
try:
super().validate_numeric_op(node)
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:

if isinstance(node, vy_ast.Subscript):
# ex. HashMap, DynArray, Bytes, static arrays
if node.value.get("id") in ("HashMap", "Bytes", "String", "DynArray"):
if node.value.get("id") in ("HashMap", "Bytes", "String", "DynArray", "Decimal"):
assert isinstance(node.value, vy_ast.Name) # mypy hint
type_ctor = namespace[node.value.id]
else:
Expand Down

0 comments on commit 52d1be4

Please sign in to comment.