Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
broskoTT committed Nov 29, 2024
1 parent 53c32c0 commit a2a5e40
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 16 deletions.
11 changes: 6 additions & 5 deletions tests/tt_metal/test_utils/env_vars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "common/utils.hpp"

#include "umd/device/device_api_metal.h"
#include "umd/device/tt_cluster_descriptor.h"

#include <string>

Expand Down Expand Up @@ -43,11 +44,11 @@ inline std::string get_umd_arch_name() {
return get_env_arch_name();
}

std::vector<chip_id_t> physical_mmio_device_ids = tt::umd::Cluster::detect_available_device_ids();
tt::ARCH arch = detect_arch(physical_mmio_device_ids.at(0));
for (int dev_index = 1; dev_index < physical_mmio_device_ids.size(); dev_index++) {
chip_id_t device_id = physical_mmio_device_ids.at(dev_index);
tt::ARCH detected_arch = detect_arch(device_id);
auto cluster_desc = tt_ClusterDescriptor::create();
const std::unordered_set<chip_id_t> &device_ids = cluster_desc->get_all_chips();
tt::ARCH arch = cluster_desc->get_arch(*device_ids.begin());
for (auto device_id : device_ids) {
tt::ARCH detected_arch = cluster_desc->get_arch(device_id);
TT_FATAL(
arch == detected_arch,
"Expected all devices to be {} but device {} is {}",
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/common/metal_soc_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ void metal_SocDescriptor::update_pcie_cores(const BoardType& board_type) {
return;
}
switch (board_type) {
case DEFAULT: { // Workaround for BHs running FW that does not return board type in the cluster yaml
case UNKNOWN: { // Workaround for BHs running FW that does not return board type in the cluster yaml
this->pcie_cores = {CoreCoord(11, 0)};
} break;
case P150A: {
Expand Down
17 changes: 9 additions & 8 deletions tt_metal/llrt/get_platform_architecture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "tt_metal/common/tt_backend_api_types.hpp"
#include "tt_metal/common/assert.hpp"
#include "umd/device/tt_cluster_descriptor.h"
#include "umd/device/cluster.h"

namespace tt::tt_metal {
Expand Down Expand Up @@ -47,8 +48,8 @@ namespace tt::tt_metal {
* @endcode
*
* @see tt::get_arch_from_string
* @see tt::umd::Cluster::detect_available_device_ids
* @see detect_arch
* @see tt_ClusterDescriptor::detect_arch
* @see tt_ClusterDescriptor::get_arch
*/
inline tt::ARCH get_platform_architecture() {
auto arch = tt::ARCH::Invalid;
Expand All @@ -57,12 +58,12 @@ inline tt::ARCH get_platform_architecture() {
TT_FATAL(arch_env, "ARCH_NAME env var needed for VCS");
arch = tt::get_arch_from_string(arch_env);
} else {
std::vector<chip_id_t> physical_mmio_device_ids = tt::umd::Cluster::detect_available_device_ids();
if (!physical_mmio_device_ids.empty()) {
arch = detect_arch(physical_mmio_device_ids.at(0));
for (int i = 1; i < physical_mmio_device_ids.size(); ++i) {
chip_id_t device_id = physical_mmio_device_ids.at(i);
tt::ARCH detected_arch = detect_arch(device_id);
auto cluster_desc = tt_ClusterDescriptor::create();
if (cluster_desc->get_number_of_chips() > 0) {
const std::unordered_set<chip_id_t> &device_ids = cluster_desc->get_all_chips();
arch = cluster_desc->get_arch(*device_ids.begin());
for (auto device_id : device_ids) {
tt::ARCH detected_arch = cluster_desc->get_arch(device_id);
TT_FATAL(
arch == detected_arch,
"Expected all devices to be {} but device {} is {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ operation::ProgramWithCallbacks Prod_op::create_program(
Tensor prod_all(const Tensor& input, const MemoryConfig& output_mem_config) {
Tensor result = ttnn::tiled_prod(
operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0), output_mem_config);
auto arch_env = detect_arch();
auto arch_env = tt_ClusterDescriptor::detect_arch((chip_id_t)0);
if (arch_env == tt::ARCH::WORMHOLE_B0) {
return ttnn::numpy::prod_result_computation_WH_B0<bfloat16>(
result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config);
Expand Down

0 comments on commit a2a5e40

Please sign in to comment.