Skip to content

Commit

Permalink
core: Add operands directive (#3507)
Browse files Browse the repository at this point in the history
Adds the operands directive to the assembly format, which works as a format directive and in type directives.
  • Loading branch information
alexarice authored Nov 27, 2024
1 parent a44a424 commit 9bc6dc1
Show file tree
Hide file tree
Showing 3 changed files with 353 additions and 1 deletion.
229 changes: 229 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,235 @@ class TwoOperandsOp(IRDLOperation):
check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"test.operands_directive %0 : i32",
"test.operands_directive %0, %1 : i32, i32",
"test.operands_directive %0, %1, %2 : i32, i32, i32",
],
)
def test_operands_directive(program: str):
"""Test the operands directive"""

@irdl_op_definition
class OperandsDirectiveOp(IRDLOperation):
name = "test.operands_directive"

op1 = operand_def()
op2 = var_operand_def()

assembly_format = "operands `:` type(operands) attr-dict"

ctx = MLContext()
ctx.load_op(OperandsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"test.operands_directive %0 : i32",
"test.operands_directive %0, %1 : i32, i32",
],
)
def test_operands_directive_with_optional(program: str):
"""Test the operands directive"""

@irdl_op_definition
class OperandsDirectiveOp(IRDLOperation):
name = "test.operands_directive"

op1 = opt_operand_def()
op2 = operand_def()

assembly_format = "operands `:` type(operands) attr-dict"

ctx = MLContext()
ctx.load_op(OperandsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


def test_operands_directive_fails_with_two_var():
"""Test operands directive cannot be used with two variadic operands"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'operands' is ambiguous with multiple variadic operands",
):

@irdl_op_definition
class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_var_op"

op1 = var_operand_def()
op2 = var_operand_def()

irdl_options = [AttrSizedOperandSegments()]

assembly_format = "operands attr-dict `:` type(operands)"


def test_operands_directive_fails_with_no_operands():
"""Test operands directive cannot be used with no operands"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'operands' should not be used when there are no operands",
):

@irdl_op_definition
class NoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_operands_op"

assembly_format = "operands attr-dict `:` type(operands)"


def test_operands_directive_fails_with_other_directive():
"""Test operands directive cannot be used with no operands"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'operands' cannot be used with other operand directives",
):

@irdl_op_definition
class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_operands_op"

op1 = operand_def()
op2 = operand_def()

assembly_format = "$op1 `,` operands attr-dict `:` type(operands)"


def test_operands_directive_fails_with_other_type_directive():
"""Test operands directive cannot be used with no operands"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'operands' cannot be used in a type directive with other operand type directives",
):

@irdl_op_definition
class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_operands_op"

op1 = operand_def()
op2 = operand_def()

assembly_format = "operands attr-dict `:` type($op1) `,` type(operands)"


@pytest.mark.parametrize(
"program, error",
[
("test.two_operands %0 : i32, i32", "Expected 2 operands but found 1"),
(
"test.two_operands %0, %1, %2 : i32, i32",
"Expected 2 operands but found 3",
),
("test.two_operands %0, %1 : i32", "Expected 2 operand types but found 1"),
(
"test.two_operands %0, %1 : i32, i32, i32",
"Expected 2 operand types but found 3",
),
],
)
def test_operands_directive_bounds(program: str, error: str):
@irdl_op_definition
class TwoOperandsOp(IRDLOperation):
name = "test.two_operands"

op1 = operand_def()
op2 = operand_def()

assembly_format = "operands attr-dict `:` type(operands)"

