Skip to content

Commit

Permalink
Implement the trivial dead code elmination algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
chuang0221 committed Sep 5, 2024
1 parent c0ab3e1 commit 0874fd7
Show file tree
Hide file tree
Showing 15 changed files with 186 additions and 12 deletions.
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ set(SOURCES
src/buildCFG.cpp
src/common.cpp
src/logger.cpp
src/config.cpp
src/localValueNumbering.cpp
src/deadCodeElimination.cpp
)

add_library(bril_common ${SOURCES})
Expand All @@ -27,4 +29,7 @@ add_executable(build_cfg src/buildCFGMain.cpp)
target_link_libraries(build_cfg PRIVATE bril_common nlohmann_json::nlohmann_json)

add_executable(local_value_numbering src/localValueNumberingMain.cpp)
target_link_libraries(local_value_numbering PRIVATE bril_common nlohmann_json::nlohmann_json)
target_link_libraries(local_value_numbering PRIVATE bril_common nlohmann_json::nlohmann_json)

add_executable(dead_code_elimination src/deadCodeEliminationMain.cpp)
target_link_libraries(dead_code_elimination PRIVATE bril_common nlohmann_json::nlohmann_json)
18 changes: 16 additions & 2 deletions include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,28 @@

#include <string>

class Config {
class PassConfig {
public:
virtual ~PassConfig() = default;
};

class LVNConfig : public PassConfig {
public:
bool enableCommutative;
bool enableConstantFolding;
bool enableAlgebraicIdentity;

Config(bool enableCommutative, bool enableConstantFolding, bool enableAlgebraicIdentity)
LVNConfig(bool enableCommutative, bool enableConstantFolding, bool enableAlgebraicIdentity)
: enableCommutative(enableCommutative), enableConstantFolding(enableConstantFolding), enableAlgebraicIdentity(enableAlgebraicIdentity) {}
};

class DCEConfig : public PassConfig {
public:
bool enableAggressiveDCE;

DCEConfig(bool enableAggressiveDCE) : enableAggressiveDCE(enableAggressiveDCE) {}
};

PassConfig* createPassConfig(const std::string& passName);

#endif
17 changes: 17 additions & 0 deletions include/deadCodeElimination.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef DEAD_CODE_ELIMINATION_H
#define DEAD_CODE_ELIMINATION_H

#include "config.h"
#include <nlohmann/json.hpp>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include <tuple>
#include <functional>
#include <iostream>

using json = nlohmann::json;

void deadCodeElimination(std::vector<std::vector<json>>& blocks, DCEConfig& config);

#endif
4 changes: 2 additions & 2 deletions include/localValueNumbering.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class ValueNumbering {
int refreshNumber(const std::string& name);
long long stov(const std::string& value);
long long evaluate(const std::string& type1, const std::string& value1, const std::string& type2, const std::string& value2, const std::string& op);
void update(json& instr, const Config& config);
void update(json& instr, const LVNConfig& config);
};

void checkCommutative(std::tuple<int, std::string, int>& nameTuple);
IdentityType checkAlgebraicIdentity(std::tuple<int, std::string, int>& nameTuple);

void localValueNumbering(std::vector<std::vector<json>>& blocks, const Config& config);
void localValueNumbering(std::vector<std::vector<json>>& blocks, const LVNConfig& config);
#endif
10 changes: 10 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include "config.h"

PassConfig* createPassConfig(const std::string& passName) {
if (passName == "LVN") {
return new LVNConfig(false, false, false);
} else if (passName == "DCE") {
return new DCEConfig(false);
}
return nullptr;
}
34 changes: 34 additions & 0 deletions src/deadCodeElimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "deadCodeElimination.h"
#include "logger.h"

void deadCodeElimination(std::vector<std::vector<json>>& blocks, DCEConfig& config) {
std::unordered_set<std::string> usedVars;
// 1. Traverse the blocks to get all used variables
for (auto &block : blocks) {
for (auto &instr : block) {
for (auto &arg : instr["args"]) {
usedVars.insert(arg.get<std::string>());
}
}
}
// 2. Traverse the blocks to remove unused variables
bool changed = false;
for (auto &block : blocks) {
std::vector<json> newBlock;
for (auto &instr : block) {
if (instr.find("dest") == instr.end() || usedVars.find(instr["dest"].get<std::string>()) != usedVars.end()) {
LOG_DEBUG(instr.dump());
newBlock.push_back(instr);
}
}
if (block.size() != newBlock.size()) {
changed = true;
}
block = newBlock;
}
if (changed) {
deadCodeElimination(blocks, config);
}
LOG_DEBUG("\n");
return;
}
36 changes: 36 additions & 0 deletions src/deadCodeEliminationMain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "deadCodeElimination.h"
#include "common.h"
#include "buildBlocks.h"
#include "logger.h"

int main(int argc, char* argv[]) {
try {
Logger::getInstance().setLogLevel(LogLevel::INFO);
json program = parseJsonFromStdin();
LOG_DEBUG("Starting dead code elimination");
for (auto& func : program["functions"]) {
std::vector<std::vector<json>> blocks = buildBlocks(func["instrs"]);
std::vector<std::string> args(argv + 1, argv + argc);
DCEConfig* config = static_cast<DCEConfig*>(createPassConfig("DCE"));
for (auto& arg : args) {
if (arg == "-g") {
Logger::getInstance().setLogLevel(LogLevel::DEBUG);
}
if (arg == "-a") {
config->enableAggressiveDCE = true;
}
}
deadCodeElimination(blocks, *config);
delete config;
LOG_DEBUG("After dead code elimination");
//printBlocks(blocks, false);
func["instrs"] = flattenBlocks(blocks);
}
std::cout << program.dump(2) << std::endl;
LOG_DEBUG("Local value numbering finished");
} catch (const std::exception& e) {
LOG_ERROR(std::string("Error: ") + e.what());
return 1;
}
return 0;
}
4 changes: 2 additions & 2 deletions src/localValueNumbering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ IdentityType checkAlgebraicIdentity(std::tuple<int, std::string, int>& nameTuple
return IDENTITY_NONE;
}

void ValueNumbering::update(json& instr, const Config& config) {
void ValueNumbering::update(json& instr, const LVNConfig& config) {
std::string op = instr["op"];
std::string dest = instr["dest"];
if (op == "const") {
Expand Down Expand Up @@ -146,7 +146,7 @@ void checkCommutative(std::tuple<int, std::string, int>& nameTuple) {
return;
}

void localValueNumbering(std::vector<std::vector<json>>& blocks, const Config& config) {
void localValueNumbering(std::vector<std::vector<json>>& blocks, const LVNConfig& config) {
for (auto& block : blocks) {
ValueNumbering vn;
for (auto& instr : block) {
Expand Down
11 changes: 6 additions & 5 deletions src/localValueNumberingMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ int main(int argc, char* argv[]) {
LOG_DEBUG("Before local value numbering");
//printBlocks(blocks, true);
std::vector<std::string> args(argv + 1, argv + argc);
Config config(false, false, false);
LVNConfig* config = static_cast<LVNConfig*>(createPassConfig("LVN"));
for (auto& arg : args) {
if (arg == "-g") {
Logger::getInstance().setLogLevel(LogLevel::DEBUG);
}
if (arg == "-c") {
config.enableCommutative = true;
config->enableCommutative = true;
}
else if (arg == "-f") {
config.enableConstantFolding = true;
config->enableConstantFolding = true;
}
else if (arg == "-a") {
config.enableAlgebraicIdentity = true;
config->enableAlgebraicIdentity = true;
}
}
localValueNumbering(blocks, config);
localValueNumbering(blocks, *config);
delete config;
LOG_DEBUG("After local value numbering");
//printBlocks(blocks, false);
func["instrs"] = flattenBlocks(blocks);
Expand Down
12 changes: 12 additions & 0 deletions test/dce/dce_test.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@main {
num1: int = const 1;
num2: int = const 2;
should_be_removed1: int = const 3;
should_be_removed2: int = const 4;
should_not_be_removed: int = const 5;

num3: int = add num1 num2;
num4: int = add num3 num3;
num5: int = and num4 should_not_be_removed;
print num4;
}
9 changes: 9 additions & 0 deletions test/dce/lvn_dce.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@main {
num1: int = const 1;
num2: int = const 2;
should_be_removed2: int = add num1 num2;
should_not_be_removed1: int = id should_be_removed2;
num4: int = add num3 num3;
print num4;
print should_not_be_removed1;
}
12 changes: 12 additions & 0 deletions test/dce/lvn_dce_test.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@main {
num1: int = const 1;
num2: int = const 2;
should_be_removed1: int = const 3;
should_be_removed2: int = add num1 num2;
should_be_removed3: int = add num1 num2;
should_not_be_removed1: int = add num1 num2;

num4: int = add num3 num3;
print num4;
print should_not_be_removed1;
}
8 changes: 8 additions & 0 deletions test/dce/only_dce.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@main {
num1: int = const 1;
num2: int = const 2;
should_not_be_removed1: int = add num1 num2;
num4: int = add num3 num3;
print num4;
print should_not_be_removed1;
}
10 changes: 10 additions & 0 deletions test/lvn/cse.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@main{
a: int = const 1;
b: int = const 1;
two: int = const 2;
t1: int = mul a two;
t2: int = mul a two;
t3: int = mul t2 b;
result: int = add t1 t3;
print result;
}
6 changes: 6 additions & 0 deletions test/lvn/cse_out.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@main {
t1: int = mul a two;
t2: int = mul t1 b;
result: int = add t1 t2;
print result;
}

0 comments on commit 0874fd7

Please sign in to comment.