From 05ec5a17b0bcfc6f4af083e264f468ed407ddb5b Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Tue, 19 Mar 2024 22:42:01 +0100 Subject: [PATCH] fix[lang]: fix importing of flag types (#3871) fix imports for flag types. this is a missed case from 8ccacb3f47f. this commit adds flags into the module data and threads them through codegen. --------- Co-authored-by: Charles Cooper --- .../codegen/modules/test_flag_imports.py | 41 +++++++++++++++++++ vyper/ast/grammar.lark | 2 +- vyper/codegen/expr.py | 9 ++-- vyper/semantics/analysis/module.py | 1 + vyper/semantics/types/module.py | 8 ++++ 5 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 tests/functional/codegen/modules/test_flag_imports.py diff --git a/tests/functional/codegen/modules/test_flag_imports.py b/tests/functional/codegen/modules/test_flag_imports.py new file mode 100644 index 0000000000..fd954dab02 --- /dev/null +++ b/tests/functional/codegen/modules/test_flag_imports.py @@ -0,0 +1,41 @@ +def test_import_flag_types(make_input_bundle, get_contract): + lib1 = """ +import lib2 + +flag Roles: + ADMIN + USER + +enum Roles2: + ADMIN + USER + +role: Roles +role2: Roles2 +role3: lib2.Roles3 + """ + lib2 = """ +flag Roles3: + ADMIN + USER + NOBODY + """ + contract = """ +import lib1 + +initializes: lib1 + +@external +def bar(r: lib1.Roles, r2: lib1.Roles2, r3: lib1.lib2.Roles3) -> bool: + lib1.role = r + lib1.role2 = r2 + lib1.role3 = r3 + assert lib1.role == lib1.Roles.ADMIN + assert lib1.role2 == lib1.Roles2.USER + assert lib1.role3 == lib1.lib2.Roles3.NOBODY + return True + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(contract, input_bundle=input_bundle) + assert c.bar(1, 2, 4) is True diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index cd7b64e6c3..1c318a76a5 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -98,7 +98,7 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a // NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps) _MAP: "HashMap" map_def: _MAP "[" ( NAME | array_def ) "," type "]" -imported_type: NAME "." NAME +imported_type: NAME ("." NAME)+ type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def ) // Structs can be composed of 1+ basic types or other custom_types diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index d7afe6c7f6..8ce4288c89 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -206,12 +206,9 @@ def parse_Name(self): def parse_Attribute(self): typ = self.expr._metadata["type"] - # MyFlag.foo - if ( - isinstance(typ, FlagT) - and isinstance(self.expr.value, vy_ast.Name) - and typ.name == self.expr.value.id - ): + # check if we have a flag constant, e.g. + # [lib1].MyFlag.FOO + if isinstance(typ, FlagT) and is_type_t(self.expr.value._metadata["type"], FlagT): # 0, 1, 2, .. 255 flag_id = typ._flag_members[self.expr.attr] value = 2**flag_id # 0 => 0001, 1 => 0010, 2 => 0100, etc. diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 90493d643b..f4b7db129f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -657,6 +657,7 @@ def _validate_self_namespace(): def visit_FlagDef(self, node): obj = FlagT.from_FlagDef(node) + node._metadata["flag_type"] = obj self.namespace[node.name] = obj def visit_EventDef(self, node): diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index a242bfa1fe..0d2b343e0d 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -300,6 +300,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # add the type of the event so it can be used in call position self.add_member(e.name, TYPE_T(e._metadata["event_type"])) # type: ignore + for f in self.flag_defs: + self.add_member(f.name, TYPE_T(f._metadata["flag_type"])) + self._helper.add_member(f.name, TYPE_T(f._metadata["flag_type"])) + for s in self.struct_defs: # add the type of the struct so it can be used in call position self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore @@ -347,6 +351,10 @@ def function_defs(self): def event_defs(self): return self._module.get_children(vy_ast.EventDef) + @cached_property + def flag_defs(self): + return self._module.get_children(vy_ast.FlagDef) + @property def struct_defs(self): return self._module.get_children(vy_ast.StructDef)