ctx = MLContext()
ctx.load_op(TwoOperandsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


@pytest.mark.parametrize(
"program, error",
[
(
"test.three_operands %0 : i32, i32",
"Expected at least 2 operands but found 1",
),
(
"test.three_operands %0, %1, %2, %3 : i32, i32, i32",
"Expected at most 3 operands but found 4",
),
(
"test.three_operands %0, %1 : i32",
"Expected at least 2 operand types but found 1",
),
(
"test.three_operands %0, %1, %3 : i32, i32, i32, i32",
"Expected at most 3 operand types but found 4",
),
],
)
def test_operands_directive_bounds_with_opt(program: str, error: str):
@irdl_op_definition
class ThreeOperandsOp(IRDLOperation):
name = "test.three_operands"

op1 = operand_def()
op2 = opt_operand_def()
op3 = operand_def()

assembly_format = "operands attr-dict `:` type(operands)"

ctx = MLContext()
ctx.load_op(ThreeOperandsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


@pytest.mark.parametrize(
"program, error",
[
(
"test.three_operands %0 : i32, i32",
"Expected at least 2 operands but found 1",
),
(
"test.three_operands %0, %1 : i32",
"Expected at least 2 operand types but found 1",
),
],
)
def test_operands_directive_bound_with_var(program: str, error: str):
@irdl_op_definition
class ThreeOperandsOp(IRDLOperation):
name = "test.three_operands"

op1 = operand_def()
op2 = var_operand_def()
op3 = operand_def()

assembly_format = "operands attr-dict `:` type(operands)"

ctx = MLContext()
ctx.load_op(ThreeOperandsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


################################################################################
# Results #
################################################################################
Expand Down
96 changes: 95 additions & 1 deletion xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Literal
from typing import Literal, TypeVar

from xdsl.dialects.builtin import UnitAttr
from xdsl.ir import (
Expand Down Expand Up @@ -600,6 +600,100 @@ def set_types_empty(self, state: ParsingState) -> None:
state.operand_types[self.index] = ()


_T = TypeVar("_T")


@dataclass(frozen=True)
class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective):
"""
An operands directive, with the following format:
operands-directive ::= operands
Prints each operand of the operation, inserting a comma between each.
"""

variadic_index: tuple[bool, int] | None
"""
Represents the position of a (single) variadic variable, with the boolean
representing whether it is optional
"""

def _set_using_variadic_index(
self,
field: list[_T | None | Sequence[_T]],
field_name: str,
set_to: Sequence[_T],
) -> str | None:
if self.variadic_index is None:
if len(set_to) != len(field):
return f"Expected {len(field)} {field_name} but found {len(set_to)}"
field = [o for o in set_to] # Copy needed as list is not covariant
return

is_optional, var_position = self.variadic_index
var_length = len(set_to) - len(field) + 1
if var_length < 0:
return f"Expected at least {len(field) - 1} {field_name} but found {len(set_to)}"
if var_length > 1 and is_optional:
return f"Expected at most {len(field)} {field_name} but found {len(set_to)}"
field[:var_position] = set_to[:var_position]
field[var_position] = set_to[var_position : var_position + var_length]
field[var_position + 1 :] = set_to[var_position + var_length :]

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
operands = (
parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_unresolved_operand,
parser.parse_unresolved_operand,
)
or []
)

if s := self._set_using_variadic_index(state.operands, "operands", operands):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(operands)

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.operand_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.operand_types[0] = parser.parse_type()

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
types = (
parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
or []
)

if s := self._set_using_variadic_index(
state.operand_types, "operand types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if op.operands:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_list(op.operands, printer.print_ssa_value)
state.last_was_punctuation = False
state.should_emit_space = True

def set_types_empty(self, state: ParsingState) -> None:
state.operand_types = [() for _ in state.operand_types]

def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return op.operand_types

def set_empty(self, state: ParsingState):
state.operands = [() for _ in state.operands]

def is_present(self, op: IRDLOperation) -> bool:
return bool(op.operands)


@dataclass(frozen=True)
class ResultVariable(VariableDirective, TypeableDirective):
"""
Expand Down
29 changes: 29 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
FormatProgram,
KeywordDirective,
OperandOrResult,
OperandsDirective,
OperandVariable,
OptionalAttributeVariable,
OptionalGroupDirective,
Expand Down Expand Up @@ -699,6 +700,8 @@ def parse_typeable_directive(self) -> TypeableDirective:
Parse a typeable directive, with the following format:
directive ::= variable
"""
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(False)
if variable := self.parse_optional_typeable_variable():
return variable
self.raise_error(f"unexpected token '{self._current_token.text}'")
Expand All @@ -718,6 +721,8 @@ def parse_format_directive(self) -> FormatDirective:
return self.create_attr_dict_directive(True)
if self.parse_optional_keyword("type"):
return self.parse_type_directive()
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(True)
if self._current_token.text == "`":
return self.parse_keyword_or_punctuation()
if self.parse_optional_punctuation("("):
Expand All @@ -744,3 +749,27 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective:
reserved_attr_names=set(),
print_properties=print_properties,
)

def create_operands_directive(self, top_level: bool) -> OperandsDirective:
if not self.op_def.operands:
self.raise_error("'operands' should not be used when there are no operands")
if top_level and any(self.seen_operands):
self.raise_error("'operands' cannot be used with other operand directives")
if not top_level and any(self.seen_operand_types):
self.raise_error(
"'operands' cannot be used in a type directive with other operand type directives"
)
variadics = tuple(
(isinstance(o, OptionalDef), i)
for i, (_, o) in enumerate(self.op_def.operands)
if isinstance(o, VariadicDef)
)
if len(variadics) > 1:
self.raise_error("'operands' is ambiguous with multiple variadic operands")
if top_level:
self.seen_operands = [True] * len(self.seen_operands)
else:
self.seen_operand_types = [True] * len(self.seen_operand_types)
if not variadics:
return OperandsDirective(None)
return OperandsDirective(variadics[0])

0 comments on commit 9bc6dc1

Please sign in to comment.