Skip to content

Commit

Permalink
interactive: app - ability to choose and apply passes onto IR (4/7) (#…
Browse files Browse the repository at this point in the history
…1772)

Have added the ability to apply passes to the IR via the selection list.
the output IR text area contains the changed IR. Tests have been added
to test that pass application works.

![image](https://github.com/xdslproject/xdsl/assets/144673861/fa70482e-4b73-4e26-a09a-9837213e14d2)

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
dshaaban01 and superlopuh authored Nov 13, 2023
1 parent 523f2a6 commit 50d3837
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 17 deletions.
108 changes: 99 additions & 9 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@

import pytest

from xdsl.backend.riscv.lowering import convert_func_to_riscv_func
from xdsl.builder import ImplicitBuilder
from xdsl.dialects import arith, func
from xdsl.dialects.builtin import IndexType, IntegerAttr, ModuleOp
from xdsl.dialects import arith, func, riscv, riscv_func
from xdsl.dialects.builtin import (
IndexType,
IntegerAttr,
ModuleOp,
UnrealizedConversionCastOp,
)
from xdsl.interactive.app import InputApp
from xdsl.ir import Block, Region
from xdsl.utils.exceptions import ParseError


@pytest.mark.asyncio()
async def test_input_and_buttons():
"""Test pressing keys has the desired result."""
async def test_inputs():
"""Test different inputs produce desired result."""
async with InputApp().run_test() as pilot:
app = cast(InputApp, pilot.app)

Expand All @@ -27,6 +33,7 @@ async def test_input_and_buttons():
app.output_text_area.text
== "(Span[5:6](text=''), 'Operation builtin.unregistered does not have a custom format.')"
)

