#7571: Update index module for dim 1, 0 and test files
umadevimcw committed May 17, 2024
1 parent 43b9e07 · commit fd6379b
Showing 3 changed files with 18 additions and 23 deletions.
@@ -19,13 +19,11 @@
         (torch.Size([5, 4, 3, 20])),
     ),
 )
-@pytest.mark.parametrize("dim", (3, 2, -1, -2))
+@pytest.mark.parametrize("dim", (3, 2, 1, 0, -1, -2, -3, -4))
 @pytest.mark.parametrize("all", (True, False))
 class TestArgmax:
     def test_argmax(self, input_shapes, dim, all, device):
         torch.manual_seed(10)
-
-        # input_data = torch.tensor([5,3,2,1,10,9,6,7,4,8]).reshape(input_shapes).bfloat16()
         input_data = torch.randn(input_shapes).bfloat16()
         input_tensor = (
             tt_lib.tensor.Tensor(input_data, tt_lib.tensor.DataType.BFLOAT16)
@@ -34,16 +32,14 @@ def test_argmax(self, input_shapes, dim, all, device):
             .to(device)
         )
         tt_output_tensor_on_device = tt_lib.tensor.argmax(input_tensor, dim=dim, all=all)
-
         tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
-        # tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR)
         if all:
             golden_tensor = torch.argmax(input_data)
             tt_out_tensor = tt_out_tensor[0, 0, 0, 0]
         else:
             golden_tensor = torch.argmax(input_data, dim=dim)
             if dim == 1 or dim == -3 or dim == 0 or dim == -4:
-                tt_out_tensor = tt_out_tensor[0]
+                tt_out_tensor = tt_out_tensor[0, :, 0 : input_shapes[2], 0 : input_shapes[3]]
             else:
                 if input_shapes[1] != 1 or input_shapes[0] != 1:
                     if dim == 2 or dim == -2:
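The new dim parametrization extends coverage from the W/H axes (3, 2, -1, -2) to the N/C axes (1, 0, -3, -4), and the dim 1/0 branch now slices the unpadded region out of the device result instead of taking tt_out_tensor[0]. A PyTorch-only sketch of the semantics the test now exercises; the shapes here are illustrative, not taken from the test:

import torch

# argmax over dim=0 (N) or dim=1 (C) removes that axis; each output entry
# is the index of the max along the reduced axis.
x = torch.randn(2, 3, 4, 5)
print(torch.argmax(x, dim=0).shape)  # torch.Size([3, 4, 5])
print(torch.argmax(x, dim=1).shape)  # torch.Size([2, 4, 5])

# The device output stays 4D and tile-padded, so the test compares only the
# unpadded region: tt_out_tensor[0, :, 0:input_shapes[2], 0:input_shapes[3]]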
tt_eager/tt_dnn/op_library/composite/composite_ops.cpp (12 changes: 5 additions & 7 deletions)

@@ -1570,7 +1570,6 @@ Tensor create_mask(const Tensor& input_a, const MemoryConfig& output_mem_config)
 {
     auto& padded_shape = input_a.get_legacy_shape();
     auto& unpadded_shape = padded_shape.without_padding();
-
     if (padded_shape == unpadded_shape)
         return input_a;
     float t_inf = -std::numeric_limits<float>::infinity();

@@ -1579,17 +1578,16 @@ Tensor create_mask(const Tensor& input_a, const MemoryConfig& output_mem_config)
     return masked_input;
 }
 // Argmax returns the index of maximum element in the tensor
-Tensor _argmax(const Tensor& input, int64_t _dim, bool all, const MemoryConfig& output_mem_config) {
-    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input}))};
+Tensor _argmax(const Tensor& input_t, int64_t _dim, bool all, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_t}))};
     operation::launch_with_autoformat(
         [_dim, all, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
-            const auto& input_a = input_tensors.at(0);
-            auto& input_shape = input_a.get_legacy_shape();
+            const auto& input = input_tensors.at(0);
+            auto& input_shape = input.get_legacy_shape();
             TT_FATAL(input_shape.rank() == 4, "supported for rank-4 tensors at this time");
 
+            Tensor input_a = create_mask(input, output_mem_config);
 
-
             uint32_t dim = input_shape.get_normalized_index(_dim);
             int size = input_a.volume();
 
@@ -1687,7 +1685,7 @@ Tensor _argmax(const Tensor& input, int64_t _dim, bool all, const MemoryConfig&
             max_indices.deallocate();
             result = global_min(result, output_mem_config);
             return {result};
-        }, {input}, output_tensors);
+        }, {input_t}, output_tensors);
     return output_tensors.at(0);
 }
 
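Inside the launched lambda, _argmax now works on a masked copy produced by create_mask (while the outer tensor is renamed to input_t): the tile-padding region is overwritten with -inf, so a padded location can never win the reduction. A standalone sketch of that masking idea in plain PyTorch; TILE, the shapes, and the variable names are assumptions for illustration, not the tt-lib API:

import torch

TILE = 32
H, W = 30, 20                           # logical (unpadded) extent
Hp = ((H + TILE - 1) // TILE) * TILE    # height padded to a tile multiple
Wp = ((W + TILE - 1) // TILE) * TILE    # width padded to a tile multiple

padded = torch.randn(1, 1, Hp, Wp)      # stand-in for a tile-padded tensor
valid = torch.zeros(1, 1, Hp, Wp, dtype=torch.bool)
valid[..., :H, :W] = True               # True marks real (unpadded) data
masked = torch.where(valid, padded, torch.full_like(padded, float("-inf")))

# A flat argmax over the masked tensor can only land inside the real region.
flat = torch.argmax(masked)
assert flat % Wp < W                    # the winner's column is unpadded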
tt_eager/tt_numpy/functions.hpp (21 changes: 11 additions & 10 deletions)

@@ -369,7 +369,6 @@ static Tensor index_all(
             } // dim H
             index = index + ((shape[penultimate] - up_shape[penultimate]) * TILE_WIDTH);
         } // dim C
-        //index = index + ((shape[rank - 3] - up_shape[rank - 3]) * TILE_WIDTH * TILE_HEIGHT);
     } // dim N
     auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR).to(layout);
     if (device != nullptr) {
@@ -390,7 +389,6 @@ static Tensor mask_padded_input(
     auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(padded_shape));
 
     auto index = 0;
-    //auto value = 0;
     auto rank = padded_shape.rank();
     auto penultimate = rank - 2;
     auto ultimate = rank - 1;
@@ -576,24 +574,27 @@ static Tensor index_batch(
     Device* device = nullptr,
     const MemoryConfig& output_mem_config = MemoryConfig{
         .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
-    auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
-
+    auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
+    std::fill(owned_buffer.begin(), owned_buffer.end(), -std::numeric_limits<float>::infinity());
+    auto& up_shape = shape.without_padding();
     auto index = 0;
     auto value = 0;
-    auto rank = shape.rank();
+    auto rank = up_shape.rank();
     auto penultimate = rank - 2;
     auto ultimate = rank - 1;
-    for (uint32_t b = 0; b < shape[rank - 4]; b++) {
-        for (uint32_t c = 0; c < shape[rank - 3]; c++) {
-            for (uint32_t y = 0; y < shape[penultimate]; y++) {
-                for (uint32_t x = 0; x < shape[ultimate]; x++) {
+    for (uint32_t b = 0; b < up_shape[rank - 4]; b++) {
+        for (uint32_t c = 0; c < up_shape[rank - 3]; c++) {
+            for (uint32_t y = 0; y < up_shape[penultimate]; y++) {
+                for (uint32_t x = 0; x < up_shape[ultimate]; x++) {
                     owned_buffer[index++] = T(static_cast<float>(value));
                 } // dim W
+                index = index + (shape[ultimate] - up_shape[ultimate]);
             } // dim H
+            index = index + ((shape[penultimate] - up_shape[penultimate]) * TILE_WIDTH);
         } // dim C
         value = value + 1;
     } // dim N
     auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR).to(layout);
     if (device != nullptr) {
         output = output.to(device, output_mem_config);
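index_batch builds the index helper used by the dim 0 (batch) reduction: the buffer is pre-filled with -inf, every element of batch b gets the value b, and the loops now walk only the unpadded extent, jumping over the padded tails. A pure-Python sketch of the rewritten fill loop; the function name and arguments are illustrative, and the skip arithmetic simply mirrors the diff, including its use of TILE_WIDTH:

TILE_WIDTH = 32  # tile granularity assumed by the row-skip arithmetic

def index_batch_sketch(shape, up_shape):
    # shape: padded (N, C, H, W); up_shape: unpadded (n, c, h, w)
    N, C, H, W = shape
    n, c, h, w = up_shape
    buf = [float("-inf")] * (N * C * H * W)    # padding stays at -inf
    index = 0
    value = 0
    for _ in range(n):                         # dim N
        for _ in range(c):                     # dim C
            for _ in range(h):                 # dim H
                for _ in range(w):             # dim W
                    buf[index] = float(value)
                    index += 1
                index += W - w                 # skip padded columns of the row
            index += (H - h) * TILE_WIDTH      # skip padded rows, as in the diff
        value += 1                             # one index value per batch
    return buf

# e.g. buf = index_batch_sketch((1, 1, 32, 32), (1, 1, 30, 20))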