Skip to content

Commit

Permalink
WIP: Add infra to construct a while loop using the loop_info analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
VedantParanjape committed Aug 26, 2023
1 parent 17634dd commit 94be2dd
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 5 deletions.
1 change: 1 addition & 0 deletions include/blocks/basic_blocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class basic_block {
block::expr::Ptr branch_expr;
block::stmt::Ptr parent;
unsigned int ast_index;
unsigned int ast_depth;
unsigned int id;
std::string name;
};
Expand Down
4 changes: 4 additions & 0 deletions include/blocks/loops.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class loop {
stmt::Ptr entry_stmt;
} loop_bounds;

unsigned int loop_id;
basic_block::cfg_block blocks;
std::unordered_set<int> blocks_id_map;
std::shared_ptr<loop> parent_loop;
Expand All @@ -37,13 +38,16 @@ class loop_info {
analyze();
}
std::shared_ptr<loop> allocate_loop(std::shared_ptr<basic_block> header);
block::stmt_block::Ptr convert_to_ast(block::stmt_block::Ptr ast);
std::map<unsigned int, std::vector<int>> postorder_loops_map;
std::vector<std::shared_ptr<loop>> loops;
std::vector<std::shared_ptr<loop>> top_level_loops;

private:
basic_block::cfg_block parent_ast;
dominator_analysis dta;
std::map<int, std::shared_ptr<loop>> bb_loop_map;
void postorder_dfs_helper(std::vector<int> &postorder_loops_map, std::vector<bool> &visited_loops, int id);
// discover loops during traversal of the abstract syntax tree
void analyze();
};
Expand Down
8 changes: 8 additions & 0 deletions src/blocks/basic_blocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) {
auto bb = std::make_shared<basic_block>(std::to_string(basic_block_count));
bb->parent = st;
bb->ast_index = ast_index_counter++;
bb->ast_depth = 0;
work_list.push_back(bb);
basic_block_count++;
}
Expand All @@ -40,6 +41,7 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) {
stmt_block_list.push_back(std::make_shared<basic_block>(std::to_string(basic_block_count++)));
stmt_block_list.back()->parent = st;
stmt_block_list.back()->ast_index = ast_index_counter++;
stmt_block_list.back()->ast_depth = bb->ast_depth + 1;
}