assert isinstance(app.current_module, ParseError)
assert (
str(app.current_module)
Expand Down Expand Up @@ -70,6 +77,12 @@ async def test_input_and_buttons():
assert isinstance(app.current_module, ModuleOp)
assert app.current_module.is_structurally_equivalent(expected_module)


@pytest.mark.asyncio()
async def test_buttons():
"""Test pressing keys has the desired result."""
async with InputApp().run_test() as pilot:
app = cast(InputApp, pilot.app)
# Test clicking the "clear input" button
app.input_text_area.clear()
app.input_text_area.insert(
Expand All @@ -81,17 +94,94 @@ async def test_input_and_buttons():
}
"""
)
# press clear input button
await pilot.click("#clear_input_button")

# assert that the input text area has been cleared
await pilot.pause()
assert (
app.input_text_area.text
== """
assert app.input_text_area.text == ""


@pytest.mark.asyncio()
async def test_passes():
"""Test pass application has the desired result."""
async with InputApp().run_test() as pilot:
app = cast(InputApp, pilot.app)
# Testing a pass
app.input_text_area.insert(
"""
func.func @hello(%n : index) -> index {
%two = arith.constant 2 : index
%res = arith.muli %n, %two : index
func.return %res : index
}
"""
)
await pilot.click("#clear_input_button")

# Await on test update to make sure we only update due to pass change later
await pilot.pause()
assert app.input_text_area.text == ""
assert (
app.output_text_area.text
== """builtin.module {
func.func @hello(%n : index) -> index {
%two = arith.constant 2 : index
%res = arith.muli %n, %two : index
func.return %res : index
}
}
"""
)

# Select a pass
app.passes_selection_list.select(
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass
)

# assert that the Output Text Area has changed accordingly
await pilot.pause()
assert (
app.output_text_area.text
== """builtin.module {
riscv.assembly_section ".text" {
riscv.directive ".globl" "hello"
riscv.directive ".p2align" "2"
riscv_func.func @hello(%n : !riscv.reg<a0>) -> !riscv.reg<a0> {
%0 = riscv.mv %n : (!riscv.reg<a0>) -> !riscv.reg<>
%n_1 = builtin.unrealized_conversion_cast %0 : !riscv.reg<> to index
%two = arith.constant 2 : index
%res = arith.muli %n_1, %two : index
%1 = builtin.unrealized_conversion_cast %res : index to !riscv.reg<>
%2 = riscv.mv %1 : (!riscv.reg<>) -> !riscv.reg<a0>
riscv_func.return %2 : !riscv.reg<a0>
}
}
}
"""
)

index = IndexType()
expected_module = ModuleOp(Region([Block()]))
with ImplicitBuilder(expected_module.body):
section = riscv.AssemblySectionOp(".text")
with ImplicitBuilder(section.data):
riscv.DirectiveOp(".globl", "hello")
riscv.DirectiveOp(".p2align", "2")
function = riscv_func.FuncOp(
"hello",
Region([Block(arg_types=[riscv.Registers.A0])]),
((riscv.Registers.A0,), (riscv.Registers.A0,)),
)
with ImplicitBuilder(function.body) as (n,):
zero = riscv.MVOp(n, rd=riscv.IntRegisterType(""))
n_one = UnrealizedConversionCastOp.get([zero.rd], [index])
two = arith.Constant(IntegerAttr(2, index)).result
res = arith.Muli(n_one, two)
one = UnrealizedConversionCastOp.get(
[res.result], [riscv.IntRegisterType("")]
)
two_two = riscv.MVOp(one, rd=riscv.Registers.A0)
riscv_func.ReturnOp(two_two)

assert isinstance(app.current_module, ModuleOp)
# Assert that the current module has been changed accordingly
assert app.current_module.is_structurally_equivalent(expected_module)
46 changes: 39 additions & 7 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Execute `xdsl-gui` in your terminal to run it.
Run `terminal -m xdsl.interactive.app:InputApp --def` to run in development mode. Please
Run `terminal -m xdsl.interactive.app:InputApp --dev` to run in development mode. Please
be sure to install `textual-dev` to run this command.
"""

Expand All @@ -14,17 +14,21 @@
from textual.app import App, ComposeResult
from textual.containers import Horizontal, Vertical
from textual.reactive import reactive
from textual.widgets import Button, Footer, TextArea
from textual.widgets import Button, Footer, SelectionList, TextArea
from textual.widgets.text_area import TextAreaTheme

from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.parser import Parser
from xdsl.passes import ModulePass, PipelinePass
from xdsl.printer import Printer
from xdsl.tools.command_line_tool import get_all_dialects
from xdsl.tools.command_line_tool import get_all_dialects, get_all_passes

from ._pasteboard import pyclip_copy

ALL_PASSES = tuple(get_all_passes())
"""Contains the list of xDSL passes."""


class OutputTextArea(TextArea):
"""Used to prevent users from being able to change/alter the Output TextArea"""
Expand Down Expand Up @@ -61,8 +65,16 @@ class InputApp(App[None]):
(i.e. is the Output TextArea)
"""

input_text_area = TextArea(id="input")
output_text_area = OutputTextArea(id="output")
input_text_area: TextArea
output_text_area: OutputTextArea

passes_selection_list: SelectionList[type[ModulePass]]

def __init__(self):
self.input_text_area = TextArea(id="input")
self.output_text_area = OutputTextArea(id="output")
self.passes_selection_list = SelectionList(id="passes_selection_list")
super().__init__()

def compose(self) -> ComposeResult:
"""
Expand All @@ -72,6 +84,8 @@ def compose(self) -> ComposeResult:
and sort the list in alphabetical order.
"""

yield self.passes_selection_list

with Horizontal(id="input_output"):
with Vertical(id="input_container"):
yield self.input_text_area
Expand All @@ -83,19 +97,25 @@ def compose(self) -> ComposeResult:
yield Button("Copy Output", id="copy_output_button")
yield Footer()

@on(SelectionList.SelectedChanged)
@on(TextArea.Changed, "#input")
def update_current_module(self) -> None:
"""
Function called when the Input TextArea is cahnged. This function parses the Input
IR and updates the current_module reactive variable.
Function called when the Input TextArea is changed or a pass is selected/
unselected. This function parses the Input IR, applies selected passes and
updates the Output TextArea.
"""
input_text = self.input_text_area.text
selected_passes = self.passes_selection_list.selected

try:
ctx = MLContext(True)
for dialect in get_all_dialects():
ctx.load_dialect(dialect)
parser = Parser(ctx, input_text)
module = parser.parse_module()
pipeline = PipelinePass([p() for p in selected_passes])
pipeline.apply(ctx, module)
self.current_module = module
except Exception as e:
self.current_module = e
Expand Down Expand Up @@ -130,6 +150,18 @@ def on_mount(self) -> None:

self.query_one("#input_container").border_title = "Input xDSL IR"
self.query_one("#output_container").border_title = "Output xDSL IR"
self.query_one(
"#passes_selection_list"
).border_title = "Choose a pass or multiple passes to be applied."

# aids in the construction of the seleciton list containing all the passes
selections = sorted((value.name, value) for value in ALL_PASSES)

# type error due to Textual Bug requires pyright ignore
# Link to issue: https://github.com/xdslproject/xdsl/issues/1777
self.passes_selection_list.add_options( # pyright: ignore[reportUnknownMemberType]
selections
)

def action_toggle_dark(self) -> None:
"""An action to toggle dark mode."""
Expand Down
18 changes: 17 additions & 1 deletion xdsl/interactive/app.tcss
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# SelectionList
#passes_selection_list {
margin: 1 1 1 1;
border: heavy $warning-lighten-1;
border-title-color: $primary;
border-title-align: center;
width: 100%;
height: 100%;
align: center middle;
}

# Horizontal(Button)
#clear_input{
Expand All @@ -22,10 +32,16 @@
text-opacity: 60%;
}

# Vertical(OutputTextArea)
# Vertical(TextArea, Horizontal(Button))
#output_container {
margin: 1;
border: heavy $warning-lighten-1;
border-title-color: $primary;
border-title-align: center;
}

Screen {
layout: grid;
grid-size: 1;
grid-rows: 45% 55%;
}

0 comments on commit 50d3837

Please sign in to comment.