Matmul1DProgramConfig workaround removal (#1248)
The workaround introduced in #894 is no longer needed; the underlying
issue was fixed in metal (tenstorrent/tt-metal#13819).

Closes #891

FYI @odjuricicTT
azecevicTT authored Nov 14, 2024
1 parent a5cc742 commit 6dd5eba
Showing 5 changed files with 6 additions and 114 deletions.
runtime/include/tt/runtime/detail/workarounds.h (12 changes: 3 additions & 9 deletions)
@@ -17,12 +17,12 @@ struct Env {
 #endif
   get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
       bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
-      bool setMatmul1DProgramConfig = true, bool swapBinaryOperands = true)
+      bool swapBinaryOperands = true)
 #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
       ;
 #else
   {
-    return Env(true, true, true, true, true, true);
+    return Env(true, true, true, true, true);
   }
 #endif
   // TODO(bug #272), determine correct layout by tile shape in the future
@@ -42,9 +42,6 @@ struct Env {
   // instead of adding a method in runtime
   bool maxpool2dPreshard;
 
-  // TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
-  bool setMatmul1DProgramConfig;
-
   // TODO(bug #1124): We're currently swapping the operands for binary ops
   // in runtime if the lhs operand is smaller (and requires broadcast onto the
   // rhs operand). We should add this check in the compiler.
@@ -53,12 +50,11 @@
 private:
   constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
                 bool fullOpForceRowMajor, bool maxpool2dPreshard,
-                bool setMatmul1DProgramConfig, bool swapBinaryOperands)
+                bool swapBinaryOperands)
       : ignoreTileShape(ignoreTileShape),
         emptyOpForceRowMajor(emptyOpForceRowMajor),
         fullOpForceRowMajor(fullOpForceRowMajor),
         maxpool2dPreshard(maxpool2dPreshard),
-        setMatmul1DProgramConfig(setMatmul1DProgramConfig),
         swapBinaryOperands(swapBinaryOperands) {}
 };

@@ -72,8 +68,6 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
      << "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
   os << "\t"
      << "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
-  os << "\t"
-     << "setMatmul1DProgramConfig: " << env.setMatmul1DProgramConfig << "\n";
   os << "\t"
      << "swapBinaryOperands: " << env.swapBinaryOperands << "\n";
   os << "}";
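
For context on the change above: the workaround flags live in a first-call-wins singleton, so the first call to Env::get() constructs the config and every later call returns it unchanged, ignoring its arguments. The sketch below is a minimal self-contained illustration of that pattern with only two of the flags; it is illustrative, not the actual header. Note that the arguments are positional, so dropping setMatmul1DProgramConfig shifts every argument after it, which is why the header, workarounds.cpp, and ttrt's run.py all change together in this commit.

#include <iostream>

struct Env {
  // The first call constructs the config; later calls ignore their arguments.
  static const Env &get(bool maxpool2dPreshard = true,
                        bool swapBinaryOperands = true) {
    static const Env config(maxpool2dPreshard, swapBinaryOperands);
    return config;
  }

  bool maxpool2dPreshard;
  bool swapBinaryOperands;

private:
  constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands)
      : maxpool2dPreshard(maxpool2dPreshard),
        swapBinaryOperands(swapBinaryOperands) {}
};

int main() {
  Env::get(/*maxpool2dPreshard=*/false); // first call fixes both flags
  // A workaround site takes its fallback path only while the flag is set.
  if (Env::get().maxpool2dPreshard) {
    std::cout << "maxpool2d preshard workaround active\n";
  } else {
    std::cout << "maxpool2d preshard workaround disabled\n";
  }
}
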
runtime/lib/common/workarounds.cpp (4 changes: 2 additions & 2 deletions)
@@ -8,10 +8,10 @@ namespace tt::runtime::workaround {
 #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
 const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
                     bool fullOpForceRowMajor, bool maxpool2dPreshard,
-                    bool setMatmul1DProgramConfig, bool swapBinaryOperands) {
+                    bool swapBinaryOperands) {
   static const Env config(ignoreTileShape, emptyOpForceRowMajor,
                           fullOpForceRowMajor, maxpool2dPreshard,
-                          setMatmul1DProgramConfig, swapBinaryOperands);
+                          swapBinaryOperands);
   return config;
 }
 #endif
runtime/lib/ttnn/operations/matmul/matmul.cpp (79 changes: 1 addition & 78 deletions)
@@ -5,81 +5,11 @@
 #include "matmul.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
-#include "tt/runtime/detail/workarounds.h"
 #include "tt/runtime/ttnn/operations/utils.h"
-#include <optional>
 
 // ANCHOR: adding_an_op_matmul_runtime_operations
 namespace tt::runtime::ttnn::operations::matmul {
 
-// This is a workaround for the lack of program config selection in
-// ttnn::matmul. The logic here is temporary and totally incomplete.
-static ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig
-createProgramConfig(const ::tt::target::ttnn::MatmulOp *op,
-                    ProgramContext &context,
-                    ::tt::tt_metal::MemoryConfig outputMemoryConfig) {
-
-  uint32_t numCores = outputMemoryConfig.shard_spec->grid.num_cores();
-  bool fuseBatch = true; // required for sharded inputs
-
-  ProgramTensorPool &tensorPool = context.getTensorPool();
-  const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
-  const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
-
-  // note: ttnn::Shape::value returns a legacy tt::tt_metal::Shape object
-  // which does take padding into account.
-  uint32_t volume = 1;
-  for (size_t i = 0; i < lhs.shape().rank(); i++) {
-    volume *= lhs.shape().value[i];
-  }
-
-  uint32_t M =
-      fuseBatch ? volume / lhs.shape().value[-1] : lhs.shape().value[-2];
-  // uint32_t K = lhs.shape().value[-1];
-  uint32_t N = rhs.shape().value[-1];
-  bool mcastIn0 = N >= M;
-
-  uint32_t perCoreM, perCoreN;
-
-  if (mcastIn0) {
-    perCoreM = M / tt::constants::TILE_HEIGHT;
-    perCoreN = tt::div_up(tt::div_up(N, numCores), tt::constants::TILE_WIDTH);
-  } else {
-    perCoreM = tt::div_up(tt::div_up(M, numCores), tt::constants::TILE_HEIGHT);
-    perCoreN = N / tt::constants::TILE_WIDTH;
-  }
-
-  // uint32_t in0_block_w = (K / tt::constants::TILE_WIDTH) % 2 == 0 ? 2 : 1;
-  uint32_t in0BlockW = 1;
-
-  // These should work in most cases, but there is logic by which we can
-  // optimize this later.
-  uint32_t outSubblockH = 1, outSubblockW = 1;
-
-  LOG_ASSERT(outputMemoryConfig.shard_spec->grid.ranges().size() == 1);
-  CoreCoord computeWithStorageGridSize =
-      outputMemoryConfig.shard_spec->grid.ranges().begin()->grid_size();
-  if (lhs.is_sharded()) {
-    CoreCoord lhs_grid_size =
-        lhs.shard_spec()->grid.ranges().begin()->grid_size();
-    if (computeWithStorageGridSize < lhs_grid_size) {
-      computeWithStorageGridSize = lhs_grid_size;
-    }
-  }
-
-  return ::ttnn::operations::matmul::
-      MatmulMultiCoreReuseMultiCast1DProgramConfig{
-          .compute_with_storage_grid_size = computeWithStorageGridSize,
-          .in0_block_w = in0BlockW,
-          .out_subblock_h = outSubblockH,
-          .out_subblock_w = outSubblockW,
-          .per_core_M = perCoreM,
-          .per_core_N = perCoreN,
-          .fuse_batch = true,
-          .fused_activation = std::nullopt,
-          .mcast_in0 = mcastIn0};
-};
-
 void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
@@ -94,13 +24,6 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
       ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig>
       programConfig = std::nullopt;
 
-  // TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
-  if (workaround::Env::get().setMatmul1DProgramConfig &&
-      outputMemoryConfig.memory_layout ==
-          ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) {
-    programConfig = createProgramConfig(op, context, outputMemoryConfig);
-  }
-
   const std::optional<const ::tt::tt_metal::MemoryConfig> memoryConfig =
       std::make_optional(outputMemoryConfig);
 
@@ -109,7 +32,7 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
 
   ::ttnn::Tensor out = ::ttnn::matmul(
       lhs, rhs, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, dtype,
-      programConfig, /*activation*/ std::nullopt,
+      /*programConfig*/ std::nullopt, /*activation*/ std::nullopt,
       /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt);
 
   tensorPool.insert_or_assign(op->out()->global_id(), out);
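
To make the deleted heuristic concrete: for width-sharded outputs it built a MatmulMultiCoreReuseMultiCast1DProgramConfig by hand, multicasting in0 whenever the output was at least as wide as it was tall (N >= M) and splitting the wide dimension across cores. The sketch below isolates just that per-core split with hypothetical shapes (M = 64, N = 2048, 8 cores) and the standard 32x32 tile size; it is illustrative, not the removed function itself.

#include <cstdint>
#include <iostream>

// Ceiling division, as tt::div_up does in the removed code.
static uint32_t divUp(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

int main() {
  constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;
  uint32_t M = 64, N = 2048, numCores = 8; // hypothetical example values

  bool mcastIn0 = N >= M; // wide output: broadcast in0, shard N across cores
  uint32_t perCoreM, perCoreN;
  if (mcastIn0) {
    perCoreM = M / TILE_HEIGHT; // each core keeps all M rows, in tiles
    perCoreN = divUp(divUp(N, numCores), TILE_WIDTH); // its slice of N
  } else {
    perCoreM = divUp(divUp(M, numCores), TILE_HEIGHT);
    perCoreN = N / TILE_WIDTH;
  }

  // Prints: mcastIn0=1 perCoreM=2 perCoreN=8
  std::cout << "mcastIn0=" << mcastIn0 << " perCoreM=" << perCoreM
            << " perCoreN=" << perCoreN << "\n";
}

After this commit none of that logic lives in tt-mlir: the call site passes std::nullopt for the program config and the selection happens inside ttnn/metal, per tenstorrent/tt-metal#13819.
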
runtime/tools/python/test/test_run.py (17 changes: 0 additions & 17 deletions)
@@ -377,20 +377,3 @@ def test_disable_maxpool2d_preshard_run():
 def test_disable_maxpool2d_preshard_cmd_run():
     command = f"ttrt run {BINARY_FILE_PATH} --disable-maxpool2d-preshard --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json"
     sub_process_command(command)
-
-
-def test_disable_matmul_1d_program_config_run():
-    API.initialize_apis()
-    custom_args = {}
-    custom_args[
-        "--result-file"
-    ] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json"
-    custom_args["binary"] = BINARY_FILE_PATH
-    custom_args["--disable-matmul-1d-program-config"] = True
-    run_instance = API.Run(args=custom_args)
-    run_instance()
-
-
-def test_disable_matmul_1d_program_config_cmd_run():
-    command = f"ttrt run {BINARY_FILE_PATH} --disable-matmul-1d-program-config --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json"
-    sub_process_command(command)
runtime/tools/python/ttrt/common/run.py (8 changes: 0 additions & 8 deletions)
@@ -151,13 +151,6 @@ def initialize_api():
         choices=[True, False],
         help="disable maxpool2d preshard workaround",
     )
-    Run.register_arg(
-        name="--disable-matmul-1d-program-config",
-        type=bool,
-        default=False,
-        choices=[True, False],
-        help="disable matmul 1d program config workaround",
-    )
     Run.register_arg(
         name="--disable-swap-binary-operands",
         type=bool,
@@ -370,7 +363,6 @@ def _execute(binaries):
             not self["--disable-empty-op-row-major"],
             not self["--disable-full-op-row-major"],
             not self["--disable-maxpool2d-preshard"],
-            not self["--disable-matmul-1d-program-config"],
             not self["--disable-swap-binary-operands"],
         )
         self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