// set the basic block successors
Expand Down Expand Up @@ -77,6 +79,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) {
auto exit_bb = std::make_shared<basic_block>("exit" + std::to_string(basic_block_count));
// assign it a empty stmt_block as parent
exit_bb->parent = std::make_shared<stmt_block>();
// set the ast depth of the basic block
exit_bb->ast_depth = bb->ast_depth;
// check if this is the last block, if yes the successor will be empty
if (bb->successor.size()) {
// set the successor to the block that if_stmt successor pointer to earlier
Expand All @@ -94,6 +98,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) {
auto then_bb = std::make_shared<basic_block>(std::to_string(++basic_block_count));
// set the parent of this block as the then stmts
then_bb->parent = if_stmt_->then_stmt;
// set the ast depth of the basic block
then_bb->ast_depth = bb->ast_depth;
// set the successor of this block to be the exit block
then_bb->successor.push_back(exit_bb);
// set the successor of the original if_stmt block to be this then block
Expand All @@ -106,6 +112,8 @@ basic_block::cfg_block generate_basic_blocks(block::stmt_block::Ptr ast) {
auto else_bb = std::make_shared<basic_block>(std::to_string(++basic_block_count));
// set the parent of this block as the else stmts
else_bb->parent = if_stmt_->else_stmt;
// set the ast depth of the basic block
else_bb->ast_depth = bb->ast_depth;
// set the successor of this block to be the exit block
else_bb->successor.push_back(exit_bb);
// set the successor of the orignal if_stmt block to be this else block
Expand Down
2 changes: 1 addition & 1 deletion src/blocks/dominance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dominator_analysis::dominator_analysis(basic_block::cfg_block &cfg) : cfg_(cfg)

void dominator_analysis::postorder_idom_helper(std::vector<bool> &visited, int id) {
for (int idom_id: idom_map[id]) {
std::cerr << idom_id << "\n";
// std::cerr << idom_id << "\n";
if (idom_id != -1 && !visited[idom_id]) {
visited[idom_id] = true;
postorder_idom_helper(visited, idom_id);
Expand Down
211 changes: 210 additions & 1 deletion src/blocks/loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ std::shared_ptr<loop> loop_info::allocate_loop(std::shared_ptr<basic_block> head
return loops.back();
}

void loop_info::postorder_dfs_helper(std::vector<int> &postorder_loops_map, std::vector<bool> &visited_loops, int id) {
for (auto subloop: loops[id]->subloops) {
if (!visited_loops[subloop->loop_id]) {
visited_loops[subloop->loop_id] = true;
postorder_dfs_helper(postorder_loops_map, visited_loops, subloop->loop_id);
postorder_loops_map.push_back(subloop->loop_id);
}
}
}

void loop_info::analyze() {
std::vector<int> idom = dta.get_idom();

Expand Down Expand Up @@ -126,4 +136,203 @@ void loop_info::analyze() {
}
}
}
}

// Assign id to the loops
for (unsigned int i = 0; i < loops.size(); i++) {
loops[i]->loop_id = i;
}

// build a loop tree
std::vector<bool> visited_loops(loops.size());
visited_loops.assign(visited_loops.size(), false);
for (auto loop: top_level_loops) {
std::vector<int> postorder_loop_tree;
visited_loops[loop->loop_id] = true;

postorder_dfs_helper(postorder_loop_tree, visited_loops, loop->loop_id);
postorder_loop_tree.push_back(loop->loop_id);
postorder_loops_map[loop->loop_id] = postorder_loop_tree;
}
}

static stmt::Ptr get_loop_block(std::shared_ptr<basic_block> loop_header, block::stmt_block::Ptr ast) {
block::stmt::Ptr current_ast = to<block::stmt>(ast);
std::vector<stmt::Ptr> current_block = to<block::stmt_block>(current_ast)->stmts;
// unsigned int ast_index = loop_header->ast_index;
std::deque<stmt::Ptr> worklist;
std::map<stmt::Ptr, stmt::Ptr> ast_parent_map;

for (auto stmt: current_block) {
ast_parent_map[stmt] = current_ast;
}
worklist.insert(worklist.end(), current_block.begin(), current_block.end());

while (worklist.size()) {
stmt::Ptr worklist_top = worklist.front();
worklist.pop_front();

if (isa<block::stmt_block>(worklist_top)) {
stmt_block::Ptr wl_stmt_block = to<stmt_block>(worklist_top);
for (auto stmt: wl_stmt_block->stmts) {
ast_parent_map[stmt] = worklist_top;
}
worklist.insert(worklist.end(), wl_stmt_block->stmts.begin(), wl_stmt_block->stmts.end());
}
else if (isa<block::if_stmt>(worklist_top)) {
if_stmt::Ptr wl_if_stmt = to<if_stmt>(worklist_top);

if (to<stmt_block>(wl_if_stmt->then_stmt)->stmts.size() != 0) {
stmt_block::Ptr wl_if_then_stmt = to<stmt_block>(wl_if_stmt->then_stmt);
for (auto stmt: wl_if_then_stmt->stmts) {
ast_parent_map[stmt] = worklist_top;
}
worklist.insert(worklist.end(), wl_if_then_stmt->stmts.begin(), wl_if_then_stmt->stmts.end());
}
if (to<stmt_block>(wl_if_stmt->else_stmt)->stmts.size() != 0) {
stmt_block::Ptr wl_if_else_stmt = to<stmt_block>(wl_if_stmt->else_stmt);
for (auto stmt: wl_if_else_stmt->stmts) {
ast_parent_map[stmt] = worklist_top;
}
worklist.insert(worklist.end(), wl_if_else_stmt->stmts.begin(), wl_if_else_stmt->stmts.end());
}
}
else if (isa<block::label_stmt>(worklist_top)) {
label_stmt::Ptr wl_label_stmt = to<label_stmt>(worklist_top);
if (worklist_top == loop_header->parent)
return ast_parent_map[worklist_top];
}
else if (isa<block::goto_stmt>(worklist_top)) {
goto_stmt::Ptr wl_goto_stmt = to<goto_stmt>(worklist_top);
if (worklist_top == loop_header->parent)
return ast_parent_map[worklist_top];
}
}

return nullptr;
}

static void replace_loop_latches(std::shared_ptr<loop> loop, block::stmt_block::Ptr ast) {
for (auto latch : loop->loop_latch_blocks) {
stmt::Ptr loop_latch_ast = get_loop_block(latch, ast);
if (isa<stmt_block>(loop_latch_ast)) {
std::vector<stmt::Ptr> &temp_loop_ast = to<stmt_block>(loop_latch_ast)->stmts;
std::replace(temp_loop_ast.begin(), temp_loop_ast.end(), temp_loop_ast[latch->ast_index], to<stmt>(std::make_shared<continue_stmt>()));
}
else if (isa<if_stmt>(loop_latch_ast)) {
stmt_block::Ptr if_then_block = to<block::stmt_block>(to<block::if_stmt>(loop_latch_ast)->then_stmt);
stmt_block::Ptr if_else_block = to<block::stmt_block>(to<block::if_stmt>(loop_latch_ast)->else_stmt);

if (if_then_block->stmts.size() && if_then_block->stmts[latch->ast_index] == latch->parent) {
std::replace(if_then_block->stmts.begin(), if_then_block->stmts.end(), if_then_block->stmts[latch->ast_index], to<stmt>(std::make_shared<continue_stmt>()));
}
else if (if_else_block->stmts.size() && if_else_block->stmts[latch->ast_index] == latch->parent) {
std::replace(if_else_block->stmts.begin(), if_else_block->stmts.end(), if_else_block->stmts[latch->ast_index], to<stmt>(std::make_shared<continue_stmt>()));
}
}
}
}

block::stmt_block::Ptr loop_info::convert_to_ast(block::stmt_block::Ptr ast) {
for (auto loop_map: postorder_loops_map) {
// std::cerr << "== top level loop tree ==\n";
for (auto postorder: loop_map.second) {
// std::cerr << postorder <<"\n";
block::stmt::Ptr loop_header_ast = get_loop_block(loops[postorder]->header_block, ast);

while_stmt::Ptr while_block = std::make_shared<while_stmt>();
while_block->cond = std::make_shared<int_const>();
to<int_const>(while_block->cond)->value = 1;
while_block->body = std::make_shared<stmt_block>();

if (isa<block::stmt_block>(loop_header_ast)) {
unsigned int ast_index = loops[postorder]->header_block->ast_index;
if (to<block::stmt_block>(loop_header_ast)->stmts[ast_index] == loops[postorder]->header_block->parent) {
stmt_block::Ptr then_block = to<block::stmt_block>(to<block::if_stmt>(to<block::stmt_block>(loop_header_ast)->stmts[ast_index + 1])->then_stmt);
stmt_block::Ptr else_block = to<block::stmt_block>(to<block::if_stmt>(to<block::stmt_block>(loop_header_ast)->stmts[ast_index + 1])->else_stmt);

// if (isa<goto_stmt>(then_block->stmts.back())) {
// then_block->stmts.pop_back();
// then_block->stmts.push_back(std::make_shared<continue_stmt>());
// }
replace_loop_latches(loops[postorder], ast);

else_block->stmts.push_back(std::make_shared<break_stmt>());
to<stmt_block>(while_block->body)->stmts.push_back(to<block::stmt_block>(loop_header_ast)->stmts[ast_index + 1]);
// while_block->cond = to<block::if_stmt>(to<block::stmt_block>(loop_header_ast)->stmts[ast_index + 1])->cond;
// while_block->dump(std::cerr, 0);
// std::cerr << "found loop header in stmt block\n";

// if block to be replaced with while block
std::vector<stmt::Ptr> &temp_ast = to<block::stmt_block>(loop_header_ast)->stmts;
std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to<stmt>(while_block));
temp_ast.erase(temp_ast.begin() + ast_index);
}
else {
// std::cerr << "not found loop header in stmt block\n";
}
}
else if (isa<block::if_stmt>(loop_header_ast)) {
unsigned int ast_index = loops[postorder]->header_block->ast_index;
stmt_block::Ptr if_then_block = to<block::stmt_block>(to<block::if_stmt>(loop_header_ast)->then_stmt);
stmt_block::Ptr if_else_block = to<block::stmt_block>(to<block::if_stmt>(loop_header_ast)->else_stmt);

if (if_then_block->stmts.size() != 0) {
if (if_then_block->stmts[ast_index] == loops[postorder]->header_block->parent) {
stmt_block::Ptr then_block = to<block::stmt_block>(to<block::if_stmt>(if_then_block->stmts[ast_index + 1])->then_stmt);
stmt_block::Ptr else_block = to<block::stmt_block>(to<block::if_stmt>(if_then_block->stmts[ast_index + 1])->else_stmt);

replace_loop_latches(loops[postorder], ast);

else_block->stmts.push_back(std::make_shared<break_stmt>());
to<stmt_block>(while_block->body)->stmts.push_back(if_then_block->stmts[ast_index + 1]);
// while_block->cond = to<block::if_stmt>(loop_header_ast)->cond;

// while_block->dump(std::cerr, 0);
// std::cerr << "found loop header in if-then stmt\n";

// if block to be replaced with while block
std::vector<stmt::Ptr> &temp_ast = if_then_block->stmts;
std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to<stmt>(while_block));
temp_ast.erase(temp_ast.begin() + ast_index);
}
else {
// loop_header_ast->dump(std::cerr, 0);
// std::cerr << "not found loop header in if-then stmt\n";
}
}
else if (if_else_block->stmts.size() != 0) {
if (if_else_block->stmts[ast_index] == loops[postorder]->header_block->parent) {
stmt_block::Ptr then_block = to<block::stmt_block>(to<block::if_stmt>(if_else_block->stmts[ast_index + 1])->then_stmt);
stmt_block::Ptr else_block = to<block::stmt_block>(to<block::if_stmt>(if_else_block->stmts[ast_index + 1])->else_stmt);

replace_loop_latches(loops[postorder], ast);

else_block->stmts.push_back(std::make_shared<break_stmt>());
to<stmt_block>(while_block->body)->stmts.push_back(if_else_block->stmts[ast_index + 1]);
// while_block->cond = to<block::if_stmt>(loop_header_ast)->cond;

// while_block->dump(std::cerr, 0);
// std::cerr << "found loop header in if-else stmt\n";

// if block to be replaced with while block
std::vector<stmt::Ptr> &temp_ast = if_else_block->stmts;
std::replace(temp_ast.begin(), temp_ast.end(), temp_ast[ast_index + 1], to<stmt>(while_block));
temp_ast.erase(temp_ast.begin() + ast_index);
}
else {
// loop_header_ast->dump(std::cerr, 0);
// std::cerr << "not found loop header in if-else stmt\n";
}
}
}
else {
// std::cerr << "loop header not found\n";
}
// insert into AST - std::replace
// set the ast to loop depth + 1
// loops[loop_tree.first]->header_block->ast_index
}
}

return ast;
}
23 changes: 20 additions & 3 deletions src/builder/builder_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ block::stmt::Ptr builder_context::extract_ast_from_function_impl(void) {
for (auto pred: bb->predecessor) {
std::cerr << pred->name << ", ";
}
std::cerr << bb->ast_depth;
std::cerr << "\n";
if (bb->branch_expr) {
std::cerr << " ";
Expand Down Expand Up @@ -386,14 +387,30 @@ block::stmt::Ptr builder_context::extract_ast_from_function_impl(void) {
for (auto subl: loop->subloops) std::cerr << "(loop header: " << subl->header_block->id << ") ";
std::cerr << "\n";
}

std::cerr << "++++++ top level loops ++++++ \n";
for (auto top_level_loop: LI.top_level_loops) std::cerr << "(loop header: " << top_level_loop->header_block->id << ") ";
std::cerr << "\n";

std::cerr << "++++++ preorder loops tree ++++++ \n";
for (auto loop_tree: LI.postorder_loops_map) {
std::cerr << "loop tree root: (loop header: " << LI.loops[loop_tree.first]->header_block->id << ")\n";
std::cerr << "postorder: ";
for (auto node: loop_tree.second) std::cerr << node << " ";
std::cerr << "\n";
}
std::cerr << "++++++ loop info ++++++ \n";

std::cerr << "++++++ convert to ast ++++++ \n";
LI.convert_to_ast(block::to<block::stmt_block>(ast));
std::cerr << "++++++ convert to ast ++++++ \n";

if (feature_unstructured)
return ast;

block::loop_finder finder;
finder.ast = ast;
ast->accept(&finder);
// block::loop_finder finder;
// finder.ast = ast;
// ast->accept(&finder);

block::for_loop_finder for_finder;
for_finder.ast = ast;
Expand Down

0 comments on commit 94be2dd

Please sign in to comment.