Skip to content

Commit

Permalink
Remove workarounds for empty and full op (#1485)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT authored Dec 4, 2024
1 parent d837ac7 commit 9405473
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 130 deletions.
34 changes: 4 additions & 30 deletions runtime/include/tt/runtime/detail/workarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,14 @@ struct Env {
#else
constexpr static Env
#endif
get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
bool swapBinaryOperands = true)
get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true, true);
return Env(true, true);
}
#endif
// TODO(bug #272), determine correct layout by tile shape in the future
// currently tile shape is not set correctly, so as a workaround, hardcode
// layout
bool ignoreTileShape;

// TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
// using ROW_MAJOR until we fix it
bool emptyOpForceRowMajor;

// TODO(bug #582): ttnn::full doesn't work properly with tile layout,
// using ROW_MAJOR until we fix it
bool fullOpForceRowMajor;

// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
// instead of adding a method in runtime
bool maxpool2dPreshard;
Expand All @@ -48,24 +33,13 @@ struct Env {
bool swapBinaryOperands;

private:
constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool swapBinaryOperands)
: ignoreTileShape(ignoreTileShape),
emptyOpForceRowMajor(emptyOpForceRowMajor),
fullOpForceRowMajor(fullOpForceRowMajor),
maxpool2dPreshard(maxpool2dPreshard),
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands)
: maxpool2dPreshard(maxpool2dPreshard),
swapBinaryOperands(swapBinaryOperands) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "workaround::Env{\n";
os << "\t"
<< "ignoreTileShape: " << env.ignoreTileShape << ",\n";
os << "\t"
<< "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n";
os << "\t"
<< "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
os << "\t"
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
Expand Down
8 changes: 2 additions & 6 deletions runtime/lib/common/workarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@

namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool swapBinaryOperands) {
static const Env config(ignoreTileShape, emptyOpForceRowMajor,
fullOpForceRowMajor, maxpool2dPreshard,
swapBinaryOperands);
const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands) {
static const Env config(maxpool2dPreshard, swapBinaryOperands);
return config;
}
#endif
Expand Down
5 changes: 0 additions & 5 deletions runtime/lib/ttnn/operations/creation/empty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ struct EmptyTensorConfig {
dtype(::tt::runtime::ttnn::operations::utils::getDataType(op->out())),
numShards(op->num_shards()), strategy(op->strategy()) {
layout = ::tt::runtime::ttnn::utils::toTTNNLayout(op->layout());
// TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
// using ROW_MAJOR until we fix it
if (workaround::Env::get().emptyOpForceRowMajor) {
layout = ::ttnn::Layout::ROW_MAJOR;
}
if (op->device()) {
LOG_ASSERT(op->memcfg(),
"Memory config must be provided when device is provided");
Expand Down
13 changes: 0 additions & 13 deletions runtime/lib/ttnn/operations/creation/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ struct FullTensorConfig {

layout = ::tt::runtime::ttnn::utils::inferLayoutFromTileShape(op->out());

// TODO(bug #272), determine correct layout by tile shape in the future
// currently tile shape is not set correctly, so as a workaround, hardcode
// layout
if (workaround::Env::get().ignoreTileShape) {
layout = ::ttnn::Layout::TILE;
}

// TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
// using ROW_MAJOR until we fix it
if (workaround::Env::get().fullOpForceRowMajor) {
layout = ::ttnn::Layout::ROW_MAJOR;
}

if (!utils::inSystemMemory(op->out())) {
memoryConfig = ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
}
Expand Down
2 changes: 1 addition & 1 deletion runtime/test/python/ttnn/test_runtime_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,5 @@ def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request):
golden = (
(inputs_torch[0] + inputs_torch[1]).mul(inputs_torch[1]).sub(inputs_torch[1])
)
assert_pcc(golden, torch_result_tensor, threshold=0.999), program_index
assert_pcc(golden, torch_result_tensor, threshold=0.99)
helper.teardown()
51 changes: 0 additions & 51 deletions runtime/tools/python/test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,57 +311,6 @@ def test_enable_async_ttnn_cmd_run():
sub_process_command(command)


def test_disable_ignore_tile_shape_run():
API.initialize_apis()
custom_args = {}
custom_args[
"--result-file"
] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json"
custom_args["binary"] = BINARY_FILE_PATH
custom_args["--disable-ignore-tile-shape"] = True
run_instance = API.Run(args=custom_args)
run_instance()


def test_disable_ignore_tile_shape_cmd_run():
command = f"ttrt run {BINARY_FILE_PATH} --disable-ignore-tile-shape --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json"
sub_process_command(command)


def test_disable_empty_op_row_major_run():
API.initialize_apis()
custom_args = {}
custom_args[
"--result-file"
] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json"
custom_args["binary"] = BINARY_FILE_PATH
custom_args["--disable-empty-op-row-major"] = True
run_instance = API.Run(args=custom_args)
run_instance()


def test_disable_empty_op_row_major_cmd_run():
command = f"ttrt run {BINARY_FILE_PATH} --disable-empty-op-row-major --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json"
sub_process_command(command)


def test_disable_full_op_row_major_run():
API.initialize_apis()
custom_args = {}
custom_args[
"--result-file"
] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json"
custom_args["binary"] = BINARY_FILE_PATH
custom_args["--disable-full-op-row-major"] = True
run_instance = API.Run(args=custom_args)
run_instance()


def test_disable_full_op_row_major_cmd_run():
command = f"ttrt run {BINARY_FILE_PATH} --disable-full-op-row-major --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json"
sub_process_command(command)


def test_disable_maxpool2d_preshard_run():
API.initialize_apis()
custom_args = {}
Expand Down
24 changes: 0 additions & 24 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,6 @@ def initialize_api():
choices=[True, False],
help="enable async mode device execution for TTNN runtime",
)
Run.register_arg(
name="--disable-ignore-tile-shape",
type=bool,
default=False,
choices=[True, False],
help="disable ignore tile shape workaround",
)
Run.register_arg(
name="--disable-empty-op-row-major",
type=bool,
default=False,
choices=[True, False],
help="disable empty op force row major workaround",
)
Run.register_arg(
name="--disable-full-op-row-major",
type=bool,
default=False,
choices=[True, False],
help="disable full op force row major workaround",
)
Run.register_arg(
name="--disable-maxpool2d-preshard",
type=bool,
Expand Down Expand Up @@ -370,9 +349,6 @@ def _execute(binaries):
)
self.logging.debug(f"setting tt runtime debug env={debug_env}")
workaround_env = ttrt.runtime.WorkaroundEnv.get(
not self["--disable-ignore-tile-shape"],
not self["--disable-empty-op-row-major"],
not self["--disable-full-op-row-major"],
not self["--disable-maxpool2d-preshard"],
not self["--disable-swap-binary-operands"],
)
Expand Down

0 comments on commit 9405473

Please sign in to comment.