Skip to content

Commit

Permalink
#4686: fix get device for tensors not on device
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Feb 26, 2024
1 parent 0564b8e commit df729c9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
12 changes: 8 additions & 4 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ inline Tensor matmul(
std::optional<const DataType> output_dtype = std::nullopt,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt
) {
auto kernel_config_val = init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config);
auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);
return operation::run(Matmul{program_config, mem_config, output_dtype.value_or(input_tensor_a.dtype()), kernel_config_val}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0);
}

Expand All @@ -314,7 +315,8 @@ inline Tensor matmul(
std::optional<const DataType> output_dtype = std::nullopt,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt
) {
auto kernel_config_val = init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config);
auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);
return operation::run(Matmul{program_config, mem_config, output_dtype.value_or(input_tensor_a.dtype()), kernel_config_val}, {input_tensor_a, input_tensor_b}, {bias}).at(0);
}

Expand Down Expand Up @@ -365,7 +367,8 @@ inline Tensor matmul (const Tensor &input_tensor_a, const Tensor &input_tensor_b
TT_FATAL(input_tensor_a.shape()[3] == input_tensor_b.shape()[2] && "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op"); // A.K == B.K
TT_FATAL(input_tensor_b.shape()[0]*input_tensor_b.shape()[1] == 1 && "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN");

auto kernel_config_val = init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);
auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4);

// TODO: Uplift interleaved path to call tt::operation::primary::Matmul and deprecate old tt::tt_metal::Matmul
if (input_tensor_a.is_sharded()) {
Expand All @@ -387,7 +390,8 @@ inline Tensor bmm (const Tensor &input_tensor_a, const Tensor &input_tensor_b
TT_FATAL(input_tensor_a.shape()[1] == input_tensor_b.shape()[1] && input_tensor_a.shape()[0] == input_tensor_b.shape()[0]
&& "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN");

auto kernel_config_val = init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);
auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4);

if (input_tensor_a.is_sharded()) {
auto matmul_program_config = bmm_op_utils::get_matmul_program_config(input_tensor_a, input_tensor_b, mem_config, std::nullopt, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ struct AttnMatmul {
};

/// Runs the AttnMatmul op on the two input tensors and returns its first output.
///
/// @param input_tensor_a                 First operand; also supplies the default output dtype.
/// @param input_tensor_b                 Second operand.
/// @param compute_with_storage_grid_size Grid size forwarded to AttnMatmul.
/// @param mem_config                     Memory config forwarded to AttnMatmul.
/// @param output_dtype                   Output data type; defaults to input_tensor_a.dtype() when not set.
/// @param compute_kernel_config          Optional kernel config, resolved via init_device_compute_kernel_config.
/// @return Element 0 of the outputs produced by operation::run.
inline Tensor attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const CoreCoord& compute_with_storage_grid_size, const MemoryConfig& mem_config, std::optional<const DataType> output_dtype=std::nullopt, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt) {
    // Only query the tensor's own device when it is stored on a device; otherwise
    // take the architecture from the default device.
    // NOTE(review): assumes AutoFormat::GetDefaultDevice() is non-null here — confirm.
    auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
    auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);
    return operation::run(AttnMatmul{std::nullopt, std::nullopt, compute_with_storage_grid_size, mem_config, output_dtype.value_or(input_tensor_a.dtype()), kernel_config_val}, {input_tensor_a, input_tensor_b}).at(0);
}

/// Runs the AttnMatmul op in its "from cache" configuration and returns its first output.
///
/// @param input_tensor_a                 First operand; also supplies the default output dtype.
/// @param input_tensor_b                 Second operand; its shape()[2] is the upper bound for num_tokens.
/// @param num_tokens                     Must satisfy 1 <= num_tokens <= input_tensor_b.shape()[2];
///                                       rounded up to a multiple of 32 before being passed to AttnMatmul.
/// @param transpose_hw                   Forwarded to AttnMatmul.
/// @param compute_with_storage_grid_size Grid size forwarded to AttnMatmul.
/// @param mem_config                     Memory config forwarded to AttnMatmul.
/// @param output_dtype                   Output data type; defaults to input_tensor_a.dtype() when not set.
/// @param compute_kernel_config          Optional kernel config, resolved via init_device_compute_kernel_config.
/// @return Element 0 of the outputs produced by operation::run.
inline Tensor attn_matmul_from_cache(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const uint32_t num_tokens, const bool transpose_hw, const CoreCoord& compute_with_storage_grid_size, const MemoryConfig& mem_config, std::optional<const DataType> output_dtype=std::nullopt, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt) {
    TT_FATAL(num_tokens > 0, "Number of tokens must be at least 1!");
    TT_FATAL(num_tokens <= input_tensor_b.shape()[2], "Number of tokens must be smaller or equal to the max cache length (B.shape[2])!");
    // Round up to the next multiple of 32; num_tokens > 0 is checked above, so (num_tokens - 1) cannot wrap.
    const uint32_t num_tokens_rounded_up_to_32 = ((num_tokens - 1) / 32 + 1) * 32;
    // Only query the tensor's own device when it is stored on a device; otherwise
    // take the architecture from the default device.
    // NOTE(review): assumes AutoFormat::GetDefaultDevice() is non-null here — confirm.
    auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
    auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);
    return operation::run(AttnMatmul{num_tokens_rounded_up_to_32, transpose_hw, compute_with_storage_grid_size, mem_config, output_dtype.value_or(input_tensor_a.dtype()), kernel_config_val}, {input_tensor_a, input_tensor_b}).at(0);
}

Expand Down Expand Up @@ -120,7 +122,8 @@ inline Tensor group_attn_matmul(const Tensor &input_tensor_a, const Tensor &inpu
}
}

auto kernel_config_val = init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config);
auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);

// Need to cache on out_subblock_w because it must be a compile time arg for optimal use of templated pack_untilize APIs
const uint32_t Nt = input_tensor_b.shape()[-1] / TILE_WIDTH;
Expand Down

0 comments on commit df729c9

Please sign in to comment.