Skip to content

Commit

Permalink
Updating metal reference in tt-mlir to include the fix for reduce op (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT authored Sep 13, 2024
1 parent b39cc58 commit e578714
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
2 changes: 1 addition & 1 deletion runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#pragma clang diagnostic ignored "-Wlogical-op-parentheses"
#pragma clang diagnostic ignored "-Wundefined-inline"
#define FMT_HEADER_ONLY
#include "impl/device/device_mesh.hpp"
#include "impl/device/mesh_device.hpp"
#include "impl/event/event.hpp"
#include "tt_metal/host_api.hpp"
#pragma clang diagnostic pop
Expand Down
14 changes: 7 additions & 7 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#define FMT_HEADER_ONLY
#include "host_api.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device_mesh.hpp"
#include "impl/device/mesh_device.hpp"
#pragma clang diagnostic pop

namespace tt::runtime::system_desc {
Expand Down Expand Up @@ -174,8 +174,8 @@ calculateDRAMUnreservedEnd(const ::tt::tt_metal::Device *device) {
}

static std::unique_ptr<::tt::runtime::SystemDesc>
getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
std::vector<::tt::tt_metal::Device *> devices = deviceMesh.get_devices();
getCurrentSystemDescImpl(const ::tt::tt_metal::MeshDevice &meshDevice) {
std::vector<::tt::tt_metal::Device *> devices = meshDevice.get_devices();
std::sort(devices.begin(), devices.end(),
[](const ::tt::tt_metal::Device *a,
const ::tt::tt_metal::Device *b) { return a->id() < b->id(); });
Expand Down Expand Up @@ -267,19 +267,19 @@ std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
"Unexpected non-rectangular grid of devices");
std::vector<chip_id_t> deviceIds(numDevices);
std::iota(deviceIds.begin(), deviceIds.end(), 0);
::tt::tt_metal::DeviceGrid grid =
::tt::tt_metal::MeshShape grid =
std::make_pair(numDevices / numPciDevices, numPciDevices);
::tt::tt_metal::DeviceMesh deviceMesh = ::tt::tt_metal::DeviceMesh(
::tt::tt_metal::MeshDevice meshDevice = ::tt::tt_metal::MeshDevice(
grid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1,
::tt::tt_metal::DispatchCoreType::WORKER);
std::exception_ptr eptr = nullptr;
std::unique_ptr<::tt::runtime::SystemDesc> desc;
try {
desc = getCurrentSystemDescImpl(deviceMesh);
desc = getCurrentSystemDescImpl(meshDevice);
} catch (...) {
eptr = std::current_exception();
}
deviceMesh.close_devices();
meshDevice.close_devices();
if (eptr) {
std::rethrow_exception(eptr);
}
Expand Down
7 changes: 3 additions & 4 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ static void runReductionOp(
const ::ttnn::Tensor &,
const std::optional<std::variant<int, std::vector<int>>> &, const bool,
const std::optional<::tt::tt_metal::MemoryConfig> &,
const std::optional<::tt::tt_metal::DeviceComputeKernelConfig> &,
float)>
const std::optional<::ttnn::DeviceComputeKernelConfig> &, float)>
ttnnOp) {
::tt::tt_metal::MemoryConfig outputMemoryConfig =
createMemoryConfig(op->out());
Expand Down Expand Up @@ -758,8 +757,8 @@ static void run(::tt::target::ttnn::MaxPool2dOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
const ::ttnn::operations::pool::MaxPoolNewOp operation =
::ttnn::operations::pool::MaxPoolNewOp();
const ::ttnn::operations::pool::MaxPool2DOp operation =
::ttnn::operations::pool::MaxPool2DOp();

::ttnn::Device &device = getDevice(op->device(), devicePool);
::ttnn::Tensor out = operation.invoke(
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ ExternalProject_Add(
-DENABLE_TRACY=${TT_RUNTIME_ENABLE_PERF_TRACE}
-DENABLE_LIBCXX=OFF
GIT_REPOSITORY https://github.com/tenstorrent/tt-metal.git
GIT_TAG 516a8917fac17649abbc9052b680894b432b4de6
GIT_TAG 6d0d0800c360dfbe1ee00543466173cb675ece36
GIT_PROGRESS ON
BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH}
)
Expand Down

0 comments on commit e578714

Please sign in to comment.