Skip to content

Commit

Permalink
#12371: Migrate moreh_getitem operation from tt_eager to ttnn (#…
Browse files Browse the repository at this point in the history
…12372)

#12371: un-deprecate moreh_getitem
thd1007 authored Sep 12, 2024
1 parent 3fe27e9 commit 157cb8a
Showing 23 changed files with 660 additions and 611 deletions.
Original file line number Diff line number Diff line change
@@ -71,7 +71,7 @@ def test_getitem_RAW_MJOR_one_index(shape_index_dim, dtype, index_size, device):
elif index_dim == 4:
tt_cpu = x[:, :, :, :, idx]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, [dev_idx], [index_dim])
tt_npu = ttnn.operations.moreh.getitem(dev_x, [dev_idx], [index_dim])

assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape)
tt_dev = tt_npu.cpu().to_torch()
@@ -132,7 +132,7 @@ def test_getitem_RAW_MAJOR_two_indices(shape_index_dims, dtype, index_size, devi
tt_cpu = x[:, indices[0], indices[1]]
if index_dims == (2, 3):
tt_cpu = x[:, :, indices[0], indices[1]]
tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)

assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape)
tt_dev = tt_npu.cpu().to_torch()
@@ -191,7 +191,7 @@ def test_getitem_RAW_MAJOR_three_indices(shape_index_dims, dtype, index_size, de
tt_cpu = x[indices[0], indices[1], indices[2]]
if index_dims == (1, 2, 3):
tt_cpu = x[:, indices[0], indices[1], indices[2]]
tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)

assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape)
tt_dev = tt_npu.cpu().to_torch()
@@ -300,7 +300,7 @@ def test_getitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major
elif index_dim == 4:
tt_cpu = x[:, :, :, :, idx]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, [dev_idx], [index_dim])
tt_npu = ttnn.operations.moreh.getitem(dev_x, [dev_idx], [index_dim])
tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT)

cpu_5d_shape = to_output_5d_shape(shape, [index_dim], index_size)
@@ -392,7 +392,7 @@ def test_getitem_tilized_two_indices(shape_index_dims, dtype, index_size, row_ma
if index_dims == (3, 4):
tt_cpu = x[:, :, :, indices[0], indices[1]]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)
tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT)

output_5d_shape = to_output_5d_shape(shape, index_dims, index_size)
@@ -478,7 +478,7 @@ def test_getitem_tilized_three_indices(shape_index_dims, dtype, index_size, row_
if index_dims == (2, 3, 4):
tt_cpu = x[:, :, indices[0], indices[1], indices[2]]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)
tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT)

output_5d_shape = to_output_5d_shape(shape, index_dims, index_size)
@@ -559,7 +559,7 @@ def test_getitem_tilized_four_indices(shape_index_dims, dtype, index_size, row_m
if index_dims == (1, 2, 3, 4):
tt_cpu = x[:, indices[0], indices[1], indices[2], indices[3]]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)
tt_npu = tt_npu.cpu().to(ttnn.Layout.ROW_MAJOR)

output_5d_shape = to_output_5d_shape(shape, index_dims, index_size)
@@ -634,7 +634,7 @@ def test_getitem_tilized_five_indices(shape_index_dims, dtype, index_size, row_m

tt_cpu = x[indices[0], indices[1], indices[2], indices[3], indices[4]]

tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims)
tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims)
tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT)

output_5d_shape = to_output_5d_shape(shape, index_dims, index_size)
6 changes: 6 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -342,6 +342,12 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp
)

# Split src and python bindings
3 changes: 0 additions & 3 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -77,9 +77,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/moreh_cumsum/moreh_cumsum_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp

CACHE INTERNAL "tt_dnn sources to reuse in ttnn build"
)

This file was deleted.

This file was deleted.

Loading

0 comments on commit 157cb8a

Please sign in to comment.