Skip to content

Commit

Permalink
interactive: adding a function that takes a module pass and returns a…
Browse files Browse the repository at this point in the history
… pipelinepassspec + tests (#1841)
  • Loading branch information
dshaaban01 authored Dec 10, 2023
1 parent 9306f8f commit 997d671
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
72 changes: 72 additions & 0 deletions tests/test_pass_to_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass, field

import pytest

from xdsl.dialects import builtin
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.utils.parse_pipeline import PipelinePassSpec


@dataclass
class CustomPass(ModulePass):
name = "custom-pass"

number: int

int_list: list[int]

str_thing: str | None

list_str: list[str] = field(default_factory=list)

optional_bool: bool = False

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
pass


@dataclass
class EmptyPass(ModulePass):
name = "empty"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
pass


@dataclass
class SimplePass(ModulePass):
name = "simple"

a: list[float]
b: int

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
pass


@pytest.mark.parametrize(
"test_pass, test_spec",
(
(
CustomPass(3, [1, 2], None, ["clown", "season"]),
PipelinePassSpec(
"custom-pass",
{
"number": [3],
"int_list": [1, 2],
"str_thing": [],
"list_str": ["clown", "season"],
"optional_bool": [False],
},
),
),
(EmptyPass(), PipelinePassSpec("empty", {})),
(
SimplePass([3.14, 2.13], 2),
PipelinePassSpec("simple", {"a": [3.14, 2.13], "b": [2]}),
),
),
)
def test_pass_to_spec_equality(test_pass: ModulePass, test_spec: PipelinePassSpec):
assert test_pass.pipeline_pass_spec() == test_spec
27 changes: 27 additions & 0 deletions xdsl/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,33 @@ def from_pass_spec(cls: type[ModulePassT], spec: PipelinePassSpec) -> ModulePass
# instantiate the dataclass using kwargs
return cls(**arg_dict)

def pipeline_pass_spec(self) -> PipelinePassSpec:
"""
This function takes a ModulePass and returns a PipelinePassSpec.
"""
# get all dataclass fields
fields: tuple[Field[Any], ...] = dataclasses.fields(self)
arg_dict: dict[str, PassArgListType] = {}

# iterate over all fields of the dataclass
for op_field in fields:
# ignore the name field and everything that's not used by __init__
if op_field.name == "name" or not op_field.init:
continue

if _is_optional(op_field):
arg_dict[op_field.name] = _get_default(op_field)

val = getattr(self, op_field.name)
if val is None:
arg_dict.update({op_field.name: []})
elif isinstance(val, PassArgElementType):
arg_dict.update({op_field.name: [getattr(self, op_field.name)]})
else:
arg_dict.update({op_field.name: getattr(self, op_field.name)})

return PipelinePassSpec(self.name, arg_dict)


# Git Issue: https://github.com/xdslproject/xdsl/issues/1845
def get_pass_argument_names_and_types(arg: type[ModulePassT]) -> str:
Expand Down

0 comments on commit 997d671

Please sign in to comment.