diff --git a/.github/workflows/remove-stale-branches.yaml b/.github/workflows/remove-stale-branches.yaml index de6f3121d39..ffd59b5bad1 100644 --- a/.github/workflows/remove-stale-branches.yaml +++ b/.github/workflows/remove-stale-branches.yaml @@ -2,7 +2,7 @@ name: "[internal] Remove Stale Branches" on: schedule: - - cron: "0 0 * * *" # Runs every night at midnight + - cron: "0 */6 * * *" # Runs at midnight, 6AM, noon, 6pm workflow_dispatch: # Allows manual trigger jobs: diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 18bc919ccf0..40371321e11 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -14,6 +14,9 @@ on: - ccl.line_all_gather - ccl.all_gather_n300 - ccl.all_gather_n300_focused + - creation.zeros.zeros + - creation.empty.empty + - creation.zeros_like.zeros_like - eltwise.unary.abs.abs_pytorch2 - eltwise.unary.relu.relu - eltwise.unary.relu.relu_pytorch2 @@ -52,6 +55,8 @@ on: - eltwise.unary.expm1.expm1 - eltwise.unary.tanh.tanh - eltwise.unary.tanh.tanh_pytorch2 + - eltwise.unary.atanh.atanh + - eltwise.unary.atan.atan - eltwise.unary.sign.sign - eltwise.unary.rad2deg.rad2deg - eltwise.unary.deg2rad.deg2rad @@ -98,6 +103,8 @@ on: - eltwise.unary.identity.identity - eltwise.unary.neg.neg - eltwise.unary.sinh.sinh + - eltwise.unary.asinh.asinh + - eltwise.unary.cosh.cosh - eltwise.unary.relu_min.relu_min - eltwise.unary.relu_max.relu_max - eltwise.unary.softplus.softplus diff --git a/CODEOWNERS b/CODEOWNERS index beb2eb6452f..ef6bf8072f2 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -120,6 +120,7 @@ tests/sweep_framework/ @xanderchin @jdesousa-TT @sjameelTT tests/sweep_framework/sweeps tests/sweep_framework/sweeps/eltwise/ @patrickroberts @yan-zaretskiy @eyonland tests/sweep_framework/sweeps/conv2d/ @nkpatel-tt @mywoodstock @shwetankTT @sankarmanoj-tt @pavlejosipovic +tests/sweep_framework/sweeps/data_movement/ @sjameelTT @ntarafdar @jaykru-tt @yugi957 # TTNN Distributed ttnn/cpp/ttnn/distributed/ @cfjchu @ayerofieiev-tt @dmakoviichuk-tt diff --git a/Doxyfile b/Doxyfile index 8f33e44c003..93eca4ad46d 100644 --- a/Doxyfile +++ b/Doxyfile @@ -499,7 +499,12 @@ NUM_PROC_THREADS = 1 # normally produced when WARNINGS is set to YES. # The default value is: NO. -EXTRACT_ALL = NO +EXTRACT_ALL = YES + +EXTRACT_NAMESPACES = YES + +INLINE_NAMESPACES = YES + # If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will # be included in the documentation. @@ -946,7 +951,17 @@ INPUT = tt_metal/hw/inc/dataflow_api.h \ tt_metal/include/compute_kernel_api/transpose_wh.h \ tt_metal/include/compute_kernel_api/untilize.h \ tt_metal/include/compute_kernel_api.h \ - tt_metal/impl/kernels/kernel_args.hpp + tt_metal/impl/kernels/kernel_args.hpp \ + tt_metal/include/tt_metal/metal.hpp \ + tt_metal/include/tt_metal/types.hpp \ + tt_metal/include/tt_metal/buffer.hpp \ + tt_metal/include/tt_metal/command_queue.hpp \ + tt_metal/include/tt_metal/device.hpp \ + tt_metal/include/tt_metal/event.hpp \ + tt_metal/include/tt_metal/kernel.hpp \ + tt_metal/include/tt_metal/program.hpp \ + tt_metal/include/tt_metal/trace.hpp + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. 
Doxygen uses diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/AssignGlobalBufferToProgram.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/AssignGlobalBufferToProgram.rst index 75ec7b84aa3..5be958a386e 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/AssignGlobalBufferToProgram.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/AssignGlobalBufferToProgram.rst @@ -1,5 +1,5 @@ AssignGlobalBufferToProgram -=========================== +=============================== -.. doxygenfunction:: AssignGlobalBufferToProgram(std::shared_ptr buffer, Program& program) +.. doxygenfunction:: tt::tt_metal::v0::AssignGlobalBufferToProgram(std::shared_ptr buffer, Program& program) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CircularBuffers.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CircularBuffers.rst index 43b73866cb6..761cf6249e0 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CircularBuffers.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CircularBuffers.rst @@ -1,12 +1,17 @@ CircularBuffers ================ -.. doxygenfunction:: CreateCircularBuffer +.. doxygenfunction:: tt::tt_metal::v0::CreateCircularBuffer -.. doxygenfunction:: GetCircularBufferConfig +.. doxygenfunction:: tt::tt_metal::v0::GetCircularBufferConfig -.. doxygenfunction:: UpdateCircularBufferTotalSize +.. doxygenfunction:: tt::tt_metal::v0::UpdateCircularBufferTotalSize -.. doxygenfunction:: UpdateCircularBufferPageSize +.. doxygenfunction:: tt::tt_metal::v0::UpdateCircularBufferPageSize -.. doxygenfunction:: UpdateDynamicCircularBufferAddress +.. doxygenfunction:: tt::tt_metal::v0::UpdateDynamicCircularBufferAddress + + +Version 1 +------------------------- +.. doxygenfunction:: tt::tt_metal::v1::CreateCircularBuffer diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateBuffer.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateBuffer.rst index 6a1848c3cd0..2d05db10694 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateBuffer.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateBuffer.rst @@ -1,5 +1,5 @@ CreateBuffer ================= -.. doxygenfunction:: CreateBuffer(const InterleavedBufferConfig & config); -.. doxygenfunction:: CreateBuffer(const ShardedBufferConfig & config); +.. doxygenfunction:: tt::tt_metal::v0::CreateBuffer(const InterleavedBufferConfig & config); +.. doxygenfunction:: tt::tt_metal::v0::CreateBuffer(const ShardedBufferConfig & config); diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateSemaphore.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateSemaphore.rst index 452d47cd991..eb740c4169f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateSemaphore.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/CreateSemaphore.rst @@ -1,4 +1,4 @@ CreateSemaphore ================ -.. doxygenfunction:: CreateSemaphore(Program &program, const std::variant &core_spec, uint32_t initial_value, CoreType core_type) +.. 
doxygenfunction:: tt::tt_metal::v0::CreateSemaphore(Program &program, const std::variant &core_spec, uint32_t initial_value, CoreType core_type) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/DeallocateBuffer.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/DeallocateBuffer.rst index 91683085250..53ddd2255a1 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/DeallocateBuffer.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/buffers/DeallocateBuffer.rst @@ -1,4 +1,4 @@ DeallocateBuffer ================= -.. doxygenfunction:: DeallocateBuffer +.. doxygenfunction:: tt::tt_metal::v0::DeallocateBuffer diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/BeginTraceCapture.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/BeginTraceCapture.rst index 2b09d801459..b314937ed24 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/BeginTraceCapture.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/BeginTraceCapture.rst @@ -1,4 +1,4 @@ BeginTraceCapture ================= -.. doxygenfunction:: BeginTraceCapture +.. doxygenfunction:: tt::tt_metal::v0::BeginTraceCapture diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EndTraceCapture.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EndTraceCapture.rst index c00dc574e46..de80b2dc152 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EndTraceCapture.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EndTraceCapture.rst @@ -1,4 +1,4 @@ EndTraceCapture =============== -.. doxygenfunction:: EndTraceCapture +.. doxygenfunction:: tt::tt_metal::v0::EndTraceCapture diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueProgram.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueProgram.rst index ebe9c16a676..936c7bea641 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueProgram.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueProgram.rst @@ -1,4 +1,4 @@ EnqueueProgram ============== -.. doxygenfunction:: EnqueueProgram(CommandQueue& cq, Program& program, bool blocking) +.. doxygenfunction:: tt::tt_metal::v0::EnqueueProgram(CommandQueue& cq, Program& program, bool blocking) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueReadBuffer.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueReadBuffer.rst index 4fff0c3fe3e..6f7b9929086 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueReadBuffer.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueReadBuffer.rst @@ -1,5 +1,5 @@ EnqueueReadBuffer ================== -.. doxygenfunction:: EnqueueReadBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, std::vector& dst, bool blocking) -.. doxygenfunction:: EnqueueReadBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, void * dst, bool blocking) +.. doxygenfunction:: tt::tt_metal::v0::EnqueueReadBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, std::vector& dst, bool blocking) +.. 
doxygenfunction:: tt::tt_metal::v0::EnqueueReadBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, void * dst, bool blocking) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueRecordEvent.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueRecordEvent.rst index 697aeab2d2e..005c8a05387 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueRecordEvent.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueRecordEvent.rst @@ -1,4 +1,4 @@ EnqueueRecordEvent ================== -.. doxygenfunction:: EnqueueRecordEvent +.. doxygenfunction:: tt::tt_metal::v0::EnqueueRecordEvent diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueTrace.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueTrace.rst index 88f6989ff57..a696da0d52f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueTrace.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueTrace.rst @@ -1,4 +1,4 @@ EnqueueTrace ============ -.. doxygenfunction:: EnqueueTrace(CommandQueue &cq, uint32_t trace_id, bool blocking) +.. doxygenfunction:: tt::tt_metal::v0::EnqueueTrace(CommandQueue &cq, uint32_t trace_id, bool blocking) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWaitForEvent.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWaitForEvent.rst index bbb19e16704..5d2160f41e0 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWaitForEvent.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWaitForEvent.rst @@ -1,4 +1,4 @@ EnqueueWaitForEvent =================== -.. doxygenfunction:: EnqueueWaitForEvent +.. doxygenfunction:: tt::tt_metal::v0::EnqueueWaitForEvent diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWriteBuffer.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWriteBuffer.rst index 7a7413c89c1..3b48f8b7b4a 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWriteBuffer.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EnqueueWriteBuffer.rst @@ -1,5 +1,5 @@ EnqueueWriteBuffer ================== -.. doxygenfunction:: EnqueueWriteBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, std::vector& src, bool blocking) -.. doxygenfunction:: EnqueueWriteBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, HostDataType src, bool blocking) +.. doxygenfunction:: tt::tt_metal::v0::EnqueueWriteBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, std::vector& src, bool blocking) +.. doxygenfunction:: tt::tt_metal::v0::EnqueueWriteBuffer(CommandQueue& cq, std::variant, std::shared_ptr > buffer, HostDataType src, bool blocking) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventQuery.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventQuery.rst index 1af02761a7e..abc9327b240 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventQuery.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventQuery.rst @@ -1,4 +1,4 @@ EventQuery ========== -.. doxygenfunction:: EventQuery +.. 
doxygenfunction:: tt::tt_metal::v0::EventQuery diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventSynchronize.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventSynchronize.rst index c621536b09f..0b3374fc25c 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventSynchronize.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/EventSynchronize.rst @@ -1,4 +1,4 @@ EventSynchronize ================ -.. doxygenfunction:: EventSynchronize +.. doxygenfunction:: tt::tt_metal::v0::EventSynchronize diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Finish.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Finish.rst index 88bb6d667b2..521b17b0822 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Finish.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Finish.rst @@ -3,4 +3,4 @@ Finish ====== -.. doxygenfunction:: Finish(CommandQueue& cq) +.. doxygenfunction:: tt::tt_metal::v0::Finish(CommandQueue& cq) diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReleaseTrace.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReleaseTrace.rst index e61728730bc..cffa870e4f7 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReleaseTrace.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReleaseTrace.rst @@ -1,4 +1,4 @@ ReleaseTrace ============ -.. doxygenfunction:: ReleaseTrace +.. doxygenfunction:: tt::tt_metal::v0::ReleaseTrace diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReplayTrace.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReplayTrace.rst index 457f5d42c4c..c47d10020c7 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReplayTrace.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/ReplayTrace.rst @@ -1,4 +1,4 @@ ReplayTrace =========== -.. doxygenfunction:: ReplayTrace +.. doxygenfunction:: tt::tt_metal::v0::ReplayTrace diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Synchronize.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Synchronize.rst index c071ec028fd..d49ab68575b 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Synchronize.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/Synchronize.rst @@ -3,4 +3,4 @@ Synchronize =========== -.. doxygenfunction:: Synchronize +.. doxygenfunction:: tt::tt_metal::v0::Synchronize diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CloseDevice.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CloseDevice.rst index 4aef6573de1..345374bb133 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CloseDevice.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CloseDevice.rst @@ -2,4 +2,4 @@ CloseDevice ============= -.. doxygenfunction:: CloseDevice +.. 
doxygenfunction:: tt::tt_metal::v0::CloseDevice diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CreateDevice.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CreateDevice.rst index e5891a6bfb3..59eca837261 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CreateDevice.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/CreateDevice.rst @@ -1,4 +1,4 @@ CreateDevice ============= -.. doxygenfunction:: CreateDevice +.. doxygenfunction:: tt::tt_metal::v0::CreateDevice diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/QueryDevices.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/QueryDevices.rst index 4cfa2a53188..9691f9aaf6e 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/QueryDevices.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/device_management/QueryDevices.rst @@ -1,6 +1,6 @@ QueryDevices ============= -.. doxygenfunction:: GetNumAvailableDevices +.. doxygenfunction:: tt::tt_metal::v0::GetNumAvailableDevices -.. doxygenfunction:: GetNumPCIeDevices +.. doxygenfunction:: tt::tt_metal::v0::GetNumPCIeDevices diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernel.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernel.rst index a48b895e4db..cd8b26ebfb5 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernel.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernel.rst @@ -1,4 +1,4 @@ CreateKernel ==================== -.. doxygenfunction:: CreateKernel +.. doxygenfunction:: tt::tt_metal::v0::CreateKernel diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernelFromString.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernelFromString.rst index 35a7d2672a4..ab13479a66f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernelFromString.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/kernels/CreateKernelFromString.rst @@ -1,4 +1,4 @@ CreateKernelFromString ======================= -.. doxygenfunction:: CreateKernelFromString +.. doxygenfunction:: tt::tt_metal::v0::CreateKernelFromString diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/profiler/DumpDeviceProfileResults.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/profiler/DumpDeviceProfileResults.rst index 45800c406e2..d8631775c3f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/profiler/DumpDeviceProfileResults.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/profiler/DumpDeviceProfileResults.rst @@ -3,4 +3,4 @@ DumpDeviceProfileResults ======================== -.. doxygenfunction:: DumpDeviceProfileResults(Device *device, const Program &program); +.. doxygenfunction:: tt::tt_metal::v0::DumpDeviceProfileResults(Device *device, const Program &program); diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/program/CreateProgram.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/program/CreateProgram.rst index 2419fdf9415..3052decc265 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/program/CreateProgram.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/program/CreateProgram.rst @@ -1,4 +1,4 @@ CreateProgram ======================== -.. doxygenfunction:: CreateProgram() +.. 
doxygenfunction:: tt::tt_metal::v0::CreateProgram() diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst index a269e88e034..7a4e6f79417 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst @@ -1,18 +1,18 @@ Runtime Arguments ================== -.. doxygenfunction:: SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::variant &logical_core, stl::Span runtime_args) +.. doxygenfunction:: tt::tt_metal::v0::SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::variant &logical_core, stl::Span runtime_args) -.. doxygenfunction:: SetRuntimeArgs(const Program &program, KernelHandle kernel, const std::vector< CoreCoord > & core_spec, const std::vector< std::vector > &runtime_args) +.. doxygenfunction:: tt::tt_metal::v0::SetRuntimeArgs(const Program &program, KernelHandle kernel, const std::vector< CoreCoord > & core_spec, const std::vector< std::vector > &runtime_args) -.. doxygenfunction:: SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::variant &core_spec, std::shared_ptr runtime_args) +.. doxygenfunction:: tt::tt_metal::v0::SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::variant &core_spec, std::shared_ptr runtime_args) -.. doxygenfunction:: SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::vector< CoreCoord > & core_spec, const std::vector> runtime_args) +.. doxygenfunction:: tt::tt_metal::v0::SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::vector< CoreCoord > & core_spec, const std::vector> runtime_args) -.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) +.. doxygenfunction:: tt::tt_metal::v0::GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) -.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id) +.. doxygenfunction:: tt::tt_metal::v0::GetRuntimeArgs(const Program &program, KernelHandle kernel_id) -.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, stl::Span runtime_args) +.. doxygenfunction:: tt::tt_metal::v0::SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, stl::Span runtime_args) -.. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) +.. doxygenfunction:: tt::tt_metal::v0::GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_available_at_front.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_available_at_front.rst new file mode 100644 index 00000000000..5c9ca56ac3b --- /dev/null +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_available_at_front.rst @@ -0,0 +1,4 @@ +cb_pages_available_at_front +=========================== + +.. 
doxygenfunction:: cb_pages_available_at_front(int32_t operand, int32_t num_pages) diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_reservable_at_back.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_reservable_at_back.rst new file mode 100644 index 00000000000..a81a7095c76 --- /dev/null +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/cb_pages_reservable_at_back.rst @@ -0,0 +1,4 @@ +cb_pages_reservable_at_back +=========================== + +.. doxygenfunction:: cb_pages_reservable_at_back(int32_t operand, int32_t num_pages) diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/circular_buffers.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/circular_buffers.rst index a7b73db848f..e50285489ad 100644 --- a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/circular_buffers.rst +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/circular_buffers/circular_buffers.rst @@ -4,7 +4,9 @@ Circular Buffer APIs Circular buffers are used for communication between threads of the Tensix core. They act as limited capacity double-ended queues with producers pushing tiles to the back of the queue and consumers popping tiles off the front of the queue. .. toctree:: + cb_pages_available_at_front cb_wait_front + cb_pages_reservable_at_back cb_reserve_back cb_push_back cb_pop_front diff --git a/docs/source/tt-metalium/tt_metal/examples/eltwise_binary.rst b/docs/source/tt-metalium/tt_metal/examples/eltwise_binary.rst index 1b5dbf6ed39..749834b28d7 100644 --- a/docs/source/tt-metalium/tt_metal/examples/eltwise_binary.rst +++ b/docs/source/tt-metalium/tt_metal/examples/eltwise_binary.rst @@ -37,19 +37,19 @@ We already have set the circular buffers needed for compute data communication. constexpr uint32_t num_input_tiles = 2; constexpr uint32_t input_cb_size = num_input_tiles * single_tile_size; CircularBufferConfig cb_src0_config = CircularBufferConfig(input_cb_size, {{src0_cb_index, tt::DataFormat::Float16_b}}, src0_cb_addr).set_page_size(src0_cb_index, single_tile_size); - CBHandle cb_src0 = CreateCircularBuffer(program, core, cb_src0_config); + CBHandle cb_src0 = v0::CreateCircularBuffer(program, core, cb_src0_config); constexpr uint32_t src1_cb_index = CB::c_in1; constexpr uint32_t src1_cb_addr = 300 * 1024; CircularBufferConfig cb_src1_config = CircularBufferConfig(input_cb_size, {{src1_cb_index, tt::DataFormat::Float16_b}}, src1_cb_addr).set_page_size(src1_cb_index, single_tile_size); - CBHandle cb_src1 = CreateCircularBuffer(program, core, cb_src1_config); + CBHandle cb_src1 = v0::CreateCircularBuffer(program, core, cb_src1_config); constexpr uint32_t output_cb_index = CB::c_out0; constexpr uint32_t output_cb_addr = 400 * 1024; constexpr uint32_t num_output_tiles = 2; constexpr uint32_t input_cb_size = num_input_tiles * single_tile_size; CircularBufferConfig cb_output_config = CircularBufferConfig(input_cb_size, {{output_cb_index, tt::DataFormat::Float16_b}}, output_cb_addr).set_page_size(output_cb_index, single_tile_size); - CBHandle cb_output = CreateCircularBuffer(program, core, cb_output); + CBHandle cb_output = v0::CreateCircularBuffer(program, core, cb_output); We will create two input circular buffers to accommodate our two input tensors, and an output one for the result of the eltwise binary operation. 
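The circular-buffer documentation added above introduces two non-blocking queries, ``cb_pages_available_at_front`` and ``cb_pages_reservable_at_back``, alongside the existing blocking ``cb_wait_front`` and ``cb_reserve_back``. Below is a minimal kernel-side sketch of how the two styles might be combined; the CB indices, tile counts, and the assumption that the queries return a boolean without blocking are illustrative only and are not taken from this change.

.. code-block:: cpp

    // Illustrative dataflow-kernel sketch (hypothetical CB indices and tile counts).
    #include "dataflow_api.h"

    void kernel_main() {
        constexpr int32_t cb_in0  = 0;   // filled by an upstream producer
        constexpr int32_t cb_out0 = 16;  // drained by a downstream consumer
        constexpr int32_t num_tiles = 2;

        // Poll instead of stalling, so unrelated work can be overlapped while waiting.
        while (!cb_pages_available_at_front(cb_in0, num_tiles)) {
            // e.g. issue or retire unrelated NOC transactions here
        }
        cb_wait_front(cb_in0, num_tiles);     // returns immediately, pages already present

        while (!cb_pages_reservable_at_back(cb_out0, num_tiles)) {
            // output space not freed by the consumer yet
        }
        cb_reserve_back(cb_out0, num_tiles);  // returns immediately, space already free

        // ... move or transform the tiles between the two buffers here ...

        cb_push_back(cb_out0, num_tiles);     // publish produced tiles to the consumer
        cb_pop_front(cb_in0, num_tiles);      // release consumed tiles back to the producer
    }

The blocking calls remain the synchronization points; the queries are only a way to test queue state ahead of them.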
diff --git a/docs/source/tt-metalium/tt_metal/examples/eltwise_sfpu.rst b/docs/source/tt-metalium/tt_metal/examples/eltwise_sfpu.rst index 91f1af4869f..09640cbd571 100644 --- a/docs/source/tt-metalium/tt_metal/examples/eltwise_sfpu.rst +++ b/docs/source/tt-metalium/tt_metal/examples/eltwise_sfpu.rst @@ -34,12 +34,12 @@ compute, and writer engines. constexpr uint32_t src0_cb_index = CB::c_in0; constexpr uint32_t num_input_tiles = 2; CircularBufferConfig cb_src0_config = CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, tt::DataFormat::Float16_b}}).set_page_size(src0_cb_index, single_tile_size); - CBHandle cb_src0 = tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + CBHandle cb_src0 = tt_metal::v0::CreateCircularBuffer(program, core, cb_src0_config); constexpr uint32_t output_cb_index = CB::c_out0; constexpr uint32_t num_output_tiles = 2; CircularBufferConfig cb_output_config = CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, tt::DataFormat::Float16_b}}).set_page_size(output_cb_index, single_tile_size); - CBHandle cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); + CBHandle cb_output = tt_metal::v0::CreateCircularBuffer(program, core, cb_output_config); We will create one input circular buffers to accommodate our input tensor, and an output one for the result of the eltwise sfpu operation. diff --git a/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_mcast.rst b/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_mcast.rst index 07efc583b05..710f70f4167 100644 --- a/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_mcast.rst +++ b/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_mcast.rst @@ -111,13 +111,13 @@ Recall in our data reuse example, we created our L1 circular buffers for all the .. code-block:: cpp - auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); + auto cb_output = tt_metal::v0::CreateCircularBuffer(program, all_cores, cb_output_config); METALIUM also allows us to pass all of our CoreRanges defined above through a ``CoreRangeSet(...)`` function call as the 2nd argument. Let's do so with the following: .. code-block:: cpp - auto cb_output = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), cb_output_config); + auto cb_output = tt_metal::v0::CreateCircularBuffer(program, CoreRangeSet({all_cores}), cb_output_config); In fact, you can instantiate circular buffers on any one of these three options: ``const std::variant``. Please refer to the CircularBuffers page for further details. 
diff --git a/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_reuse.rst b/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_reuse.rst index 3c194ebdc89..dd86afdcae3 100644 --- a/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_reuse.rst +++ b/docs/source/tt-metalium/tt_metal/examples/matmul_multi_core_optimizations/data_reuse.rst @@ -49,7 +49,7 @@ In addition to our double-buffer config, we introduce a third circular buffer de CircularBufferConfig cb_output_config = CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, single_tile_size) .set_page_size(interm0_cb_index, single_tile_size); - auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); + auto cb_output = tt_metal::v0::CreateCircularBuffer(program, all_cores, cb_output_config); Stride Kernel Arguments ----------------------- diff --git a/docs/source/tt-metalium/tt_metal/examples/matmul_single_core.rst b/docs/source/tt-metalium/tt_metal/examples/matmul_single_core.rst index 105f1790c3b..3c6984e009b 100644 --- a/docs/source/tt-metalium/tt_metal/examples/matmul_single_core.rst +++ b/docs/source/tt-metalium/tt_metal/examples/matmul_single_core.rst @@ -151,18 +151,18 @@ double buffering.. uint32_t num_input_tiles = 2; tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) .set_page_size(src0_cb_index, single_tile_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + auto cb_src0 = tt_metal::v0::CreateCircularBuffer(program, core, cb_src0_config); uint32_t src1_cb_index = CB::c_in1; // 1 tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src1_cb_index, cb_data_format}}) .set_page_size(src1_cb_index, single_tile_size); - auto cb_src1 = tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + auto cb_src1 = tt_metal::v0::CreateCircularBuffer(program, core, cb_src1_config); uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 uint32_t num_output_tiles = 2; tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}}) .set_page_size(output_cb_index, single_tile_size); - auto cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); + auto cb_output = tt_metal::v0::CreateCircularBuffer(program, core, cb_output_config); Compile-time kernels arguments ------------------------------ diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index 30137696eec..e91e24ac397 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -420,6 +420,7 @@ Data Movement ttnn.reshape ttnn.repeat ttnn.repeat_interleave + ttnn.slice ttnn.tilize ttnn.tilize_with_val_padding ttnn.fill_rm @@ -467,7 +468,6 @@ Transformer ttnn.experimental.rotary_embedding ttnn.transformer.scaled_dot_product_attention ttnn.transformer.scaled_dot_product_attention_decode - ttnn.transformer.scaled_dot_product_attention_decode_gqa CCL === diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index 1d328d9985d..d640e13d707 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -516,7 +516,7 @@ def run_llama3_demo(user_input, batch_size, mesh_device, instruct_mode, is_ci_en "instruct_weights-long", 
], ) -@pytest.mark.parametrize("device_params", [{"trace_region_size": 5560320, "num_command_queues": 2}], indirect=True) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 5700000, "num_command_queues": 2}], indirect=True) @pytest.mark.parametrize( "mesh_device", [ diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index cc0b49e9742..be84481d38f 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -69,13 +69,13 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, (32, True): "llama32_11b", }[(model_args.n_layers, model_args.is_vision())] - final_model_pcc = {"llama32_1b": 0.9991, "llama32_3b": 0.9989, "llama31_8b": 0.9976, "llama32_11b": 0.9976}[ + final_model_pcc = {"llama32_1b": 0.99912, "llama32_3b": 0.99898, "llama31_8b": 0.99888, "llama32_11b": 0.9976}[ model_name ] - final_k_cache_pcc = {"llama32_1b": 0.9998, "llama32_3b": 0.9998, "llama31_8b": 0.9995, "llama32_11b": 0.9995}[ + final_k_cache_pcc = {"llama32_1b": 0.99984, "llama32_3b": 0.99980, "llama31_8b": 0.99983, "llama32_11b": 0.9995}[ model_name ] - final_v_cache_pcc = {"llama32_1b": 0.9996, "llama32_3b": 0.9998, "llama31_8b": 0.9996, "llama32_11b": 0.9996}[ + final_v_cache_pcc = {"llama32_1b": 0.99984, "llama32_3b": 0.99982, "llama31_8b": 0.99984, "llama32_11b": 0.9996}[ model_name ] quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6}[model_name] diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 3b204164ab5..a1ac038740c 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -279,19 +279,17 @@ def forward_decode( values, cur_pos_tensor=current_pos, page_table_tensor=page_table, - transpose_q=False, scale=self.scale, program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"], memory_config=ttnn.DRAM_MEMORY_CONFIG, ) else: - attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode_gqa( + attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( q_heads_1BQD, keys, values, cur_pos_tensor=current_pos, - transpose_q=False, scale=self.scale, program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"], diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index b2bae0fbab7..c044a343399 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -244,7 +244,7 @@ def forward_decode( q_heads_1B4D, keys_1BPD, values_1BPD, - start_pos_ids, + cur_pos=start_pos_ids, scale=self.scale, program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"], diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index 80d2a57d187..75ecb8d2451 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -360,7 +360,7 @@ def attn_mqa( query_layer, keys, values, - [start_pos for _ in range(self.max_batch_size)], + cur_pos=[start_pos for _ in range(self.max_batch_size)], scale=self.scale, program_config=program_config, 
compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SDPA"], diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index ff94196d37e..267e4e097bf 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -635,13 +635,6 @@ def __init__( width=self.conv1_output_width, in_channels=self.conv1_input_channels, out_channels=self.conv1_output_channels, - kernel_size=[self.conv1_kernel_size[0], self.conv1_kernel_size[1]], - stride=[self.conv1_stride[0], self.conv1_stride[1]], - padding=[self.conv1_padding[0], self.conv1_padding[1]], - dilation=[1, 1], - groups=1, - weights_width=self.conv1_weight_tensor.shape[3], - input_width=self.conv1_input_width, ) def __del__(self): diff --git a/models/demos/wormhole/llama31_8b/demo/demo_trace.py b/models/demos/wormhole/llama31_8b/demo/demo_trace.py index dad009e7cfd..35978ca980c 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo_trace.py +++ b/models/demos/wormhole/llama31_8b/demo/demo_trace.py @@ -585,7 +585,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num "instruct_weights-3_batch", ], ) -@pytest.mark.parametrize("device_params", [{"trace_region_size": 7943168, "num_command_queues": 2}], indirect=True) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 8000000, "num_command_queues": 2}], indirect=True) def test_llama_demo( device, use_program_cache, input_prompts, instruct_weights, is_ci_env, is_single_card_n300, num_batches ): diff --git a/models/demos/wormhole/mamba/tests/test_mamba_perf.py b/models/demos/wormhole/mamba/tests/test_mamba_perf.py index b515c52e345..1f9a2c2ab13 100644 --- a/models/demos/wormhole/mamba/tests/test_mamba_perf.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_perf.py @@ -143,11 +143,11 @@ def test_mamba_perf_e2e( @pytest.mark.models_device_performance_bare_metal @pytest.mark.parametrize( "batch, expected_layer_duration_ms", - ((32, 1.689),), + ((32, 1.670),), ) def test_mamba_perf_device(batch, expected_layer_duration_ms): subdir = "ttnn_mamba" - margin = 0.01 + margin = 0.015 command = f"pytest models/demos/wormhole/mamba/tests/test_mamba_model.py::test_device_perf[1]" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] diff --git a/models/demos/wormhole/mamba/tt/mamba_block.py b/models/demos/wormhole/mamba/tt/mamba_block.py index e3d333ae9f2..c895a15a17d 100644 --- a/models/demos/wormhole/mamba/tt/mamba_block.py +++ b/models/demos/wormhole/mamba/tt/mamba_block.py @@ -197,7 +197,8 @@ def forward(self, x): for i in range(0, 4): slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0) slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner) - entry = ttnn.slice(x_ssm, slice_start, slice_end) + step = (1, 1, 1, 1) + entry = ttnn.slice(x_ssm, starts=slice_start, ends=slice_end, steps=step) self.convolution_cache.set(self.configs["current_user"], i, entry) ttnn.deallocate(entry) diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py index c4dd0d961ef..a2700198f83 100644 --- a/models/demos/wormhole/mamba/tt/mamba_conv.py +++ b/models/demos/wormhole/mamba/tt/mamba_conv.py @@ -75,9 +75,11 @@ def prepare_input(self, input_tensor): input_tensor_splits = [] split_size = self.config.input_channels // self.config.channels_split_factor for i in range(self.config.channels_split_factor): - slice_start = 
ttnn.Shape((0, 0, 0, i * split_size)) - slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size)) - input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end)) + slice_start = (0, 0, 0, i * split_size) + slice_end = (1, self.config.input_length, 1, (i + 1) * split_size) + input_tensor_splits.append( + ttnn.slice(input_tensor, starts=slice_start, ends=slice_end, steps=(1, 1, 1, 1)) + ) ttnn.deallocate(input_tensor) return input_tensor_splits diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 45b297d0054..70c0a9a4705 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -266,7 +266,7 @@ def __init__( output_height=self.conv2.input_height, output_width=self.conv2.input_width, output_channels=self.conv1.out_channels, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) @@ -320,7 +320,7 @@ def __init__( output_height=self.conv2.input_height, output_width=self.conv2.input_width, output_channels=self.conv1.out_channels, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) @@ -448,7 +448,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: output_height=self.bnc2.input_height, output_width=self.bnc2.input_width, output_channels=self.bnc.out_channels, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index 198ad658af3..52a417526b1 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -20,6 +20,7 @@ # Contains following params # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation] [1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1], + [1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 100, False, 1], [1, 1008, 1008, 14, 14, 3, 3, 2, 2, 1, 1, 21, False, 1], [1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 21, False, 1], [1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1024, False, 1], @@ -454,25 +455,16 @@ def test_conv2d_localrun(device, input_spec): # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation] # Input is 32MB maps to MM 64 cores, we neeed to avoid sharding this tensor and use dram intrelaved directly with MM [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6 - [1, 1024, 1024, 19, 19, 1, 1, 1, 1, 0, 0, 1, True, 1], # 9 - [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, True, 1], # 11 [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14 [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15 [1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1], # 100 [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141 [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170 [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171 - [1, 1024, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1], # 172 [1, 
1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 173 - [1, 768, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1], # 181 [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 182 [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 183 - [1, 32, 3, 299, 299, 3, 3, 2, 2, 0, 0, 1, False, 1], # 192 - [1, 32, 3, 381, 381, 3, 3, 2, 2, 0, 0, 1, False, 1], # 197 [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 199 - [1, 192, 3, 512, 672, 16, 16, 16, 16, 0, 0, 1, True, 1], # 202 - [1, 1280, 3, 518, 518, 14, 14, 14, 14, 0, 0, 1, True, 1], # 203 - [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204 [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 205 [1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], # 241 [1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1], # 245 diff --git a/tests/sweep_framework/sweeps/creation/empty/empty.py b/tests/sweep_framework/sweeps/creation/empty/empty.py new file mode 100644 index 00000000000..cce64282e0f --- /dev/null +++ b/tests/sweep_framework/sweeps/creation/empty/empty.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch +import ttnn + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "batch_sizes": [(1, 2), (3, 6)], + "height": [384, 1024], + "width": [1024, 4096], + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b, ttnn.float32], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG], + "layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "Skipped as ROW_MAJOR_LAYOUT and ttnn.bfloat8_b not supported" + return False, None + + +def check_output(torch_output_tensor, output_tensor): + status = list(torch_output_tensor.shape) == list(output_tensor.shape) + msg = "" + msg = "pass" if status else "fail" + + return status, msg + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. 
+def run( + batch_sizes, + height, + width, + input_dtype, + output_memory_config, + layout, + *, + device, +) -> list: + torch.manual_seed(0) + + input_shape = (*batch_sizes, height, width) + + torch_output_tensor = torch.empty(input_shape) + + start_time = start_measuring_time() + + output_tensor = ttnn.empty(input_shape, input_dtype, layout, device=device, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_output(torch_output_tensor, output_tensor), e2e_perf] diff --git a/tests/sweep_framework/sweeps/creation/zeros/zeros.py b/tests/sweep_framework/sweeps/creation/zeros/zeros.py new file mode 100644 index 00000000000..586f7e57cba --- /dev/null +++ b/tests/sweep_framework/sweeps/creation/zeros/zeros.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch +import ttnn + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "batch_sizes": [(1, 2), (3, 6)], + "height": [384, 1024], + "width": [1024, 4096], + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b, ttnn.float32], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG], + "layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "Skipped as ROW_MAJOR_LAYOUT and ttnn.bfloat8_b not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. 
+def run( + batch_sizes, + height, + width, + input_dtype, + output_memory_config, + layout, + *, + device, +) -> list: + torch.manual_seed(0) + + input_shape = (*batch_sizes, height, width) + + torch_output_tensor = torch.zeros(input_shape) + + start_time = start_measuring_time() + + output_tensor = ttnn.zeros(input_shape, input_dtype, layout, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/creation/zeros_like/zeros_like.py b/tests/sweep_framework/sweeps/creation/zeros_like/zeros_like.py new file mode 100644 index 00000000000..88aef5601d1 --- /dev/null +++ b/tests/sweep_framework/sweeps/creation/zeros_like/zeros_like.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [2, 6, 256, 256], [1, 1, 320, 320], 128), + "input_dtype": [ttnn.bfloat16, ttnn.float32, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "Skipped as bfloat8_b dtype not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. 
+def run( + input_shape, + input_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + torch.manual_seed(0) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(input_shape) + + torch_output_tensor = torch.zeros_like(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.zeros_like(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py b/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py index b79e211bdcf..e35a28c1a5f 100644 --- a/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py +++ b/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py @@ -172,7 +172,7 @@ {"dims": [8732, 4], "dim": 1, "start": 0, "end": -1, "step": 4}, {"dims": [8732, 4], "dim": 1, "start": 0, "end": 2}, ], - "dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "dtype": [ttnn.bfloat16], "layout": [ttnn.TILE_LAYOUT], } } diff --git a/tests/sweep_framework/sweeps/eltwise/unary/asinh/asinh.py b/tests/sweep_framework/sweeps/eltwise/unary/asinh/asinh.py new file mode 100644 index 00000000000..0a316463c78 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/asinh/asinh.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + } +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. 
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "ROW_MAJOR_LAYOUT and ttnn.bfloat8_b are not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_dtype, + input_layout, + input_memory_config, + output_memory_config, + *, + device, +) -> list: + torch.manual_seed(0) + + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.asinh) + torch_output_tensor = golden_function(torch_input_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.asinh(input_tensor, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/atan/atan.py b/tests/sweep_framework/sweeps/eltwise/unary/atan/atan.py new file mode 100644 index 00000000000..be5eab3e04d --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/atan/atan.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + } +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. 
+# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "ROW_MAJOR_LAYOUT is not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_dtype, + input_layout, + input_memory_config, + output_memory_config, + *, + device, +) -> list: + torch.manual_seed(0) + + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.atan) + torch_output_tensor = golden_function(torch_input_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.atan(input_tensor, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/atanh/atanh.py b/tests/sweep_framework/sweeps/eltwise/unary/atanh/atanh.py new file mode 100644 index 00000000000..2494927780f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/atanh/atanh.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + } +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. 
+# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "ROW_MAJOR_LAYOUT and ttnn.bfloat8_b are not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_dtype, + input_layout, + input_memory_config, + output_memory_config, + *, + device, +) -> list: + torch.manual_seed(0) + + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-0.9, high=0.9, dtype=torch.float32), input_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.atanh) + torch_output_tensor = golden_function(torch_input_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.atanh(input_tensor, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/cosh/cosh.py b/tests/sweep_framework/sweeps/eltwise/unary/cosh/cosh.py new file mode 100644 index 00000000000..2c01010ede0 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/cosh/cosh.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + } +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. 
+# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "ROW_MAJOR_LAYOUT and ttnn.bfloat8_b are not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_dtype, + input_layout, + input_memory_config, + output_memory_config, + *, + device, +) -> list: + torch.manual_seed(0) + + torch_input_tensor = gen_func_with_cast_tt(partial(torch_random, low=-9, high=9, dtype=torch.float32), input_dtype)( + input_shape + ) + + golden_function = ttnn.get_golden_function(ttnn.cosh) + torch_output_tensor = golden_function(torch_input_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.cosh(input_tensor, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/sinh/sinh.py b/tests/sweep_framework/sweeps/eltwise/unary/sinh/sinh.py index f9421005c69..fa08d9ea184 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/sinh/sinh.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/sinh/sinh.py @@ -6,7 +6,6 @@ from functools import partial import torch -import random import ttnn from tests.sweep_framework.sweep_utils.utils import gen_shapes from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt @@ -14,10 +13,6 @@ from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time from models.utility_functions import torch_random -# Override the default timeout in seconds for hang detection. -TIMEOUT = 30 - -random.seed(0) # Parameters provided to the test vector generator are defined here. # They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. 
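The suite dictionaries used throughout these sweeps map run-function argument names to lists of candidate values, and the generator presumably enumerates combinations of those lists to produce individual test vectors. A rough sketch of that expansion, under the assumption of a simple Cartesian-product generator (the helper name expand_suite is illustrative only and is not part of the sweep framework):

# Illustrative sketch: expand one parameter suite into per-combination test vectors.
# The real sweep framework may expand, sample, or batch vectors differently.
from itertools import product

def expand_suite(suite: dict) -> list[dict]:
    keys = list(suite.keys())
    return [dict(zip(keys, values)) for values in product(*(suite[k] for k in keys))]

# e.g. expand_suite(parameters["nightly"]) would yield one dict per combination of
# input_shape, input_dtype, input_layout, input_memory_config, output_memory_config.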
@@ -28,23 +23,32 @@ "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + gen_shapes([32, 32], [256, 256], [32, 32], 32), - "input_a_dtype": [ttnn.bfloat16], - "input_a_layout": [ttnn.TILE_LAYOUT], - "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "use_safe_nums": [True], }, "xfail": { "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 32), - "input_a_dtype": [ttnn.bfloat16], - "input_a_layout": [ttnn.TILE_LAYOUT], - "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_dtype": [ttnn.bfloat16], + "input_layout": [ttnn.TILE_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "use_safe_nums": [False], }, } +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "ROW_MAJOR_LAYOUT and ttnn.bfloat8_b are not supported" + return False, None + + # This is the run instructions for the test, defined by the developer. # The run function must take the above-defined parameters as inputs. # The runner will call this run function with each test vector, and the returned results from this function will be stored. 
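Taken together, the invalidate_vector and run hooks defined in these sweep modules describe a simple contract: each generated vector is first screened by invalidate_vector and, if not rejected, is passed as keyword arguments to run, which returns the PCC check result and an end-to-end timing. A minimal sketch of that flow, assuming the runner just iterates over the expanded vectors (drive_sweep is a made-up name for illustration, not the framework's real entry point):

# Hedged sketch of the runner flow described in the comments above; the actual
# sweep infrastructure handles result storage, timeouts, and device setup itself.
def drive_sweep(vectors, device):
    results = []
    for vector in vectors:
        is_invalid, reason = invalidate_vector(vector)
        if is_invalid:
            results.append(("skipped", reason))  # vector is stored but not executed
            continue
        results.append(run(**vector, device=device))  # [check_with_pcc(...), e2e_perf]
    return results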
@@ -52,37 +56,37 @@ def run( use_safe_nums, input_shape, - input_a_dtype, - input_a_layout, - input_a_memory_config, + input_dtype, + input_layout, + input_memory_config, output_memory_config, *, device, ) -> list: - data_seed = random.randint(0, 20000000) - torch.manual_seed(data_seed) + torch.manual_seed(0) if use_safe_nums is True: - torch_input_tensor_a = gen_func_with_cast_tt( - partial(torch_random, low=-9, high=9, dtype=torch.float32), input_a_dtype + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-9, high=9, dtype=torch.float32), input_dtype )(input_shape) else: - torch_input_tensor_a = gen_func_with_cast_tt( - partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype )(input_shape) - torch_output_tensor = torch.sinh(torch_input_tensor_a) + golden_function = ttnn.get_golden_function(ttnn.sinh) + torch_output_tensor = golden_function(torch_input_tensor) - input_tensor_a = ttnn.from_torch( - torch_input_tensor_a, - dtype=input_a_dtype, - layout=input_a_layout, + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, device=device, - memory_config=input_a_memory_config, + memory_config=input_memory_config, ) start_time = start_measuring_time() - result = ttnn.sinh(input_tensor_a, memory_config=output_memory_config) + result = ttnn.sinh(input_tensor, memory_config=output_memory_config) output_tensor = ttnn.to_torch(result) e2e_perf = stop_measuring_time(start_time) diff --git a/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh.py b/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh.py index a7b2903668b..6ed4f25ac2c 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh.py @@ -6,7 +6,6 @@ from functools import partial import torch -import random import ttnn from tests.sweep_framework.sweep_utils.utils import gen_shapes from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt @@ -14,10 +13,6 @@ from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time from models.utility_functions import torch_random -# Override the default timeout in seconds for hang detection. -TIMEOUT = 30 - -random.seed(0) # Parameters provided to the test vector generator are defined here. # They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. @@ -28,45 +23,55 @@ "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + gen_shapes([32, 32], [256, 256], [32, 32], 32), - "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], - "input_a_layout": [ttnn.TILE_LAYOUT], - "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], }, } +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. 
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "ROW_MAJOR_LAYOUT is not supported" + return False, None + + # This is the run instructions for the test, defined by the developer. # The run function must take the above-defined parameters as inputs. # The runner will call this run function with each test vector, and the returned results from this function will be stored. # If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. def run( input_shape, - input_a_dtype, - input_a_layout, - input_a_memory_config, + input_dtype, + input_layout, + input_memory_config, output_memory_config, *, device, ) -> list: - data_seed = random.randint(0, 20000000) - torch.manual_seed(data_seed) + torch.manual_seed(0) - torch_input_tensor_a = gen_func_with_cast_tt( - partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype )(input_shape) - torch_output_tensor = torch.tanh(torch_input_tensor_a) - input_tensor_a = ttnn.from_torch( - torch_input_tensor_a, - dtype=input_a_dtype, - layout=input_a_layout, + golden_function = ttnn.get_golden_function(ttnn.tanh) + torch_output_tensor = golden_function(torch_input_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, device=device, - memory_config=input_a_memory_config, + memory_config=input_memory_config, ) start_time = start_measuring_time() - result = ttnn.tanh(input_tensor_a, memory_config=output_memory_config) + result = ttnn.tanh(input_tensor, memory_config=output_memory_config) output_tensor = ttnn.to_torch(result) e2e_perf = stop_measuring_time(start_time) diff --git a/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py b/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py index 50be876d545..c0b90358064 100644 --- a/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py @@ -132,7 +132,7 @@ def run( output_height=out_h, output_width=out_w, output_channels=in_c, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), is_out_tiled=False, ) sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index 3eec6eef4b1..4f6bc6ce792 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1288,7 +1288,7 @@ def clip( **kwargs, ): t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = ttnn.clip(t0, min=low, max=high, memory_config=output_mem_config) + t1 = ttnn.clip(t0, low, high, memory_config=output_mem_config) return tt2torch_tensor(t1) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py index 170753f68ea..0783e8c464e 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py +++ 
b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py @@ -48,18 +48,23 @@ def num_to_corerange(x): ) -def get_chunk_size(s): - if s <= 32: - return 32 - if s <= 64: - return 32 - if s <= 128: - return 32 - if s <= 256: - return 256 - if s <= 2048: - return 512 - return 512 +def get_chunk_size(max_start_pos, s): + if max_start_pos <= 32: + chunk_size = 32 + elif max_start_pos <= 64: + chunk_size = 32 + elif max_start_pos <= 128: + chunk_size = 32 + elif max_start_pos <= 1024: + chunk_size = 128 + else: + chunk_size = 512 + # find maximum power of 2 divisor of s + for i in range(1, s): + if s % (2 ** (i + 1)) != 0: + break + chunk_size = min(chunk_size, 2**i) + return chunk_size def fa_rand(*shape): @@ -217,7 +222,7 @@ def run_test_sdpa_decode_multi_pos( scale = d**-0.5 start_indices = np.linspace(0, max_start_idx, b, dtype=np.int32).tolist() if b > 1 else [max_start_idx] - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -265,7 +270,7 @@ def run_test_sdpa_decode_multi_pos( tt_Q, tt_K, tt_V, - start_indices, + cur_pos=start_indices, scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, @@ -290,7 +295,7 @@ def run_test_sdpa_decode_multi_pos( expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc) @@ -315,6 +320,7 @@ def run_test_sdpa_decode_single_iter( sharded_in=False, sharded_out=False, start_indices=None, + causal=True, ): compute_grid_size = device.compute_with_storage_grid_size() if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: @@ -355,7 +361,7 @@ def run_test_sdpa_decode_single_iter( max_start_idx = max(start_indices) scale = d**-0.5 - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, q_chunk_size=padded_num_heads, @@ -363,20 +369,29 @@ def run_test_sdpa_decode_single_iter( exp_approx_mode=False, ) - padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) + padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) if causal else s # Test various sequence lengths - logger.debug(f"Testing with sequence length: {max_start_idx}") + logger.debug(f"Testing with sequence length: {max_start_idx if causal else s}") logger.debug(f"Using chunk size: {k_chunk_size}") logger.debug(f"Using padded layer length: {padded_layer_len}") logger.debug(f"Using padded num heads: {padded_num_heads}") - attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) - for i in range(b): - start_idx = start_indices[i] - attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + if causal: + attn_mask = torch.zeros((b, nh, 1, padded_layer_len)) + for i in range(b): + start_idx = start_indices[i] + attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + else: + attn_mask = torch.bernoulli( + torch.full( + (b, nh, 1, padded_layer_len), + 0.25, + ) + ) + attn_mask = attn_mask * torch.finfo(torch.float32).min - Q = fa_rand(1, b, padded_num_heads, d) + Q = fa_rand(1, b, nh, d) tt_Q = ttnn.as_tensor( 
Q[:, :, :nh], @@ -385,24 +400,44 @@ def run_test_sdpa_decode_single_iter( layout=ttnn.TILE_LAYOUT, memory_config=height_sharded_memcfg if sharded_in else dram_memcfg, ) - if cur_pos_tensor: - start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) - tt_back = ttnn.transformer.scaled_dot_product_attention_decode( - tt_Q, - tt_K, - tt_V, - cur_pos_tensor=start_indices_tt, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, - ) + if causal: + if cur_pos_tensor: + start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) + tt_back = ttnn.transformer.scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + cur_pos_tensor=start_indices_tt, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) + else: + tt_back = ttnn.transformer.scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + cur_pos=start_indices, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) else: + tt_mask = ttnn.as_tensor( + attn_mask.transpose(1, 2).contiguous(), + device=device, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + memory_config=dram_memcfg, + ) tt_back = ttnn.transformer.scaled_dot_product_attention_decode( tt_Q, tt_K, tt_V, - start_indices, + is_causal=False, + attn_mask=tt_mask, scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, @@ -425,7 +460,7 @@ def run_test_sdpa_decode_single_iter( expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) non_skip_indices = torch.tensor(start_indices) != -1 out_pass, out_pcc = comp_pcc(expect[:, non_skip_indices], tt_back[:, non_skip_indices], min_pcc) @@ -483,6 +518,38 @@ def test_sdpa_decode( ) +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize( + "dtype, q_dtype", + [ + # [ttnn.bfloat16, ttnn.bfloat16], + [ttnn.bfloat8_b, ttnn.bfloat16], + ], + ids=[ + # "all_bfp16", + "kv_bfp8", + ], +) +@pytest.mark.parametrize( + "b, nh, nkv, s, d, grid_size", + ( + [32, 32, 8, 4224, 128, (8, 8)], # llama3.2 vision encoder on n150 + [8, 16, 4, 4224, 128, (8, 8)], # llama3.2 vision encoder on n300 + [32, 4, 1, 4224, 128, (8, 8)], # llama3.2 vision encoder on n300 + ), +) +def test_sdpa_decode_non_causal(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, use_program_cache): + if nkv > 1 and q_dtype != ttnn.bfloat16: + pytest.skip("nkv > 1 requires q_dtype to be bfloat16") + + ttnn.device.DisablePersistentKernelCache() + for _ in range(2): + run_test_sdpa_decode_single_iter( + device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, sharded_in=False, sharded_out=False, causal=False + ) + assert device.num_program_cache_entries() == 1 + + @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") @pytest.mark.parametrize( "dtype, q_dtype", @@ -620,16 +687,20 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc tt_page_table = ttnn.Tensor(page_table, ttnn.int32).to(device) max_start_idx = 0 + causal = True - while max_start_idx < s: + while max_start_idx < s or not causal: 
scale = d**-0.5 start_indices = np.linspace(max(max_start_idx - b, 0), max_start_idx, b, dtype=np.int32).tolist() # Test when page_table does not contain blocks for full sequence length - last_block = max(1, int(math.ceil((max_start_idx + 1) / block_size))) - tt_page_table = ttnn.Tensor(page_table[:, :last_block], ttnn.int32).to(device) + if causal: + last_block = max(1, int(math.ceil((max_start_idx + 1) / block_size))) + tt_page_table = ttnn.Tensor(page_table[:, :last_block], ttnn.int32).to(device) + else: + tt_page_table = ttnn.Tensor(page_table, ttnn.int32).to(device) - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -637,20 +708,31 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc exp_approx_mode=False, ) - padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) + padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) if causal else s # Test various sequence lengths - logger.info(f"Testing with sequence length: {max_start_idx}") + logger.debug( + f"Testing {'causal' if causal else 'non-causal'} with sequence length: {max_start_idx if causal else s}" + ) logger.info(f"Using chunk size: {k_chunk_size}") logger.info(f"Using padded layer length: {padded_layer_len}") logger.info(f"Using padded num heads: {padded_num_heads}") - attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) - for i in range(b): - start_idx = start_indices[i] - attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + if causal: + attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) + for i in range(b): + start_idx = start_indices[i] + attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + else: + attn_mask = torch.bernoulli( + torch.full( + (b, nh, 1, padded_layer_len), + 0.25, + ) + ) + attn_mask = attn_mask * torch.finfo(torch.float32).min - Q = fa_rand(1, b, padded_num_heads, d) + Q = fa_rand(1, b, nh, d) tt_Q = ttnn.as_tensor( Q[:, :, :nh], @@ -662,17 +744,38 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) - tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( - tt_Q, - tt_K, - tt_V, - cur_pos_tensor=start_indices_tt, - page_table_tensor=tt_page_table, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, - ) + if causal: + tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + tt_page_table, + cur_pos_tensor=start_indices_tt, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) + else: + tt_mask = ttnn.as_tensor( + attn_mask.transpose(1, 2).contiguous(), + device=device, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + memory_config=dram_memcfg, + ) + tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + tt_page_table, + is_causal=False, + attn_mask=tt_mask, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) tt_back = 
ttnn.to_torch(tt_back) @@ -692,7 +795,7 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc) @@ -701,7 +804,13 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc assert out_pass max_start_idx += 71 if max_start_idx < 4096 else 3001 - # return + + if not causal: + # only run one iteration for non-causal + break + if max_start_idx >= s: + # run last iteration to test non-causal + causal = False @skip_for_blackhole("Unsupported on BH, see #12349") @@ -724,8 +833,8 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc @pytest.mark.parametrize( "b, nh, nkv, s, d, grid_size, cur_pos_tensor", ( - [32, 8, 1, 32768, 128, (8, 6), True], # Llama2-70B - [4, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b + # [32, 8, 1, 32768, 128, (8, 6), True], # Llama2-70B + # [4, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b # [4, 16, 4, 32768, 128, (8, 8), True], # [32, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b [8, 16, 4, 4096, 128, (8, 2), True], # llama 3.1 8b N300 @@ -757,7 +866,7 @@ def test_sdpa_decode_paged_attention( sharded_out=False, ) - assert device.num_program_cache_entries() == 3 + assert device.num_program_cache_entries() == 4 @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") @@ -985,7 +1094,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty scale = d**-0.5 - k_chunk_size = get_chunk_size(start_idx + 1) + k_chunk_size = get_chunk_size(start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -1013,7 +1122,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) all_out_pass = True @@ -1030,7 +1139,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty tt_Q, tt_K, tt_V, - [start_idx for _ in range(b)], + cur_pos=[start_idx for _ in range(b)], scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode_gqa.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode_gqa.py deleted file mode 100644 index 63fbf648c52..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode_gqa.py +++ /dev/null @@ -1,332 +0,0 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import os -import torch -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( - comp_allclose, - comp_pcc, - comp_and_get_pcc, -) -import ttnn -from loguru import logger -import pytest -from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0 -import math -import numpy as np - - -def is_watcher_enabled(): - return os.environ.get("TT_METAL_WATCHER") is not None - - -def nearest_n(x, n): - return ((x + n - 1) // n) * n - - -def nearest_pow_2(x): - if x < 1: - raise ValueError("x must be >= 1") - import math - - power = math.ceil(math.log2(x)) - return 1 << power - # if (2**math.log2(x) == x): - # return x - # return 2**(int(x).bit_length()) - - -def num_to_corerange(x): - assert x < 8 or x % 8 == 0 - num_x = min(x, 8) - num_y = x // num_x - assert num_x * num_y == x - return ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(num_x - 1, num_y - 1), - ) - - -def get_chunk_size(s): - if s <= 32: - return 32 - if s <= 64: - return 32 - if s <= 128: - return 32 - if s <= 256: - return 256 - if s <= 2048: - return 512 - return 512 - - -def fa_rand(*shape): - normal_1 = torch.randn(shape) - normal_2 = torch.randn(shape) * 10 - bernoulli = torch.bernoulli(torch.full(shape, 0.001)) - return normal_1 + normal_2 * bernoulli - - -def run_test_sdpa_decode_single_iter( - device, - b, - nh, - nkv, - s, - d, - dtype, - grid_size, - q_dtype=ttnn.bfloat16, - start_indices=None, - transpose_q=True, - share_cache=False, - cur_pos_tensor=False, -): - compute_grid_size = device.compute_with_storage_grid_size() - if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: - pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}") - - padded_num_heads = nearest_pow_2(nearest_n(nh, n=32)) - torch.manual_seed(1234) - - num_parallel_cores = grid_size[0] * grid_size[1] // b * nkv - if num_parallel_cores == 1: - min_pcc = 0.90 - else: - min_pcc = 0.99 - if q_dtype == ttnn.bfloat8_b: - min_pcc = 0.98 - min_pcc = 0.93 if dtype == ttnn.bfloat4_b else min_pcc - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - dram_memcfg = ttnn.DRAM_MEMORY_CONFIG - - if share_cache: - K = fa_rand(1, nkv, s, d).repeat(b, 1, 1, 1) - V = fa_rand(1, nkv, s, d).repeat(b, 1, 1, 1) - else: - K = fa_rand(b, nkv, s, d) - V = fa_rand(b, nkv, s, d) - - tt_K = ttnn.as_tensor( - K[:1] if share_cache else K, - device=device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=dram_memcfg, - ) - tt_V = ttnn.as_tensor( - V[:1] if share_cache else V, - device=device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=dram_memcfg, - ) - - start_indices = [s // 2 for _ in range(b * nkv)] if start_indices is None else start_indices - max_start_idx = max(start_indices) - scale = d**-0.5 - - k_chunk_size = get_chunk_size(max_start_idx + 1) - program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=grid_size, - q_chunk_size=padded_num_heads, - k_chunk_size=k_chunk_size, - ) - - padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) - - # Test various sequence lengths - logger.debug(f"Testing with sequence length: {max_start_idx}") - logger.debug(f"Using chunk size: {k_chunk_size}") - logger.debug(f"Using padded layer length: {padded_layer_len}") - - attn_mask = torch.zeros((1, b * nkv, padded_num_heads, padded_layer_len)) - for 
i in range(b * nkv): - start_idx = start_indices[i] - attn_mask[:, i, :, start_idx + 1 :] = torch.finfo(torch.float32).min - - Q = fa_rand(1, nh, b, d) - - if not transpose_q: - Q = Q.permute(0, 2, 1, 3) - - tt_Q = ttnn.as_tensor( - Q, - device=device, - dtype=q_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=dram_memcfg, - ) - - if cur_pos_tensor: - start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) - tt_back = ttnn.transformer.scaled_dot_product_attention_decode_gqa( - tt_Q, - tt_K, - tt_V, - cur_pos_tensor=start_indices_tt, - transpose_q=transpose_q, - share_cache=share_cache, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=dram_memcfg, - ) - else: - tt_back = ttnn.transformer.scaled_dot_product_attention_decode_gqa( - tt_Q, - tt_K, - tt_V, - start_indices, - transpose_q=transpose_q, - share_cache=share_cache, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=dram_memcfg, - ) - - tt_back = ttnn.to_torch(tt_back) - - if not transpose_q: - Q_slice = Q.permute(1, 2, 0, 3) # b, nh, 1, d - else: - Q_slice = Q.permute(2, 1, 0, 3) # b, nh, 1, d - K_slice = K[:, :, :padded_layer_len, :] # b, nh, S, d - K_slice = torch.cat( - [K_slice[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1 - ) # b, nh, d, S - V_slice = V[:, :, :padded_layer_len, :] # b, nh, S, d - V_slice = torch.cat( - [V_slice[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1 - ) # b, nh, d, S - - attn_mask_slice = attn_mask[:, :b, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, S - expect = torch.nn.functional.scaled_dot_product_attention( - Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False - ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) - - out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc) - - logger.debug(f"python vs pytorch: {out_pcc}") - assert out_pass - - -@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") -@pytest.mark.parametrize( - "dtype, q_dtype", - [ - [ttnn.bfloat8_b, ttnn.bfloat16], - ], - ids=[ - "kv_bfp8", - ], -) -@pytest.mark.parametrize( - "b, nh, nkv, s, d, grid_size, single_iter", - ( - [1, 32, 8, 32768, 128, (8, 8), True], # Llama3.1-8B - [2, 32, 8, 32768, 128, (8, 8), True], # Llama3.1-8B - [4, 32, 8, 32768, 128, (8, 8), True], # Llama3.1-8B - [8, 16, 4, 32768, 128, (8, 8), True], # Llama3.1-8B on N300 - ), -) -@pytest.mark.parametrize( - "transpose_q", - (True, False), -) -@pytest.mark.parametrize( - "share_cache", - (True, False), -) -@pytest.mark.parametrize( - "cur_pos_tensor", - (True, False), -) -def test_sdpa_decode( - device, - b, - nh, - nkv, - s, - d, - dtype, - grid_size, - q_dtype, - transpose_q, - single_iter, - share_cache, - cur_pos_tensor, - use_program_cache, -): - ttnn.device.DisablePersistentKernelCache() - run_test_sdpa_decode_single_iter( - device, - b, - nh, - nkv, - s, - d, - dtype, - grid_size, - q_dtype, - transpose_q=transpose_q, - share_cache=share_cache, - cur_pos_tensor=cur_pos_tensor, - ) - - -@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") -@pytest.mark.parametrize( - "dtype", - [ - ttnn.bfloat16, - ], - ids=[ - "bf16", - ], -) -@pytest.mark.parametrize( - "b, nh, nkv, s, d", - ([4, 32, 8, 8192, 128],), # Llama3.1-8B -) -@pytest.mark.parametrize( - "transpose_q", - (True, False), -) -@pytest.mark.parametrize( - "share_cache", - (True, False), -) -def test_sdpa_decode_program_cache(device, 
b, nh, nkv, s, d, dtype, transpose_q, share_cache, use_program_cache): - ttnn.device.DisablePersistentKernelCache() - - for i in range(2): - run_test_sdpa_decode_single_iter( - device, - b, - nh, - nkv, - s, - d, - dtype, - (8, 8), - dtype, - start_indices=None, - transpose_q=transpose_q, - share_cache=share_cache, - ) - - if transpose_q: - assert device.num_program_cache_entries() == 4 - else: - assert device.num_program_cache_entries() == 1 diff --git a/tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp b/tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp index eaf18c1aada..36909cb681d 100644 --- a/tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp +++ b/tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp @@ -203,7 +203,7 @@ int main(int argc, char **argv) { dm_class_idx, 0, get_latest_kernel_binary_path(mask, riscv0_kernel)); - ll_api::memory brisc_binary = llrt::get_risc_binary(brisc_hex_path, 0, ll_api::memory::PackSpans::PACK, ll_api::memory::Relocate::XIP); + ll_api::memory brisc_binary = llrt::get_risc_binary(brisc_hex_path, 0, llrt::PackSpans::PACK); TT_FATAL( brisc_binary == brisc_binaries.at(mask).at(0), "Expected saved BRISC binary to be the same as binary in persistent cache"); @@ -212,11 +212,7 @@ int main(int argc, char **argv) { dm_class_idx, 1, get_latest_kernel_binary_path(mask, riscv1_kernel)); - ll_api::memory::Relocate relo_type = - (device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0) ? - ll_api::memory::Relocate::NONE : ll_api::memory::Relocate::XIP; - - ll_api::memory ncrisc_binary = llrt::get_risc_binary(ncrisc_hex_path, 1, ll_api::memory::PackSpans::PACK, relo_type); + ll_api::memory ncrisc_binary = llrt::get_risc_binary(ncrisc_hex_path, 1, llrt::PackSpans::PACK); TT_FATAL( ncrisc_binary == ncrisc_binaries.at(mask).at(0), "Expected saved NCRISC binary to be the same as binary in persistent cache"); @@ -227,7 +223,7 @@ int main(int argc, char **argv) { compute_class_idx, trisc_id, get_latest_kernel_binary_path(mask, compute_kernel)); - ll_api::memory trisc_binary = llrt::get_risc_binary(trisc_hex_path, 2, ll_api::memory::PackSpans::PACK, ll_api::memory::Relocate::XIP); + ll_api::memory trisc_binary = llrt::get_risc_binary(trisc_hex_path, 2, llrt::PackSpans::PACK); TT_FATAL( trisc_binary == compute_binaries.at(mask).at(trisc_id), "Expected saved TRISC binary for {} to be the same as binary in persistent cache", trisc_id_str); diff --git a/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_master_test_kernel.cpp b/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_master_test_kernel.cpp new file mode 100644 index 00000000000..a6bb11a4ac9 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_master_test_kernel.cpp @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_metal/hw/inc/dataflow_api.h" + +#include +#include +#include + + +/* + * This kernel enumerates over all the CBs and pushes a page at a time. For every page, it checks to see if the + * non-blocking call to `cb_pages_reservable_at_back` to get the result. It stores the result for the element + * corresponding to that iteration index, into the output buffer associated with that CB. The buffer can + * later be readback by host for comparison and checking. 
+ */ +void kernel_main() { + constexpr int32_t n_cbs = get_compile_time_arg_val(0); + constexpr int32_t n_pages = get_compile_time_arg_val(1); + + size_t arg_idx = 0; + + auto master_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + auto slave_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + + std::array output_buffer_addrs; + for (size_t i = 0; i < n_cbs; i++) { + output_buffer_addrs[i] = get_arg_val(arg_idx++); + } + + auto get_idx = [n_pages](size_t i, size_t j) -> size_t { + return i * n_pages + j; + }; + + for (int32_t i = 0; i < n_cbs; i++) { + auto *const output_buffer = reinterpret_cast(output_buffer_addrs[i]); + for (int32_t j = 0; j < n_pages; j++) { + + if (j > 0) { + // Induce some memory load to the CB, indicating that fewer pages are available + // for reservation (writer) and that more are available for popping (reader) + cb_reserve_back(i, j); + cb_push_back(i, j); + } + + noc_semaphore_set(master_sem_addr, 1); + noc_semaphore_wait(slave_sem_addr, 1); + // noc_semaphore_set(slave_sem_addr, 0); + for (int32_t k = 0; k < n_pages; k++) { + bool result = cb_pages_reservable_at_back(i, k); + output_buffer[get_idx(j,k)] = static_cast(result); + } + + // Notify that reader can expect the appropriate number of pages to be available + noc_semaphore_set(master_sem_addr, 2); + noc_semaphore_wait(slave_sem_addr, 2); + noc_semaphore_set(slave_sem_addr, 0); + + // snap back to alignment + if (j > 0) { + cb_reserve_back(i, n_pages - j); + cb_push_back(i, n_pages - j); + } + + } + + } +} diff --git a/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_slave_test_kernel.cpp b/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_slave_test_kernel.cpp new file mode 100644 index 00000000000..8bf11e4f175 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_slave_test_kernel.cpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_metal/hw/inc/dataflow_api.h" + +#include +#include +#include + + +/* + * This kernel enumerates over all the CBs and pushes a page at a time. For every page, it checks to see if the + * non-blocking call to `cb_pages_reservable_at_back` to get the result. It stores the result for the element + * corresponding to that iteration index, into the output buffer associated with that CB. The buffer can + * later be readback by host for comparison and checking. 
+ */ +void kernel_main() { + constexpr int32_t n_cbs = get_compile_time_arg_val(0); + constexpr int32_t n_pages = get_compile_time_arg_val(1); + + size_t arg_idx = 0; + + auto master_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + auto slave_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + + std::array output_buffer_addrs; + for (size_t i = 0; i < n_cbs; i++) { + output_buffer_addrs[i] = get_arg_val(arg_idx++); + } + + auto get_idx = [n_pages](size_t i, size_t j) -> size_t { + return i * n_pages + j; + }; + + for (int32_t i = 0; i < n_cbs; i++) { + auto *const output_buffer = reinterpret_cast(output_buffer_addrs[i]); + for (int32_t j = 0; j < n_pages; j++) { + // First level signal indicates the writer has pushed new pages to the CB + noc_semaphore_wait(master_sem_addr, 1); + noc_semaphore_set(slave_sem_addr, 1); + + for (int32_t k = 0; k < n_pages; k++) { + auto result = cb_pages_available_at_front(i, k); + output_buffer[get_idx(j,k)] = static_cast(result); + } + noc_semaphore_wait(master_sem_addr, 2); + noc_semaphore_set(master_sem_addr, 0); + if (j > 0) { + cb_wait_front(i, j); + cb_pop_front(i, j); + } + // Second level signal indicates "alignment pages". We signal back that we are + // done processing this step + noc_semaphore_set(slave_sem_addr, 2); + + if (j > 0) { + // snap back to alignment + cb_wait_front(i, n_pages - j); + cb_pop_front(i, n_pages - j); + } + + } + + } + +} diff --git a/tests/tt_metal/tt_metal/unit_tests/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests/CMakeLists.txt index c8bbfa0a9d1..bee628ad488 100644 --- a/tests/tt_metal/tt_metal/unit_tests/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/unit_tests/CMakeLists.txt @@ -13,6 +13,7 @@ set(UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/buffer/test_simple_l1_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/circular_buffer/test_CircularBuffer_allocation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/circular_buffer/test_CircularBuffer_creation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/circular_buffer/test_CircularBuffer_non_blocking.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compute/test_golden_impls.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compute/test_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compute/test_single_core_binary_compute.cpp diff --git a/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_creation.cpp b/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_creation.cpp index 3b6173fb45d..199aa429f88 100644 --- a/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_creation.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_creation.cpp @@ -75,7 +75,6 @@ TEST_F(DeviceFixture, TestCreateCircularBufferAtValidIndices) { auto cb = CreateCircularBuffer(program, cr_set, config); for (unsigned int id = 0; id < num_devices_; id++) { - detail::CompileProgram(devices_.at(id), program); program.finalize(devices_.at(id)); EXPECT_TRUE(test_cb_config_written_to_core(program, this->devices_.at(id), cr_set, golden_cb_config)); } diff --git a/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_non_blocking.cpp b/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_non_blocking.cpp new file mode 100644 index 00000000000..c10b6c6e0db --- /dev/null +++ b/tests/tt_metal/tt_metal/unit_tests/circular_buffer/test_CircularBuffer_non_blocking.cpp @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + + +#include "circular_buffer_test_utils.hpp" +#include "device_fixture.hpp" +#include "gtest/gtest.h" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/buffers/circular_buffer.hpp" + +#include "tt_metal/common/core_coord.hpp" + +#include "gtest/gtest.h" + +#include +#include +#include + +using namespace tt::tt_metal; + + +constexpr CoreCoord worker_core = {0, 0}; +constexpr size_t cb_n_pages = 32; +constexpr size_t cb_page_size = 16; +constexpr size_t n_cbs = 32; +constexpr size_t data_buffer_size = cb_n_pages * cb_n_pages; + +std::vector> create_output_buffers(Program &program, Device *device) { + std::vector> output_buffers; + output_buffers.reserve(n_cbs); + for (size_t i = 0; i < n_cbs; i++) { + // Bootleg way to put a single buffer on a single core + auto const& buffer_config = ShardedBufferConfig { + device, + data_buffer_size, + data_buffer_size, + BufferType::L1, + TensorMemoryLayout::WIDTH_SHARDED, + ShardSpecBuffer( + CoreRangeSet(CoreRange(worker_core)), + {cb_n_pages,cb_n_pages}, + ShardOrientation::ROW_MAJOR, + false, + {cb_n_pages,cb_n_pages}, + {1,1} + ), + }; + output_buffers.push_back(CreateBuffer(buffer_config)); + } + return output_buffers; +} + +std::vector generate_rt_args(uint32_t master_semaphore, uint32_t slave_semaphore, std::vector> const& data_buffers) { + std::vector rt_args; + rt_args.reserve(2 + n_cbs); + rt_args.push_back(master_semaphore); + rt_args.push_back(slave_semaphore); + std::transform(data_buffers.begin(), data_buffers.end(), std::back_inserter(rt_args), [](auto const& buffer) { return buffer->address(); }); + return rt_args; +} + +TEST_F(DeviceFixture, TestCircularBufferNonBlockingAPIs) { + Program program; + Device *device = devices_.at(0); + + auto const master_semaphore = CreateSemaphore(program, worker_core, 0, CoreType::WORKER); + auto const slave_semaphore = CreateSemaphore(program, worker_core, 0, CoreType::WORKER); + + std::vector cbs; + cbs.reserve(n_cbs); + for (size_t i = 0; i < n_cbs; i++) { + CircularBufferConfig cb_config = CircularBufferConfig(cb_n_pages * cb_page_size, {{i, tt::DataFormat::Float16_b}}).set_page_size(i, cb_page_size); + cbs.push_back(CreateCircularBuffer(program, worker_core, cb_config)); + } + + auto const& master_data_buffers = create_output_buffers(program, device); + auto const& slave_data_buffers = create_output_buffers(program, device); + + std::vector const& kernel_ct_args{ + n_cbs, + cb_n_pages + }; + + auto const master_kernel_id = tt::tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_master_test_kernel.cpp", + worker_core, + tt::tt_metal::ReaderDataMovementConfig{kernel_ct_args}); + auto const slave_kernel_id = tt::tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/circular_buffer/cb_non_blocking_slave_test_kernel.cpp", + worker_core, + tt::tt_metal::WriterDataMovementConfig{kernel_ct_args}); + + auto const& master_rt_args = generate_rt_args(master_semaphore, slave_semaphore, master_data_buffers); + auto const& slave_rt_args = generate_rt_args(master_semaphore, slave_semaphore, slave_data_buffers); + + tt::tt_metal::SetRuntimeArgs( + program, + master_kernel_id, + worker_core, + master_rt_args); + tt::tt_metal::SetRuntimeArgs( + program, + slave_kernel_id, + worker_core, + slave_rt_args); + + tt::tt_metal::detail::CompileProgram(device, program); + tt::tt_metal::detail::LaunchProgram(device, program, true); + + std::vector 
out_buf(data_buffer_size); + for (size_t i = 0; i < n_cbs; i++) { + tt::tt_metal::detail::ReadFromBuffer(master_data_buffers[i], out_buf, false); + + uint8_t const* raw_data = reinterpret_cast(out_buf.data()); + for (size_t pages_pushed = 0; pages_pushed < cb_n_pages; pages_pushed++) { + for (size_t requested_pages_free = 0; requested_pages_free < cb_n_pages; requested_pages_free++) { + ASSERT_EQ(static_cast(raw_data[pages_pushed * cb_n_pages + requested_pages_free]), requested_pages_free <= (cb_n_pages - pages_pushed)); + } + } + } + + for (size_t i = 0; i < n_cbs; i++) { + tt::tt_metal::detail::ReadFromBuffer(slave_data_buffers[i], out_buf, true); + + uint8_t const* raw_data = reinterpret_cast(out_buf.data()); + for (size_t pages_pushed = 0; pages_pushed < cb_n_pages; pages_pushed++) { + for (size_t filled_pages_requested = 0; filled_pages_requested < cb_n_pages; filled_pages_requested++) { + ASSERT_EQ(static_cast(raw_data[pages_pushed * cb_n_pages + filled_pages_requested]), filled_pages_requested <= pages_pushed); + } + } + } +} diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index b75c86d6296..53cc54a8afc 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -977,7 +977,7 @@ def run_test_sdpa_decode_single_iter( tt_Q, tt_K, tt_V, - [start_idx for _ in range(b)], + cur_pos=[start_idx for _ in range(b)], scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, diff --git a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py index 275be0d5d39..7dbe50b4472 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py @@ -23,15 +23,11 @@ def test_bw_softshrink(input_shapes, lambd, device): in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device) - in_data.retain_grad() - - pyt_y = torch.nn.functional.softshrink(in_data, lambd=lambd) tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor, lambd=lambd) - pyt_y.backward(gradient=grad_data) - - golden_tensor = [in_data.grad] + golden_function = ttnn.get_golden_function(ttnn.softshrink_bw) + golden_tensor = golden_function(grad_data, in_data, lambd) comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor) assert comp_pass @@ -48,15 +44,11 @@ def test_bw_softshrink(input_shapes, lambd, device): def test_bw_softshrink_default(input_shapes, device): in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device) - in_data.retain_grad() - - pyt_y = torch.nn.functional.softshrink(in_data) tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor) - pyt_y.backward(gradient=grad_data) - - golden_tensor = [in_data.grad] + golden_function = ttnn.get_golden_function(ttnn.softshrink_bw) + golden_tensor = golden_function(grad_data, in_data) comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor) assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py index ab3ed7d3c83..57f8ecf4284 100644 --- 
a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py @@ -349,11 +349,11 @@ def run_activation_test_scalarBC_key(device, h, w, scalar1, scalar2, ttnn_functi torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16) golden_function = ttnn.get_golden_function(ttnn_function) - torch_output_tensor = golden_function(torch_input_tensor_a, min=scalar1, max=scalar2) + torch_output_tensor = golden_function(torch_input_tensor_a, scalar1, scalar2) input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) - output_tensor = ttnn_function(input_tensor_a, min=scalar1, max=scalar2) + output_tensor = ttnn_function(input_tensor_a, scalar1, scalar2) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) output_tensor = ttnn.to_torch(output_tensor) diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py index 05c71cf113e..b38bdc82a15 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py @@ -112,12 +112,12 @@ def test_unary_composite_clamp_ttnn(input_shapes, min, max, device): in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device) if min is None and max is None: with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"): - ttnn.clamp(input_tensor1, min=min, max=max) + ttnn.clamp(input_tensor1, min, max) assert True else: - output_tensor = ttnn.clamp(input_tensor1, min=min, max=max) + output_tensor = ttnn.clamp(input_tensor1, min, max) golden_function = ttnn.get_golden_function(ttnn.clamp) - golden_tensor = golden_function(in_data1, min=min, max=max) + golden_tensor = golden_function(in_data1, min, max) comp_pass = compare_pcc([output_tensor], [golden_tensor]) assert comp_pass @@ -149,12 +149,12 @@ def test_unary_composite_clip_ttnn(input_shapes, min, max, device): in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device) if min is None and max is None: with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. 
Please provide one value"): - ttnn.clip(input_tensor1, min=min, max=max) + ttnn.clip(input_tensor1, min, max) assert True else: - output_tensor = ttnn.clip(input_tensor1, min=min, max=max) + output_tensor = ttnn.clip(input_tensor1, min, max) golden_function = ttnn.get_golden_function(ttnn.clip) - golden_tensor = golden_function(in_data1, min=min, max=max) + golden_tensor = golden_function(in_data1, min, max) comp_pass = compare_pcc([output_tensor], [golden_tensor]) assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index b90fd026c96..9b56e82e3d1 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -107,7 +107,7 @@ def run_max_pool( output_height=out_h, output_width=out_w, output_channels=in_c, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=False, ) @@ -632,7 +632,7 @@ def test_pool_core_nondivis( output_height=out_h, output_width=out_w, output_channels=in_c, - device=device, + compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) diff --git a/tests/ttnn/unit_tests/operations/test_slice.py b/tests/ttnn/unit_tests/operations/test_slice.py index 7882dde8697..a85b33ada9a 100644 --- a/tests/ttnn/unit_tests/operations/test_slice.py +++ b/tests/ttnn/unit_tests/operations/test_slice.py @@ -311,7 +311,8 @@ def test_stride_slice_three_dim(c, h, w, begins_c, begins_h, begins_w, stride_c, @pytest.mark.parametrize("begins", [[2, 0, 0, 2]]) @pytest.mark.parametrize("ends", [[18, 16, 16, 18]]) @pytest.mark.parametrize("strides", [[2, 2, 2, 2]]) -def test_stride_slice_four_dim(dims, begins, ends, strides, device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_stride_slice_four_dim(dims, begins, ends, strides, layout, device): torch.manual_seed(2005) torch_input = torch.rand(dims) slices = [] @@ -320,7 +321,28 @@ def test_stride_slice_four_dim(dims, begins, ends, strides, device): torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]] - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]] + ttnn_output = ttnn.to_torch(ttnn_output) + + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize("dims", [[1, 56, 56, 96]]) +@pytest.mark.parametrize("begins", [[0, 0, 0, 0]]) +@pytest.mark.parametrize("ends", [[1, -1, 56, 96]]) +@pytest.mark.parametrize("strides", [[1, 2, 1, 1]]) +@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT]) +def test_stride_slice_four_dim_tiled(dims, begins, ends, strides, layout, device): + torch.manual_seed(2005) + torch_input = torch.rand(dims) + slices = [] + for i in range(len(dims)): + slices.append(slice(begins[i], ends[i], strides[i])) + + torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]] + + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]] ttnn_output = ttnn.to_torch(ttnn_output) @@ -328,9 +350,10 @@ def test_stride_slice_four_dim(dims, begins, ends, strides, device): # these tests are copy and paste from the yolo customers #8920 -def 
test_slice_usecase1(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase1(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., ::2, ::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., ::2, ::2] @@ -339,9 +362,10 @@ def test_slice_usecase1(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase2(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase2(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., ::2, 1::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., ::2, 1::2] @@ -350,9 +374,10 @@ def test_slice_usecase2(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase3(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase3(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., 1::2, ::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., 1::2, ::2] @@ -361,9 +386,10 @@ def test_slice_usecase3(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase4(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase4(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., 1::2, 1::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., 1::2, 1::2] @@ -428,10 +454,10 @@ def test_slice_bert(input_shape, input_start, input_ends, layout, device): torch_input = torch.randn(input_shape, dtype=torch.bfloat16) ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) else: - if input_ends[-1] - input_start[-1] == 1: + if (input_ends[-1] - input_start[-1]) % 2 != 0: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) if len(input_shape) == 4: torch_output = torch_input[ @@ -478,10 +504,10 @@ def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, memory_co torch_input = torch.randn(input_shape, dtype=torch.bfloat16) ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) else: - if input_ends[-1] - input_start[-1] == 1: + if (input_ends[-1] - input_start[-1]) % 2 != 0: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, 
device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) if len(input_shape) == 4: torch_output = torch_input[ @@ -558,7 +584,7 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou if (input_ends[-1] - input_start[-1]) % 2: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) torch_output = torch_input[ input_start[0] : input_ends[0], @@ -573,3 +599,190 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou ttnn_output = ttnn.to_torch(ttnn_output) assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, input_start, input_ends", + ( + ((1, 1, 1, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 255)), + ((1, 1, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1)), + ((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32)), + ((1, 1, 1, 64, 64), (0, 0, 0, 0, 0), (1, 1, 1, 1, 1)), + ((4, 3, 2, 1, 4), (1, 1, 1, 0, 0), (1, 1, 2, 1, 4)), + ), +) +@pytest.mark.parametrize( + "layout", + (ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), +) +@pytest.mark.parametrize( + "memory_config", + (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG), +) +def test_ttnn_slice_5d(input_shape, input_start, input_ends, layout, memory_config, device): + if layout == ttnn.TILE_LAYOUT: + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + else: + if (input_ends[-1] - input_start[-1]) % 2: + pytest.skip("Cannot slice the last dimension to 1 in row major layout") + torch_input = torch.randn(input_shape, dtype=torch.float32) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + + torch_output = torch_input[ + input_start[0] : input_ends[0], + input_start[1] : input_ends[1], + input_start[2] : input_ends[2], + input_start[3] : input_ends[3], + input_start[4] : input_ends[4], + ] + + ttnn_output = ttnn.slice(ttnn_input, input_start, input_ends, (1, 1, 1, 1, 1), memory_config=memory_config) + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, input_start, input_ends, input_stride", + ( + ((1, 1, 5, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 234), (1, 1, 1, 1, 1)), + ((1, 2, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1), (1, 1, 1, 1, 1)), + ((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32), (1, 1, 2, 1, 2)), + ((2, 1, 1, 64, 64), (1, 0, 0, 0, 0), (2, 1, 1, 1, 1), (1, 1, 1, 1, 1)), + ((4, 3, 2, 1, 18), (1, 1, 1, 0, 0), (1, 1, 2, 1, -2), (1, 1, 1, 1, 2)), + ), +) +@pytest.mark.parametrize( + "layout", + (ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), +) +def test_slice_5d(input_shape, input_start, input_ends, input_stride, layout, device): + if layout == ttnn.TILE_LAYOUT: + if input_stride is not (1, 1, 1, 1, 1): + pytest.skip("Cannot untilize 5D tensor") + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + else: + if (input_ends[-1] - input_start[-1]) % 2: + pytest.skip("Cannot slice the last dimension to 1 in row major layout") + torch_input = 
torch.randn(input_shape, dtype=torch.float32) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + + torch_output = torch_input[ + input_start[0] : input_ends[0] : input_stride[0], + input_start[1] : input_ends[1] : input_stride[1], + input_start[2] : input_ends[2] : input_stride[2], + input_start[3] : input_ends[3] : input_stride[3], + input_start[4] : input_ends[4] : input_stride[4], + ] + ttnn_output = ttnn_input[ + input_start[0] : input_ends[0] : input_stride[0], + input_start[1] : input_ends[1] : input_stride[1], + input_start[2] : input_ends[2] : input_stride[2], + input_start[3] : input_ends[3] : input_stride[3], + input_start[4] : input_ends[4] : input_stride[4], + ] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +def test_slice_7d_strided(device): + torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + + torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2] + ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +def test_slice_7d(device): + torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200] + ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, dim, start, end, step, layout", + ( + ([1, 28, 56, 96], 2, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc + ([1, 56, 56, 96], 1, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc + ([8732, 4], 1, 0, 2, 1, ttnn.ROW_MAJOR_LAYOUT), # Formerly bad pcc + ([1, 14, 28, 192], 2, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test (low priority) + ([1, 23, 40, 128], 3, 0, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test + ([1, 28, 28, 256], 1, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test + ( + [1, 3], + 1, + 0, + -1, + 1, + ttnn.TILE_LAYOUT, + ), # works when you turn it into a 2D tensor (compared to [3] example in the next test) + ), +) +def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, device): + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + + slice_obj = slice(start, end, step) + + # Prepare indices for slicing in the specified dimension + indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension + indices[dim] = slice_obj # Apply slicing to the target dimension + indices = tuple(indices) + + # Apply slicing to the input_tensor + torch_output_tensor = torch_input[indices] + + ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_tensor[indices] + + ttnn_output_tensor = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999) + + +@pytest.mark.parametrize( + "input_shape, dim, start, end, step, layout", + ( + ([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT), # Need tensor for this or a padding aware tiled kernel + ([1, 7], 0, 0, -1, 1, ttnn.ROW_MAJOR_LAYOUT), # page size must equal buffer size + ( + [1, 7, 71, 64], + 3, + 0, + -1, + 1, + ttnn.ROW_MAJOR_LAYOUT, + ), # An 
unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary + ([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT), # Buffer size and page size should be larger than 0 bytes + ([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT), # Difference in expected shape as it's a 1D tensor + ), +) +def test_slice_adversarial(input_shape, dim, start, end, step, layout, device): + pytest.skip("These tests are expected to fail at the moment") + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + + slice_obj = slice(start, end, step) + + # Prepare indices for slicing in the specified dimension + indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension + indices[dim] = slice_obj # Apply slicing to the target dimension + indices = tuple(indices) + + # Apply slicing to the input_tensor + torch_output_tensor = torch_input[indices] + + ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_tensor[indices] + + ttnn_output_tensor = ttnn.to_torch(ttnn_output) + + assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999) diff --git a/tt_metal/hostdevcommon/common_runtime_address_map.h b/tt_metal/hostdevcommon/common_runtime_address_map.h index 3b1d25268bc..308f1b21b99 100644 --- a/tt_metal/hostdevcommon/common_runtime_address_map.h +++ b/tt_metal/hostdevcommon/common_runtime_address_map.h @@ -13,9 +13,11 @@ * This file contains addresses that are visible to both host and device compiled code. */ -// TODO: move this to the memory manager, make configurable through the API +// Kernel config buffer is WIP +// Size is presently based on the old sizes of the RTAs + CB config + Sems +// plus some extra space freed up in the mem map constexpr static std::uint32_t L1_KERNEL_CONFIG_BASE = MEM_MAP_END; -constexpr static std::uint32_t L1_KERNEL_CONFIG_SIZE = 69 * 1024; +constexpr static std::uint32_t L1_KERNEL_CONFIG_SIZE = 4 * 1024 + 256 + 128 + 512; constexpr static std::uint32_t NUM_CIRCULAR_BUFFERS = 32; constexpr static std::uint32_t UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG = 4; diff --git a/tt_metal/hw/firmware/src/brisc.cc b/tt_metal/hw/firmware/src/brisc.cc index 8b59ec9bc0e..a3b22ccfdf1 100644 --- a/tt_metal/hw/firmware/src/brisc.cc +++ b/tt_metal/hw/firmware/src/brisc.cc @@ -169,13 +169,13 @@ void set_deassert_addresses() { #endif } -void l1_to_ncrisc_iram_copy(uint32_t src_addr, uint16_t size, uint32_t address_offset = 0) { +void l1_to_ncrisc_iram_copy(uint16_t size, uint32_t address_offset = 0) { #ifdef NCRISC_HAS_IRAM // Always copy ncrisc even if its size is 0 (save branch)... 
// Copy NCRISC firmware from L1 to local IRAM using tensix DMA tdma_xmov( TDMA_MOVER0, - src_addr, + (MEM_NCRISC_INIT_IRAM_L1_BASE >> 4) + address_offset, MEM_MOVER_VIEW_IRAM_BASE_ADDR + address_offset, size, XMOV_L1_TO_L0); @@ -267,22 +267,16 @@ void init_sync_registers() { } } -inline void init_ncrisc_iram() { -#ifdef NCRISC_HAS_IRAM +inline void deassert_ncrisc_trisc() { + // Below sets ncrisc to go so we can wait until it is cleared on first iteration + mailboxes->slave_sync.all = RUN_SYNC_MSG_ALL_SLAVES_DONE; + uint16_t fw_size16 = mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.ncrisc_kernel_size16; ncrisc_kernel_start_offset16 = fw_size16; // Copies from L1 to IRAM on chips where NCRISC has IRAM - l1_to_ncrisc_iram_copy(MEM_NCRISC_INIT_IRAM_L1_BASE >> 4, fw_size16); + l1_to_ncrisc_iram_copy(fw_size16); l1_to_ncrisc_iram_copy_wait(); -#endif -} - -inline void deassert_ncrisc_trisc() { - // Below sets ncrisc to go so we can wait until it is cleared on first iteration - mailboxes->slave_sync.all = RUN_SYNC_MSG_ALL_SLAVES_DONE; - - init_ncrisc_iram(); // Bring ncrisc/triscs out of reset deassert_all_reset(); @@ -406,13 +400,8 @@ int main() { DeviceValidateProfiler(launch_msg_address->kernel_config.enables); DeviceZoneSetCounter(launch_msg_address->kernel_config.host_assigned_id); // Copies from L1 to IRAM on chips where NCRISC has IRAM - uint32_t kernel_config_base = firmware_config_init(mailboxes, ProgrammableCoreType::TENSIX, DISPATCH_CLASS_TENSIX_DM0); - int ncrisc_index = static_cast::type>(TensixProcessorTypes::DM1); - uint32_t ncrisc_kernel_src_address = - kernel_config_base + launch_msg_address->kernel_config.kernel_text_offset[ncrisc_index]; - l1_to_ncrisc_iram_copy(ncrisc_kernel_src_address >> 4, - launch_msg_address->kernel_config.ncrisc_kernel_size16, - ncrisc_kernel_start_offset16); + l1_to_ncrisc_iram_copy(launch_msg_address->kernel_config.ncrisc_kernel_size16, ncrisc_kernel_start_offset16); + // Invalidate the i$ now the kernels have loaded and before running volatile tt_reg_ptr uint32_t* cfg_regs = core.cfg_regs_base(0); cfg_regs[RISCV_IC_INVALIDATE_InvalidateAll_ADDR32] = RISCV_IC_BRISC_MASK | RISCV_IC_TRISC_ALL_MASK | RISCV_IC_NCRISC_MASK; @@ -434,6 +423,7 @@ int main() { } prev_noc_mode = noc_mode; + uint32_t kernel_config_base = firmware_config_init(mailboxes, ProgrammableCoreType::TENSIX, DISPATCH_CLASS_TENSIX_DM0); uint32_t tt_l1_ptr *cb_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + launch_msg_address->kernel_config.cb_offset); setup_cb_read_write_interfaces(cb_l1_base, 0, num_cbs_to_early_init, true, true, false); @@ -443,13 +433,10 @@ int main() { WAYPOINT("R"); if (enables & DISPATCH_CLASS_MASK_TENSIX_ENABLE_DM0) { setup_cb_read_write_interfaces(cb_l1_base, num_cbs_to_early_init, launch_msg_address->kernel_config.max_cb_index, true, true, false); - int index = static_cast::type>(TensixProcessorTypes::DM0); - void (*kernel_address)(uint32_t) = (void (*)(uint32_t)) - (kernel_config_base + launch_msg_address->kernel_config.kernel_text_offset[index]); - (*kernel_address)((uint32_t)kernel_address); + kernel_init(); RECORD_STACK_USAGE(); } else { - // This was not initialized in the kernel + // This was not initialized in kernel_init if (noc_mode == DM_DEDICATED_NOC) { noc_local_state_init(noc_index); } diff --git a/tt_metal/hw/firmware/src/brisck.cc b/tt_metal/hw/firmware/src/brisck.cc index f9f04eec011..7b01d4ba354 100644 --- a/tt_metal/hw/firmware/src/brisck.cc +++ b/tt_metal/hw/firmware/src/brisck.cc @@ -19,9 +19,8 @@ #include extern uint32_t 
__kernel_init_local_l1_base[]; -extern uint32_t __fw_export_end_text[]; -void kernel_launch(uint32_t kernel_base_addr) { +void kernel_launch() { #if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL) #ifdef KERNEL_RUN_TIME @@ -29,7 +28,7 @@ void kernel_launch(uint32_t kernel_base_addr) { while (c_tensix_core::read_wall_clock() < end_time); #endif #else - firmware_kernel_common_init((void tt_l1_ptr *)(kernel_base_addr + (uint32_t) __kernel_init_local_l1_base - (uint32_t)__fw_export_end_text)); + firmware_kernel_common_init((void tt_l1_ptr *)(__kernel_init_local_l1_base)); if constexpr (NOC_MODE == DM_DEDICATED_NOC) { noc_local_state_init(NOC_INDEX); diff --git a/tt_metal/hw/firmware/src/erisc.cc b/tt_metal/hw/firmware/src/erisc.cc index dba2673dac7..2c1f978b994 100644 --- a/tt_metal/hw/firmware/src/erisc.cc +++ b/tt_metal/hw/firmware/src/erisc.cc @@ -72,13 +72,10 @@ void __attribute__((noinline)) Application(void) { launch_msg_t* launch_msg_address = &(mailboxes->launch[launch_msg_rd_ptr]); DeviceValidateProfiler(launch_msg_address->kernel_config.enables); DeviceZoneSetCounter(launch_msg_address->kernel_config.host_assigned_id); - // Note that a core may get "GO" w/ enable false to keep its launch_msg's in sync enum dispatch_core_processor_masks enables = (enum dispatch_core_processor_masks)launch_msg_address->kernel_config.enables; if (enables & DISPATCH_CLASS_MASK_ETH_DM0) { - WAYPOINT("R"); firmware_config_init(mailboxes, ProgrammableCoreType::ACTIVE_ETH, DISPATCH_CLASS_ETH_DM0); - kernel_init(0); - WAYPOINT("D"); + kernel_init(); } mailboxes->go_message.signal = RUN_MSG_DONE; @@ -92,6 +89,7 @@ void __attribute__((noinline)) Application(void) { // Only executed if watcher is enabled. Ensures that we don't report stale data due to invalid launch messages in the ring buffer CLEAR_PREVIOUS_LAUNCH_MESSAGE_ENTRY_FOR_WATCHER(); } + WAYPOINT("R"); } else if (go_message_signal == RUN_MSG_RESET_READ_PTR) { // Reset the launch message buffer read ptr diff --git a/tt_metal/hw/firmware/src/erisck.cc b/tt_metal/hw/firmware/src/erisck.cc index 54d9ab9b958..b80fa99ad8a 100644 --- a/tt_metal/hw/firmware/src/erisck.cc +++ b/tt_metal/hw/firmware/src/erisck.cc @@ -23,7 +23,7 @@ CBInterface cb_interface[NUM_CIRCULAR_BUFFERS]; -void kernel_launch(uint32_t) { +void kernel_launch() { DeviceZoneScopedMainChildN("ERISC-KERNEL"); rtos_context_switch_ptr = (void (*)())RtosTable[0]; diff --git a/tt_metal/hw/firmware/src/idle_erisc.cc b/tt_metal/hw/firmware/src/idle_erisc.cc index d0c779bec52..518b33f544c 100644 --- a/tt_metal/hw/firmware/src/idle_erisc.cc +++ b/tt_metal/hw/firmware/src/idle_erisc.cc @@ -25,6 +25,7 @@ #include "debug/watcher_common.h" #include "debug/waypoint.h" +#include "debug/dprint.h" #include "debug/stack_usage.h" uint8_t noc_index; @@ -134,10 +135,7 @@ int main() { // Run the ERISC kernel WAYPOINT("R"); - int index = static_cast::type>(EthProcessorTypes::DM0); - void (*kernel_address)(uint32_t) = (void (*)(uint32_t)) - (kernel_config_base + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.kernel_text_offset[index]); - (*kernel_address)((uint32_t)kernel_address); + kernel_init(); RECORD_STACK_USAGE(); WAYPOINT("D"); mailboxes->go_message.signal = RUN_MSG_DONE; diff --git a/tt_metal/hw/firmware/src/idle_erisck.cc b/tt_metal/hw/firmware/src/idle_erisck.cc index 756c71d0448..99f000c3de6 100644 --- a/tt_metal/hw/firmware/src/idle_erisck.cc +++ b/tt_metal/hw/firmware/src/idle_erisck.cc @@ -22,12 +22,10 @@ #include extern uint32_t __kernel_init_local_l1_base[]; -extern uint32_t 
__fw_export_end_text[]; -void kernel_launch(uint32_t kernel_base_addr) { +void kernel_launch() { DeviceZoneScopedMainChildN("ERISC-KERNEL"); - - firmware_kernel_common_init((void tt_l1_ptr *)(kernel_base_addr + (uint32_t) __kernel_init_local_l1_base - (uint32_t)__fw_export_end_text)); + firmware_kernel_common_init((void tt_l1_ptr *)__kernel_init_local_l1_base); noc_local_state_init(NOC_INDEX); diff --git a/tt_metal/hw/firmware/src/ncrisc.cc b/tt_metal/hw/firmware/src/ncrisc.cc index d5cb2b614f9..48b735afd9a 100644 --- a/tt_metal/hw/firmware/src/ncrisc.cc +++ b/tt_metal/hw/firmware/src/ncrisc.cc @@ -91,23 +91,13 @@ int main(int argc, char *argv[]) { notify_brisc_and_wait(); DeviceZoneScopedMainN("NCRISC-FW"); - uint32_t launch_msg_rd_ptr = mailboxes->launch_msg_rd_ptr; - launch_msg_t* launch_msg = &(mailboxes->launch[launch_msg_rd_ptr]); - uint32_t kernel_config_base = firmware_config_init(mailboxes, ProgrammableCoreType::TENSIX, DISPATCH_CLASS_TENSIX_DM1); uint32_t tt_l1_ptr *cb_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + - launch_msg->kernel_config.cb_offset); - setup_cb_read_write_interfaces(cb_l1_base, 0, launch_msg->kernel_config.max_cb_index, true, true, false); - WAYPOINT("R"); + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.cb_offset); + setup_cb_read_write_interfaces(cb_l1_base, 0, mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.max_cb_index, true, true, false); - int index = static_cast::type>(TensixProcessorTypes::DM1); - void (*kernel_address)(uint32_t) = (void (*)(uint32_t)) - (kernel_config_base + launch_msg->kernel_config.kernel_text_offset[index]); -#ifdef ARCH_BLACKHOLE - (*kernel_address)((uint32_t)kernel_address); -#else - kernel_init((uint32_t)kernel_address); -#endif + WAYPOINT("R"); + kernel_init(); RECORD_STACK_USAGE(); WAYPOINT("D"); diff --git a/tt_metal/hw/firmware/src/ncrisck.cc b/tt_metal/hw/firmware/src/ncrisck.cc index 6f24d5b107b..f59e2ce313e 100644 --- a/tt_metal/hw/firmware/src/ncrisck.cc +++ b/tt_metal/hw/firmware/src/ncrisck.cc @@ -27,9 +27,8 @@ uint32_t noc_nonposted_atomics_acked[NUM_NOCS]; uint32_t noc_posted_writes_num_issued[NUM_NOCS]; extern uint32_t __kernel_init_local_l1_base[]; -extern uint32_t __fw_export_end_text[]; -void kernel_launch(uint32_t kernel_base_addr) { +void kernel_launch() { DeviceZoneScopedMainChildN("NCRISC-KERNEL"); #if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL) @@ -38,8 +37,11 @@ void kernel_launch(uint32_t kernel_base_addr) { while (c_tensix_core::read_wall_clock() < KERNEL_RUN_TIME); #endif #else - - firmware_kernel_common_init((void tt_l1_ptr *)(kernel_base_addr + (uint32_t) __kernel_init_local_l1_base - (uint32_t)__fw_export_end_text)); +#ifdef ARCH_BLACKHOLE + firmware_kernel_common_init((void tt_l1_ptr *)__kernel_init_local_l1_base); +#else + firmware_kernel_common_init((void tt_l1_ptr *)(MEM_NCRISC_INIT_IRAM_L1_BASE + (uint32_t)__kernel_init_local_l1_base - MEM_NCRISC_IRAM_BASE)); +#endif if constexpr (NOC_MODE == DM_DEDICATED_NOC) { noc_local_state_init(NOC_INDEX); diff --git a/tt_metal/hw/firmware/src/trisc.cc b/tt_metal/hw/firmware/src/trisc.cc index 505e0bce3bf..f71c698167e 100644 --- a/tt_metal/hw/firmware/src/trisc.cc +++ b/tt_metal/hw/firmware/src/trisc.cc @@ -94,27 +94,21 @@ int main(int argc, char *argv[]) { while (*trisc_run != RUN_SYNC_MSG_GO); DeviceZoneScopedMainN("TRISC-FW"); - uint32_t launch_msg_rd_ptr = mailboxes->launch_msg_rd_ptr; - launch_msg_t* launch_msg = &(mailboxes->launch[launch_msg_rd_ptr]); - - uint32_t kernel_config_base = 
launch_msg->kernel_config.kernel_config_base[ProgrammableCoreType::TENSIX]; + uint32_t kernel_config_base = mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.kernel_config_base[ProgrammableCoreType::TENSIX]; #if !defined(UCK_CHLKC_MATH) uint32_t tt_l1_ptr *cb_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + - launch_msg->kernel_config.cb_offset); - setup_cb_read_write_interfaces(cb_l1_base, 0, launch_msg->kernel_config.max_cb_index, cb_init_read, cb_init_write, cb_init_write); + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.cb_offset); + setup_cb_read_write_interfaces(cb_l1_base, 0, mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.max_cb_index, cb_init_read, cb_init_write, cb_init_write); #endif rta_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + - launch_msg->kernel_config.rta_offset[DISPATCH_CLASS_TENSIX_COMPUTE].rta_offset); + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.rta_offset[DISPATCH_CLASS_TENSIX_COMPUTE].rta_offset); crta_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + - launch_msg->kernel_config.rta_offset[DISPATCH_CLASS_TENSIX_COMPUTE].crta_offset); + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.rta_offset[DISPATCH_CLASS_TENSIX_COMPUTE].crta_offset); WAYPOINT("R"); - int index = static_cast::type>(TensixProcessorTypes::MATH0) + thread_id; - void (*kernel_address)(uint32_t) = (void (*)(uint32_t)) - (kernel_config_base + launch_msg->kernel_config.kernel_text_offset[index]); - (*kernel_address)((uint32_t)kernel_address); + kernel_init(); RECORD_STACK_USAGE(); WAYPOINT("D"); diff --git a/tt_metal/hw/firmware/src/trisck.cc b/tt_metal/hw/firmware/src/trisck.cc index 862c2964808..f6c1cb57a38 100644 --- a/tt_metal/hw/firmware/src/trisck.cc +++ b/tt_metal/hw/firmware/src/trisck.cc @@ -34,9 +34,8 @@ volatile tt_reg_ptr uint * mailbox_base[4] = { } extern uint32_t __kernel_init_local_l1_base[]; -extern uint32_t __fw_export_end_text[]; -void kernel_launch(uint32_t kernel_base_addr) +void kernel_launch() { DeviceZoneScopedMainChildN("TRISC-KERNEL"); #if defined(DEBUG_NULL_KERNELS) && !defined(DISPATCH_KERNEL) @@ -44,7 +43,7 @@ void kernel_launch(uint32_t kernel_base_addr) ckernel::wait(KERNEL_RUN_TIME); #endif #else - firmware_kernel_common_init((void tt_l1_ptr *)(kernel_base_addr + (uint32_t) __kernel_init_local_l1_base - (uint32_t)__fw_export_end_text)); + firmware_kernel_common_init((void tt_l1_ptr *)(__kernel_init_local_l1_base)); #if defined(UCK_CHLKC_UNPACK) // Make sure DBG_FEATURE_DISABLE register is cleared before every kernel is executed diff --git a/tt_metal/hw/inc/blackhole/dev_mem_map.h b/tt_metal/hw/inc/blackhole/dev_mem_map.h index 8afc35b6000..274f787f58a 100644 --- a/tt_metal/hw/inc/blackhole/dev_mem_map.h +++ b/tt_metal/hw/inc/blackhole/dev_mem_map.h @@ -43,19 +43,11 @@ ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024) -// TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) -#define MEM_NCRISC_FIRMWARE_SIZE 1536 -#define MEM_TRISC0_FIRMWARE_SIZE 1536 -#define MEM_TRISC1_FIRMWARE_SIZE 1536 -#define MEM_TRISC2_FIRMWARE_SIZE 1536 - -#define MEM_BRISC_KERNEL_SIZE (24 * 1024) -#define MEM_NCRISC_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC0_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC1_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC2_KERNEL_SIZE (24 * 1024) - +#define MEM_BRISC_FIRMWARE_SIZE (10 * 1024 + MEM_BRISC_LOCAL_SIZE) +#define MEM_NCRISC_FIRMWARE_SIZE (16 * 1024 + MEM_NCRISC_LOCAL_SIZE) +#define MEM_TRISC0_FIRMWARE_SIZE 
(16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC1_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC2_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) #define MEM_ZEROS_SIZE 512 #define MEM_BOOT_CODE_BASE 0 diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index 56ba958d3cf..1223a7799e5 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -432,6 +432,41 @@ inline void wait_for_sync_register_value(uint32_t addr, int32_t val) { WAYPOINT("SD"); } +/** + * A non-blocking call that checks if the specified number of pages are available for reservation at the back of the circular buffer. + * This call is used by the producer to see if the consumer has freed up the desired space (in pages). + * + * CB total size must be an even multiple of the argument passed to this call. + * + * Return value: true if the specified number of pages are available + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------|---------------------------------------|----------|---------------------------------------------------------------------------------------------------|----------| + * | cb_id | The index of the circular buffer (CB) | uint32_t | 0 to 31 | True | + * | num_tiles | The number of free tiles to wait for | uint32_t | It must be less or equal than the size of the CB (the total number of tiles that fit into the CB) | True | + */ +FORCE_INLINE +bool cb_pages_reservable_at_back(int32_t operand, int32_t num_pages) { + uint32_t pages_acked_ptr = (uint32_t)get_cb_tiles_acked_ptr(operand); + + // while the producer (write-side interface) is waiting for space to free up "tiles_pushed" is not changing + // "tiles_pushed" is updated by the producer only when the tiles are pushed + uint32_t pages_received = get_cb_tiles_received_ptr(operand)[0]; + + invalidate_l1_cache(); + // uint16_t's here because Tensix updates the val at tiles_acked_ptr as uint16 in llk_pop_tiles + // TODO: I think we could have TRISC update tiles_acked_ptr, and we wouldn't need uint16 here + uint16_t pages_acked = (uint16_t)reg_read(pages_acked_ptr); +#ifdef ARCH_GRAYSKULL + // The following test slows down by 5% when removing the barrier + // TODO(pgk) investigate GS arbiter WAR in compiler, is this fixing an issue there? + // models/experimental/stable_diffusion/tests/test_perf_unbatched_stable_diffusion.py::test_perf_bare_metal + volatile uint32_t local_mem_barrier = pages_acked; +#endif + uint16_t free_space_pages_wrap = cb_interface[operand].fifo_num_pages - (pages_received - pages_acked); + return num_pages <= static_cast(free_space_pages_wrap); +} + /** * A blocking call that waits for the specified number of tiles to be free in the specified circular buffer. This call * is used by the producer to wait for the consumer to consume (ie. free up) the specified number of tiles. @@ -473,6 +508,38 @@ void cb_reserve_back(int32_t operand, int32_t num_pages) { WAYPOINT("CRBD"); } +/** + * A non-blocking call that tells the caller if the specified number of pages are available in the specified circular buffer (CB). + * This call is used by the consumer of the CB to see if the prodcuers has fill the CB with at least the specified number + * of tiles. Important note: in case multiple calls of cb_wait_front(n) are issued without a paired cb_pop_front() call, + * n is expected to be incremented by the user to be equal to a cumulative total of tiles. 
Example: 4 calls of + * cb_wait_front(8) followed by a cb_pop_front(32) would produce incorrect behavior. Instead 4 calls of cb_wait_front() + * waiting on 8, 16, 24, 32 tiles should be issued. + * + * Important note: number of tiles used in all cb_* calls must evenly divide the cb size and must be the same number in + * all cb_wait_front calls in the same kernel. Example 1: cb_wait_front(32), cb_wait_front(40), cb_pop_front(32+8) tiles + * on a CB of size 64 would produce incorrect behavior. Example 2: cb_wait_front(3) on a cb of size 32 would also + * produce incorrect behavior. These limitations are due to performance optimizations in the CB implementation. + * + * Important note: CB total size must be an even multiple of the argument passed to this call. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------|---------------------------------------|----------|---------------------------------------------------------------------------------------------------|----------| + * | cb_id | The index of the circular buffer (CB) | uint32_t | 0 to 31 | True | + * | num_tiles | The number of tiles to check for | uint32_t | It must be less or equal than the size of the CB (the total number of tiles that fit into the CB) | | + * */ +FORCE_INLINE +bool cb_pages_available_at_front(int32_t operand, int32_t num_pages) { + uint32_t pages_acked = get_cb_tiles_acked_ptr(operand)[0]; + uint32_t pages_received_ptr = (uint32_t) get_cb_tiles_received_ptr(operand); + + invalidate_l1_cache(); + uint16_t pages_received = ((uint16_t)reg_read(pages_received_ptr)) - pages_acked; + return num_pages <= pages_received; +} + /** * A blocking call that waits for the specified number of tiles to be available in the specified circular buffer (CB). 
* This call is used by the consumer of the CB to wait for the producer to fill the CB with at least the specified number diff --git a/tt_metal/hw/inc/dev_msgs.h b/tt_metal/hw/inc/dev_msgs.h index 60a0030110b..0b027259c6a 100644 --- a/tt_metal/hw/inc/dev_msgs.h +++ b/tt_metal/hw/inc/dev_msgs.h @@ -90,6 +90,8 @@ struct kernel_config_msg_t { volatile uint16_t watcher_kernel_ids[DISPATCH_CLASS_MAX]; volatile uint16_t ncrisc_kernel_size16; // size in 16 byte units + volatile uint16_t host_assigned_id; + // Ring buffer of kernel configuration data volatile uint32_t kernel_config_base[static_cast(ProgrammableCoreType::COUNT)]; volatile uint16_t sem_offset[static_cast(ProgrammableCoreType::COUNT)]; @@ -97,8 +99,6 @@ struct kernel_config_msg_t { rta_offset_t rta_offset[DISPATCH_CLASS_MAX]; volatile uint32_t kernel_text_offset[MaxProcessorsPerCoreType]; - volatile uint16_t host_assigned_id; - volatile uint8_t mode; // dispatch mode host/dev volatile uint8_t brisc_noc_id; volatile uint8_t brisc_noc_mode; diff --git a/tt_metal/hw/inc/firmware_common.h b/tt_metal/hw/inc/firmware_common.h index fd048640f3c..64d55851522 100644 --- a/tt_metal/hw/inc/firmware_common.h +++ b/tt_metal/hw/inc/firmware_common.h @@ -21,8 +21,8 @@ extern uint32_t __ldm_data_end[]; extern void (* __init_array_start[])(); extern void (* __init_array_end[])(); -extern void kernel_init(uint32_t kernel_init); -extern void kernel_launch(uint32_t kernel_base_addr); +extern void kernel_init(); +extern void kernel_launch(); inline void l1_to_local_mem_copy(uint32_t *local_mem_addr, uint32_t tt_l1_ptr *l1_addr, int32_t len) { // Cover L1 load latency of 6 cycles for the bulk of the copy diff --git a/tt_metal/hw/inc/grayskull/dev_mem_map.h b/tt_metal/hw/inc/grayskull/dev_mem_map.h index 793bd1ba789..b7e92831929 100644 --- a/tt_metal/hw/inc/grayskull/dev_mem_map.h +++ b/tt_metal/hw/inc/grayskull/dev_mem_map.h @@ -46,18 +46,11 @@ ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024) -// TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) -#define MEM_NCRISC_FIRMWARE_SIZE 1536 -#define MEM_TRISC0_FIRMWARE_SIZE 1536 -#define MEM_TRISC1_FIRMWARE_SIZE 1536 -#define MEM_TRISC2_FIRMWARE_SIZE 1536 - -#define MEM_BRISC_KERNEL_SIZE (24 * 1024) -#define MEM_NCRISC_KERNEL_SIZE MEM_NCRISC_IRAM_SIZE -#define MEM_TRISC0_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC1_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC2_KERNEL_SIZE (24 * 1024) +#define MEM_BRISC_FIRMWARE_SIZE (10 * 1024 + MEM_BRISC_LOCAL_SIZE) +#define MEM_NCRISC_FIRMWARE_SIZE (16 * 1024) +#define MEM_TRISC0_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC1_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC2_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) #define MEM_ZEROS_SIZE 512 diff --git a/tt_metal/hw/inc/wormhole/dev_mem_map.h b/tt_metal/hw/inc/wormhole/dev_mem_map.h index 5f78ec1b810..1f6e55da51e 100644 --- a/tt_metal/hw/inc/wormhole/dev_mem_map.h +++ b/tt_metal/hw/inc/wormhole/dev_mem_map.h @@ -47,18 +47,11 @@ ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024) -// TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) -#define MEM_NCRISC_FIRMWARE_SIZE 1536 -#define MEM_TRISC0_FIRMWARE_SIZE 1536 -#define MEM_TRISC1_FIRMWARE_SIZE 1536 -#define MEM_TRISC2_FIRMWARE_SIZE 1536 - -#define MEM_BRISC_KERNEL_SIZE (24 * 1024) -#define MEM_NCRISC_KERNEL_SIZE MEM_NCRISC_IRAM_SIZE -#define MEM_TRISC0_KERNEL_SIZE (24 * 1024) -#define 
MEM_TRISC1_KERNEL_SIZE (24 * 1024) -#define MEM_TRISC2_KERNEL_SIZE (24 * 1024) +#define MEM_BRISC_FIRMWARE_SIZE (10 * 1024 + MEM_BRISC_LOCAL_SIZE) +#define MEM_NCRISC_FIRMWARE_SIZE (16 * 1024) +#define MEM_TRISC0_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC1_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) +#define MEM_TRISC2_FIRMWARE_SIZE (16 * 1024 + MEM_TRISC_LOCAL_SIZE) #define MEM_ZEROS_SIZE 512 @@ -129,7 +122,6 @@ #define MEM_IERISC_MAILBOX_END (MEM_IERISC_MAILBOX_BASE + MEM_IERISC_MAILBOX_SIZE) #define MEM_IERISC_FIRMWARE_BASE MEM_IERISC_MAILBOX_END #define MEM_IERISC_MAP_END (MEM_IERISC_FIRMWARE_BASE + MEM_IERISC_FIRMWARE_SIZE) -#define MEM_IERISC_KERNEL_SIZE (24 * 1024) #define MEM_IERISC_INIT_LOCAL_L1_BASE_SCRATCH MEM_IERISC_MAP_END #define MEM_IERISC_STACK_SIZE 1024 #define MEM_IERISC_STACK_BASE (MEM_LOCAL_BASE + MEM_IERISC_LOCAL_SIZE - MEM_IERISC_STACK_SIZE) diff --git a/tt_metal/hw/toolchain/erisc-b0-app.ld b/tt_metal/hw/toolchain/erisc-b0-app.ld index 4a82d3f2f17..05f7949596e 100644 --- a/tt_metal/hw/toolchain/erisc-b0-app.ld +++ b/tt_metal/hw/toolchain/erisc-b0-app.ld @@ -19,4 +19,4 @@ __firmware_global_pointer = ORIGIN(ERISC_DATA) + 0x7f0; INCLUDE "erisc-b0-app-sections.ld" INCLUDE "tensix-address.ld" -_Z11kernel_initm = ORIGIN(REGION_APP_KERNEL_CODE); +_Z11kernel_initv = ORIGIN(REGION_APP_KERNEL_CODE); diff --git a/tt_metal/hw/toolchain/main.ld b/tt_metal/hw/toolchain/main.ld index cf62ec943c3..4bdc1d8148e 100644 --- a/tt_metal/hw/toolchain/main.ld +++ b/tt_metal/hw/toolchain/main.ld @@ -18,8 +18,6 @@ REGION_ALIAS("REGION_DATA", TARGET_LOCAL_DATA_MEM(LD_TARGET)) REGION_ALIAS("REGION_STACK", TARGET_STACK_MEM(LD_TARGET)) #define FIRMWARE_STACK_SIZE TARGET_STACK_SIZE(LD_TARGET) -#if defined(TARGET_NCRISC) -#define KERNEL_ENTRY_SYMBOL _Z11kernel_initm -#endif +#define KERNEL_ENTRY_SYMBOL _Z11kernel_initv #include "sections.ld" diff --git a/tt_metal/hw/toolchain/memory.ld b/tt_metal/hw/toolchain/memory.ld index d0dcf3bc58b..29c4ced9588 100644 --- a/tt_metal/hw/toolchain/memory.ld +++ b/tt_metal/hw/toolchain/memory.ld @@ -2,51 +2,27 @@ MEMORY { BRISC_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_BRISC_LOCAL_SIZE - MEM_BRISC_STACK_SIZE BRISC_STACK_MEM : ORIGIN = MEM_BRISC_STACK_BASE, LENGTH = MEM_BRISC_STACK_SIZE -#if defined(TYPE_FIRMWARE) BRISC_FIRMWARE_CODE : ORIGIN = MEM_BRISC_FIRMWARE_BASE, LENGTH = MEM_BRISC_FIRMWARE_SIZE -#else - BRISC_FIRMWARE_CODE : ORIGIN = MEM_BRISC_FIRMWARE_BASE, LENGTH = MEM_BRISC_KERNEL_SIZE -#endif TRISC0_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_TRISC_LOCAL_SIZE - MEM_TRISC0_STACK_SIZE TRISC0_STACK_MEM : ORIGIN = MEM_TRISC0_STACK_BASE, LENGTH = MEM_TRISC0_STACK_SIZE -#if defined(TYPE_FIRMWARE) TRISC0_FIRMWARE_CODE : ORIGIN = MEM_TRISC0_FIRMWARE_BASE, LENGTH = MEM_TRISC0_FIRMWARE_SIZE -#else - TRISC0_FIRMWARE_CODE : ORIGIN = MEM_TRISC0_FIRMWARE_BASE, LENGTH = MEM_TRISC0_KERNEL_SIZE -#endif TRISC1_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_TRISC_LOCAL_SIZE - MEM_TRISC1_STACK_SIZE TRISC1_STACK_MEM : ORIGIN = MEM_TRISC1_STACK_BASE, LENGTH = MEM_TRISC1_STACK_SIZE -#if defined(TYPE_FIRMWARE) TRISC1_FIRMWARE_CODE : ORIGIN = MEM_TRISC1_FIRMWARE_BASE, LENGTH = MEM_TRISC1_FIRMWARE_SIZE -#else - TRISC1_FIRMWARE_CODE : ORIGIN = MEM_TRISC1_FIRMWARE_BASE, LENGTH = MEM_TRISC1_KERNEL_SIZE -#endif TRISC2_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_TRISC_LOCAL_SIZE - MEM_TRISC2_STACK_SIZE TRISC2_STACK_MEM : ORIGIN = MEM_TRISC2_STACK_BASE, LENGTH = MEM_TRISC2_STACK_SIZE -#if defined(TYPE_FIRMWARE) 
TRISC2_FIRMWARE_CODE : ORIGIN = MEM_TRISC2_FIRMWARE_BASE, LENGTH = MEM_TRISC2_FIRMWARE_SIZE -#else - TRISC2_FIRMWARE_CODE : ORIGIN = MEM_TRISC1_FIRMWARE_BASE, LENGTH = MEM_TRISC2_KERNEL_SIZE -#endif NCRISC_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_NCRISC_LOCAL_SIZE - MEM_NCRISC_STACK_SIZE NCRISC_STACK_MEM : ORIGIN = MEM_NCRISC_STACK_BASE, LENGTH = MEM_NCRISC_STACK_SIZE -#if defined(TYPE_FIRMWARE) NCRISC_FIRMWARE_CODE : ORIGIN = MEM_NCRISC_FIRMWARE_BASE, LENGTH = MEM_NCRISC_FIRMWARE_SIZE -#else - NCRISC_FIRMWARE_CODE : ORIGIN = MEM_NCRISC_FIRMWARE_BASE, LENGTH = MEM_NCRISC_KERNEL_SIZE -#endif #ifdef COMPILE_FOR_IERISC IERISC_LOCAL_DATA_MEM : ORIGIN = MEM_LOCAL_BASE, LENGTH = MEM_IERISC_LOCAL_SIZE - MEM_IERISC_STACK_SIZE IERISC_STACK_MEM : ORIGIN = MEM_IERISC_STACK_BASE, LENGTH = MEM_IERISC_STACK_SIZE -#if defined(TYPE_FIRMWARE) IERISC_FIRMWARE_CODE : ORIGIN = MEM_IERISC_FIRMWARE_BASE, LENGTH = MEM_IERISC_FIRMWARE_SIZE -#else - IERISC_FIRMWARE_CODE : ORIGIN = MEM_IERISC_FIRMWARE_BASE, LENGTH = MEM_IERISC_KERNEL_SIZE -#endif #endif } diff --git a/tt_metal/hw/toolchain/sections.ld b/tt_metal/hw/toolchain/sections.ld index a1271d8c55b..d084c6efb0f 100644 --- a/tt_metal/hw/toolchain/sections.ld +++ b/tt_metal/hw/toolchain/sections.ld @@ -67,10 +67,8 @@ SECTIONS #if defined(TYPE_FIRMWARE) __fw_export_end_text = .; -#if defined(TARGET_NCRISC) PROVIDE (KERNEL_ENTRY_SYMBOL = __fw_export_end_text); #endif -#endif #if defined(TYPE_KERNEL) __kernel_init_local_l1_base = .; diff --git a/tt_metal/hw/toolchain/tmu-crt0k.S b/tt_metal/hw/toolchain/tmu-crt0k.S index 177d79cdb84..f5d4ec04215 100644 --- a/tt_metal/hw/toolchain/tmu-crt0k.S +++ b/tt_metal/hw/toolchain/tmu-crt0k.S @@ -3,5 +3,5 @@ .type _start, @function _start: - tail _Z13kernel_launchm + tail _Z13kernel_launchv .size _start, .-_start diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 7ba127c7803..0479807018f 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -2755,14 +2755,12 @@ void Device::configure_command_queue_programs() { } } - command_queue_program.finalize(this); detail::ConfigureDeviceWithProgram(this, command_queue_program, true); tt::Cluster::instance().l1_barrier(this->id()); if (device_id != mmio_device_id) { if (tt::Cluster::instance().get_device_tunnel_depth(device_id) == 1) { //first or only remote device on the tunnel, launch fd2 kernels on mmio device for all remote devices. 
Program& mmio_command_queue_program = *this->command_queue_programs[1]; - mmio_command_queue_program.finalize(mmio_device); detail::ConfigureDeviceWithProgram(mmio_device, mmio_command_queue_program, true); tt::Cluster::instance().l1_barrier(mmio_device_id); } @@ -2810,6 +2808,7 @@ void Device::init_command_queue_device() { } this->configure_command_queue_programs(); Program& command_queue_program = *this->command_queue_programs[0]; + command_queue_program.finalize(this); // TODO: should get a const ref std::vector>logical_cores = command_queue_program.logical_cores(); @@ -2829,6 +2828,7 @@ void Device::init_command_queue_device() { chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(this->id()); Device *mmio_device = tt::DevicePool::instance().get_active_device(mmio_device_id); Program& mmio_command_queue_program = *this->command_queue_programs[1]; + mmio_command_queue_program.finalize(mmio_device); std::vector>logical_cores = mmio_command_queue_program.logical_cores(); for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { const auto& logical_dispatch_cores = logical_cores[index]; diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 656c740ccee..32c7ac99e73 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -934,9 +934,7 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro } else { uint32_t base_address = this->program.kernels_buffer->address(); uint32_t page_offset = kg_transfer_info.page_offsets[kernel_idx]; - - // TODO: pack all these writes into 1 linear write - uint32_t kernel_config_buffer_offset = kg_transfer_info.dst_base_addrs[kernel_idx]; + uint32_t dst_addr = kg_transfer_info.dst_base_addrs[kernel_idx]; uint32_t aligned_length = align(kg_transfer_info.lengths[kernel_idx], hal.get_alignment(HalMemType::DRAM)); uint32_t padding = aligned_length - kg_transfer_info.lengths[kernel_idx]; while (aligned_length != 0) { @@ -957,13 +955,13 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro } kernel_bins_dispatch_subcmds.back().emplace_back(CQDispatchWritePackedLargeSubCmd{ .noc_xy_addr = noc_encoding, - .addr = kernel_config_buffer_offset, + .addr = dst_addr, .length = (uint16_t)write_length, .num_mcast_dests = (uint8_t)num_mcast_dests, .flags = CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_FLAG_NONE}); RecordDispatchData( program, DISPATCH_DATA_BINARY, write_length, kg_transfer_info.riscvs[kernel_idx]); - kernel_config_buffer_offset += write_length; + dst_addr += write_length; kernel_bins_prefetch_subcmds.back().emplace_back(CQPrefetchRelayPagedPackedSubCmd{ .start_page = (uint16_t)page_offset, @@ -1178,11 +1176,7 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); for (uint32_t i = 0; i < kernel_bins_dispatch_subcmds.size(); ++i) { device_command_sequence.add_dispatch_write_packed_large( - dram_alignment, - kernel_bins_dispatch_subcmds[i].size(), - kernel_bins_dispatch_subcmds[i], - 0, - DISPATCH_WRITE_OFFSET_TENSIX_L1_CONFIG_BASE); + dram_alignment, kernel_bins_dispatch_subcmds[i].size(), kernel_bins_dispatch_subcmds[i]); device_command_sequence.add_prefetch_relay_paged_packed( kernel_bins_write_packed_large_data_aligned_sizeB[i], kernel_bins_prefetch_subcmds[i], @@ -1463,6 +1457,11 @@ void EnqueueProgramCommand::write_program_command_sequence(const ProgramCommandS void 
EnqueueProgramCommand::process() { + bool is_finalized = program.is_finalized(); + if (not is_finalized) { + program.finalize(device); + } + const std::pair&> reservation = this->manager.get_config_buffer_mgr().reserve(program.program_config_sizes_); bool stall_first = reservation.first.need_sync; @@ -1487,7 +1486,7 @@ void EnqueueProgramCommand::process() { // Currently this is mapped by device, but will be mapped by multiple values in the future uint64_t command_hash = this->device->id(); auto cached_cmd_iter = this->program.cached_program_command_sequences_.find(command_hash); - bool is_cached = program.is_cached() && cached_cmd_iter != this->program.cached_program_command_sequences_.end(); + bool is_cached = is_finalized && cached_cmd_iter != this->program.cached_program_command_sequences_.end(); // Calculate all commands size and determine how many fetch q entries to use // Preamble, some waits and stalls @@ -1508,7 +1507,6 @@ void EnqueueProgramCommand::process() { this->write_program_command_sequence(program_command_sequence, stall_first); this->assemble_stall_commands(program_command_sequence, false); this->program.cached_program_command_sequences_.insert({command_hash, std::move(program_command_sequence)}); - program.set_cached(); } else { static constexpr uint32_t wait_count_offset = (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, wait.count)); static constexpr uint32_t tensix_l1_write_offset_offset = @@ -2231,14 +2229,12 @@ void HWCommandQueue::enqueue_write_buffer(Buffer& buffer, const void* src, bool void HWCommandQueue::enqueue_program(Program& program, bool blocking) { ZoneScopedN("HWCommandQueue_enqueue_program"); if (not program.is_finalized()) { - program.finalize(device); TT_FATAL(!this->manager.get_bypass_mode(), "Tracing should only be used when programs have been cached"); if (program.kernels_buffer != nullptr) { this->enqueue_write_buffer( *program.kernels_buffer, program.program_transfer_info.binary_data.data(), false); } } - #ifdef DEBUG if (tt::llrt::OptionsG.get_validate_kernel_binaries()) { TT_FATAL(!this->manager.get_bypass_mode(), "Tracing cannot be used while validating program binaries"); diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index 0d521ce3194..374f0942d52 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -379,11 +379,7 @@ void DataMovementKernel::read_binaries(Device *device) { uint32_t dm_class_idx = magic_enum::enum_integer(HalProcessorClassType::DM); int riscv_id = static_cast::type>(this->config_.processor); const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, dm_class_idx, riscv_id); - // TODO: from HAL - ll_api::memory::Relocate relo_type = - (riscv_id == 1 && (device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0)) ? 
- ll_api::memory::Relocate::NONE : ll_api::memory::Relocate::XIP; - ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), riscv_id, ll_api::memory::PackSpans::PACK, relo_type); + ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), riscv_id, llrt::PackSpans::PACK); binaries.push_back(binary_mem); uint32_t binary_size = binary_mem.get_packed_size(); log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", riscv_id, binary_size); @@ -400,11 +396,8 @@ void EthernetKernel::read_binaries(Device *device) { uint32_t dm_class_idx = magic_enum::enum_integer(HalProcessorClassType::DM); const JitBuildState &build_state = device->build_kernel_state(erisc_core_type, dm_class_idx, 0); int erisc_id = this->config_.eth_mode == Eth::IDLE ? 1 : 0; - // TODO: fix when active eth supports relo - ll_api::memory::Relocate relo_type = (this->config_.eth_mode == Eth::IDLE) ? - ll_api::memory::Relocate::XIP : ll_api::memory::Relocate::NONE; - ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), erisc_id + 5, ll_api::memory::PackSpans::PACK, relo_type); - binaries.push_back(binary_mem); + ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), erisc_id + 5, llrt::PackSpans::PACK); + binaries.push_back(binary_mem); uint32_t binary_size = binary_mem.get_packed_size(); log_debug(LogLoader, "ERISC {} kernel binary size: {} in bytes", erisc_id, binary_size); this->set_binaries(device->build_key(), std::move(binaries)); @@ -417,7 +410,7 @@ void ComputeKernel::read_binaries(Device *device) { uint32_t compute_class_idx = magic_enum::enum_integer(HalProcessorClassType::COMPUTE); for (int trisc_id = 0; trisc_id <= 2; trisc_id++) { const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, compute_class_idx, trisc_id); - ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), trisc_id + 2, ll_api::memory::PackSpans::PACK, ll_api::memory::Relocate::XIP); + ll_api::memory binary_mem = llrt::get_risc_binary(build_state.get_target_out_path(this->kernel_full_name_), trisc_id + 2, llrt::PackSpans::PACK); binaries.push_back(binary_mem); uint32_t binary_size = binary_mem.get_packed_size(); log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", trisc_id + 2, binary_size); @@ -438,35 +431,41 @@ RISCV EthernetKernel::processor() const { return RISCV::ERISC; } RISCV ComputeKernel::processor() const { return RISCV::COMPUTE; } -bool DataMovementKernel::configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const { +bool DataMovementKernel::configure(Device *device, const CoreCoord &logical_core) const { + bool pass = true; if (not is_on_logical_core(logical_core)) { TT_THROW("Cannot configure kernel because it is not on core {}", logical_core.str()); } auto device_id = device->id(); auto worker_core = device->worker_core_from_logical_core(logical_core); ll_api::memory binary_mem = this->binaries(device->build_key()).at(0); - int riscv_id = static_cast::type>(this->config_.processor); - llrt::write_binary_to_address(binary_mem, device_id, worker_core, base_address + offsets[riscv_id]); - return true; + int riscv_id; + switch (this->config_.processor) { + case (DataMovementProcessor::RISCV_0): { + riscv_id = 0; + } break; + case (DataMovementProcessor::RISCV_1): { + riscv_id = 1; + } break; + 
default: TT_THROW("Unsupported data movement processor!"); + } + + pass &= tt::llrt::test_load_write_read_risc_binary(binary_mem, device_id, worker_core, riscv_id); + return pass; } -bool EthernetKernel::configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const { +bool EthernetKernel::configure(Device *device, const CoreCoord &logical_core) const { + bool pass = true; auto device_id = device->id(); auto ethernet_core = device->ethernet_core_from_logical_core(logical_core); ll_api::memory binary_mem = this->binaries(device->build_key()).at(0); - - if (this->config_.eth_mode == Eth::IDLE) { - llrt::write_binary_to_address(binary_mem, device_id, ethernet_core, base_address + offsets[0]); - } else { - int riscv_id = 5; - tt::llrt::test_load_write_read_risc_binary(binary_mem, device_id, ethernet_core, riscv_id); - } - - return true; + int riscv_id = this->config_.eth_mode == Eth::IDLE ? 6 : 5; + pass &= tt::llrt::test_load_write_read_risc_binary(binary_mem, device_id, ethernet_core, riscv_id); + return pass; } -bool ComputeKernel::configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const { +bool ComputeKernel::configure(Device *device, const CoreCoord &logical_core) const { bool pass = true; if (not is_on_logical_core(logical_core)) { TT_THROW("Cannot configure kernel because it is not on core {}", logical_core.str()); @@ -476,7 +475,7 @@ bool ComputeKernel::configure(Device *device, const CoreCoord &logical_core, uin std::vector binaries = this->binaries(device->build_key()); for (int trisc_id = 0; trisc_id <= 2; trisc_id++) { - llrt::write_binary_to_address(binaries.at(trisc_id), device_id, worker_core, base_address + offsets[2 + trisc_id]); + pass &= tt::llrt::test_load_write_read_trisc_binary(binaries.at(trisc_id), device_id, worker_core, trisc_id); } return pass; diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index 1c8488ab815..60d654574fb 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -97,7 +97,7 @@ class Kernel : public JitBuildSettings { virtual RISCV processor() const = 0; uint32_t dispatch_class() { return this->dispatch_class_; } - virtual bool configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const = 0; + virtual bool configure(Device *device, const CoreCoord &logical_core) const = 0; virtual Config config() const = 0; @@ -170,7 +170,7 @@ class DataMovementKernel : public Kernel { void generate_binaries(Device *device, JitBuildOptions& build_options) const override; void read_binaries(Device *device) override; - bool configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const override; + bool configure(Device *device, const CoreCoord &logical_core) const override; Config config() const override { return this->config_; } @@ -199,7 +199,7 @@ class EthernetKernel : public Kernel { void generate_binaries(Device *device, JitBuildOptions &build_options) const override; void read_binaries(Device *device) override; - bool configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const override; + bool configure(Device *device, const CoreCoord &logical_core) const override; Config config() const override { return this->config_; } @@ -228,7 +228,7 @@ class ComputeKernel : public Kernel { void generate_binaries(Device *device, JitBuildOptions& build_options) 
const override; void read_binaries(Device *device) override; - bool configure(Device *device, const CoreCoord &logical_core, uint32_t base_address, const uint32_t offsets[]) const override; + bool configure(Device *device, const CoreCoord &logical_core) const override; Config config() const override { return this->config_; } diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index fba32d32c69..4cece79d0c4 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -111,9 +111,7 @@ Program::Program() : runtime_id(0), worker_crs_(), local_circular_buffer_allocation_needed_(false), - finalized_(false), - cached_(false) { - + finalized_(false) { uint32_t programmable_core_count = hal.get_programmable_core_type_count(); for (uint32_t i = 0; i < programmable_core_count; i++) { kernels_.push_back({}); @@ -208,7 +206,6 @@ KernelGroup::KernelGroup( for (uint32_t index = 0; index < MaxProcessorsPerCoreType; index ++) { this->kernel_bin_sizes[index] = 0; - this->kernel_text_offsets[index] = 0; this->launch_msg.kernel_config.kernel_text_offset[index] = 0; } this->launch_msg.kernel_config.ncrisc_kernel_size16 = 0; @@ -651,6 +648,23 @@ void Program::set_cb_tile_dims(Device *device, const std::vector &crs } void Program::populate_dispatch_data(Device *device) { + static const uint32_t processor_to_firmware_base[] = { + MEM_BRISC_FIRMWARE_BASE, + MEM_NCRISC_FIRMWARE_BASE, + MEM_TRISC0_FIRMWARE_BASE, + MEM_TRISC1_FIRMWARE_BASE, + MEM_TRISC2_FIRMWARE_BASE, + eth_l1_mem::address_map::FIRMWARE_BASE + }; + static const uint32_t processor_to_firmware_size[] = { + MEM_BRISC_FIRMWARE_SIZE, + MEM_NCRISC_INIT_IRAM_L1_SIZE, + MEM_TRISC0_FIRMWARE_SIZE, + MEM_TRISC1_FIRMWARE_SIZE, + MEM_TRISC2_FIRMWARE_SIZE, + eth_l1_mem::address_map::FIRMWARE_SIZE + }; + auto extract_dst_noc_unicast_info = [&device](const auto &ranges, const CoreType core_type) -> std::vector> { // This API extracts all the pairs of noc multicast encodings given a set of core ranges @@ -736,11 +750,16 @@ void Program::populate_dispatch_data(Device *device) { TT_ASSERT(kernel_bin.num_spans() == 1); - // TODO: spans are packed into 1 now, just grab it and go + uint32_t max_kernel_bin_size = processor_to_firmware_size[sub_kernels[sub_kernel_index]]; + kernel_bin.process_spans([&](vector::const_iterator mem_ptr, uint64_t dst, uint32_t len) { - // Set dst for eth kernels until they move to ring buffer - dst_base_addrs[transfer_info_index] = dst; + max_kernel_bin_size -= dst - processor_to_firmware_base[sub_kernels[sub_kernel_index]]; + + uint64_t relo_addr = + tt::llrt::relocate_dev_addr(dst); + + dst_base_addrs[transfer_info_index] = (uint32_t)relo_addr; page_offsets[transfer_info_index] = binaries_data.size() * sizeof(uint32_t) / HostMemDeviceCommand::PROGRAM_PAGE_SIZE; lengths[transfer_info_index] = len * sizeof(uint32_t); @@ -751,6 +770,12 @@ void Program::populate_dispatch_data(Device *device) { align(binaries_data.size(), HostMemDeviceCommand::PROGRAM_PAGE_SIZE / sizeof(uint32_t)), 0); transfer_info_index++; }); + + uint32_t bin_size = kernel_bin.size() * sizeof(uint32_t); + // TODO: remove this check when the ring buffer is in place (checked there) + TT_FATAL(bin_size <= max_kernel_bin_size, + "Kernel binary size, {}, overflowed kernel binary storage size, {}", + bin_size, max_kernel_bin_size); } kernel_bins_transfer_info kb_transfer_info = { @@ -778,16 +803,9 @@ void Program::populate_dispatch_data(Device *device) { kernel_group.core_ranges.ranges(), core_type); vector kernel_ids; - for 
(int dispatch_class = 0; dispatch_class < kernel_group.kernel_ids.size(); dispatch_class++) { - auto &optional_id = kernel_group.kernel_ids[dispatch_class]; + for (auto &optional_id : kernel_group.kernel_ids) { if (optional_id) { kernel_ids.push_back(optional_id.value()); - int proc_sub_class = 0; - for (uint32_t& dst_addr : kernel_transfer_info.at(optional_id.value()).dst_base_addrs) { - // TODO: ditch this w/ linear writes based on program config kernel_text_offset and size - dst_addr = kernel_group.kernel_text_offsets[dispatch_class + proc_sub_class]; - proc_sub_class++; - } } } @@ -997,33 +1015,29 @@ uint32_t Program::finalize_kernel_bins(Device *device, uint32_t programmable_cor auto& optional_id = kg.kernel_ids[class_id]; if (optional_id) { const auto kernel = this->get_kernel(optional_id.value()); - std::vector const &binaries = kernel->binaries(device->build_key()); // TODO: this is really ugly, save me future-HAL! if (programmable_core_type_index == hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX)) { uint32_t binary_packed_size = kernel->get_binary_packed_size(device, 0); if (class_id == DISPATCH_CLASS_TENSIX_DM0) { kg.kernel_bin_sizes[0] = binary_packed_size; - kg.kernel_text_offsets[0] = offset; kg.launch_msg.kernel_config.kernel_text_offset[0] = offset; offset += binary_packed_size; offset = align(offset, l1_alignment); } else if (class_id == DISPATCH_CLASS_TENSIX_DM1) { kg.kernel_bin_sizes[1] = binary_packed_size; - kg.kernel_text_offsets[1] = offset; kg.launch_msg.kernel_config.kernel_text_offset[1] = offset; offset += binary_packed_size; - offset = align(offset, l1_alignment); uint32_t binary_text_size = kernel->get_binary_text_size(device, 0); TT_ASSERT(binary_text_size >> 4 <= std::numeric_limits::max()); kg.launch_msg.kernel_config.ncrisc_kernel_size16 = (binary_text_size + 15) >> 4; + offset = align(offset, l1_alignment); } else { constexpr uint32_t max_math_processors_count = 3; for (uint32_t proc_type_index = 0; proc_type_index < max_math_processors_count; proc_type_index++) { uint32_t binary_packed_size = kernel->get_binary_packed_size(device, proc_type_index); kg.kernel_bin_sizes[2 + proc_type_index] = binary_packed_size; - kg.kernel_text_offsets[2 + proc_type_index] = offset; kg.launch_msg.kernel_config.kernel_text_offset[2 + proc_type_index] = offset; offset += binary_packed_size; offset = align(offset, l1_alignment); @@ -1032,18 +1046,9 @@ uint32_t Program::finalize_kernel_bins(Device *device, uint32_t programmable_cor } else { uint32_t binary_packed_size = kernel->get_binary_packed_size(device, 0); kg.kernel_bin_sizes[0] = binary_packed_size; - - // No kernel config buffer on active eth yet - if (hal.get_programmable_core_type(kg.programmable_core_type_index) == - HalProgrammableCoreType::IDLE_ETH) { - kg.kernel_text_offsets[0] = offset; - kg.launch_msg.kernel_config.kernel_text_offset[0] = offset; - offset += binary_packed_size; - offset = align(offset, l1_alignment); - } else { - kg.kernel_text_offsets[0] = binaries[0].get_text_addr(); - kg.launch_msg.kernel_config.kernel_text_offset[0] = binaries[0].get_text_addr(); - } + kg.launch_msg.kernel_config.kernel_text_offset[0] = offset; + offset += binary_packed_size; + offset = align(offset, l1_alignment); } } } @@ -1062,9 +1067,6 @@ uint32_t& Program::get_program_config_size(uint32_t programmable_core_type_index } void Program::finalize(Device *device) { - - this->construct_core_range_set_for_worker_cores(); - // Store the number of tensix "go signals" for use by CQ // CQ iterates over these 
to update runtime addresses, needs to know when eth begins (after tensix) // TODO: should store all the counts @@ -1079,7 +1081,6 @@ void Program::finalize(Device *device) { } for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { - HalProgrammableCoreType programmable_core_type = static_cast(index); uint32_t offset = 0; offset = finalize_rt_args(index, offset); @@ -1091,25 +1092,16 @@ void Program::finalize(Device *device) { offset = finalize_cbs(index, offset); TT_ASSERT(offset == align(offset, hal.get_alignment(HalMemType::L1))); - offset = finalize_kernel_bins(device, index, offset); + // TODO: update the offset when kernel bins are moved into the kernel config buffer + (void)finalize_kernel_bins(device, index, offset); TT_ASSERT(offset == align(offset, hal.get_alignment(HalMemType::L1))); this->get_program_config_size(index) = offset; - - auto max_size = hal.get_dev_size(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG); - TT_FATAL(offset < max_size, - "Program size ({}) too large for kernel config buffer ({}) on {}", - offset, max_size, magic_enum::enum_name(programmable_core_type)); } // The sem offsets cross programmable_core_types so must be set after the loop above this->set_launch_msg_sem_offsets(); - // TODO: This check is wrong - it populates dispatch data for dispatch kernels - if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr) { - this->populate_dispatch_data(device); // TODO: maybe rename - } - finalized_ = true; } @@ -1210,6 +1202,11 @@ void Program::compile(Device *device, bool fd_bootloader_mode) { sync_build_step(events); + this->construct_core_range_set_for_worker_cores(); + if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr) { + this->populate_dispatch_data(device); // TODO: maybe rename + } + if (detail::CompilationReporter::enabled()) { detail::CompilationReporter::inst().flush_program_entry(*this, enable_persistent_kernel_cache); } diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index 8e1dbb587ad..05aa822d787 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -50,7 +50,6 @@ struct KernelGroup { kernel_id_array_t kernel_ids; uint32_t rta_sizes[DISPATCH_CLASS_MAX]; uint32_t total_rta_size; - uint32_t kernel_text_offsets[MaxProcessorsPerCoreType]; uint32_t kernel_bin_sizes[MaxProcessorsPerCoreType]; launch_msg_t launch_msg; go_msg_t go_msg; @@ -150,8 +149,6 @@ class Program { void allocate_circular_buffers(const Device *device); bool is_finalized() const { return this->finalized_; } - bool is_cached() const { return this->cached_; } - void set_cached() { this->cached_ = true; } void finalize(Device *device); std::shared_ptr get_kernel(KernelHandle kernel_id) const; @@ -174,7 +171,6 @@ class Program { ProgramTransferInfo program_transfer_info; bool finalized_; - bool cached_; struct CircularBufferAllocator { CircularBufferAllocator(const CoreRange &core_range_) : core_range(core_range_) {} diff --git a/tt_metal/include/tt_metal/buffer.hpp b/tt_metal/include/tt_metal/buffer.hpp new file mode 100644 index 00000000000..efb44bb7f8c --- /dev/null +++ b/tt_metal/include/tt_metal/buffer.hpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
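For the capacity check added to populate_dispatch_data() above: each processor has a fixed firmware base and size, and a packed kernel span may only use whatever lies between its destination address and the end of that region. A worked sketch of the same arithmetic with stand-in constants (not the real memory map):

#include <cstdint>

// Stand-in values, not the actual per-RISC memory map.
constexpr uint32_t kFirmwareBase = 0x00000000;  // e.g. a BRISC-style firmware base
constexpr uint32_t kFirmwareSize = 0x00018000;  // region reserved for firmware + kernel text/data

bool kernel_fits(uint32_t span_dst, uint32_t bin_size_bytes) {
    // Room left once the span is placed at span_dst inside the firmware region.
    uint32_t max_kernel_bin_size = kFirmwareSize - (span_dst - kFirmwareBase);
    // Mirrors the TT_FATAL above: the packed binary must not overflow that space.
    return bin_size_bytes <= max_kernel_bin_size;
}

// Example: a span destined 0x800 bytes into the region leaves 0x17800 bytes for the binary.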
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "types.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +//================================================== +// BUFFER HANDLING +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + + +/** + * @brief Allocates an interleaved DRAM or L1 buffer on the device. + * + * @param config Configuration for the buffer. + * @return Buffer handle to the allocated buffer. + */ +Buffer CreateBuffer(const InterleavedBufferConfig &config); + +/** + * @brief Allocates a buffer on the device. + * + * @param buffer The buffer to allocate. + * @param bottom_up If true, allocates buffer from the bottom up. + */ +void AllocateBuffer(Buffer buffer, bool bottom_up); + +/** + * @brief Deallocates a buffer from the device. + * + * @param buffer The buffer to deallocate. + */ +void DeallocateBuffer(Buffer buffer); + +/** + * @brief Copies data from a host buffer into the specified device buffer. + * + * @param buffer Buffer to write data into. + * @param host_buffer Host buffer containing data to copy. + */ +void WriteToBuffer(Buffer buffer, const std::vector &host_buffer); + +/** + * @brief Copies data from a device buffer into a host buffer. + * + * @param buffer Buffer to read data from. + * @param host_buffer Host buffer to copy data into. + * @param shard_order If true, reads data in shard order. + */ +void ReadFromBuffer(Buffer buffer, std::vector &host_buffer, bool shard_order = false); + +/** + * @brief Copies data from a specific shard of a device buffer into a host buffer. + * + * @param buffer Buffer to read data from. + * @param host_buffer Host buffer to copy data into. + * @param core_id ID of the core shard to read. + */ +void ReadFromShard(Buffer buffer, std::vector &host_buffer, const uint32_t &core_id); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/command_queue.hpp b/tt_metal/include/tt_metal/command_queue.hpp new file mode 100644 index 00000000000..e0d9f24149a --- /dev/null +++ b/tt_metal/include/tt_metal/command_queue.hpp @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +#include "tt_metal/impl/buffers/buffer.hpp" + +//================================================== +// COMMAND QUEUE OPERATIONS +//================================================== + + +namespace tt::tt_metal{ +namespace v1 { + +/** + * @brief Reads a buffer from the device. + * + * @param cq The command queue used to dispatch the command. + * @param buffer The device buffer to read from. + * @param dst Pointer to the destination memory where data will be stored. + * @param blocking Indicates whether the operation is blocking. + */ +void EnqueueReadBuffer( + CommandQueue cq, + Buffer buffer, + std::byte *dst, + bool blocking); + +/** + * @brief Writes data to a buffer on the device. + * + * @param cq The command queue used to dispatch the command. + * @param buffer The device buffer to write to. + * @param src Source data vector to write to the device. + * @param blocking Indicates whether the operation is blocking. + */ +void EnqueueWriteBuffer( + CommandQueue cq, + Buffer buffer, + const std::byte *src, + bool blocking); + + +/** + * @brief Writes a program to the device and launches it. + * + * @param cq The command queue used to dispatch the command. + * @param program The program to execute on the device. 
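A minimal host round-trip through the v1 buffer API declared above. The InterleavedBufferConfig field names and order and the uint32_t host-word type are assumptions; the config struct and the vectors' element type live in impl/buffers/buffer.hpp, not in this header:

using namespace tt::tt_metal;

// `device` is an existing v1::Device handle (see device.hpp below).
InterleavedBufferConfig config{
    /*device=*/device, /*size=*/4096, /*page_size=*/4096, /*buffer_type=*/BufferType::DRAM};  // assumed field order
v1::Buffer dram_buffer = v1::CreateBuffer(config);

std::vector<uint32_t> src(1024, 0xA5A5A5A5u);   // 1024 words == 4096 bytes
std::vector<uint32_t> dst(1024, 0);

v1::WriteToBuffer(dram_buffer, src);            // host -> device
v1::ReadFromBuffer(dram_buffer, dst);           // device -> host (shard_order defaults to false)
// src and dst should now compare equal.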
+ * @param blocking Indicates whether the operation is blocking. + */ +void EnqueueProgram(CommandQueue cq, Program program, bool blocking); + +/** + * @brief Blocks until all previously dispatched commands on the device have completed. + * + * @param cq The command queue to wait on. + */ +void Finish(CommandQueue cq); + + +/** + * @brief Sets the command queue mode to lazy or immediate. + * + * @param lazy If true, sets the command queue to lazy mode. + */ +void SetLazyCommandQueueMode(bool lazy); + + +/** + * @brief Retrieves the device associated with the command queue. + * + * @param cq The command queue to query. + * @return Device handle associated with the command queue. + */ +Device GetDevice(class CommandQueue cq); + +/** + * @brief Retrieves the ID of the command queue. + * + * @param cq The command queue to query. + * @return ID of the command queue. + */ +uint32_t GetId(class CommandQueue cq); + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/device.hpp b/tt_metal/include/tt_metal/device.hpp new file mode 100644 index 00000000000..90b9f77bada --- /dev/null +++ b/tt_metal/include/tt_metal/device.hpp @@ -0,0 +1,305 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "types.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/dispatch/work_executor.hpp" + +//================================================== +// DEVICE MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +/** + * @brief Returns the number of Tenstorrent devices that can be targeted. + * + * @return Size_t representing the number of available devices. + */ +size_t GetNumAvailableDevices(); + +/** + * @brief Returns the number of Tenstorrent devices connected via PCIe. + * + * @return Size_t representing the number of PCIe devices. + */ +size_t GetNumPCIeDevices(); + +/** + * @brief Retrieves the PCIe device ID for a given device ID. + * + * @param device_id ID of the device to query. + * @return Chip ID of the PCIe device. + */ +chip_id_t GetPCIeDeviceID(chip_id_t device_id); + +/** + * @brief Instantiates a Device object. + * + * @param device_id ID of the device to target (0 to GetNumAvailableDevices() - 1). + * @param num_hw_cqs Number of hardware command queues (default: 1, valid range: 1 to 2). + * @param l1_small_size L1 small space to reserve (default: DEFAULT_L1_SMALL_SIZE). + * @param trace_region_size Trace region size to reserve (default: DEFAULT_TRACE_REGION_SIZE). + * @param dispatch_core_type Dispatch core type to use (default: DispatchCoreType::WORKER). + * @return Device handle to the created device. + */ +Device CreateDevice( + chip_id_t device_id, + uint8_t num_hw_cqs = 1, + size_t l1_small_size = DEFAULT_L1_SMALL_SIZE, + size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE, + DispatchCoreType dispatch_core_type = DispatchCoreType::WORKER); + +/** + * @brief Resets and closes the device. + * + * @param device Handle to the device to close. + * @return True if the device was successfully closed; otherwise, false. + */ +bool CloseDevice(Device device); + + +/** + * @brief Deallocates all buffers on the device. + */ +void DeallocateBuffers(Device device); + + +/** + * @brief Dumps device-side profiler data to a CSV log. + * + * @param device The device holding the program being profiled. 
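End-to-end lifecycle with the device management calls above: open device 0, dispatch a program through a command queue, and close. GetDefaultCommandQueue and CreateProgram are declared further down in device.hpp and in program.hpp respectively; the sketch assumes fast dispatch is available:

using namespace tt::tt_metal;

void run_program_once(v1::Program program) {
    if (v1::GetNumAvailableDevices() == 0) return;

    v1::Device device = v1::CreateDevice(/*device_id=*/0);     // defaults: 1 HW CQ, WORKER dispatch cores
    v1::CommandQueue cq = v1::GetDefaultCommandQueue(device);  // declared later in this header

    v1::EnqueueProgram(cq, program, /*blocking=*/false);       // dispatch
    v1::Finish(cq);                                            // wait for completion

    v1::CloseDevice(device);
}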
+ * @param worker_cores CoreRangeSet of worker cores being profiled. + * @param last_dump If true, indicates the last dump before process termination. + */ +void DumpDeviceProfileResults(Device device, const CoreRangeSet &worker_cores, bool last_dump = false); + + +/** + * @brief Retrieves the architecture of the device. + * + * @param device The device to query. + * @return ARCH representing the device architecture. + */ +ARCH GetArch(Device device); + +/** + * @brief Retrieves the ID of the device. + * + * @param device The device to query. + * @return Chip ID of the device. + */ +chip_id_t GetId(Device device); + +/** + * @brief Retrieves the number of DRAM channels on the device. + * + * @param device The device to query. + * @return Number of DRAM channels. + */ +int GetNumDramChannels(Device device); + +/** + * @brief Retrieves the available L1 size per worker core on the device. + * + * @param device The device to query. + * @return L1 size per core in bytes. + */ +uint32_t GetL1SizePerCore(Device device); + +/** + * @brief Computes the storage grid size for the device. + * + * @param device The device to query. + * @return CoreCoord representing the storage grid size. + */ +CoreCoord GetComputeWithStorageGridSize(Device device); + +/** + * @brief Retrieves the DRAM grid size for the device. + * + * @param device The device to query. + * @return CoreCoord representing the DRAM grid size. + */ +CoreCoord GetDramGridSize(Device device); + +/** + * @brief Converts a logical core coordinate to a physical core coordinate. + * + * @param device The device to query. + * @param logical_core The logical core coordinate. + * @param core_type The type of the core. + * @return CoreCoord representing the physical core coordinate. + */ +CoreCoord PhysicalCoreFromLogical(Device device, const CoreCoord &logical_core, const CoreType &core_type); + +/** + * @brief Retrieves the worker core coordinate corresponding to a logical core. + * + * @param device The device to query. + * @param logical_core The logical core coordinate. + * @return CoreCoord representing the worker core coordinate. + */ +CoreCoord WorkerCoreFromLogical(Device device, const CoreCoord &logical_core); + +/** + * @brief Retrieves the Ethernet core coordinate corresponding to a logical core. + * + * @param device The device to query. + * @param logical_core The logical core coordinate. + * @return CoreCoord representing the Ethernet core coordinate. + */ +CoreCoord EthernetCoreFromLogical(Device device, const CoreCoord &logical_core); + +/** + * @brief Enables the program cache on the device. + * + * @param device The device to modify. + */ +void EnableProgramCache(Device device); + +/** + * @brief Disables and clears the program cache on the device. + * + * @param device The device to modify. + */ +void DisableAndClearProgramCache(Device device); + +/** + * @brief Pushes a work function onto the device's work queue. + * + * @param device The device to which the work will be pushed. + * @param work The work function to execute. + * @param blocking Indicates whether the operation should be blocking (default: false). + */ +void PushWork(Device device, std::function &&work, bool blocking = false); + +/** + * @brief Pushes a shared work function onto the device's work queue. + * + * @param device The device to which the work will be pushed. + * @param work Shared pointer to the work function to execute. + * @param blocking Indicates whether the operation should be blocking (default: false). 
+ */ +void PushWork(Device device, std::function work, bool blocking = false); + +/** + * @brief Synchronizes operations on the given device. + * + * @param device The device to synchronize. + */ +void Synchronize(Device device); + +/** + * @brief Retrieves a list of Ethernet socket coordinates connected to a specific chip ID. + * + * @param device The device to query. + * @param connected_chip_id The connected chip ID. + * @return Vector of CoreCoord representing Ethernet socket coordinates. + */ +std::vector GetEthernetSockets(Device device, chip_id_t connected_chip_id); + +/** + * @brief Returns the number of banks for a specific buffer type on the device. + * + * @param device The device to query. + * @param buffer_type The type of buffer. + * @return Number of banks. + */ +uint32_t GetNumBanks(Device device, const BufferType &buffer_type); + +/** + * @brief Computes the offset of a specific bank for a buffer type on the device. + * + * @param device The device to query. + * @param buffer_type The type of buffer. + * @param bank_id The ID of the bank. + * @return Offset of the bank. + */ +int32_t GetBankOffset(Device device, BufferType buffer_type, uint32_t bank_id); + +/** + * @brief Retrieves bank IDs associated with a logical core for a given buffer type. + * + * @param device The device to query. + * @param buffer_type The type of buffer. + * @param logical_core The logical core coordinate. + * @return Reference to a vector of bank IDs. + */ +const std::vector &BankIdsFromLogicalCore(Device device, BufferType buffer_type, const CoreCoord &logical_core); + + +/** + * @brief Retrieves the machine epsilon for the SFPU on the device. + * + * @param device The device to query. + * @return SFPU machine epsilon. + */ +float GetSfpuEps(Device device); + +/** + * @brief Retrieves the representation of NaN for the SFPU on the device. + * + * @param device The device to query. + * @return SFPU NaN value. + */ +float GetSfpuNan(Device device); + +/** + * @brief Retrieves the representation of infinity for the SFPU on the device. + * + * @param device The device to query. + * @return SFPU infinity value. + */ +float GetSfpuInf(Device device); + +/** + * @brief Retrieves a command queue from the device for a given queue ID. + * + * @param device The device to query. + * @param cq_id The command queue ID. + * @return CommandQueue handle. + */ +CommandQueue GetCommandQueue(Device device, size_t cq_id); + +/** + * @brief Retrieves the default command queue for the given device. + * + * @param device The device to query. + * @return CommandQueue handle. + */ +CommandQueue GetDefaultCommandQueue(Device device); + +/** + * @brief Retrieves the current worker mode of the device. + * + * @param device The device to query. + * @return WorkExecutorMode representing the current worker mode. + */ +WorkExecutorMode GetWorkerMode(Device device); + +/** + * @brief Retrieves the number of entries in the program cache on the device. + * + * @param device The device to query. + * @return Number of program cache entries. + */ +std::size_t GetNumProgramCacheEntries(Device device); + +/** + * @brief Checks if the current execution is in the main thread for the device. + * + * @param device The device to query. + * @return True if in the main thread; otherwise, false. 
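A small sketch exercising the query surface above. The CoreCoord x/y members and plain std::printf output are assumptions made for brevity here, rather than anything declared in this header:

#include <cstdio>
using namespace tt::tt_metal;

void print_device_summary(v1::Device device) {
    CoreCoord grid = v1::GetComputeWithStorageGridSize(device);
    std::printf("device %d: %d DRAM channels, %u bytes L1 per core, %zu x %zu worker grid\n",
                static_cast<int>(v1::GetId(device)),
                v1::GetNumDramChannels(device),
                v1::GetL1SizePerCore(device),
                static_cast<size_t>(grid.x), static_cast<size_t>(grid.y));

    v1::EnableProgramCache(device);                              // later compiles may be cached
    v1::CommandQueue cq = v1::GetCommandQueue(device, /*cq_id=*/0);
    std::printf("command queue id: %u\n", v1::GetId(cq));
}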
+ */ +bool InMainThread(Device device); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/event.hpp b/tt_metal/include/tt_metal/event.hpp new file mode 100644 index 00000000000..f73ac2a9ca3 --- /dev/null +++ b/tt_metal/include/tt_metal/event.hpp @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +//================================================== +// EVENT MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +/** + * @brief Enqueues a command to record an event on the device. + * + * @param cq The command queue used to dispatch the command. + * @param event Shared pointer to the Event object to record. + */ +void EnqueueRecordEvent(CommandQueue cq, const std::shared_ptr &event); + +/** + * @brief Enqueues a command to wait for an event to complete on the device. + * + * @param cq The command queue that will wait for the event. + * @param event Shared pointer to the Event object to wait on. + */ +void EnqueueWaitForEvent(CommandQueue cq, const std::shared_ptr &event); + +/** + * @brief Blocks the host until the specified event has completed on the device. + * + * @param event Shared pointer to the Event object to synchronize. + */ +void EventSynchronize(const std::shared_ptr &event); + +/** + * @brief Queries the completion status of an event on the device. + * + * @param event Shared pointer to the Event object to query. + * @return True if the event is completed; otherwise, false. + */ +bool EventQuery(const std::shared_ptr &event); + + +/** + * @brief Synchronizes the device with the host by waiting for all operations to complete. + * + * @param device The device to synchronize. + * @param cq_id Optional command queue ID to synchronize. If not provided, all queues are synchronized. + */ +void Synchronize(Device device, const std::optional cq_id = std::nullopt); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/buffer.hpp b/tt_metal/include/tt_metal/internal/buffer.hpp new file mode 100644 index 00000000000..ddc81f0fbf4 --- /dev/null +++ b/tt_metal/include/tt_metal/internal/buffer.hpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "types.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +//================================================== +// BUFFER HANDLING +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/command_queue.hpp b/tt_metal/include/tt_metal/internal/command_queue.hpp new file mode 100644 index 00000000000..ab1aa83f89c --- /dev/null +++ b/tt_metal/include/tt_metal/internal/command_queue.hpp @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
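Events pair a record on one queue with a wait on another, or with a host-side synchronize. A sketch using the declarations above; it assumes Event is default-constructible and that the queue, buffer, and program handles already exist:

#include <cstddef>
#include <memory>
using namespace tt::tt_metal;

// Overlap a write on cq0 with compute on cq1, ordering the compute after the write.
void overlapped_dispatch(v1::CommandQueue cq0, v1::CommandQueue cq1,
                         v1::Buffer buffer, const std::byte *src, v1::Program program) {
    auto event = std::make_shared<v1::Event>();      // assumes a default-constructible Event

    v1::EnqueueWriteBuffer(cq0, buffer, src, /*blocking=*/false);
    v1::EnqueueRecordEvent(cq0, event);              // completes once the write has drained

    v1::EnqueueWaitForEvent(cq1, event);             // cq1 stalls until the cq0 event fires
    v1::EnqueueProgram(cq1, program, /*blocking=*/false);

    v1::EventSynchronize(event);                     // host-side wait; EventQuery() polls instead
}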
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +#include "tt_metal/impl/buffers/buffer.hpp" + +//================================================== +// COMMAND QUEUE OPERATIONS +//================================================== + + +namespace tt::tt_metal{ +namespace v1 { + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/device.hpp b/tt_metal/include/tt_metal/internal/device.hpp new file mode 100644 index 00000000000..996b4d3da25 --- /dev/null +++ b/tt_metal/include/tt_metal/internal/device.hpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "types.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/dispatch/work_executor.hpp" + +//================================================== +// DEVICE MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/event.hpp b/tt_metal/include/tt_metal/internal/event.hpp new file mode 100644 index 00000000000..4a5cc90f37f --- /dev/null +++ b/tt_metal/include/tt_metal/internal/event.hpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +//================================================== +// INTERNAL EVENT MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/kernel.hpp b/tt_metal/include/tt_metal/internal/kernel.hpp new file mode 100644 index 00000000000..f2d407306d8 --- /dev/null +++ b/tt_metal/include/tt_metal/internal/kernel.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" +#include "tt_metal/impl/kernels/kernel_types.hpp" + +//================================================== +// INTERNAL KERNEL EXECUTION +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/metal_internal.hpp b/tt_metal/include/tt_metal/internal/metal_internal.hpp new file mode 100644 index 00000000000..6d0e57d39df --- /dev/null +++ b/tt_metal/include/tt_metal/internal/metal_internal.hpp @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/** +These are internal API calls that are used for testing or other internal use only +for the purpose of supporting the the official Metal API located under tt_metal/api. + +Note that the directory structure here mirrors that of the Metal API. + +*/ +#include "types.hpp" +#include "buffer.hpp" +#include "command_queue.hpp" +#include "device.hpp" +#include "event.hpp" +#include "kernel.hpp" +#include "program.hpp" +#include "trace.hpp" diff --git a/tt_metal/include/tt_metal/internal/program.hpp b/tt_metal/include/tt_metal/internal/program.hpp new file mode 100644 index 00000000000..a6926f4bfab --- /dev/null +++ b/tt_metal/include/tt_metal/internal/program.hpp @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "types.hpp" + +#include "tt_metal/impl/kernels/kernel_types.hpp" +#include "tt_metal/impl/buffers/circular_buffer_types.hpp" + +//================================================== +// INTERNAL PROGRAM MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +//================================================== +// INTERNAL PROGRAM FUNCTIONS +//================================================== + +/** + * @brief Launches a program on the device. + * + * @param device The device on which to launch the program. + * @param program The program to execute. + * @param wait_until_cores_done If true, waits until cores have completed execution. + * @param force_slow_dispatch If true, forces slow dispatch mode. + */ +void LaunchProgram(Device device, Program program, bool wait_until_cores_done = true, bool force_slow_dispatch = false); + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/trace.hpp b/tt_metal/include/tt_metal/internal/trace.hpp new file mode 100644 index 00000000000..b9c0143f777 --- /dev/null +++ b/tt_metal/include/tt_metal/internal/trace.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +namespace tt::tt_metal{ +namespace v1 { + +//================================================== +// INTERNAL TRACE OPERATIONS +//================================================== + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/internal/types.hpp b/tt_metal/include/tt_metal/internal/types.hpp new file mode 100644 index 00000000000..d9e3c46cc53 --- /dev/null +++ b/tt_metal/include/tt_metal/internal/types.hpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_metal/api/types.hpp" + +namespace tt::tt_metal{ + +namespace v1 { + +} + + + +} diff --git a/tt_metal/include/tt_metal/kernel.hpp b/tt_metal/include/tt_metal/kernel.hpp new file mode 100644 index 00000000000..332e1cd6c13 --- /dev/null +++ b/tt_metal/include/tt_metal/kernel.hpp @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" +#include "tt_metal/impl/kernels/kernel_types.hpp" + +//================================================== +// KERNEL EXECUTION +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +/** + * @brief Sets runtime arguments for a kernel. + * + * @param program The program containing the kernel. + * @param kernel KernelHandle representing the kernel ID. + * @param core_spec Specifies the cores where the runtime arguments will be set. + * @param runtime_args The runtime arguments to be set. + */ +void SetRuntimeArgs( + const Program program, + KernelHandle kernel, + const CoreRangeSet &core_spec, + const RuntimeArgs &runtime_args); + +/** + * @brief Sets multiple runtime arguments of a kernel at once. + * + * @param program The program containing the kernel. + * @param kernel KernelHandle representing the kernel ID. + * @param core_spec Vector of core coordinates where the runtime arguments will be set. + * @param runtime_args The runtime arguments to be set. 
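LaunchProgram above is the internal slow-dispatch path: it writes the program to the device and, by default, blocks until the worker cores report done, with no command queue involved. Minimal usage, assuming device and program handles already exist:

using namespace tt::tt_metal;

// Default: block until every core running the program signals completion.
v1::LaunchProgram(device, program);

// Fire-and-forget: return once the program has been written and started.
v1::LaunchProgram(device, program, /*wait_until_cores_done=*/false);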
+ */ +void SetRuntimeArgs( + const Program program, + KernelHandle kernel, + const std::vector &core_spec, + const RuntimeArgs &runtime_args); + +/** + * @brief Sets common runtime arguments for a kernel, shared by all cores. + * + * @param program The program containing the kernel. + * @param kernel_id KernelHandle representing the kernel ID. + * @param runtime_args The runtime arguments to be set. + */ +void SetCommonRuntimeArgs(const Program program, KernelHandle kernel_id, const RuntimeArgs &runtime_args); + +/** + * @brief Gets the runtime arguments for a kernel. + * + * @param program The program containing the kernel. + * @param kernel_id KernelHandle representing the kernel ID. + * @param logical_core The logical core coordinate. + * @return Reference to RuntimeArgsData. + */ +RuntimeArgsData &GetRuntimeArgs(const Program program, KernelHandle kernel_id, const CoreCoord &logical_core); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/metal.hpp b/tt_metal/include/tt_metal/metal.hpp new file mode 100644 index 00000000000..a800e24b4c8 --- /dev/null +++ b/tt_metal/include/tt_metal/metal.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "types.hpp" +#include "buffer.hpp" +#include "command_queue.hpp" +#include "device.hpp" +#include "event.hpp" +#include "kernel.hpp" +#include "program.hpp" +#include "trace.hpp" diff --git a/tt_metal/include/tt_metal/program.hpp b/tt_metal/include/tt_metal/program.hpp new file mode 100644 index 00000000000..5c268498cb7 --- /dev/null +++ b/tt_metal/include/tt_metal/program.hpp @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "types.hpp" + +#include "tt_metal/impl/kernels/kernel_types.hpp" +#include "tt_metal/impl/buffers/circular_buffer_types.hpp" + +//================================================== +// PROGRAM MANAGEMENT +//================================================== + +namespace tt::tt_metal{ +namespace v1 { + +/** + * @brief Creates a Program object, which bundles kernels, circular buffers, and semaphores for execution on the device. + * + * @return Program handle to the created program. + */ +Program CreateProgram(); + + +/** + * @brief Creates a data movement or compute kernel and adds it to the program. + * + * @param program The program to which this kernel will be added. + * @param file_name Path to the kernel source file. + * @param core_spec Specifies the cores on which the kernel will be placed. + * @param config DataMovementConfig for the kernel. + * @return KernelHandle representing the kernel ID. + */ +KernelHandle CreateKernel( + Program program, + std::string_view file_name, + const CoreRangeSet &core_spec, + const DataMovementConfig &config); + +/** + * @brief Creates a data movement or compute kernel and adds it to the program. + * + * @param program The program to which this kernel will be added. + * @param file_name Path to the kernel source file. + * @param core_spec Specifies the cores on which the kernel will be placed. + * @param config ComputeConfig for the kernel. + * @return KernelHandle representing the kernel ID. + */ +KernelHandle CreateKernel( + Program program, + std::string_view file_name, + const CoreRangeSet &core_spec, + const ComputeConfig &config); + +/** + * @brief Creates a data movement or compute kernel and adds it to the program. 
+ * + * @param program The program to which this kernel will be added. + * @param file_name Path to the kernel source file. + * @param core_spec Specifies the cores on which the kernel will be placed. + * @param config EthernetConfig for the kernel. + * @return KernelHandle representing the kernel ID. + */ +KernelHandle CreateKernel( + Program program, + std::string_view file_name, + const CoreRangeSet &core_spec, + const EthernetConfig &config); + + +/** + * @brief Initializes a semaphore on specified cores. + * + * @param program The program to which the semaphore will be added. + * @param core_spec Range of cores using the semaphore. + * @param initial_value Initial value of the semaphore. + * @param core_type Core type on which to create the semaphore (default: CoreType::WORKER). + * @return Semaphore address as a uint32_t. + */ +uint32_t CreateSemaphore( + Program program, + const CoreRangeSet &core_spec, + uint32_t initial_value, + CoreType core_type = CoreType::WORKER); + + +/** + * @brief Creates a Circular Buffer in L1 memory of specified cores and adds it to the program. + * + * @param program The program to which the buffer will be added. + * @param core_spec Specifies the cores where the circular buffer will be configured. + * @param config Configuration for the circular buffer. + * @return CBHandle representing the Circular Buffer ID. + */ +CBHandle CreateCircularBuffer( + Program program, + const CoreRangeSet &core_spec, + const CircularBufferConfig &config); + +/** + * @brief Gets the configuration of a circular buffer. + * + * @param program The program containing the circular buffer. + * @param cb_handle Handle of the circular buffer. + * @return Reference to the CircularBufferConfig. + */ +const CircularBufferConfig &GetCircularBufferConfig(Program program, CBHandle cb_handle); + +/** + * @brief Retrieves the circular buffers associated with the program. + * + * @param program The program to query. + * @return Reference to a vector of shared pointers to CircularBuffer objects. + */ +const std::vector &GetCircularBuffers(Program program); + +/** + * @brief Retrieves the circular buffers associated with the program on a specific core range. + * + * @param program The program to query. + * @param cr The core range to consider. + * @return Vector of shared pointers to CircularBuffer objects on the core range. + */ +std::vector GetCircularBuffersOnCoreRange(Program program, const CoreRange &cr); + + +//================================================== +// PROGRAM FUNCTIONS +//================================================== + +/** + * @brief Updates the total size of a circular buffer. + * + * @param program The program containing the circular buffer. + * @param cb_handle Handle of the circular buffer. + * @param total_size New total size of the circular buffer in bytes. + */ +void UpdateCircularBufferTotalSize(Program program, CBHandle cb_handle, uint32_t total_size); + +/** + * @brief Updates the address of a dynamic circular buffer. + * + * @param program The program containing the circular buffer. + * @param cb_handle Handle of the circular buffer. + * @param buffer Dynamically allocated L1 buffer that shares address space with the circular buffer. + */ +void UpdateDynamicCircularBufferAddress(Program program, CBHandle cb_handle, const Buffer buffer); + + +/** + * @brief Captures dependencies for multi-device execution in the program. + * + * @param program The program to modify. 
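Putting the builder calls above together with the runtime-argument setters from kernel.hpp: one data-movement kernel on a 4x4 core grid, a semaphore, and a circular buffer. The kernel path is a placeholder, and the CoreRange, CircularBufferConfig, and RuntimeArgs construction follows their conventional shapes from other headers rather than anything declared here:

using namespace tt::tt_metal;

v1::Program program = v1::CreateProgram();

CoreRange all_cores(CoreCoord{0, 0}, CoreCoord{3, 3});
CoreRangeSet core_set({all_cores});

KernelHandle reader = v1::CreateKernel(
    program, "kernels/dataflow/reader_unary.cpp",   // placeholder path
    core_set, DataMovementConfig{});                // default-constructed config

uint32_t sem_addr = v1::CreateSemaphore(program, core_set, /*initial_value=*/0);

// Two 2 KiB pages on CB index 0, bfloat16 data; ctor/setter shape assumed from the CB header.
CBHandle cb = v1::CreateCircularBuffer(
    program, core_set,
    CircularBufferConfig(2 * 2048, {{0, tt::DataFormat::Float16_b}}).set_page_size(0, 2048));

// Same args on every core in the range; RuntimeArgs assumed to brace-init from uint32_t values.
v1::SetRuntimeArgs(program, reader, core_set, {sem_addr, /*num_pages=*/2u});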
+ */ +void CaptureMultiDeviceDependencies(Program program); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/trace.hpp b/tt_metal/include/tt_metal/trace.hpp new file mode 100644 index 00000000000..5131d8a978c --- /dev/null +++ b/tt_metal/include/tt_metal/trace.hpp @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "types.hpp" + +namespace tt::tt_metal{ +namespace v1 { + +//================================================== +// TRACE OPERATIONS +//================================================== + +/** + * @brief Begins capture on a trace. Captured commands will have their execution delayed until the trace is replayed. + * + * @param device The device being traced. + * @param cq The command queue associated with the trace. + * @return Trace ID. + */ +uint32_t BeginTraceCapture(Device device, CommandQueue cq); + +/** + * @brief Ends capture on a trace. The trace can be replayed on the same device command queue. + * + * @param device The device being traced. + * @param cq The command queue associated with the trace. + * @param tid The trace ID returned by BeginTraceCapture. + */ +void EndTraceCapture(Device device, CommandQueue cq, uint32_t tid); + +/** + * @brief Replays a captured trace on the device. + * + * @param device The device holding the trace. + * @param cq The command queue associated with the trace. + * @param trace The trace ID to replay. + * @param blocking Indicates whether the operation is blocking. + */ +void ReplayTrace(Device device, CommandQueue cq, Trace trace, bool blocking); + +/** + * @brief Releases a previously captured trace, deallocating associated resources. + * + * @param device The device holding the trace. + * @param trace The trace to release. + */ +void ReleaseTrace(Device device, Trace trace); + +/** + * @brief Enqueues a trace for execution on the device. + * + * @param cq The command queue used to dispatch the trace. + * @param trace The Trace to enqueue. + * @param blocking Indicates whether the operation is blocking. + */ +void EnqueueTrace(CommandQueue cq, Trace trace, bool blocking); + + +} // namespace v1 +} // namespace tt::tt_metal diff --git a/tt_metal/include/tt_metal/types.hpp b/tt_metal/include/tt_metal/types.hpp new file mode 100644 index 00000000000..b962a06a111 --- /dev/null +++ b/tt_metal/include/tt_metal/types.hpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "device/tt_cluster_descriptor_types.h" +#include "hostdevcommon/common_values.hpp" +#include "tt_metal/impl/dispatch/dispatch_core_manager.hpp" + +namespace tt::tt_metal{ + +namespace v1 { + +// Opaque classes +class Program; +class Device; +class CommandQueue; +class Trace; + +// Ideally these would be opaque but this work requires +// completion of the prototype of the runtime args. +class CircularBuffer; +class Buffer; + +// Not likely going to be opaque, but pending review of +// completion of the prototype of the runtime args. 
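Capture-and-replay with the trace calls above. Note the asymmetry in this header: capture works with a uint32_t trace id while replay, enqueue, and release take a Trace handle, so the sketch leaves the id-to-handle step as an assumption rather than inventing a lookup. The device, cq, program, and trace handles are assumed to exist already:

using namespace tt::tt_metal;

// Capture: commands enqueued between Begin and End are recorded, not executed.
uint32_t tid = v1::BeginTraceCapture(device, cq);
v1::EnqueueProgram(cq, program, /*blocking=*/false);
v1::EndTraceCapture(device, cq, tid);

// Replay: `trace` is the Trace handle corresponding to `tid` (lookup not declared here).
v1::EnqueueTrace(cq, trace, /*blocking=*/false);   // run the captured commands again
v1::Finish(cq);                                    // wait for the replay to drain
v1::ReleaseTrace(device, trace);                   // free trace resources when done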
+class Event; +class RuntimeArgs; +class RuntimeArgsData; + +} + + + +} diff --git a/tt_metal/jit_build/build.cpp b/tt_metal/jit_build/build.cpp index 083a30f2377..7043fdde0ee 100644 --- a/tt_metal/jit_build/build.cpp +++ b/tt_metal/jit_build/build.cpp @@ -196,11 +196,6 @@ void JitBuildState::finish_init() { // Note the preceding slash which defies convention as this gets appended to // the kernel name used as a path which doesn't have a slash this->target_full_path_ = "/" + this->target_name_ + "/" + this->target_name_ + ".elf"; - - if (not this->is_fw_) { - // Emit relocations, so we can relocate the resulting binary - this->lflags_ += "-Wl,--emit-relocs "; - } } JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : diff --git a/tt_metal/llrt/llrt.cpp b/tt_metal/llrt/llrt.cpp index 183df6abb24..28521cabc72 100644 --- a/tt_metal/llrt/llrt.cpp +++ b/tt_metal/llrt/llrt.cpp @@ -25,8 +25,7 @@ using std::uint16_t; using std::uint32_t; using std::uint64_t; -ll_api::memory get_risc_binary(string const &path, uint32_t riscv_id, - ll_api::memory::PackSpans span_type, ll_api::memory::Relocate relo_type) { +ll_api::memory get_risc_binary(string const &path, uint32_t riscv_id, PackSpans pack_spans) { static const uint32_t processor_to_fw_base_addr[] = { MEM_BRISC_FIRMWARE_BASE, @@ -49,15 +48,13 @@ ll_api::memory get_risc_binary(string const &path, uint32_t riscv_id, if (inserted) { // We're the first with PATH. Create and insert. lock.unlock(); - auto *ptr = new ll_api::memory(path, relo_type); + auto *ptr = new ll_api::memory(path); // TODO: pass pack_spans into reader, generate text/data sizes // from segment sizes and pack there - if (span_type == ll_api::memory::PackSpans::PACK) { + if (pack_spans == PackSpans::PACK) { uint64_t data_start = MEM_LOCAL_BASE; - uint64_t text_start = (relo_type == ll_api::memory::Relocate::XIP) ? 
- 0 : - processor_to_fw_base_addr[riscv_id]; + uint64_t text_start = processor_to_fw_base_addr[riscv_id]; ptr->pack_data_into_text(text_start, data_start); } @@ -205,14 +202,6 @@ bool test_load_write_read_trisc_binary(ll_api::memory &mem, chip_id_t chip_id, c return test_load_write_read_risc_binary(mem, chip_id, core, triscv_id + 2); } -void write_binary_to_address(ll_api::memory &mem, chip_id_t chip_id, const CoreCoord &core, uint32_t address) { - - log_debug(tt::LogLLRuntime, "vec size = {}, size_in_bytes = {}", mem.size(), mem.size() * sizeof(uint32_t)); - mem.process_spans([&](std::vector::const_iterator mem_ptr, uint64_t addr, uint32_t len_words) { - tt::Cluster::instance().write_core(&*mem_ptr, len_words * sizeof(uint32_t), tt_cxy_pair(chip_id, core), address); - }); -} - CoreCoord get_core_for_dram_channel(int dram_channel_id, chip_id_t chip_id) { return tt::Cluster::instance().get_soc_desc(chip_id).get_preferred_worker_core_for_dram_channel(dram_channel_id); } diff --git a/tt_metal/llrt/llrt.hpp b/tt_metal/llrt/llrt.hpp index d690aebc144..391622a2f67 100644 --- a/tt_metal/llrt/llrt.hpp +++ b/tt_metal/llrt/llrt.hpp @@ -52,10 +52,8 @@ using NUM_REPETITIONS = std::uint32_t; using WorkerCore = tt_cxy_pair; using WorkerCores = std::vector; -ll_api::memory get_risc_binary(string const &path, uint32_t riscv_id = 0, - ll_api::memory::PackSpans span_type = ll_api::memory::PackSpans::NO_PACK, - ll_api::memory::Relocate relo_type = ll_api::memory::Relocate::NONE); - +enum class PackSpans { PACK, NO_PACK }; +ll_api::memory get_risc_binary(string const &path, uint32_t riscv_id = 0, PackSpans pack_spans = PackSpans::NO_PACK); // TODO: try using "stop" method from device instead, it's the proper way of asserting reset @@ -95,8 +93,8 @@ uint32_t generate_risc_startup_addr(bool is_eth_core); void program_risc_startup_addr(chip_id_t chip_id, const CoreCoord &core); bool test_load_write_read_risc_binary(ll_api::memory &mem, chip_id_t chip_id, const CoreCoord &core, int riscv_id); + bool test_load_write_read_trisc_binary(ll_api::memory &mem, chip_id_t chip_id, const CoreCoord &core, int triscv_id); -void write_binary_to_address(ll_api::memory &mem, chip_id_t chip_id, const CoreCoord &core, uint32_t address); // subchannel hard-coded to 0 for now CoreCoord get_core_for_dram_channel(int dram_channel_id, chip_id_t chip_id = 0); diff --git a/tt_metal/llrt/tt_elffile.cpp b/tt_metal/llrt/tt_elffile.cpp index 0f5b8145c08..6007219430c 100644 --- a/tt_metal/llrt/tt_elffile.cpp +++ b/tt_metal/llrt/tt_elffile.cpp @@ -8,8 +8,6 @@ #include #include "common/assert.hpp" -// C++ -#include // C #include // OS @@ -46,19 +44,6 @@ enum { #define EM_RISCV_WORMHOLE 0x5151 #define EM_RISCV_BLACKHOLE 0x6151 -// We have to translate these two instructions -static constexpr uint32_t insn_opc_auipc = 0x00000017; -static constexpr uint32_t insn_opc_lui = 0x00000037; -static constexpr uint32_t insn_mask_u = 0x0000007f; -static constexpr uint32_t mask_hi20 = 0x00000fff; -static constexpr unsigned mask_hi20_shift = 12; -static constexpr uint32_t mask_lo12_i = 0x000fffff; -static constexpr unsigned mask_lo12_i_shift = 20; -static constexpr uint32_t mask_lo12_s = 0x01fff07f; -static constexpr unsigned mask_lo12_s_split = 5; -static constexpr unsigned mask_lo12_s_shift_1 = 7; -static constexpr unsigned mask_lo12_s_shift_2 = 25; - using namespace ll_api; class ElfFile::Impl { @@ -78,7 +63,6 @@ class ElfFile::Impl { public: void LoadImage(); void WeakenDataSymbols(std::span strong_names); - void XIPify(); private: [[nodiscard]] auto 
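With the Relocate parameter gone, loading a kernel image reduces to reading the packed ELF spans at their firmware-relative addresses and writing them to a core, as the llrt hunks above converge on. A sketch of that load path; the file path, chip id, core coordinate, and the tt::llrt namespace qualification are illustrative assumptions:

// Read the BRISC image (riscv_id 0) with the data span packed after text.
ll_api::memory mem = tt::llrt::get_risc_binary(
    "/path/to/built/kernel/brisc/brisc.elf",   // illustrative path
    /*riscv_id=*/0,
    tt::llrt::PackSpans::PACK);

// Write it to worker core (1,1) on chip 0 and read it back to verify.
bool ok = tt::llrt::test_load_write_read_risc_binary(
    mem, /*chip_id=*/0, CoreCoord{1, 1}, /*riscv_id=*/0);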
GetHeader() const -> Elf32_Ehdr const & { return *ByteOffset(GetContents().data()); } @@ -143,13 +127,6 @@ class ElfFile::Impl { [[nodiscard]] static T const *ByteOffset(std::byte const *base, size_t offset = 0) { return reinterpret_cast(base + offset); } - - uint32_t Read32(Elf32_Shdr const &shdr, address_t addr) { - return *ByteOffset(GetContents(shdr).data(), addr - shdr.sh_addr); - } - void Write32(Elf32_Shdr const &shdr, address_t addr, uint32_t value) { - *ByteOffset(GetContents(shdr).data(), addr - shdr.sh_addr) = value; - } }; ElfFile::~ElfFile() { @@ -198,8 +175,6 @@ void ElfFile::WriteImage(std::string const &path) { void ElfFile::WeakenDataSymbols(std::span strong) { pimpl_->WeakenDataSymbols(strong); } -void ElfFile::MakeExecuteInPlace() { pimpl_->XIPify(); } - void ElfFile::Impl::LoadImage() { auto &hdr = GetHeader(); @@ -389,309 +364,3 @@ void ElfFile::Impl::WeakenDataSymbols(std::span strong) weakener.RewriteSymbols(); } } - -void ElfFile::Impl::XIPify() { - // In general there can be several lo12 relocs for a hi20 - // reloc. This is particularly true for lui/{addi,lw,sw,etc} - // pairs -- a load and a store might share a single lui, as - // the compiler now emits those insns separately. Thus we have - // to build a work list and then process it. Furthermore, - // although auipc/lo12 pairings are clear because the lo12 - // part directly points at the auipc, that is not true of - // lui/lo12 pairings. We have to use heuristics to locate the - // matching relocs and that could get arbitrarily hard. We - // presume (a) the compiler doesn't duplicate lui insns, and - // (b) the lui preceeds the lo12 in program counter - // order. Thus we look for a hi20 reloc matching the symbol at - // a lower offset than the lo12 in question. Fortunately we - // only need to do this for relocs that need translating, and - // those happen to be rare when all data-like sections are in - // the data segment (so putting .rodata in text is - // problematic). If that proves insufficient here are some - // ideas: - - // * Insert fn boundaries from symbols of FNtype -- you'll - // need to tweak the fn address to not cause collisions in - // the reloc map. this might fail with hot/cold block - // splitting. - - // * Construct the CFG by examining R_RISCV_BRANCH - // relocs. Then walk it (backwards) from each lo12 to find - // the reachable hi20. This would be able to deal with - // hot/cold splitting, if one constructed the complete - // section CFG, not as a per-fn entity. One might get away - // with not disasembling to discover ret instructions that - // terminate the CFG. - - struct ComposedReloc { - std::vector lo_relocs; - Elf32_Rela *hi_reloc = nullptr; // the high part - - ComposedReloc(Elf32_Rela *hi) : hi_reloc(hi) {} - }; - - enum { ABS, PCREL, HWM }; - static char const *const r_names[][2] = { - {"R_RISCV_HI20", "R_RISCV_LO12"}, {"R_RISCV_PCREL_HI20", "R_RISCV_PCREL_LO12"}}; - - auto check_relaxed = [&](Elf32_Rela const &reloc) { - // If RELOC is the final reloc, this will - // be out of bounds (and probably fail), - // but we kind of want that anyway - if (ELF32_R_TYPE((&reloc)[1].r_info) != R_RISCV_RELAX) - log_debug(tt::LogLLRuntime, "{}: Relocation at {x} is not relaxed", path_, reloc.r_offset); - }; - - unsigned num_reloc_sections = 0; - for (auto const &relocHdr : GetShdrs()) { - if (relocHdr.sh_type != SHT_RELA) - continue; - - // Is this relocating a section of interest? 
- unsigned section_ix = relocHdr.sh_info; - auto §ion = GetShdr(section_ix); - if (!(section.sh_flags & SHF_ALLOC && section.sh_type != SHT_NOBITS)) - continue; - - int segment_ix = GetSegmentIx(section); - if (segment_ix < 0) - continue; - - num_reloc_sections++; - std::map composed[HWM]; - std::vector lo[HWM]; - - auto symbols = GetSymbols(GetShdr(relocHdr.sh_link)); - auto relocs = GetRelocations(relocHdr); - bool is_from_text = !segment_ix; - - // ADD32/SUB32 pairs are used for switch tables. Make sure - // they're consistent. - Elf32_Rela const *sub_reloc = nullptr; // Active sub reloc. - for (auto ix = relocs.size(); ix--;) { - auto &reloc = relocs[ix]; - if (reloc.r_offset & 3 || reloc.r_offset - section.sh_addr >= section.sh_size) - TT_THROW( - "{}: relocation @ {x} is {} section {}", - path_, - reloc.r_offset, - reloc.r_offset & 3 ? "misaligned in" : "outside of", - GetName(section)); - - auto type = ELF32_R_TYPE(reloc.r_info); - auto sym_ix = ELF32_R_SYM(reloc.r_info); - auto const *symbol = &symbols[sym_ix]; - bool is_to_text = IsTextSymbol(*symbol); - - // Check add/sub relocs are paired and do not cross text/non-text boundary. - if (bool(sub_reloc) != (type == R_RISCV_ADD32) || (sub_reloc && sub_reloc->r_offset != reloc.r_offset)) - unpaired_sub: - TT_THROW( - "{}: unpaired {} reloc at {x}", - path_, - sub_reloc ? "sub32" : "add32", - (sub_reloc ? sub_reloc : &reloc)->r_offset); - if (type == R_RISCV_ADD32) { - auto const *sub_symbol = &symbols[ELF32_R_SYM(sub_reloc->r_info)]; - bool sub_is_to_text = IsTextSymbol(*sub_symbol); - if (is_to_text != sub_is_to_text) - TT_THROW( - "{}: mismatched add32/sub32 relocs at {x} & {x}", path_, reloc.r_offset, sub_reloc->r_offset); - } - sub_reloc = nullptr; - if (type == R_RISCV_SUB32) { - sub_reloc = &reloc; - if (!ix) - goto unpaired_sub; - } - - unsigned kind = PCREL; - switch (type) { - // Abs relocs to text will need fixing up - case R_RISCV_LO12_I: - case R_RISCV_LO12_S: - if (!is_to_text) - break; - kind = ABS; - [[fallthrough]]; - - // PCrel relocs not to text will need fixing up. At - // this point we don't know the symbol from the LO12 - // relocs, as that points at the hi20 reloc. 
- case R_RISCV_PCREL_LO12_I: - case R_RISCV_PCREL_LO12_S: lo[kind].push_back(&reloc); break; - - case R_RISCV_HI20: kind = ABS; [[fallthrough]]; - - case R_RISCV_PCREL_HI20: - if (is_to_text && !is_from_text) - TT_THROW( - "{}: segment-crossing {} relocation found at {x}", path_, r_names[kind][0], reloc.r_offset); - - if (!is_to_text && kind == ABS) - break; - composed[kind].emplace(reloc.r_offset, ComposedReloc(&reloc)); - break; - - case R_RISCV_32: { - if (!is_to_text) - break; - // Emit dynamic reloc - log_debug( - tt::LogLLRuntime, "{}: emitting dynamic R_RISCV_32 relocation at {x}", path_, reloc.r_offset); - address_t value = - (symbol->st_value + reloc.r_addend - GetSegments().front().address); - Write32(section, reloc.r_offset, value); - auto &seg = GetSegments()[segment_ix]; - seg.relocs.push_back(reloc.r_offset - seg.address); - } break; - - case R_RISCV_JAL: - if (is_from_text != is_to_text) - TT_THROW("{}: segment-crossing R_RISCV_JAL relocation found at {x}", path_, reloc.r_offset); - break; - - case R_RISCV_CALL: - case R_RISCV_CALL_PLT: - TT_THROW("{}: R_RISCV_CALL{,_PLT} relocation found at {x}", path_, reloc.r_offset); - break; - - case R_RISCV_32_PCREL: - TT_THROW("{}: R_RISCV_32_PCREL relocation found at {x}", path_, reloc.r_offset); - break; - } - } - - // Combine hi/lo relocs - - // We can't do abs ones in general with complete accuracy, - // because there could be multiple possible matching hi - // relocs. If we construct the CFG then it becomes more - // accurate, but it's always going to be somewhat - // heuristic. Let's hope CFG construction is unnecessary. A - // first step in that direction might be to insert function - // boundaries, to stop the search. - for (unsigned kind = HWM; kind--;) { - for (auto *lo_reloc : lo[kind]) { - // Find the matching hi-reloc by searching backwards. This - // presumes block reordering hasn't done something to - // break that. 
- unsigned sym_ix = ELF32_R_SYM(lo_reloc->r_info); - auto hi_reloc = composed[kind].begin(); - - if (kind == ABS) { - hi_reloc = composed[kind].lower_bound(lo_reloc->r_offset); - while (hi_reloc != composed[kind].begin()) { - --hi_reloc; - if (ELF32_R_SYM(hi_reloc->second.hi_reloc->r_info) == sym_ix) - goto found; - } - } else { - uint32_t hi_offset = symbols[sym_ix].st_value + lo_reloc->r_addend; - hi_reloc = composed[kind].find(hi_offset); - if (hi_reloc != composed[kind].end()) - goto found; - } - TT_THROW( - "{}: {} relocation at {x} has no matching {}", - path_, - r_names[kind][true], - lo_reloc->r_offset, - r_names[kind][false]); - found: - hi_reloc->second.lo_relocs.push_back(lo_reloc); - } - } - - // Process composed relocations - for (unsigned kind = HWM; kind--;) { - for (auto &slot : composed[kind]) { - if (slot.second.lo_relocs.empty()) - TT_THROW( - "{}: R_RISCV_{}HI20 relocation at {x} has no matching R_RISCV_{}LO12", - path_, - r_names[kind][false], - r_names[kind][true], - slot.first); - - auto hi_reloc = slot.second.hi_reloc; - unsigned sym_ix = ELF32_R_SYM(hi_reloc->r_info); - auto const &symbol = symbols[sym_ix]; - bool is_to_text = IsTextSymbol(symbol); - if (is_to_text == is_from_text) - continue; - - address_t value = symbol.st_value + hi_reloc->r_addend; - if (kind == ABS) { - value -= slot.first; - sym_ix = 0; - } - - // translate hi - check_relaxed(*hi_reloc); - uint32_t insn = Read32(section, hi_reloc->r_offset); - log_debug( - tt::LogLLRuntime, - "{}: translating {} at {x} to {}", - path_, - r_names[kind][false], - hi_reloc->r_offset, - r_names[HWM - 1 - kind][false]); - if ((insn & insn_mask_u) != (kind == ABS ? insn_opc_lui : insn_opc_auipc)) - TT_THROW( - "{}: translating instruction at {x} is not `{}'", - path_, - hi_reloc->r_offset, - kind == ABS ? "lui" : "auipc"); - insn &= mask_hi20; // Remove old immediate - insn ^= insn_opc_auipc ^ insn_opc_lui; // Convert opcode - // Insert new immediate - insn |= ((value + (1 << 11)) >> 12) << mask_hi20_shift; - Write32(section, hi_reloc->r_offset, insn); - hi_reloc->r_info ^= ELF32_R_INFO(0, R_RISCV_HI20 ^ R_RISCV_PCREL_HI20); - - // translate lo - for (auto *lo_reloc : slot.second.lo_relocs) { - unsigned type = ELF32_R_TYPE(lo_reloc->r_info); - bool is_form_i = type == (kind == PCREL ? R_RISCV_PCREL_LO12_I : R_RISCV_LO12_I); - check_relaxed(*lo_reloc); - uint32_t insn = Read32(section, lo_reloc->r_offset); - log_debug( - tt::LogLLRuntime, - "{}: translating R_RISCV{}_LO12 at {x} to R_RISCV{}_LO12", - path_, - r_names[kind][true], - lo_reloc->r_offset, - r_names[HWM - 1 - kind][true]); - if (is_form_i) { - insn &= mask_lo12_i; - insn |= (value & 0x0fff) << mask_lo12_i_shift; - } else { - // S form splits the immediate - insn &= mask_lo12_s; - insn |= (value & ((1 << mask_lo12_s_split) - 1)) << mask_lo12_s_shift_1; - insn |= ((value & 0x0fff) >> mask_lo12_s_split) << mask_lo12_s_shift_2; - } - Write32(section, lo_reloc->r_offset, insn); - - // We can't convert to PCREL with fidelity, as - // that involves adding a symbol. Instead, let's - // use a null symbol and an addend. - lo_reloc->r_info = ELF32_R_INFO( - sym_ix, - type ^ (is_form_i ? (R_RISCV_LO12_I ^ R_RISCV_PCREL_LO12_I) - : (R_RISCV_LO12_S ^ R_RISCV_PCREL_LO12_S))); - lo_reloc->r_addend = kind == PCREL ? 
slot.second.hi_reloc->r_addend - : slot.second.hi_reloc->r_offset - lo_reloc->r_offset; - } - } - } - } - - if (!num_reloc_sections) - // Hm, that's suspicious - TT_THROW("{}: there are no relocation sections", path_); - - // The text segment is now XIP - GetSegments().front().address = 0; -} diff --git a/tt_metal/llrt/tt_elffile.hpp b/tt_metal/llrt/tt_elffile.hpp index 7c1b09ff034..b93bfd8d0f5 100644 --- a/tt_metal/llrt/tt_elffile.hpp +++ b/tt_metal/llrt/tt_elffile.hpp @@ -25,7 +25,6 @@ class ElfFile { using word_t = std::uint32_t; // Contents struct Segment { - std::vector relocs; // 32-bit relocs to apply std::span contents; // Non-owning span address_t address = 0; // byte address or 0 for XIP offset_t bss = 0; // words of BSS @@ -76,9 +75,6 @@ class ElfFile { // globs ending in '*'. void WeakenDataSymbols(std::span strong_names); - // XIPify - void MakeExecuteInPlace(); - private: class Impl; // We can't use unique_ptr here, because the above move semantics diff --git a/tt_metal/llrt/tt_memory.cpp b/tt_metal/llrt/tt_memory.cpp index 376b2213038..b8933f60379 100644 --- a/tt_metal/llrt/tt_memory.cpp +++ b/tt_metal/llrt/tt_memory.cpp @@ -23,20 +23,16 @@ memory::memory() { packed_size_ = 0; } -memory::memory(std::string const &path, Relocate relo_type) : memory() { +memory::memory(std::string const &path) : memory() { ElfFile elf; elf.ReadImage(path); - if (relo_type == Relocate::XIP) { - elf.MakeExecuteInPlace(); - } // The ELF file puts the text segment first, but memory wants // ordered spans. // FIXME: Perhaps we can relax that? uint32_t total_size = 0; auto emit_segment = [&](ElfFile::Segment const& segment) { - TT_ASSERT(segment.relocs.empty(), "Unexpected dynamic relocations"); link_spans_.emplace_back( segment.address, segment.contents.size()); data_.insert(data_.end(), segment.contents.begin(), segment.contents.end()); @@ -53,7 +49,7 @@ memory::memory(std::string const &path, Relocate relo_type) : memory() { if (text) emit_segment(*text); - set_text_size(elf.GetSegments()[0].contents.size() * sizeof(word_t)); + set_text_size(elf.GetSegments()[0].contents.size() * sizeof(uint32_t)); set_packed_size(total_size * sizeof(uint32_t)); } @@ -148,7 +144,6 @@ void memory::pack_data_into_text(std::uint64_t text_start, std::uint64_t data_st this->link_spans_.resize(1); this->link_spans_[0] = new_span; this->data_ = new_data; - this->text_addr_ = new_span.addr; } } // namespace ll_api diff --git a/tt_metal/llrt/tt_memory.h b/tt_metal/llrt/tt_memory.h index 98eda2331c8..b39e899e0a3 100644 --- a/tt_metal/llrt/tt_memory.h +++ b/tt_metal/llrt/tt_memory.h @@ -20,8 +20,6 @@ class memory { public: typedef std::uint64_t address_t; typedef std::uint32_t word_t; - enum class PackSpans { PACK, NO_PACK }; - enum class Relocate { XIP, NONE }; private: static constexpr uint32_t initial_data_space_ = 0x400; @@ -39,11 +37,10 @@ class memory { std::vector link_spans_; uint32_t text_size_; uint32_t packed_size_; - uint32_t text_addr_; public: memory(); - memory(std::string const &path, Relocate relo_type); + memory(std::string const &path); public: const std::vector& data() const { return this->data_; } @@ -55,7 +52,6 @@ class memory { void set_packed_size(uint32_t size) { this->packed_size_ = size; } uint32_t get_text_size() const { return this->text_size_; } uint32_t get_packed_size() const { return this->packed_size_; } - uint32_t get_text_addr() const { return this->text_addr_; } size_t size() const { return data_.size(); } diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 
fca037ad2a1..f71c5c49302 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -113,18 +113,11 @@ DataMovementConfigStatus CheckDataMovementConfig(Program &program, const CoreRan } void ConfigureKernelGroup( - Program &program, - uint32_t programmable_core_type_index, - const KernelGroup *kernel_group, - Device *device, - const CoreCoord &logical_core) { + const Program &program, const KernelGroup *kernel_group, Device *device, const CoreCoord &logical_core) { - uint32_t kernel_config_base = hal.get_dev_addr(programmable_core_type_index, HalL1MemAddrType::KERNEL_CONFIG); for (auto& optional_id : kernel_group->kernel_ids) { if (optional_id) { - // Need the individual offsets of each bin - detail::GetKernel(program, optional_id.value())->configure(device, logical_core, - kernel_config_base, kernel_group->kernel_text_offsets); + detail::GetKernel(program, optional_id.value())->configure(device, logical_core); } } } @@ -703,7 +696,7 @@ bool ConfigureDeviceWithProgram(Device *device, Program &program, bool fd_bootlo KernelGroup *kernel_group = program.kernels_on_core(logical_core, index); CoreCoord physical_core = device->physical_core_from_logical_core(logical_core, core_type); - ConfigureKernelGroup(program, index, kernel_group, device, logical_core); + ConfigureKernelGroup(program, kernel_group, device, logical_core); // TODO: add support for CB for ethernet cores if (core_type == CoreType::WORKER) { // CircularBufferConfigVec -- common across all kernels, so written once to the core diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 5f1d7d0ff1d..873296beece 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -44,6 +44,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/device/update_cache_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/common/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/concat.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp @@ -348,9 +349,6 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp diff --git a/ttnn/cpp/pybind11/operations/creation.hpp b/ttnn/cpp/pybind11/operations/creation.hpp index 8872df553cb..ce6f2a44d55 100644 --- a/ttnn/cpp/pybind11/operations/creation.hpp +++ 
b/ttnn/cpp/pybind11/operations/creation.hpp @@ -96,7 +96,7 @@ void bind_full_operation(py::module& module, const creation_operation_t& operati } template -void bind_full_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation, const std::string& value_string) { +void bind_full_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation, const std::string& value_string, const std::string& info_doc = "") { auto doc = fmt::format( R"doc( Creates a tensor with the specified shape and fills it with the value of {1}. @@ -115,6 +115,9 @@ void bind_full_operation_with_hard_coded_value(py::module& module, const creatio Returns: ttnn.Tensor: A tensor filled with {1}. + Note: + {2} + Example: >>> tensor = ttnn.{0}(shape=[1, 2, 2, 2], dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) >>> print(tensor) @@ -124,7 +127,8 @@ void bind_full_operation_with_hard_coded_value(py::module& module, const creatio [{1}, {1}]]]]], shape=Shape([1, 2, 2, 2]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR) )doc", operation.base_name(), - value_string); + value_string, + info_doc); bind_registered_operation( module, @@ -221,7 +225,7 @@ void bind_full_like_operation(py::module& module, const creation_operation_t& op } template -void bind_full_like_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation, const std::string& value_string) { +void bind_full_like_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation, const std::string& value_string, const std::string& info_doc = "") { auto doc = fmt::format( R"doc( Creates a tensor of the same shape as the input tensor and fills it with the value of {1}. The data type, layout, device, and memory configuration of the resulting tensor can be specified. @@ -238,6 +242,9 @@ void bind_full_like_operation_with_hard_coded_value(py::module& module, const cr Returns: ttnn.Tensor: A tensor filled with {1}. + Note: + {2} + Example: >>> tensor = ttnn.{0}(ttnn.from_torch(torch.randn(1, 2, 2, 2), ttnn.bfloat16, ttnn.TILE_LAYOUT) >>> output_tensor = ttnn.{0}(tensor=input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) @@ -248,7 +255,8 @@ void bind_full_like_operation_with_hard_coded_value(py::module& module, const cr [{1}, {1}]]]]], shape=Shape([1, 2, 2, 2]), dtype=DataType::BFLOAT16, layout=Layout::TILE_LAYOUT) )doc", operation.base_name(), - value_string); + value_string, + info_doc); bind_registered_operation( module, @@ -320,7 +328,7 @@ void bind_arange_operation(py::module& module, const creation_operation_t& opera py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}); } -void bind_empty_operation(py::module& module) { +void bind_empty_operation(py::module& module, const std::string& info_doc = "") { auto doc = fmt::format( R"doc( Creates a device tensor with uninitialized values of the specified shape, data type, layout, and memory configuration. @@ -335,12 +343,16 @@ void bind_empty_operation(py::module& module) { Returns: ttnn.Tensor: The output uninitialized tensor. 
+ Note: + {1} + Example: >>> tensor = ttnn.empty(shape=[2, 3], device=device) >>> print(tensor) ttnn.Tensor([[[[0.9, 0.21, 0.5], [0.67, 0.11, 0.30]]]], shape=Shape([2, 3]), dtype=DataType::BFLOAT16, layout=Layout::TILE) )doc", - ttnn::empty.base_name()); + ttnn::empty.base_name(), + info_doc); using EmptyType = decltype(ttnn::empty); bind_registered_operation( @@ -414,16 +426,40 @@ void bind_empty_like_operation(py::module& module) { void py_module(py::module& module) { detail::bind_full_operation(module, ttnn::full); - detail::bind_full_operation_with_hard_coded_value(module, ttnn::zeros, "0.0"); + detail::bind_full_operation_with_hard_coded_value(module, ttnn::zeros, "0.0", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, FLOAT32 | ROW_MAJOR, TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+)doc"); + detail::bind_full_operation_with_hard_coded_value(module, ttnn::ones, "1.0"); detail::bind_full_like_operation(module, ttnn::full_like); - detail::bind_full_like_operation_with_hard_coded_value(module, ttnn::zeros_like, "0.0"); + detail::bind_full_like_operation_with_hard_coded_value(module, ttnn::zeros_like, "0.0", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, FLOAT32 | ROW_MAJOR, TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+)doc"); detail::bind_full_like_operation_with_hard_coded_value(module, ttnn::ones_like, "1.0"); detail::bind_arange_operation(module, ttnn::arange); - detail::bind_empty_operation(module); + detail::bind_empty_operation(module, + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, FLOAT32 | ROW_MAJOR, TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT_8 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+)doc"); detail::bind_empty_like_operation(module); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index e49a602f839..17c9c0cd91c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -73,7 +73,6 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, u return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(conv_weight_tensor, num_groups, output_dtype); } -template ParallelConfig determine_parallel_config( const TensorMemoryLayout shard_layout, uint32_t batch_size, @@ -81,7 +80,7 @@ ParallelConfig determine_parallel_config( uint32_t output_height, uint32_t output_width, uint32_t output_channels, - T * device, + const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, bool is_out_tiled) { @@ -91,21 +90,20 @@ ParallelConfig determine_parallel_config( uint32_t out_c_ntiles = tt::round_up(output_channels, effective_tile_width) 
/ effective_tile_width; // calculate num_core_nhw and the grid - auto grid_size = device->compute_with_storage_grid_size(); - uint32_t max_num_cores = grid_size.x * grid_size.y; + uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; uint32_t num_cores_nhw = 0; CoreRangeSet grid; if (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { num_cores_nhw = find_closest_largest_divisor(out_nhw_ntiles, max_num_cores); - if (num_cores_nhw < grid_size.x && out_nhw_ntiles > grid_size.x) { - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, grid_size.x); + if (num_cores_nhw < compute_grid_size.x && out_nhw_ntiles > compute_grid_size.x) { + num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, compute_grid_size.x); } - grid = num_cores_to_corerange_set(num_cores_nhw, grid_size, true); + grid = num_cores_to_corerange_set(num_cores_nhw, compute_grid_size, true); } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { uint32_t start_divisor = - block_shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.x : grid_size.y; + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); - uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.y : grid_size.x); + uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x); uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? 
num_cores_c : num_cores_nhw; CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); @@ -113,7 +111,7 @@ ParallelConfig determine_parallel_config( } else if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { num_cores_nhw = 1; uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), max_num_cores); - grid = num_cores_to_corerange_set(num_cores_c, grid_size, true); + grid = num_cores_to_corerange_set(num_cores_c, compute_grid_size, true); } else { TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); } @@ -252,8 +250,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( uint32_t act_block_w_div, uint32_t window_h, uint32_t window_w, - bool fp32_accum, - bool use_shallow_conv_variant) { + bool fp32_accum) { if (act_block_h_override > 0) { TT_ASSERT( @@ -263,7 +260,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( auto grid_size = parallel_config.grid.bounding_box().grid_size(); uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; if (parallel_config.shard_scheme != TensorMemoryLayout::WIDTH_SHARDED && act_block_h_override > 0 ) { - log_info(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); + log_debug(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); act_block_h_ntiles = act_block_h_override / constants::TILE_HEIGHT; } uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED @@ -281,13 +278,6 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles; auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum); - if (use_shallow_conv_variant && ((act_block_h_ntiles / out_subblock_h_ntiles) % 2 != 0)) { - TT_ASSERT(parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED); - // TODO: do a proper fix and remove this temporary hack for shallow conv - TT_ASSERT(act_block_h_ntiles % 2 == 0); - out_subblock_h_ntiles = act_block_h_ntiles / 2; - TT_ASSERT((out_subblock_h_ntiles * out_subblock_w_ntiles) <= 8); - } return { .act_block_h_ntiles = act_block_h_ntiles, .act_block_w_ntiles = act_block_w_ntiles, @@ -307,8 +297,8 @@ static bool use_matmul_for_1x1_conv( // Implements a heuristic for selecting shard layout based on how many tenix cores are available // for each shard. 
-template static TensorMemoryLayout select_shard_spec( + bool is_mm_conv, uint32_t batch_size, uint32_t in_channels, uint32_t out_channels, @@ -316,13 +306,10 @@ static TensorMemoryLayout select_shard_spec( uint32_t output_width, uint32_t weights_width, uint32_t input_width, - uint32_t groups, ShardOrientation shard_orientation, const std::array& kernel_size, const std::array& stride, - const std::array& padding, - const std::array& dilation, - T const * device) { + const CoreCoord& compute_grid_size) { auto get_core_count_for_sharding = [&](TensorMemoryLayout shard_layout) { return determine_parallel_config( shard_layout, @@ -331,7 +318,7 @@ static TensorMemoryLayout select_shard_spec( output_height, output_width, out_channels, - device, + compute_grid_size, shard_orientation) .grid.num_cores(); }; @@ -340,7 +327,6 @@ static TensorMemoryLayout select_shard_spec( const bool is_block_sharding_valid = (kernel_size[0] == 3 && kernel_size[1] == 3 && (stride[0] == 1 || stride[0] == 2)) || (kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == 2); - const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); // 1d convs support only height sharding const bool is_conv1d = weights_width == 1 && input_width == 1; @@ -348,18 +334,16 @@ static TensorMemoryLayout select_shard_spec( const uint32_t cc_height = get_core_count_for_sharding(TensorMemoryLayout::HEIGHT_SHARDED); // matmul doesn't support width sharding const uint32_t cc_width = - !mm_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0; + !is_mm_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0; const uint32_t cc_block = - (is_block_sharding_valid || mm_conv) && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0; + (is_block_sharding_valid || is_mm_conv) && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0; uint32_t max_cc = cc_block; TensorMemoryLayout shard_layout = TensorMemoryLayout::BLOCK_SHARDED; - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - // Prefer block sharding over height sharding but make sure that we got at least // some blocking on width dimension as well. 
- if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_with_storage_grid_size.x)) { + if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_grid_size.x)) { shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; max_cc = cc_height; } @@ -394,14 +378,7 @@ std::tuple get_conv_padded_input_sh uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups) { + uint32_t out_channels) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); bool needs_shard_or_reshard = false; @@ -411,31 +388,14 @@ std::tuple get_conv_padded_input_sh "Incorrect config provided: reshard_if_not_optimal and override_sharding_config cannot both be set."); } + TT_FATAL( + (!input_tensor_on_device || input_tensor_.is_sharded()) || conv_config.shard_layout.has_value(), + "Tensor must be sharded or shard_layout must be set."); + TensorMemoryLayout shard_layout; if (conv_config.shard_layout.has_value()) { shard_layout = conv_config.shard_layout.value(); - } else { - ShardOrientation shard_orientation = - conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - shard_layout = select_shard_spec( - batch_size, - in_channels, - out_channels, - height, - width, - weights_width, - input_width, - groups, - shard_orientation, - kernel_size, - stride, - padding, - dilation, - device); } - bool use_non_tile_height = shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && - conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR; - use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; //shalow conv varient ParallelConfig input_tensor_parallel_config; if (!input_tensor_on_device) { @@ -478,20 +438,17 @@ std::tuple get_conv_padded_input_sh } } } + + bool use_non_tile_height = shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && + conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR; + use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; // shallow conv variant + ParallelConfig parallel_config = input_tensor_parallel_config; if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { auto block_shard_orientation = conv_config.transpose_shards ?
ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - const ParallelConfig& optimal_parallel_config = determine_parallel_config( - shard_layout, - batch_size, - in_channels, - height, - width, - out_channels, - device, - block_shard_orientation, - !use_non_tile_height); + ParallelConfig optimal_parallel_config = determine_parallel_config( + shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height); if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Error"); @@ -548,13 +505,7 @@ std::tuple shard_or_reshard_tensor_if_ uint32_t width, uint32_t in_channels, uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups) { + bool is_mm_conv) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); @@ -567,14 +518,7 @@ std::tuple shard_or_reshard_tensor_if_ height, width, in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - weights_width, - input_width, - groups); + out_channels); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, @@ -608,9 +552,8 @@ std::tuple shard_or_reshard_tensor_if_ } } - const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); if (input_tensor_on_device) { - if (mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && + if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout input_tensor = @@ -624,7 +567,7 @@ std::tuple shard_or_reshard_tensor_if_ } input_tensor = resharded_input_tensor; } else { - if (mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && + if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout input_tensor = ttnn::to_device(input_tensor, device, std::nullopt); @@ -801,11 +744,55 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co } } -template +static void adjust_conv_op_config_for_auto_shard( + bool is_mm_conv, + uint32_t batch_size, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t weights_width, + uint32_t input_width, + const std::array& kernel_size, + const std::array& stride, + const CoreCoord& compute_grid_size, + Conv2dConfig& conv_config) { + ShardOrientation shard_orientation = + conv_config.transpose_shards ? 
ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + conv_config.shard_layout = select_shard_spec( + is_mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + weights_width, + input_width, + shard_orientation, + kernel_size, + stride, + compute_grid_size); + + if (conv_config.act_block_h_override == 0 && conv_config.shard_layout != TensorMemoryLayout::WIDTH_SHARDED) { + if (in_channels <= constants::TILE_WIDTH / 2 && conv_config.input_channels_alignment == constants::TILE_WIDTH && + !is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + log_debug(LogOp, "Auto shard, enable shallow conv"); + // height sharded, non matmul conv, with input channels <= 16, and default setting for + // input_channels_alignment + conv_config.input_channels_alignment = constants::TILE_WIDTH / 2; + } + + // Set act_block_h_override to min value to + // be conservative with L1 memory usage. + conv_config.act_block_h_override = constants::TILE_HEIGHT; + } +} + +template std::tuple> conv2d( const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - T * device, + T* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -819,36 +806,31 @@ std::tuple bias_tensor, std::optional conv_config_, const std::optional memory_config) { + const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); + const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_width = + ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - if (conv_config.act_block_h_override == 0 && !input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { - // This is a path for auto_sharding, set act_block_h_override to min value to - // be conservative with L1 memory usage. - if (conv_config.input_channels_alignment == (constants::TILE_WIDTH / 2)) { - // shallow conv, requires at least two tiles - conv_config.act_block_h_override = constants::TILE_HEIGHT * 2; - } else { - conv_config.act_block_h_override = constants::TILE_HEIGHT; - } + if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { + // In this case we deduce the shard layout. 
+ adjust_conv_op_config_for_auto_shard( + mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + weight_tensor.get_shape()[3], + input_width, + kernel_size, + stride, + device->compute_with_storage_grid_size(), + conv_config); } - uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; - uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; + auto [input_tensor_post_tm, parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( - device, - input_tensor, - conv_config, - batch_size, - output_height, - output_width, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - weight_tensor.get_shape()[3], - input_width, - groups); + device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv); if (tensor_manipulated) { if (conv_config.deallocate_activation) { ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op @@ -873,8 +855,7 @@ std::tuple bias_tensor_on_device = bias_tensor; @@ -894,7 +875,6 @@ std::tuple( - const TensorMemoryLayout shard_layout, - uint32_t batch_size, - uint32_t input_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t output_channels, - Device * device, - ShardOrientation block_shard_orientation, - bool is_out_tiled); - -template ParallelConfig determine_parallel_config( - const TensorMemoryLayout shard_layout, - uint32_t batch_size, - uint32_t input_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t output_channels, - MeshDevice * device, - ShardOrientation block_shard_orientation, - bool is_out_tiled); - template std::tuple get_conv_padded_input_shape_and_mem_config( Device* device, const ttnn::Tensor& input_tensor_, @@ -1080,14 +1038,7 @@ template std::tuple get_conv_padded uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + uint32_t out_channels); template std::tuple get_conv_padded_input_shape_and_mem_config( MeshDevice * device, @@ -1097,14 +1048,7 @@ template std::tuple get_conv_padded uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + uint32_t out_channels); template std::tuple shard_or_reshard_tensor_if_required( Device* device, @@ -1115,13 +1059,7 @@ template std::tuple shard_or_reshard_t uint32_t width, uint32_t in_channels, uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + bool is_mm_conv); template std::tuple shard_or_reshard_tensor_if_required( MeshDevice * device, @@ -1131,14 +1069,8 @@ template std::tuple shard_or_reshard_t uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + uint32_t out_channel, + bool is_mm_conv); template std::pair> 
prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index e13658a4135..efcbcd46693 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -102,7 +102,6 @@ uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t st uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor); -template sliding_window::ParallelConfig determine_parallel_config( const TensorMemoryLayout shard_layout, uint32_t batch_size, @@ -110,7 +109,7 @@ sliding_window::ParallelConfig determine_parallel_config( uint32_t output_height, uint32_t output_width, uint32_t output_channels, - T * device, + const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, bool is_out_tiled=true); @@ -124,7 +123,13 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o std::pair determine_largest_subblock_size(uint32_t block_height, uint32_t block_width, bool fp32_accum); -OptimizedConvBlockConfig determine_per_core_conv_block_config(const sliding_window::ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, uint32_t act_block_h_override, uint32_t window_w, bool fp32_accum, bool use_shallow_conv_variant); +OptimizedConvBlockConfig determine_per_core_conv_block_config( + const sliding_window::ParallelConfig& parallel_config, + const OptimizedConvParallelizationConfig& conv_op_parallel_config, + uint32_t padded_in_channels, + uint32_t act_block_h_override, + uint32_t window_w, + bool fp32_accum); template std::tuple get_conv_padded_input_shape_and_mem_config( @@ -135,18 +140,11 @@ std::tuple get_conv_padded_input_sh uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + uint32_t out_channels); -template +template std::tuple shard_or_reshard_tensor_if_required( - T * device, + T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, @@ -154,13 +152,7 @@ std::tuple shard_or_re uint32_t width, uint32_t in_channels, uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups); + bool is_mm_conv); void validate_weight_and_bias_tensors(const ttnn::Tensor& weight_tensor, std::optional& bias_tensor); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index 02f3f1ee7ae..0163c3d43a0 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -129,14 +129,7 @@ void py_bind_conv2d(py::module& module) { uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t weights_width, - uint32_t input_width, - uint32_t groups) -> std::tuple { + uint32_t out_channels) -> std::tuple { return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( device, input_tensor, @@ -145,14 +138,7 @@ void py_bind_conv2d(py::module& module) { height, width, 
in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - weights_width, - input_width, - groups); + out_channels); }, py::kw_only(), py::arg("device"), @@ -162,14 +148,7 @@ void py_bind_conv2d(py::module& module) { py::arg("height"), py::arg("width"), py::arg("in_channels"), - py::arg("out_channels"), - py::arg("kernel_size"), - py::arg("stride"), - py::arg("padding"), - py::arg("dilation"), - py::arg("weights_width"), - py::arg("input_width"), - py::arg("groups")); + py::arg("out_channels")); module.def( "get_conv_padded_input_shape_and_mem_config", @@ -196,14 +175,7 @@ void py_bind_conv2d(py::module& module) { height, width, in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - weights_width, - input_width, - groups); + out_channels); }, py::kw_only(), py::arg("device"), @@ -253,35 +225,11 @@ void py_bind_conv2d(py::module& module) { uint32_t output_height, uint32_t output_width, uint32_t output_channels, - ttnn::Device* device, - ShardOrientation block_shard_orientation, - bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { - return ttnn::operations::conv::conv2d::determine_parallel_config( - shard_layout, batch_size, input_channels, output_height, output_width, output_channels, device, block_shard_orientation, is_out_tiled); - }, - py::arg("shard_layout"), - py::arg("batch_size"), - py::arg("input_channels"), - py::arg("output_height"), - py::arg("output_width"), - py::arg("output_channels"), - py::arg("device"), - py::arg("block_shard_orientation"), - py::arg("is_out_tiled") = true); - - module.def( - "determine_parallel_config", - [](const ttnn::TensorMemoryLayout& shard_layout, - uint32_t batch_size, - uint32_t input_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t output_channels, - ttnn::MeshDevice* device, + const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { - return ttnn::operations::conv::conv2d::determine_parallel_config( - shard_layout, batch_size, input_channels, output_height, output_width, output_channels, device, block_shard_orientation, is_out_tiled); + return ttnn::operations::conv::conv2d::determine_parallel_config( + shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled); }, py::arg("shard_layout"), py::arg("batch_size"), @@ -289,7 +237,7 @@ void py_bind_conv2d(py::module& module) { py::arg("output_height"), py::arg("output_width"), py::arg("output_channels"), - py::arg("device"), + py::arg("compute_grid_size"), py::arg("block_shard_orientation"), py::arg("is_out_tiled") = true); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index eac1d5ad840..8bd6bd51a0d 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -485,7 +485,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( TT_FATAL(not weight_width_sliced, "split reader does not work with 2d conv"); TT_FATAL((act_block_h_ntiles / block_config.out_subblock_h_ntiles) >= 2, "split reader needs to have at leaset two subblocks"); } - bool split_reader = use_shallow_conv_variant or enable_split_reader; + bool split_reader = enable_split_reader; if (split_reader) 
{ TT_FATAL( block_config.act_block_h_ntiles % block_config.out_subblock_h_ntiles == 0, diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp new file mode 100644 index 00000000000..b1dbec7c2c7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" + +namespace ttnn { +namespace operations { +namespace data_movement { + ttnn::Tensor pad_to_tile_vol(uint8_t queue_id, + const ttnn::Tensor& tensor, + const float value, + const bool use_multicore, + const std::optional& memory_config) { + auto logical_shape = tensor.get_logical_shape(); + auto padded_shape = tensor.get_padded_shape(); + auto rank = tensor.get_shape().rank(); + if (padded_shape.volume() % tt::constants::TILE_HW != 0) { + TT_ASSERT(rank >= 2, "rank of tensor to pad to tile must be at least 2."); + + auto padded_height = tt::round_up(padded_shape[-2], tt::constants::TILE_HEIGHT); + auto padded_width = tt::round_up(padded_shape[-1], tt::constants::TILE_WIDTH); + uint32_t num_non_hw_dims = rank - 2u; + auto padding_vec = std::vector>(num_non_hw_dims, {0,0}); + padding_vec.reserve(rank); + padding_vec.emplace_back(0, padded_height - padded_shape[-2]); + padding_vec.emplace_back(0, padded_width - padded_shape[-1]); + + constexpr bool pad_use_multicore = true; + auto padded_output = ttnn::pad(queue_id, + tensor, + padding_vec, + value, + use_multicore, + memory_config); + return padded_output; + } + return tensor; + } + uint32_t wrap_index(int index, int size) { + return index < 0 ? 
size + index : index; + } +} +} +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp index 2280dc608db..f82ef63ccf6 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp @@ -2,6 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/cpp/ttnn/tensor/types.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor.hpp" + namespace ttnn { namespace operations { namespace data_movement { @@ -9,32 +12,9 @@ namespace data_movement { const ttnn::Tensor& tensor, const float value, const bool use_multicore, - const std::optional& memory_config) { - auto logical_shape = tensor.get_logical_shape(); - auto padded_shape = tensor.get_padded_shape(); - auto rank = tensor.get_shape().rank(); - if (padded_shape.volume() % tt::constants::TILE_HW != 0) { - TT_ASSERT(rank >= 2, "rank of tensor to pad to tile must be at least 2."); - - auto padded_height = tt::round_up(padded_shape[-2], tt::constants::TILE_HEIGHT); - auto padded_width = tt::round_up(padded_shape[-1], tt::constants::TILE_WIDTH); - uint32_t num_non_hw_dims = rank - 2u; - auto padding_vec = std::vector>(num_non_hw_dims, {0,0}); - padding_vec.reserve(rank); - padding_vec.emplace_back(0, padded_height - padded_shape[-2]); - padding_vec.emplace_back(0, padded_width - padded_shape[-1]); - - constexpr bool pad_use_multicore = true; - auto padded_output = ttnn::pad(queue_id, - tensor, - padding_vec, - value, - use_multicore, - memory_config); - return padded_output; - } - return tensor; - } + const std::optional& memory_config); + + uint32_t wrap_index(int index, int size); template struct MassagedOperationParams { diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp new file mode 100644 index 00000000000..bf7062ab92b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// This file contains common kernel functions used in data movement device kernels. +// It's best to copy and paste the functions in rather than include the header, as code size will likely explode. +// Best to separate into cpp/hpp at some point to avoid the code size explosion, but we need to figure out the linking issues first. + +namespace tt::data_movement::common { + + // This function is useful for converting bfloat16 values to float32 + float bfloat16_to_float32(uint16_t bfloat16_data) { + uint32_t bits = static_cast(bfloat16_data) << 16; + + // Extract the sign bit + uint32_t sign = bits & 0x80000000; + + // Extract the exponent + uint32_t exponent = bits & 0x7F800000; + + // Extract the mantissa + uint32_t mantissa = bits & 0x007FFFFF; + + // Handle special cases + if (exponent == 0 && mantissa == 0) { + // Zero + return sign ? -0.0f : 0.0f; + } else if (exponent == 0x7F800000) { + if (mantissa == 0) { + // Infinity + return sign ?
-__builtin_huge_valf() : __builtin_huge_valf(); + } else { + // NaN + return __builtin_nanf(""); + } + } + + // Assemble the float + union { + uint32_t u; + float f; + } ieee_float; + + ieee_float.u = sign | exponent | mantissa; + return ieee_float.f; + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp deleted file mode 100644 index ac3f9e2ac01..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - -#ifdef DEBUG_PRINT -// this function is useful for printing bfloat16 values -#include "dprint.h" - -float bfloat16_to_float32(uint16_t bfloat16_data) { - uint32_t bits = static_cast(bfloat16_data) << 16; - - // Extract the sign bit - uint32_t sign = bits & 0x80000000; - - // Extract the exponent - uint32_t exponent = bits & 0x7F800000; - - // Extract the mantissa - uint32_t mantissa = bits & 0x007FFFFF; - - // Handle special cases - if (exponent == 0 && mantissa == 0) { - // Zero - return sign ? -0.0f : 0.0f; - } else if (exponent == 0x7F800000) { - if (mantissa == 0) { - // Infinity - return sign ? -__builtin_huge_valf() : __builtin_huge_valf(); - } else { - // NaN - return __builtin_nanf(""); - } - } - - // Assemble the float - union { - uint32_t u; - float f; - } ieee_float; - - ieee_float.u = sign | exponent | mantissa; - return ieee_float.f; -} -#endif - - -void kernel_main() { - - constexpr bool src0_is_dram = (bool) get_compile_time_arg_val(0); - constexpr uint32_t W = get_compile_time_arg_val(1); - constexpr uint32_t H = get_compile_time_arg_val(2); - constexpr uint32_t C = get_compile_time_arg_val(3); - constexpr uint32_t N = get_compile_time_arg_val(4); - - constexpr uint32_t stride_W = get_compile_time_arg_val(5); - constexpr uint32_t stride_H = get_compile_time_arg_val(6); - constexpr uint32_t stride_C = get_compile_time_arg_val(7); - constexpr uint32_t stride_N = get_compile_time_arg_val(8); - constexpr uint32_t page_size = get_compile_time_arg_val(9); - - const uint32_t src_addr = get_arg_val(0); - const uint32_t start_W = get_arg_val(1); - const uint32_t start_H = get_arg_val(2); - const uint32_t start_C = get_arg_val(3); - const uint32_t start_N = get_arg_val(4); - - const uint32_t end_W = get_arg_val(5); - const uint32_t end_H = get_arg_val(6); - const uint32_t end_C = get_arg_val(7); - const uint32_t end_N = get_arg_val(8); - - const InterleavedAddrGen s0 = { - .bank_base_address = src_addr, - .page_size = page_size - }; - - constexpr uint32_t cb_id_in0 = 0; - constexpr uint32_t cb_id_out0 = 24; - uint32_t src_buffer_l1_addr = get_write_ptr(cb_id_in0); - volatile tt_l1_ptr uint16_t* in_stick = reinterpret_cast(src_buffer_l1_addr); - constexpr uint32_t CH = C*H; - // TODO: optimize this kernel to read in multiple sticks at once - // TODO: add support for negative strides - // TODO: add axis support - for (uint32_t i = start_N; i < end_N; i+=stride_N) { - uint32_t iCH = i*CH; - for (uint32_t j = start_C; j < end_C; j+=stride_C) { - uint32_t jHplusiCH = j*H + iCH; - for (uint32_t k = start_H; k < end_H; k+=stride_H) { - - // relevant page/stick id - uint32_t src_stick_id = k + jHplusiCH; - - // read in entire stick and wait - we may want to allocate a CB and 
max out our reads before waiting - noc_async_read_page(src_stick_id, s0, src_buffer_l1_addr); - noc_async_read_barrier(); - - - // TODO: optimize when there's no slice or stride along W. In that case, we can just do a single read and write. - // reserve space in output buffer - cb_reserve_back(cb_id_out0, 1); - // write out element by element into output buffer - volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(get_write_ptr(cb_id_out0)); - uint32_t out_stick_id = 0; - for (uint32_t l = start_W; l < end_W; l+=stride_W) { - out_stick[out_stick_id] = in_stick[l]; - out_stick_id++; - } - cb_push_back(cb_id_out0, 1); - } - } - } - - -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp new file mode 100644 index 00000000000..792b9e1ee91 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + + constexpr bool src0_is_dram = (bool) get_compile_time_arg_val(0); + constexpr uint32_t page_size = get_compile_time_arg_val(1); + constexpr uint32_t dims = get_compile_time_arg_val(2); + + const uint32_t src_addr = get_arg_val(0); + + // Initialize shape, starts, ends, strides + uint32_t shape[dims], starts[dims], ends[dims], strides[dims]; + for (uint32_t i = 1; i <= dims; i++) { + shape[i - 1] = get_arg_val(i); + starts[i - 1] = get_arg_val(i + dims); + ends[i - 1] = get_arg_val(i + 2*dims); + strides[i - 1] = get_arg_val(i + 3*dims); + } + + // Calculate the product array, excluding the last dimension + uint32_t prod[dims]; + for (uint32_t i = 0; i < dims - 1; i++) { + prod[i] = 1; + for (uint32_t j = i + 1; j < dims - 1; j++) { // Exclude the last dimension + prod[i] *= shape[j]; + } + } + prod[dims - 1] = 1; // Not used, but set to 1 for completeness + + const InterleavedAddrGen s0 = { + .bank_base_address = src_addr, + .page_size = page_size + }; + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_out0 = 24; + uint32_t src_buffer_l1_addr = get_write_ptr(cb_id_in0); + volatile tt_l1_ptr uint16_t* in_stick = reinterpret_cast(src_buffer_l1_addr); + + + uint32_t index[dims - 1]; // To hold current index in each of the first dims-1 dimensions + for (uint32_t i = 0; i < dims - 1; i++) { + index[i] = starts[i]; // Initialize the index with the start values + } + + // Flag to determine when to terminate the loop + bool done = false; + + while (!done) { + // Calculate the base linear index based on the first dims-1 indices + uint32_t base_linear_index = 0; + for (uint32_t i = 0; i < dims - 1; i++) { + base_linear_index += index[i] * prod[i]; + } + + // Now iterate over the last dimension + uint32_t out_stick_id = 0; + // Perform the read operation + noc_async_read_page(base_linear_index, s0, src_buffer_l1_addr); + // Reserve space in the output buffer + cb_reserve_back(cb_id_out0, 1); + noc_async_read_barrier(); + for (uint32_t l = starts[dims - 1]; l < ends[dims - 1]; l += strides[dims - 1]) { + // Write the element into the output buffer + volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(get_write_ptr(cb_id_out0)); + out_stick[out_stick_id] = in_stick[l]; // Assuming you write one element at a time + out_stick_id++; + } + 
cb_push_back(cb_id_out0, 1); + + // Increment the indices for the first dims-1 dimensions + for (int32_t i = dims - 2; i >= 0; i--) { // Start from the last of the first dims-1 + index[i] += strides[i]; + if (index[i] < ends[i]) { + break; // Successfully incremented this dimension, no carry over + } else { + index[i] = starts[i]; // Reset this dimension and carry over to the next + if (i == 0) { + done = true; // If the first dimension is reset, we've completed all iterations + } + } + } + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp index 2585f3561bc..e7215850ea1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp @@ -59,7 +59,7 @@ inline std::vector, std::vector>> get_ accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; } - uint32_t unpadded_row_size_bytes_offset = tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); + uint32_t unpadded_row_size_bytes_offset = output_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? tt::round_up(unpadded_row_size_bytes, TILE_WIDTH) : tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); vector common_reader_kernel_args = { input_tensor.buffer()->address() + output_tensor_start[-1] * output_tensor.element_size(), @@ -261,7 +261,7 @@ operation::ProgramWithCallbacks slice_rm_multi_core( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } -operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end, const tt::tt_metal::LegacyShape& step) { +operation::ProgramWithCallbacks slice_rm_strided_single_core_n_dims(const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end, const tt::tt_metal::LegacyShape& step) { // TODO: multi core implementation - work division is not trivial as we need to determine the N/C/H/W start and end points for each split, and base that off stride tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); const tt::tt_metal::LegacyShape output_shape = output.get_legacy_shape(); @@ -291,20 +291,13 @@ operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Te tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp", + "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp", core, tt::tt_metal::ReaderDataMovementConfig( { src_is_dram, - input_shape[3], - input_shape[2], - input_shape[1], - input_shape[0], - step[3], - step[2], - step[1], - step[0], (uint32_t) page_size_input, + (uint32_t) input_shape.rank(), } )); @@ -320,26 +313,24 @@ operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Te } )); - tt::tt_metal::SetRuntimeArgs( - program, unary_reader_kernel_id, core, - { - a.buffer()->address(), - output_tensor_start[3], - output_tensor_start[2], - output_tensor_start[1], - output_tensor_start[0], - output_tensor_end[3], - output_tensor_end[2], - output_tensor_end[1], - output_tensor_end[0], + std::vector reader_runtime_args; + 
reader_runtime_args.reserve(1 + (4*input_shape.rank())); + reader_runtime_args.push_back(a.buffer()->address()); - }); + reader_runtime_args.insert(reader_runtime_args.end(), input_shape.begin(), input_shape.end()); + reader_runtime_args.insert(reader_runtime_args.end(), output_tensor_start.begin(), output_tensor_start.end()); + reader_runtime_args.insert(reader_runtime_args.end(), output_tensor_end.begin(), output_tensor_end.end()); + reader_runtime_args.insert(reader_runtime_args.end(), step.begin(), step.end()); + + tt::tt_metal::SetRuntimeArgs( + program, unary_reader_kernel_id, core, reader_runtime_args); + uint32_t pages = output.volume() / output_shape[-1]; tt::tt_metal::SetRuntimeArgs( program, unary_writer_kernel_id, core, { output.buffer()->address(), - output_shape[0]*output_shape[1]*output_shape[2], + pages, }); auto override_address_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( @@ -962,7 +953,7 @@ operation::ProgramWithCallbacks slice_multi_core( case Layout::ROW_MAJOR: return a.is_sharded() ? slice_rm_multi_core_sharded(a, output, output_tensor_start, output_tensor_end) : (has_step ? - slice_rm_strided_single_core(a, output, output_tensor_start, output_tensor_end, step) : + slice_rm_strided_single_core_n_dims(a, output, output_tensor_start, output_tensor_end, step) : slice_rm_multi_core(a, output, output_tensor_start, output_tensor_end)); case Layout::TILE: return slice_tile_multi_core(a, output, output_tensor_start, output_tensor_end); default: TT_ASSERT(false, "Unsupported Layout"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index cff6bf540ec..e65f1bba9ce 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -10,17 +10,10 @@ #include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/common/constants.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp" - +#include "ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp" namespace ttnn::operations::data_movement { -namespace detail { - static inline uint32_t wrap_index(int index, int size) { - return index < 0 ? size + index : index; - } - static inline uint32_t round_up_to_multiple_of_32(uint32_t value) { - return value == 0 ? 32u : ((value + 31u) & ~31); - } -} template ttnn::Tensor SliceOperation::invoke( @@ -33,8 +26,10 @@ ttnn::Tensor SliceOperation::invoke( const std::optional& optional_output_tensor) { // Ensure start and end vectors have matching sizes and correct tensor rank - uint32_t input_rank = input_tensor.get_shape().rank(); - const auto &input_shape = input_tensor.get_shape(); + + const auto &input_shape = input_tensor.get_logical_shape(); + uint32_t input_rank = input_shape.rank(); + bool no_step = std::ranges::all_of(step, [](uint32_t s) { return s == 1; }); bool starts_zero = std::ranges::all_of(begins, [](uint32_t s) { return s == 0; }); bool ends_max = true; @@ -44,6 +39,7 @@ ttnn::Tensor SliceOperation::invoke( break; } } + if (no_step && starts_zero && ends_max) { if (input_tensor.storage_type() == StorageType::DEVICE) { auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); @@ -56,57 +52,80 @@ ttnn::Tensor SliceOperation::invoke( TT_FATAL(begins.size() == ends.size(), "Start {} and end {} must have the same size", begins.size(), ends.size()); TT_FATAL(step.size() == begins.size(), "Step {} must have the same size as start {} and end", step.size(), begins.size()); - // Create modified vectors with appropriate size (max rank 4) and wrap indices - Tensor input_4d = (input_rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; - auto padded_4d_shape = input_4d.get_legacy_shape(); - std::array modified_begins = {0, 0, 0, 0}; - std::array modified_ends = {padded_4d_shape[0], padded_4d_shape[1], padded_4d_shape[2], padded_4d_shape[3]}; - std::array modified_step = {1, 1, 1, 1}; - uint32_t rank_diff = 4 - input_rank; - - // Ideally we would call the 4D array implementation of slice here and then handle reshapes and padding outside of it but it's not ready yet - // Insert values for start, step and end, wrapping indices using detail::wrap_index - // should be able to skip wrap_index if T is uint32_t + bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; + Tensor input = input_tensor; + if (rm_only) { + TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Strided slice is not supported for BFLOAT8 tensors"); + input = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + } + + // Unsqueeze tensor to 4D if necessary + if (input_rank < 4) { + input = ttnn::unsqueeze_to_4D(input); + } + + auto padded_shape = input.get_padded_shape(); + size_t adjusted_rank = padded_shape.rank(); // Now adjusted to 4 after unsqueeze + + // Create modified vectors with wrapped indices and adjust them to match the tensor's rank + std::vector modified_begins(adjusted_rank, 0); + std::vector modified_ends = padded_shape.as_vector(); + std::vector modified_step(adjusted_rank, 1); + + size_t rank_diff = adjusted_rank - input_rank; + + // Wrap indices and adjust begins, ends, and step for (size_t i = 0; i < begins.size(); ++i) { - modified_begins[i + rank_diff] = detail::wrap_index(begins[i], input_tensor.get_shape()[i]); - modified_ends[i + rank_diff] = detail::wrap_index(ends[i], input_tensor.get_shape()[i]); - modified_step[i + rank_diff] = step[i]; + size_t idx = i + rank_diff; + + if constexpr (std::is_signed_v) { + modified_begins[idx] = wrap_index(begins[i], input_shape[i]); + modified_ends[idx] = wrap_index(ends[i], input_shape[i]); + modified_step[idx] = static_cast(step[i]); + } else { + modified_begins[idx] = begins[i]; + modified_ends[idx] = ends[i]; + modified_step[idx] = step[i]; + } } - auto output_dim_i = [&modified_begins, &modified_step] (size_t i, const std::array &modified_ends) { + auto output_dim_i = [&modified_begins, &modified_step](size_t i, const std::vector &modified_ends) { return (modified_ends[i] - modified_begins[i] + modified_step[i] - 1) / modified_step[i]; }; - std::array padded_ends = modified_ends; - if (input_tensor.layout() == Layout::TILE) { - padded_ends[2] = detail::round_up_to_multiple_of_32(padded_ends[2]); - padded_ends[3] = detail::round_up_to_multiple_of_32(padded_ends[3]); + std::vector padded_ends = modified_ends; + if (input.layout() == Layout::TILE) { + padded_ends[adjusted_rank - 2] = std::max(tt::round_up(padded_ends[adjusted_rank - 2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT); + padded_ends[adjusted_rank - 1] = std::max(tt::round_up(padded_ends[adjusted_rank - 
1], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH); } - std::vector actual_shape, padded_shape; + + std::vector actual_shape, final_padded_shape; actual_shape.reserve(input_rank); - padded_shape.reserve(input_rank); + final_padded_shape.reserve(input_rank); bool empty = false; - for (int i = 0; i < input_rank; ++i) { - // Check that end indices are greater than or equal to start indices (empty tensor where end=start is supported) - TT_FATAL(modified_ends[i + rank_diff] >= modified_begins[i + rank_diff], "End {} must be greater than or equal to start {}", modified_ends[i + rank_diff], modified_begins[i + rank_diff]); - auto val = output_dim_i(i + rank_diff, modified_ends); + + // Compute actual and padded shapes for the original input rank + for (size_t i = 0; i < input_rank; ++i) { + size_t idx = i + rank_diff; + TT_FATAL(modified_ends[idx] >= modified_begins[idx], "End {} must be greater than or equal to start {}", modified_ends[idx], modified_begins[idx]); + auto val = output_dim_i(idx, modified_ends); if (val == 0) { empty = true; } actual_shape.push_back(val); - padded_shape.push_back(std::max(output_dim_i(i + rank_diff, padded_ends), (uint32_t)1)); + final_padded_shape.push_back(std::max(output_dim_i(idx, padded_ends), static_cast(1))); } - ttnn::Shape output_shape(actual_shape, padded_shape); - // PyTorch supports final dimension = 0 (start = end, where end is inclusive) so >= is okay, just return an empty tensor + ttnn::Shape output_shape(actual_shape, final_padded_shape); + if (empty) { TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Host tensor slice cannot return a scalar or empty tensor"); return ttnn::empty(output_shape, input_tensor.dtype(), input_tensor.layout(), input_tensor.device(), memory_config_arg.value_or(input_tensor.memory_config())); } - // Early exit if slice is a no-op (ends = padding ends and step = 1 for all dimensions) - if (tt::tt_metal::LegacyShape(padded_shape) == input_tensor.get_legacy_shape() and no_step) { + // Early exit if slice is a no-op + if (ttnn::SimpleShape(final_padded_shape) == input.get_padded_shape() && no_step) { if (input_tensor.storage_type() == StorageType::DEVICE) { auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); auto res = ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); @@ -117,33 +136,28 @@ ttnn::Tensor SliceOperation::invoke( if (input_tensor.storage_type() != StorageType::DEVICE) { TT_FATAL(no_step, "Host tensor slice does not support strides"); - // if we support negative strides, we can't do this early exit - if (input_tensor.get_legacy_shape() == actual_shape) { + if (input_tensor.get_padded_shape() == actual_shape) { return input_tensor; } else { - auto input_4d_rm = ttnn::to_layout(input_4d, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); - auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(modified_begins), ttnn::SimpleShape(modified_ends)); - auto output_4d_rm = ttnn::to_layout(output_4d, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); - return ttnn::reshape(output_4d_rm, output_shape); + input = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + input = input.unpad(ttnn::SimpleShape(modified_begins), ttnn::SimpleShape(modified_ends)); + input = ttnn::to_layout(input, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); + return ttnn::reshape(input, output_shape); } - } - else { - // TODO: Generalize this early exit of slice for other cases - auto& input_tensor_shape = input_4d.get_legacy_shape(); + } else { + const auto& input_tensor_shape = input.get_padded_shape(); auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - if (input_4d.is_sharded() && input_4d.memory_config() == memory_config && - input_tensor_shape.rank() > 1) { + + if (input.is_sharded() && input.memory_config() == memory_config && input_tensor_shape.rank() > 1) { TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); uint32_t i; - // Require all leading dims to be 1 (TODO: This can be relaxed to support outermost non-1 dim unpadding) bool in_place_unpad = true; - for (i = 0; i < input_4d.get_legacy_shape().rank() - 2; ++i) { - in_place_unpad &= - modified_begins[i] == 0 && modified_ends[i] == 1 && input_tensor_shape[i] == 1; + for (i = 0; i < input_tensor_shape.rank() - 2; ++i) { + in_place_unpad &= modified_begins[i] == 0 && modified_ends[i] == 1 && input_tensor_shape[i] == 1; } in_place_unpad &= modified_begins[i] == 0 && - tt::div_up(modified_ends[i], input_4d.shard_spec().value().shape[0]) == - tt::div_up(input_tensor_shape[i], input_4d.shard_spec().value().shape[0]); + tt::div_up(modified_ends[i], input.shard_spec().value().shape[0]) == + tt::div_up(input_tensor_shape[i], input.shard_spec().value().shape[0]); i++; in_place_unpad &= modified_begins[i] == 0 && modified_ends[i] == input_tensor_shape[i]; if (in_place_unpad) { @@ -152,16 +166,18 @@ ttnn::Tensor SliceOperation::invoke( } auto res = operation::run( - SliceDeviceOperation{ - tt::tt_metal::LegacyShape(modified_begins), - tt::tt_metal::LegacyShape(padded_ends), - modified_step, - memory_config}, - {input_4d}, {}, {optional_output_tensor}, queue_id) + SliceDeviceOperation{ + tt::tt_metal::LegacyShape(modified_begins), + tt::tt_metal::LegacyShape(padded_ends), + tt::tt_metal::LegacyShape(modified_step), + memory_config}, + {input}, {}, {optional_output_tensor}, queue_id) .at(0); - return ttnn::reshape(res, output_shape); + res = ttnn::reshape(res, output_shape); + return rm_only ? 
ttnn::to_layout(res, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : res; } } + template ttnn::Tensor SliceOperation::invoke( const ttnn::Tensor& input_tensor, @@ -184,7 +200,7 @@ ttnn::Tensor SliceOperation::invoke( const std::optional& memory_config_arg, const std::optional& optional_output_tensor) { - const auto& padded_input_shape = input_tensor.get_shape().with_tile_padding(); + const auto& padded_input_shape = input_tensor.get_padded_shape(); TT_FATAL(padded_input_shape.rank() == 4, "Input tensor must have rank 4"); bool no_step = step[0] == 1 && step[1] == 1 && step[2] == 1 && step[3] == 1; @@ -198,13 +214,18 @@ ttnn::Tensor SliceOperation::invoke( } return input_tensor; } + bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; + ttnn::Tensor input = input_tensor; + if (rm_only) { + input = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + } - const bool tiled = input_tensor.get_layout() == Layout::TILE; - bool on_device = input_tensor.storage_type() == StorageType::DEVICE; + const bool tiled = input.get_layout() == Layout::TILE; + bool on_device = input.storage_type() == StorageType::DEVICE; std::array actual_shape; std::array padded_shape; - const std::array padded_ends = tiled ? std::array({ends[0], ends[1], detail::round_up_to_multiple_of_32(ends[2]), detail::round_up_to_multiple_of_32(ends[3])}) : ends; + const std::array padded_ends = tiled ? std::array({ends[0], ends[1], std::max(tt::round_up(ends[2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT), std::max(tt::round_up(ends[3], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH)}) : ends; bool empty = false; for (int i = 0; i < 4; ++i) { TT_FATAL(ends[i] >= begins[i], "End {} must be greater than or equal to start {}", ends[i], begins[i]); @@ -219,58 +240,59 @@ ttnn::Tensor SliceOperation::invoke( if (empty) { TT_FATAL(on_device, "Host tensor slice cannot return a scalar or empty tensor"); - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - return ttnn::empty(output_shape, input_tensor.dtype(), input_tensor.layout(), - input_tensor.device(), memory_config); + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + return ttnn::empty(output_shape, input.dtype(), input_tensor.layout(), + input.device(), memory_config); } // Early exit if slice is a no-op if (ttnn::Shape(padded_shape) == padded_input_shape && no_step) { - if (input_tensor.storage_type() == StorageType::DEVICE) { - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - auto res = ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); + if (input.storage_type() == StorageType::DEVICE) { + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + auto res = ttnn::to_memory_config(input, memory_config, std::nullopt); return ttnn::reshape(res, output_shape); } - return ttnn::reshape(input_tensor, output_shape); // change to view + return ttnn::reshape(input, output_shape); // change to view } if (on_device) { - auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); // Check for in-place unpad optimization - if (input_tensor.is_sharded() && input_tensor.memory_config() == memory_config && padded_input_shape.rank() > 1) { + if (input.is_sharded() && input.memory_config() == memory_config && padded_input_shape.rank() > 1) { TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); bool in_place_unpad = true; for (int i = 0; i < 2; ++i) { in_place_unpad &= begins[i] == 0 && ends[i] == 1 && padded_input_shape[i] == 1; } in_place_unpad &= begins[2] == 0 && - tt::div_up(ends[2], input_tensor.shard_spec().value().shape[0]) == - tt::div_up(padded_input_shape[2], input_tensor.shard_spec().value().shape[0]); + tt::div_up(ends[2], input.shard_spec().value().shape[0]) == + tt::div_up(padded_input_shape[2], input.shard_spec().value().shape[0]); in_place_unpad &= begins[3] == 0 && ends[3] == padded_input_shape[3]; if (in_place_unpad) { - return ttnn::reshape(input_tensor, output_shape); + return ttnn::reshape(input, output_shape); } } - auto res = operation::run( + input = operation::run( SliceDeviceOperation{ begins, padded_ends, step, memory_config}, - {input_tensor}, {}, {optional_output_tensor}, queue_id)[0]; - return ttnn::reshape(res, output_shape); + {input}, {}, {optional_output_tensor}, queue_id)[0]; + input = ttnn::reshape(input, output_shape); + return rm_only ? ttnn::to_layout(input, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : input; } TT_FATAL(no_step, "Host tensor slice does not support strides"); - if (input_tensor.get_legacy_shape() == actual_shape) { - return input_tensor; + if (input.get_padded_shape() == actual_shape) { + return input; } else { - auto input_4d_rm = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + auto input_4d_rm = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(begins), ttnn::SimpleShape(ends)); - auto output_4d_rm = ttnn::to_layout(output_4d, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); + auto output_4d_rm = ttnn::to_layout(output_4d, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); return ttnn::reshape(output_4d_rm, output_shape); } } @@ -301,7 +323,6 @@ ttnn::Tensor SliceOperation::invoke( return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg); } - template ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp index 342b282ec96..1aeae87eefd 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp @@ -17,26 +17,31 @@ namespace py = pybind11; void bind_slice(py::module& module) { auto doc = R"doc( - slice(input_tensor: ttnn.Tensor, slice_start: List[int[tensor rank], slice_end: List[int[tensor rank], value: Union[int, float], *, Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - Returns a sliced tensor. 
If the input tensor is on host, the slice will be performed on host, and if its on device it will be performed on device. - Equivalent pytorch code: + Args: + input_tensor: Input Tensor. + slice_start: Start indices of input tensor. Values along each dim must be < input_tensor_shape[i]. + slice_end: End indices of input tensor. Values along each dim must be < input_tensor_shape[i]. + slice_step: (Optional[List[int[tensor rank]]]) Step size for each dim. Default is None, which works out to be 1 for each dimension. - .. code-block:: python + Keyword Args: + memory_config (ttnn.MemoryConfig, optional): Memory Config of the output tensor + queue_id (uint8, optional): command queue id - output_tensor = input_tensor[output_start: output_end] + Returns: + ttnn.Tensor: the output tensor. - Args: - * :attr:`input_tensor`: Input Tensor. - * :attr:`slice_start`: Start indices of input tensor. Values along each dim must be < input_tensor_shape[i]. - * :attr:`slice_end`: End indices of input tensor. Values along each dim must be < input_tensor_shape[i]. - * :attr:`step` (Optional[List[int[tensor rank]]): Step size for each dim. Default is None, which works out be 1 for each dimension. + Example: + >>> tensor = ttnn.slice(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device), [0, 0, 0, 0], [1, 1, 64, 16], [1, 1, 2, 1]) + >>> print(tensor.shape) + [1, 1, 32, 16] + >>> input = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device) + >>> output = ttnn.slice(input, [0, 0, 0, 0], [1, 1, 32, 32]) + >>> print(output.shape) + [1, 1, 32, 32] + )doc"; - Keyword Args: - * :attr:`memory_config`: Memory Config of the output tensor - * :attr:`queue_id` (Optional[uint8]): command queue id - )doc"; // TODO: implementing the array version and overloading the pybind with all the possible array sizes is better than a vector with a fixed size default value using OperationType = decltype(ttnn::slice); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index cf2ba6de4b0..28b63cd3c80 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -67,9 +67,9 @@ void bind_unary_composite_optional_floats_with_default(py::module& module, const return self(input_tensor, parameter_a, parameter_b, memory_config); }, py::arg("input_tensor"), - py::kw_only(), py::arg(parameter_name_a.c_str()) = parameter_a_value, py::arg(parameter_name_b.c_str()) = parameter_b_value, + py::kw_only(), py::arg("memory_config") = std::nullopt}); } @@ -1403,7 +1403,17 @@ void py_module(py::module& module) { detail::bind_unary_operation(module, ttnn::acos, R"doc(\mathrm{{output\_tensor}}_i = acos(\mathrm{{input\_tensor}}_i))doc"); detail::bind_unary_operation(module, ttnn::asin, R"doc(\mathrm{{output\_tensor}}_i = asin(\mathrm{{input\_tensor}}_i))doc"); - detail::bind_unary_operation(module, ttnn::atan, R"doc(\mathrm{{output\_tensor}}_i = atan(\mathrm{{input\_tensor}}_i))doc"); + detail::bind_unary_operation(module, ttnn::atan, R"doc(\mathrm{{output\_tensor}}_i = atan(\mathrm{{input\_tensor}}_i))doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); + 
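Stepping back to the slice changes above: the per-dimension output extent computed by the output_dim_i lambda in slice.cpp is ceil((end - begin) / step) in integer arithmetic, which is what produces the [1, 1, 32, 16] shape in the docstring example. A minimal sketch, with sliced_extent as a hypothetical stand-in for that expression:

#include <cassert>
#include <cstdint>

// Hypothetical stand-alone version of the per-dimension extent used by output_dim_i:
// ceil((end - begin) / step) via integer arithmetic; assumes end >= begin and step >= 1.
static uint32_t sliced_extent(uint32_t begin, uint32_t end, uint32_t step) {
    return (end - begin + step - 1) / step;
}

int main() {
    // Matches the docstring example: a (1, 1, 64, 32) input sliced to ends (1, 1, 64, 16)
    // with steps (1, 1, 2, 1) yields shape (1, 1, 32, 16).
    assert(sliced_extent(0, 64, 2) == 32);
    assert(sliced_extent(0, 16, 1) == 16);
    return 0;
}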
detail::bind_unary_operation(module, ttnn::cos, R"doc(\mathrm{{output\_tensor}}_i = cos(\mathrm{{input\_tensor}}_i))doc"); detail::bind_unary_operation(module, ttnn::erfinv, R"doc(\mathrm{{output\_tensor}}_i = erfinv(\mathrm{{input\_tensor}}_i))doc", R"doc(Supported dtypes, layouts, and ranks: @@ -1544,7 +1554,16 @@ void py_module(py::module& module) { detail::bind_unary_operation(module, ttnn::sqrt, R"doc(\mathrm{{output\_tensor}}_i = sqrt(\mathrm{{input\_tensor}}_i))doc"); detail::bind_unary_operation(module, ttnn::square, R"doc(\mathrm{{output\_tensor}}_i = square(\mathrm{{input\_tensor}}_i))doc"); detail::bind_unary_operation(module, ttnn::tan, R"doc(\mathrm{{output\_tensor}}_i = tan(\mathrm{{input\_tensor}}_i))doc"); - detail::bind_unary_operation(module, ttnn::tanh, R"doc(\mathrm{{output\_tensor}}_i = tanh(\mathrm{{input\_tensor}}_i))doc"); + detail::bind_unary_operation(module, ttnn::tanh, R"doc(\mathrm{{output\_tensor}}_i = tanh(\mathrm{{input\_tensor}}_i))doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_operation(module, ttnn::log_sigmoid, R"doc(\mathrm{{output\_tensor}}_i = \verb|log_sigmoid|(\mathrm{{input\_tensor}}_i))doc", R"doc(Supported dtypes, layouts, and ranks: @@ -1803,7 +1822,17 @@ void py_module(py::module& module) { )doc"); detail::bind_unary_composite(module, ttnn::cbrt, R"doc(Performs cbrt function on :attr:`input_tensor`.)doc"); - detail::bind_unary_composite(module, ttnn::cosh, R"doc(Performs cosh function on :attr:`input_tensor`.)doc", "[supported range -9 to 9]"); + detail::bind_unary_composite(module, ttnn::cosh, R"doc(Performs cosh function on :attr:`input_tensor`.)doc", "[supported range -9 to 9]", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); + detail::bind_unary_composite(module, ttnn::digamma, R"doc(Performs digamma function on :attr:`input_tensor`.)doc", "[supported for value greater than 0]"); detail::bind_unary_composite(module, ttnn::lgamma, R"doc(Performs lgamma function on :attr:`input_tensor`.)doc", "[supported for value greater than 0]"); detail::bind_unary_composite(module, ttnn::log1p, R"doc(Performs log1p function on :attr:`input_tensor`.)doc", "[supported range -1 to 1]", @@ -1818,7 +1847,16 @@ void py_module(py::module& module) { )doc"); detail::bind_unary_composite(module, ttnn::mish, R"doc(Performs mish function on :attr:`input_tensor`, not supported for grayskull.)doc"); detail::bind_unary_composite(module, ttnn::multigammaln, R"doc(Performs multigammaln function on :attr:`input_tensor`.)doc", "[supported range 1.6 to inf]"); - detail::bind_unary_composite(module, ttnn::sinh, R"doc(Performs sinh function on :attr:`input_tensor`.)doc", "[supported range -88 to 88]"); + detail::bind_unary_composite(module, ttnn::sinh, R"doc(Performs sinh function on :attr:`input_tensor`.)doc", "[supported range -88 to 88]", + R"doc(Supported dtypes, layouts, and 
ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_composite(module, ttnn::softsign, R"doc(Performs softsign function on :attr:`input_tensor`.)doc"); detail::bind_unary_composite(module, ttnn::swish, R"doc(Performs swish function on :attr:`input_tensor`.)doc"); detail::bind_unary_composite(module, ttnn::var_hw, R"doc(Performs var_hw function on :attr:`input_tensor`.)doc"); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index e9ac44f7244..afcaae4f8d3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -177,7 +177,7 @@ void bind_unary_backward_rsqrt( template void bind_unary_backward_op_reciprocal( - py::module& module, const unary_backward_operation_t& operation, const std::string& description) { + py::module& module, const unary_backward_operation_t& operation, const std::string& description, const std::string_view supported_dtype = "") { auto doc = fmt::format( R"doc( {2} @@ -192,6 +192,9 @@ void bind_unary_backward_op_reciprocal( Returns: List of ttnn.Tensor: the output tensor. + Note: + {3} + Example: >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device) @@ -200,7 +203,8 @@ void bind_unary_backward_op_reciprocal( )doc", operation.base_name(), operation.python_fully_qualified_name(), - description); + description, + supported_dtype); bind_registered_operation( module, @@ -347,7 +351,8 @@ void bind_unary_backward_two_float_with_default( const std::string& parameter_name_b, const std::string& parameter_b_doc, float parameter_b_value, - const std::string_view description) { + const std::string_view description, + const std::string_view supported_dtype = "") { auto doc = fmt::format( R"doc( {8} @@ -357,18 +362,21 @@ void bind_unary_backward_two_float_with_default( input_tensor (ComplexTensor or ttnn.Tensor): the input tensor. Keyword args: - {2} (float, optional): {3} , Default to {4} - {5} (float, optional): {6} , Default to {7} + {2} (float, optional): {3} , Defaults to {4}. + {5} (float, optional): {6} , Defaults to {7}. memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. Returns: List of ttnn.Tensor: the output tensor. 
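These doc templates are assembled with fmt::format and positional placeholders, so {3}/{6} are the parameter descriptions while {4}/{7} are the default values; that is why the Example line below switches to {4} and {7}. A small sketch of how the indices resolve, using the hardtanh_bw values; the min default of -1.0 is an assumption, only the max default appears in this hunk.

#include <fmt/core.h>
#include <string>

int main() {
    // Index: 0 base name, 1 qualified name, 2/5 parameter names, 3/6 descriptions, 4/7 defaults.
    std::string line = fmt::format(
        ">>> output = {1}(grad_tensor, input, {2} = {4}, {5} = {7})",
        "hardtanh_bw", "ttnn.hardtanh_bw", "min", "Minimum value", -1.0, "max", "Maximum value", 1.0);
    // line == ">>> output = ttnn.hardtanh_bw(grad_tensor, input, min = -1, max = 1)"
    return 0;
}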
+ Note: + {9} + Example: >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device) >>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device) - >>> output = {1}(grad_tensor, input, {2} = {3}, {5} = {6}) + >>> output = {1}(grad_tensor, input, {2} = {4}, {5} = {7}) )doc", operation.base_name(), operation.python_fully_qualified_name(), @@ -378,7 +386,8 @@ void bind_unary_backward_two_float_with_default( parameter_name_b, parameter_b_doc, parameter_b_value, - description); + description, + supported_dtype); bind_registered_operation( module, @@ -535,7 +544,8 @@ void bind_unary_backward_float_string_default( const std::string& parameter_name_b, const std::string& parameter_b_doc, string parameter_b_value, - const std::string_view description) { + const std::string_view description, + const std::string_view supported_dtype = "") { auto doc = fmt::format( R"doc( {7} @@ -552,6 +562,9 @@ void bind_unary_backward_float_string_default( Returns: List of ttnn.Tensor: the output tensor. + Note: + {8} + Example: >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) @@ -565,7 +578,8 @@ void bind_unary_backward_float_string_default( parameter_name_b, parameter_b_doc, parameter_b_value, - description); + description, + supported_dtype); bind_registered_operation( module, @@ -1012,7 +1026,8 @@ void bind_unary_backward_gelu( const std::string& parameter_name_a, const std::string& parameter_a_doc, string parameter_a_value, - const std::string_view description) { + const std::string_view description, + const std::string_view supported_dtype = "") { auto doc = fmt::format( R"doc( {5} @@ -1027,6 +1042,9 @@ void bind_unary_backward_gelu( output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`. queue_id (uint8, optional): command queue id. Defaults to `0`. 
+ Note: + {6} + Example: @@ -1039,7 +1057,8 @@ void bind_unary_backward_gelu( parameter_name_a, parameter_a_doc, parameter_a_value, - description); + description, + supported_dtype); bind_registered_operation( module, operation, @@ -1086,7 +1105,14 @@ void py_module(py::module& module) { "max", "Maximum value", 1.0, - R"doc(Performs backward operations for hardtanh activation function on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc"); + R"doc(Performs backward operations for hardtanh activation function on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc", + R"doc(Supported dtypes, layouts, and ranks: + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + )doc"); detail::bind_unary_backward_float_with_default( @@ -1214,7 +1240,16 @@ void py_module(py::module& module) { "threshold", "Threshold value", 20.0, - R"doc(Performs backward operations for softplus on :attr:`input_tensor`, :attr:`beta`, :attr:`threshold` with given :attr:`grad_tensor`.)doc"); + R"doc(Performs backward operations for softplus on :attr:`input_tensor`, :attr:`beta`, :attr:`threshold` with given :attr:`grad_tensor`.)doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_backward_float_string_default( module, @@ -1225,7 +1260,16 @@ void py_module(py::module& module) { "Mode of Rounding", "None", R"doc(Performs backward operations for Unary rdiv on :attr:`input_tensor`, :attr:`scalar` with given :attr:`grad_tensor` using given :attr:`round_mode`. - :attr:`round_mode` can be 'None', 'trunc', or 'floor'.)doc"); + :attr:`round_mode` can be 'None', 'trunc', or 'floor'.)doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_backward_shape( module, @@ -1241,7 +1285,16 @@ void py_module(py::module& module) { "Approximation type", "none", R"doc(Performs backward operations for gelu on :attr:`input_tensor_a` or :attr:`input_tensor`, with given :attr:`grad_tensor` using given :attr:`approximate` mode. 
- :attr:`approximate` mode can be 'none', 'tanh'.)doc"); + :attr:`approximate` mode can be 'none', 'tanh'.)doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_backward_unary_optional_float( module, @@ -1897,7 +1950,16 @@ void py_module(py::module& module) { detail::bind_unary_backward_op_reciprocal( module, ttnn::reciprocal_bw, - R"doc(Performs backward operations for reciprocal on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + R"doc(Performs backward operations for reciprocal on :attr:`input_tensor` with given :attr:`grad_tensor`)doc", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16 | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + )doc"); detail::bind_unary_backward_op( module, diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp index 3fe422e7786..bf2cf007f15 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp @@ -61,7 +61,7 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, output_shape[1], output_shape[2], channels, - input_tensor.device(), + input_tensor.device()->compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR, false); num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp index a5d9f2df35c..d21dc41a84f 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp @@ -375,6 +375,8 @@ void MAIN { constexpr uint32_t k_chunk_size = get_compile_time_arg_val(17); constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(18); constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(19); + constexpr bool is_causal = get_compile_time_arg_val(20) == 1; + constexpr bool use_attention_mask = get_compile_time_arg_val(21) == 1; constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt; constexpr uint32_t k_chunk_tiles = Sk_chunk_t * DHt; @@ -423,23 +425,26 @@ void MAIN { } // Get cur_pos - uint32_t cur_pos = 0; - // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - cb_wait_front(cb_index_id, 1); - volatile uint32_t *index_addr_ptr; - cb_get_tile(cb_index_id, 0, &index_addr_ptr); - cur_pos = index_addr_ptr[4+cur_batch]; - cb_release_tile(cb_index_id); - } + constexpr uint32_t cur_pos_base = St*32-1; + uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. 
In this case we set cur_pos to the last position + if constexpr(is_causal) { + // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + cb_wait_front(cb_index_id, 1); + volatile uint32_t *index_addr_ptr; + cb_get_tile(cb_index_id, 0, &index_addr_ptr); + cur_pos = index_addr_ptr[4+cur_batch]; + cb_release_tile(cb_index_id); + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } // Sequence length assignment auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size); @@ -464,11 +469,19 @@ void MAIN { /* QK *= SCALE */ mul_block_bcast_scalar_inplace(cb_qk_im, cb_scale_in, qk_chunk_tiles); - // For decode, we only apply mask at the last chunk on reducer cor - if (k_chunk == k_chunk_end - 1 && do_reduce) { - /* QK += MASK */ - reconfig_data_format(cb_qk_im, cb_mask_in); - add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + if constexpr(is_causal){ + // For decode, we only apply mask at the last chunk on reducer core for causal mode + if (k_chunk == k_chunk_end - 1 && do_reduce) { + /* QK += MASK */ + reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } + } + else { + if constexpr(use_attention_mask){ + reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } } reconfig_data_format(cb_qk_im, cb_identity_scale_in); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp index 2e973fbcb8b..5e82ebcfa62 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp @@ -27,6 +27,31 @@ uint32_t virtual_seq_tile_id_to_physical_tile_id(uint32_t seq_tile_idx, uint32_t return physical_block * block_stride + head_offset + block_offset; } +template +uint32_t read_mask_chunk(uint32_t PSt, uint32_t mask_start_tile_id, const InterleavedAddrGenFast mask_reader) { + // Read mask chunk + cb_reserve_back(cb_mask_in, mask_chunk_tiles); + uint32_t mask_write_ptr = get_write_ptr(cb_mask_in); + uint32_t barrier_count = 0; + for (uint32_t row = 0; row < PNHt; ++row) { + uint32_t mask_tile_id = mask_start_tile_id + row * PSt; + for (uint32_t col = 0; col < Sk_chunk_t; ++col) { + noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr); + mask_tile_id++; + mask_write_ptr += mask_tile_bytes; + + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_mask_in, mask_chunk_tiles); + mask_start_tile_id += mask_chunk_tiles; + return mask_start_tile_id; +} + void kernel_main() { /* In DRAM, Q is (B, PNHt, DHt), K is (B, St, DHt), V is (B, St, DHt), mask is (B, PNHt, PSt) @@ -50,6 +75,8 @@ void kernel_main() { constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14); constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15); constexpr uint32_t num_output_cores = get_compile_time_arg_val(16); + constexpr bool is_causal 
= get_compile_time_arg_val(17) == 1; + constexpr bool use_attention_mask = get_compile_time_arg_val(18) == 1; uint32_t arg_idx = 0; const uint32_t q_addr = get_arg_val(arg_idx++); @@ -57,6 +84,7 @@ void kernel_main() { const uint32_t v_addr = get_arg_val(arg_idx++); const uint32_t pos_addr = get_arg_val(arg_idx++); const uint32_t page_table_addr = get_arg_val(arg_idx++); + const uint32_t mask_addr = get_arg_val(arg_idx++); const uint32_t page_table_page_size = get_arg_val(arg_idx++); const bool is_worker = get_arg_val(arg_idx++) == 0; const bool is_output_core = get_arg_val(arg_idx++) == 1; @@ -71,32 +99,35 @@ void kernel_main() { return; } // Get cur_pos - uint32_t cur_pos = 0; - // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - const InterleavedAddrGen addrg = { - .bank_base_address = pos_addr, - .page_size = index_stick_size_B - }; - - cb_reserve_back(cb_index_id, 1); - uint32_t index_cb_wr_ptr = get_write_ptr(cb_index_id); - // index_tensor has one page to read - uint64_t tensor_index_noc_addr = get_noc_addr(0, addrg); - noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B); - noc_async_read_barrier(); - cb_push_back(cb_index_id, 1); - volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_wr_ptr); - cur_pos = index_ptr[cur_batch]; - } + constexpr uint32_t cur_pos_base = St*32-1; + uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position + if constexpr(is_causal) { + // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + const InterleavedAddrGen addrg = { + .bank_base_address = pos_addr, + .page_size = index_stick_size_B + }; + + cb_reserve_back(cb_index_id, 1); + uint32_t index_cb_wr_ptr = get_write_ptr(cb_index_id); + // index_tensor has one page to read + uint64_t tensor_index_noc_addr = get_noc_addr(0, addrg); + noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B); + noc_async_read_barrier(); + cb_push_back(cb_index_id, 1); + volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_wr_ptr); + cur_pos = index_ptr[cur_batch]; + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } const uint32_t valid_seq_len_tiles = (cur_pos + 1 + 32 - 1) / 32; @@ -137,6 +168,7 @@ void kernel_main() { constexpr uint32_t cb_q_in = tt::CB::c_in0; constexpr uint32_t cb_k_in = tt::CB::c_in1; constexpr uint32_t cb_v_in = tt::CB::c_in2; + constexpr uint32_t cb_mask_in = tt::CB::c_in3; constexpr uint32_t onetile = 1; @@ -146,6 +178,8 @@ void kernel_main() { constexpr DataFormat k_data_format = get_dataformat(cb_k_in); constexpr uint32_t v_tile_bytes = get_tile_size(cb_v_in); constexpr DataFormat v_data_format = get_dataformat(cb_v_in); + constexpr uint32_t mask_tile_bytes = get_tile_size(cb_mask_in); + constexpr DataFormat mask_data_format = get_dataformat(cb_mask_in); constexpr uint32_t barrier_threshold = get_barrier_read_threshold(); uint32_t barrier_count = 0; @@ -202,7 +236,16 @@ void kernel_main() { .data_format = v_data_format }; + const InterleavedAddrGenFast mask_reader = { + 
.bank_base_address = mask_addr, + .page_size = mask_tile_bytes, + .data_format = mask_data_format + }; + for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { + const uint32_t mask_batch_offset = (cur_batch % Bkv) * PNHt * St; + const uint32_t mask_chunk_offset = k_chunk_start * Sk_chunk_t; + uint32_t mask_start_tile_id = mask_batch_offset + mask_chunk_offset; if constexpr (is_paged_attention) { for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { @@ -229,6 +272,10 @@ void kernel_main() { noc_async_read_barrier(); cb_push_back(cb_k_in, k_chunk_tiles); + if constexpr(use_attention_mask){ + mask_start_tile_id = read_mask_chunk(PSt, mask_start_tile_id, mask_reader); + } + // Read V chunk in row major order, write in row-major order cb_reserve_back(cb_v_in, k_chunk_tiles); uint32_t v_write_ptr = get_write_ptr(cb_v_in); @@ -289,6 +336,10 @@ void kernel_main() { cb_push_back(cb_k_in, k_chunk_tiles); k_start_tile_id += k_chunk_tiles; + if constexpr(use_attention_mask){ + mask_start_tile_id = read_mask_chunk(PSt, mask_start_tile_id, mask_reader); + } + // Read V chunk cb_reserve_back(cb_v_in, k_chunk_tiles); uint32_t v_write_ptr = get_write_ptr(cb_v_in); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp index 7374a7594f6..4059ad84736 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp @@ -244,6 +244,7 @@ void kernel_main() { constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(16); constexpr uint32_t num_output_cores = get_compile_time_arg_val(17); constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(18); + constexpr bool is_causal = get_compile_time_arg_val(19) == 1; uint32_t arg_idx = 0; const uint32_t out_addr = get_arg_val(arg_idx++); @@ -262,22 +263,25 @@ void kernel_main() { return; } // Get cur_pos - uint32_t cur_pos = 0; + constexpr uint32_t cur_pos_base = St*32-1; + uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. 
In this case we set cur_pos to the last position // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - cb_wait_front(cb_index_id, 1); - uint32_t index_cb_ptr = get_read_ptr(cb_index_id); - volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_ptr); - cur_pos = index_ptr[cur_batch]; - } + if constexpr(is_causal) { + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + cb_wait_front(cb_index_id, 1); + uint32_t index_cb_ptr = get_read_ptr(cb_index_id); + volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_ptr); + cur_pos = index_ptr[cur_batch]; + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } // Sequence length assignment auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size); @@ -347,8 +351,8 @@ void kernel_main() { constexpr uint32_t barrier_threshold = get_barrier_read_threshold(); uint32_t barrier_count = 0; - // generate and send mask to compute - generate_mask(k_num_chunks, PSt, cur_pos); + // generate and send mask to compute if causal + if constexpr(is_causal) generate_mask(k_num_chunks, PSt, cur_pos); for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { if (k_chunk_end - k_chunk_start < k_num_chunks){ diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp index 476e0bbd57e..b192917fe09 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp @@ -23,10 +23,10 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ input_tensor.get_dtype()); } - const auto q_shape = input_tensors.at(0).get_legacy_shape(); - const auto q_shape_unpadded = input_tensors.at(0).get_shape(); - const auto k_shape = input_tensors.at(1).get_legacy_shape(); - const auto v_shape = input_tensors.at(2).get_legacy_shape(); + const auto q_shape = input_tensors.at(0).get_padded_shape(); + const auto q_shape_unpadded = input_tensors.at(0).get_logical_shape(); + const auto k_shape = input_tensors.at(1).get_padded_shape(); + const auto v_shape = input_tensors.at(2).get_padded_shape(); // Input 0 must be sharded by height or DRAM interleaved. All other inputs must be in DRAM. 
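A quick note on the non-causal default introduced in the three kernels above: St is the KV sequence length in 32-row tiles, so cur_pos_base = St*32 - 1 is the index of the last valid position, and the non-causal path therefore attends over the entire KV cache. A minimal sketch of that arithmetic; S = 4096 is just an assumed example value.

#include <cassert>
#include <cstdint>

int main() {
    const uint32_t S = 4096;                             // assumed padded KV sequence length
    const uint32_t St = S / 32;                          // mirrors S / TILE_HEIGHT with TILE_HEIGHT = 32
    const uint32_t cur_pos_base = St * 32 - 1;           // last valid position index
    assert(cur_pos_base == S - 1);
    const uint32_t valid_seq_len_tiles = (cur_pos_base + 1 + 32 - 1) / 32;
    assert(valid_seq_len_tiles == St);                   // the whole cache is processed
    return 0;
}

On the causal path, cur_pos instead comes from the runtime argument or the position tensor, and UINT32_MAX still marks users to skip.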
const auto Q_memcfg = input_tensors.at(0).memory_config(); @@ -44,30 +44,60 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); } + if (!this->is_causal) { + if (optional_input_tensors.at(2).has_value()){ + // Non-causal attention verification + const auto& mask_tensor = optional_input_tensors.at(2).value(); + const auto mask_shape = mask_tensor.get_padded_shape(); + const auto mask_shape_unpadded = mask_tensor.get_logical_shape(); + + TT_FATAL(mask_shape[2] == q_shape[2], "Expect same number of padded heads in mask as in Q, got {} and {}", mask_shape[2], q_shape[2]); + TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[2], q_shape_unpadded[2]); + if (! this->paged_attention) TT_FATAL(mask_shape[3] == k_shape[2], "Expect same sequence length in mask as in K, got {} and {}", mask_shape[3], k_shape[2]); + TT_FATAL(mask_shape[3] % k_chunk_size == 0, "Mask sequence length must be multiple of chunk size, got: {} and {}", mask_shape[3], k_chunk_size); + + TT_FATAL( + mask_tensor.get_dtype() == DataType::BFLOAT16 || mask_tensor.get_dtype() == DataType::BFLOAT8_B || + mask_tensor.get_dtype() == DataType::BFLOAT4_B, + "Unsupported data type for mask tensor: {}.", + mask_tensor.get_dtype()); + } + } else { + // Causal attention verification + TT_FATAL(not optional_input_tensors.at(2).has_value(), "Must not have attn_mask tensor for causal attention"); + } + if (this->paged_attention) { // Paged attention verification TT_FATAL(! this->share_cache.value_or(false), "Share cache feature not supported for paged attention"); - TT_FATAL(optional_input_tensors.at(0).has_value(), "Must have cur_pos tensor for paged attention"); - TT_FATAL(optional_input_tensors.at(1).has_value(), "Must have page_table tensor for paged attention"); + const auto B = q_shape[1]; - const auto& cur_pos_tensor = optional_input_tensors.at(0).value(); - const auto& page_table_tensor = optional_input_tensors.at(1).value(); + if (this->is_causal) { + // Check cur pos tensor for causal mode + TT_FATAL(optional_input_tensors.at(0).has_value(), "Must have cur_pos tensor for paged attention in causal mode"); + const auto& cur_pos_tensor = optional_input_tensors.at(0).value(); + TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Expect cur_pos to be INT32, got {}", cur_pos_tensor.get_dtype()); + TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Expect cur_pos to be ROW_MAJOR, got {}", cur_pos_tensor.get_layout()); + const auto cur_pos_shape = cur_pos_tensor.get_padded_shape(); + TT_FATAL(cur_pos_shape[0] == B, "cur_pos must have batch size equal to Q, got {} and {}", cur_pos_shape[0], B); + } - TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Error"); - TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); + TT_FATAL(optional_input_tensors.at(1).has_value(), "Must have page_table tensor for paged attention"); + const auto& page_table_tensor = optional_input_tensors.at(1).value(); TT_FATAL(page_table_tensor.get_dtype() == DataType::INT32, "Error"); TT_FATAL(page_table_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); - const auto cur_pos_shape = cur_pos_tensor.get_legacy_shape(); - const auto page_table_shape = page_table_tensor.get_legacy_shape(); + const auto page_table_shape = page_table_tensor.get_padded_shape(); - const auto B = q_shape[1]; - TT_FATAL(cur_pos_shape[0] == B, "cur_pos must 
have batch size equal to Q"); TT_FATAL(page_table_shape[0] == B, "page_table must have hidden size equal to Q"); TT_FATAL(k_shape[2] == v_shape[2], "K and V must have same block size"); TT_FATAL(k_shape[3] == v_shape[3] && k_shape[3] == q_shape[3], "Q, K, V must have same hidden size"); + + // Validate chunk size for paged version + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, got: {}", k_chunk_size); + if (! this->is_causal) TT_FATAL((page_table_shape[1]*k_shape[2]) % k_chunk_size == 0, "K sequence length must be multiple of chunk size, got: {} and {}", page_table_shape[1]*k_shape[2], k_chunk_size); } else { // Unpaged attention verification TT_FATAL(not optional_input_tensors.at(1).has_value(), "Must not have page_table tensor for unpaged attention"); @@ -93,6 +123,10 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ // Check sequence lengths TT_FATAL(k_shape[-2] == v_shape[-2], "Error"); + // Validate chunk size for unpaged version + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, got: {}", k_chunk_size); + TT_FATAL(k_shape[2] % k_chunk_size == 0, "K sequence length must be multiple of chunk size, got: {} and {}", k_shape[2], k_chunk_size); + // Check hidden size const auto D = q_shape[-1]; TT_FATAL(k_shape[-1] == D, "Error"); @@ -115,9 +149,9 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ } } -std::vector ScaledDotProductAttentionDecode::compute_output_shapes( +std::vector ScaledDotProductAttentionDecode::compute_output_shapes( const std::vector& input_tensors) const { - return {input_tensors.at(0).get_legacy_shape()}; + return {input_tensors.at(0).get_padded_shape()}; } std::vector ScaledDotProductAttentionDecode::create_output_tensors( @@ -134,12 +168,13 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( auto& cur_pos_tensor = optional_input_tensors.at(0); auto& page_table_tensor = optional_input_tensors.at(1); + auto& attn_mask = optional_input_tensors.at(2); auto& output_tensor = output_tensors.at(0); auto scale = this->scale; if (not scale.has_value()) { - scale = 1.0f / std::sqrt(static_cast(input_tensor_q.get_legacy_shape()[-1])); + scale = 1.0f / std::sqrt(static_cast(input_tensor_q.get_padded_shape()[-1])); } return detail::sdpa_decode_multi_core( @@ -148,7 +183,9 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( input_tensor_v, cur_pos_tensor, page_table_tensor, + attn_mask, output_tensor, + this->is_causal, this->cur_pos, scale, this->compute_kernel_config, @@ -158,6 +195,8 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( } operation::Hash ScaledDotProductAttentionDecode::compute_program_hash(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { + bool has_cur_pos = optional_input_tensors.at(0).has_value(); + bool has_attn_mask = optional_input_tensors.at(2).has_value(); return operation::hash_operation( this->scale, this->output_mem_config, @@ -165,6 +204,9 @@ operation::Hash ScaledDotProductAttentionDecode::compute_program_hash(const std: this->compute_kernel_config, this->k_chunk_size, this->paged_attention, + this->is_causal, + has_attn_mask, + has_cur_pos, input_tensors); } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp index 626f10f133d..c055bbb77b9 100644 --- 
a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp @@ -14,6 +14,7 @@ namespace ttnn::operations::transformer { struct ScaledDotProductAttentionDecode { + const bool is_causal; std::vector cur_pos; const std::optional scale; const MemoryConfig output_mem_config; @@ -26,7 +27,7 @@ struct ScaledDotProductAttentionDecode { void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_shapes(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp index 2c0a99f0389..3d5ba1ae74f 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp @@ -25,7 +25,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const Tensor& input_tensor_v, std::optional cur_pos_tensor, std::optional page_table_tensor, + std::optional attn_mask, const Tensor& output_tensor, + bool is_causal, const std::vector& cur_pos_ids, std::optional scale, DeviceComputeKernelConfig compute_kernel_config, @@ -46,9 +48,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const bool is_paged_attention = page_table_tensor.has_value(); - const auto q_shape = input_tensor_q.get_legacy_shape(); - const auto q_shape_unpadded = input_tensor_q.get_shape(); - const auto k_shape = input_tensor_k.get_legacy_shape(); + const auto q_shape = input_tensor_q.get_padded_shape(); + const auto q_shape_unpadded = input_tensor_q.get_logical_shape(); + const auto k_shape = input_tensor_k.get_padded_shape(); // Use k_shape for S and DH since Q might be different for decode uint32_t B = q_shape[1], PNH = q_shape[2], S = k_shape[2], DH = k_shape[3]; @@ -59,6 +61,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( if (is_paged_attention) { uint32_t block_size = k_shape[2]; page_block_size_t = block_size / TILE_HEIGHT; + // get real S using the page_table_tensor + S = page_table_tensor.value().get_padded_shape()[-1]*S; } uint32_t Bkv = k_shape[0]; uint32_t St = S/TILE_HEIGHT; @@ -102,6 +106,10 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( auto out0_buffer = output_tensor.buffer(); bool use_cur_pos_tensor = cur_pos_tensor.has_value(); + bool use_attention_mask = attn_mask.has_value(); + + log_debug("use_cur_pos_tensor: {}", use_cur_pos_tensor); + log_debug("use_attention_mask: {}", use_attention_mask); // Parallelization scheme // We will assign cores to batches @@ -262,6 +270,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( tt::DataFormat q_df = tt_metal::datatype_to_dataformat_converter(input_tensor_q.get_dtype()); tt::DataFormat k_df = tt_metal::datatype_to_dataformat_converter(input_tensor_k.get_dtype()); tt::DataFormat v_df = tt_metal::datatype_to_dataformat_converter(input_tensor_v.get_dtype()); + tt::DataFormat mask_df = use_attention_mask ? 
tt_metal::datatype_to_dataformat_converter(attn_mask.value().get_dtype()) : tt::DataFormat::Float16_b; tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); tt::DataFormat scalar_df = tt::DataFormat::Float16_b; tt::DataFormat im_df = tt::DataFormat::Float16_b; @@ -271,6 +280,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t q_tile_size = tt_metal::detail::TileSize(q_df); uint32_t k_tile_size = tt_metal::detail::TileSize(k_df); uint32_t v_tile_size = tt_metal::detail::TileSize(v_df); + uint32_t mask_tile_size = tt_metal::detail::TileSize(mask_df); uint32_t out_tile_size = tt_metal::detail::TileSize(out_df); uint32_t scalar_tile_size = tt_metal::detail::TileSize(scalar_df); uint32_t im_tile_size = tt_metal::detail::TileSize(im_df); @@ -329,7 +339,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( auto cb_in2_id = CreateCircularBuffer(program, core_grid, c_in2_config); // attn_mask input - auto c_in3_config = CircularBufferConfig(qk_tiles * stats_tile_size, {{CB::c_in3, stats_df}}).set_page_size(CB::c_in3, stats_tile_size); + auto c_in3_config = CircularBufferConfig(qk_tiles * mask_tile_size, {{CB::c_in3, mask_df}}).set_page_size(CB::c_in3, mask_tile_size); auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config); // scale input @@ -486,7 +496,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( B, PNHt, St, DHt, Sk_chunk_t, num_active_cores, is_q_sharded, num_cores_per_batch, k_chunk_size, index_stick_size, (uint32_t)is_paged_attention, num_kv_heads, page_block_size_t, - Bkv, num_cores_per_head, num_heads_per_core, num_output_cores + Bkv, num_cores_per_head, num_heads_per_core, num_output_cores, + is_causal, use_attention_mask, }; std::vector writer_compile_time_args_common = { @@ -505,14 +516,15 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( num_heads_per_core, num_reducer_cores, num_output_cores, - output_tensor.element_size() + output_tensor.element_size(), + is_causal, }; std::vector compute_compile_time_args_common = { St, DHt, PNHt, Sk_chunk_t, qk_in0_block_w, qk_out_subblock_w, qk_out_subblock_h, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_num_blocks, out_in0_block_w, out_out_subblock_w, out_out_subblock_h, out_in0_num_subblocks, out_in1_num_subblocks, out_num_blocks, - num_cores_per_batch, k_chunk_size, num_cores_per_head, num_heads_per_core + num_cores_per_batch, k_chunk_size, num_cores_per_head, num_heads_per_core, is_causal, use_attention_mask, }; std::map defines; @@ -562,6 +574,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t v_addr = v_buffer->address(); uint32_t pos_addr = use_cur_pos_tensor ? cur_pos_tensor.value().buffer()->address() : 0; uint32_t page_table_addr = is_paged_attention ? page_table_tensor.value().buffer()->address() : 0; + uint32_t attn_mask_addr = use_attention_mask ? attn_mask.value().buffer()->address() : 0; uint32_t out_addr = out0_buffer->address(); // Set rt args @@ -577,7 +590,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t core_num_in_reduce = i % num_cores_per_head; uint32_t core_num_in_output = i % num_cores_per_batch; - uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch); + uint32_t cur_pos = (use_cur_pos_tensor || ! is_causal) ? 
-1 : cur_pos_ids.at(cur_batch); log_debug("---- core_id: {}, coord: {} ----", i, core); log_debug("worker_id_for_reduce: {}", worker_id_for_reduce); @@ -591,7 +604,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( log_debug("cur_pos: {}", cur_pos); // reader runtime args - std::vector reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, page_table_stick_size, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos}; + std::vector reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, attn_mask_addr, page_table_stick_size, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos}; reader_rt_args.insert(reader_rt_args.end(), output_core_physical_xs.begin(), output_core_physical_xs.end()); reader_rt_args.insert(reader_rt_args.end(), output_core_physical_ys.begin(), output_core_physical_ys.end()); @@ -640,7 +653,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( cb_out4_id, B, use_cur_pos_tensor, - is_paged_attention + use_attention_mask, + is_paged_attention, + is_causal ] ( const void* operation, @@ -662,6 +677,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t v_addr = v_buffer->address(); uint32_t pos_addr = use_cur_pos_tensor ? optional_input_tensors.at(0).value().buffer()->address() : 0; uint32_t page_table_addr = is_paged_attention ? optional_input_tensors.at(1).value().buffer()->address() : 0; + uint32_t attn_mask_addr = use_attention_mask ? optional_input_tensors.at(2).value().buffer()->address() : 0; auto page_table_buffer = is_paged_attention ? optional_input_tensors.at(1).value().buffer() : nullptr; uint32_t page_table_stick_size = is_paged_attention ? page_table_buffer->aligned_page_size() : 0; uint32_t out_addr = out0_buffer->address(); @@ -681,7 +697,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t cur_batch = i / num_cores_per_batch; uint32_t core_num_in_reduce = (num_cores_per_head == 0) ? 0 : i % num_cores_per_head; uint32_t core_num_in_output = i % num_cores_per_batch; - uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch); + uint32_t cur_pos = (use_cur_pos_tensor || ! is_causal) ? 
-1 : cur_pos_ids.at(cur_batch); auto& reader_args = reader_args_by_core[core.x][core.y]; auto& writer_args = writer_args_by_core[core.x][core.y]; @@ -694,6 +710,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( reader_args[arg_idx++] = v_addr; reader_args[arg_idx++] = pos_addr; reader_args[arg_idx++] = page_table_addr; + reader_args[arg_idx++] = attn_mask_addr; reader_args[arg_idx++] = page_table_stick_size; reader_args[arg_idx++] = do_reduce; reader_args[arg_idx++] = do_output; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp index 9d619f9052d..ea25388791e 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp @@ -16,7 +16,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const Tensor &input_tensor_v, std::optional cur_pos_tensor, std::optional page_table_tensor, + std::optional attn_mask, const Tensor &output_tensor, + bool is_causal, const std::vector& cur_pos_ids, std::optional scale, DeviceComputeKernelConfig compute_kernel_config, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp index cf5804d559a..d1da3bc4c99 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp @@ -9,14 +9,20 @@ #include "ttnn/run_operation.hpp" namespace { -uint32_t get_chunk_size(uint32_t s) { - if (s <= 128) { - return 32; +inline uint32_t get_chunk_size(uint32_t s) { + /* + # find maximum power of 2 divisor of s + for i in range(1, s): + if s % (2**(i+1)) != 0: + break + */ + uint32_t i = 1; + for (; i < s; i++) { + if (s % (1 << (i + 1)) != 0) { + break; + } } - if (s <= 256) { - return 256; - } - return 512; + return std::min(512, 1 << i); } } // namespace @@ -27,6 +33,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -35,13 +43,15 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( std::optional compute_kernel_config) { auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? 
input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - //uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end()); - uint32_t k_chunk_size = 512; //get_chunk_size(max_cur_pos + 1); + uint32_t s = input_tensor_k.get_logical_shape()[-2]; + uint32_t k_chunk_size = get_chunk_size(s); if (program_config.has_value() && program_config.value().k_chunk_size > 0) { k_chunk_size = program_config.value().k_chunk_size; // assert chunk size must be power of 2 and multiple of 32 TT_FATAL((k_chunk_size & (k_chunk_size - 1)) == 0, "User provided k_chunk_size must be power of 2, got: {}", k_chunk_size); TT_FATAL(k_chunk_size % 32 == 0, "User provided k_chunk_size must be multiple of 32, got: {}", k_chunk_size); + } else { + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, but the maximum calculated k_chunk_size is: {}", k_chunk_size); } // get chunk size and then pass to sdpa decode as an attribute for prgm cache @@ -50,6 +60,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( return operation::run( ScaledDotProductAttentionDecode{ + .is_causal = is_causal, .cur_pos = cur_pos, .scale = scale, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), @@ -58,7 +69,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( .k_chunk_size = k_chunk_size, .paged_attention = false}, {input_tensor_q, input_tensor_k, input_tensor_v}, - {cur_pos_tensor, std::nullopt}, + {cur_pos_tensor, std::nullopt, attn_mask}, {}, queue_id) .at(0); @@ -68,6 +79,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -79,6 +92,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( input_tensor_q, input_tensor_k, input_tensor_v, + is_causal, + attn_mask, cur_pos, cur_pos_tensor, scale, @@ -93,21 +108,25 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, std::optional compute_kernel_config) { auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? 
input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - //uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end()); - uint32_t k_chunk_size = 512; //get_chunk_size(max_cur_pos + 1); + uint32_t s = input_tensor_k.get_logical_shape()[-2]; + uint32_t k_chunk_size = get_chunk_size(s); if (program_config.has_value() && program_config.value().k_chunk_size > 0) { k_chunk_size = program_config.value().k_chunk_size; // assert chunk size must be power of 2 and multiple of 32 TT_FATAL((k_chunk_size & (k_chunk_size - 1)) == 0, "User provided k_chunk_size must be power of 2, got: {}", k_chunk_size); TT_FATAL(k_chunk_size % 32 == 0, "User provided k_chunk_size must be multiple of 32, got: {}", k_chunk_size); + } else { + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, but the maximum calculated k_chunk_size is: {}", k_chunk_size); } // get chunk size and then pass to sdpa decode as an attribute for prgm cache @@ -116,6 +135,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( return operation::run( ScaledDotProductAttentionDecode{ + .is_causal = is_causal, .cur_pos = std::vector(), .scale = scale, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), @@ -124,7 +144,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( .k_chunk_size = k_chunk_size, .paged_attention = true}, {input_tensor_q, input_tensor_k, input_tensor_v}, - {cur_pos_tensor, page_table_tensor}, + {cur_pos_tensor, page_table_tensor, attn_mask}, {}, queue_id) .at(0); @@ -134,8 +154,10 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, @@ -145,8 +167,10 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( input_tensor_q, input_tensor_k, input_tensor_v, - cur_pos_tensor, page_table_tensor, + is_causal, + attn_mask, + cur_pos_tensor, scale, memory_config, program_config, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp index b86a7696288..7b8eee9ce16 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp @@ -17,7 +17,9 @@ struct ExecuteScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::vector cur_pos= std::vector(), const std::optional cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, @@ -28,7 +30,9 @@ struct ExecuteScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::vector cur_pos= std::vector(), const std::optional cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = 
std::nullopt, @@ -42,8 +46,10 @@ struct ExecutePagedScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::optional &cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, std::optional program_config = std::nullopt, @@ -53,8 +59,10 @@ struct ExecutePagedScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::optional &cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, std::optional program_config = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp index 1215563ff4a..8026d56385b 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp @@ -26,11 +26,12 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q (ttnn.Tensor): the input tensor [1 x b x nh x dh] input_tensor_k (ttnn.Tensor): the input tensor [b x nkv x s x dh] input_tensor_v (ttnn.Tensor): the input tensor [b x nkv x s x dh] - cur_pos (List of int): list of integers of length b. - Keyword args: + is_causal (bool): whether the attention is causal. Defaults to `True`. + attn_mask (ttnn.Tensor, optional): the attention mask tensor [b x 1 x s x s]. Defaults to `None`. + cur_pos (List of int, optional): list of integers of length b. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. cur_pos_tensor (ttnn.Tensor, optional): [b] tensor of integers of length b. Defaults to `None`.
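As a quick illustration of the chunk-size selection introduced above: the new get_chunk_size helper returns the largest power-of-two divisor of the K sequence length, capped at 512, and the callers then require the result to be a multiple of 32 (which holds whenever the sequence length itself is a multiple of 32). A minimal Python sketch of the same logic, with illustrative values only:

def get_chunk_size(s: int) -> int:
    # Largest power-of-two divisor of s, capped at 512 (mirrors the C++ helper above).
    i = 1
    while i < s and s % (1 << (i + 1)) == 0:
        i += 1
    return min(512, 1 << i)

assert get_chunk_size(96) == 32      # 96 = 32 * 3
assert get_chunk_size(1024) == 512   # capped at 512
assert get_chunk_size(4096) == 512

The non-causal mask validation added earlier constrains the optional attn_mask, which may only be passed when is_causal is false: its head dimension must match Q's padded head dimension, its sequence dimension must match K's sequence length (in the unpaged case) and be a multiple of k_chunk_size, and its dtype must be a bfloat format. A hedged torch sketch of a mask that satisfies those checks; the shapes are made up for illustration and the additive -inf convention is only an example, not taken from the change itself:

import torch

b, padded_nh, s, dh = 8, 32, 1024, 128   # illustrative decode shapes
k_chunk_size = 512                        # e.g. get_chunk_size(1024)
q = torch.randn(1, b, padded_nh, dh, dtype=torch.bfloat16)
k = torch.randn(b, 1, s, dh, dtype=torch.bfloat16)

attn_mask = torch.zeros(b, 1, padded_nh, s, dtype=torch.bfloat16)
attn_mask[..., s // 2:] = float("-inf")   # e.g. mask out the second half of the cache

assert attn_mask.shape[2] == q.shape[2]        # same (padded) head count as Q
assert attn_mask.shape[3] == k.shape[2]        # same sequence length as K (unpaged case)
assert attn_mask.shape[3] % k_chunk_size == 0  # whole number of K chunks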
@@ -57,6 +58,8 @@ void py_bind_sdpa_decode(py::module &module) { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -69,6 +72,8 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q, input_tensor_k, input_tensor_v, + is_causal, + attn_mask, cur_pos, cur_pos_tensor, scale, @@ -79,8 +84,10 @@ void py_bind_sdpa_decode(py::module &module) { py::arg("input_tensor_q").noconvert(), py::arg("input_tensor_k").noconvert(), py::arg("input_tensor_v").noconvert(), - py::arg("cur_pos").noconvert() = std::vector(), py::kw_only(), + py::arg("is_causal").noconvert() = true, + py::arg("attn_mask").noconvert() = std::nullopt, + py::arg("cur_pos").noconvert() = std::vector(), py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, @@ -100,8 +107,10 @@ void py_bind_sdpa_decode(py::module &module) { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, @@ -112,8 +121,10 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q, input_tensor_k, input_tensor_v, - cur_pos_tensor, page_table_tensor, + is_causal, + attn_mask, + cur_pos_tensor, scale, memory_config, program_config, @@ -122,9 +133,11 @@ void py_bind_sdpa_decode(py::module &module) { py::arg("input_tensor_q").noconvert(), py::arg("input_tensor_k").noconvert(), py::arg("input_tensor_v").noconvert(), - py::arg("cur_pos_tensor").noconvert(), py::arg("page_table_tensor").noconvert(), py::kw_only(), + py::arg("is_causal").noconvert() = true, + py::arg("attn_mask").noconvert() = std::nullopt, + py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.cpp deleted file mode 100644 index 18bed2eeace..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include "sdpa_decode_gqa_op.hpp" - -#include "ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp" -#include "ttnn/run_operation.hpp" - -namespace ttnn::operations::transformer { - -void ScaledDotProductAttentionGQADecode::validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { - TT_FATAL(input_tensors.size() == 3, "Must have 3 input tensors and mask"); - for (auto& input_tensor : input_tensors) { - TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!"); - TT_FATAL(input_tensor.buffer() != nullptr, "Operands to SDPA need to be allocated in buffers on device!"); - TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to SDPA must be tilized"); - } - - const auto q_shape = input_tensors.at(0).get_legacy_shape(); - const auto k_shape = input_tensors.at(1).get_legacy_shape(); - const auto v_shape = input_tensors.at(2).get_legacy_shape(); - - if (optional_input_tensors.at(0).has_value()){ - const auto& cur_pos_tensor = optional_input_tensors.at(0).value(); - - TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Error"); - TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); - } - - // All other inputs must be in DRAM. - for (std::size_t i = 0; i < input_tensors.size(); i++) { - TT_FATAL(input_tensors.at(i).buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, "Error"); - } - - // Check dtype - for (std::size_t i = 1; i < input_tensors.size(); i++) { - TT_FATAL(input_tensors.at(i).get_dtype() == DataType::BFLOAT8_B, "Error"); - } - TT_FATAL(input_tensors.at(0).get_dtype() == DataType::BFLOAT16, "Error"); - - // Check sequence lengths - TT_FATAL(k_shape[-2] == v_shape[-2], "Error"); - - // Check hidden size - const auto D = q_shape[-1]; - TT_FATAL(k_shape[-1] == D, "Error"); - TT_FATAL(v_shape[-1] == D, "Error"); - - // Check num_heads - TT_FATAL(k_shape[1] == v_shape[1], "Error"); - TT_FATAL(q_shape[1] % k_shape[1] == 0, "Error"); - TT_FATAL(q_shape[1] <= 32, "Error"); - - // Check batch size - TT_FATAL(k_shape[0] == v_shape[0], "Error"); - - // Check valid seqlen - for (int i = 0; i < this->cur_pos.size(); i++) { - TT_FATAL(this->cur_pos[i] < k_shape[-2], "cur_pos must be <= K sequence dim"); - } - - // Check compute kernel config - std::visit( - [&](auto&& compute_kernel_config) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - TT_FATAL( - compute_kernel_config.fp32_dest_acc_en == false, - "FP32 dest acc disabled due to nd pcc and unpacker hang issue."); - } - }, - this->compute_kernel_config); -} - -std::vector ScaledDotProductAttentionGQADecode::compute_output_shapes( - const std::vector& input_tensors) const { - auto tt_q_shape = input_tensors.at(0).get_legacy_shape(); - auto tt_k_shape = input_tensors.at(1).get_legacy_shape(); - uint32_t n_groups = tt_q_shape[2] / tt_k_shape[1]; - return {input_tensors.at(0).get_legacy_shape()}; -} - -std::vector ScaledDotProductAttentionGQADecode::create_output_tensors( - const std::vector& input_tensors) const { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); -} - -operation::ProgramWithCallbacks ScaledDotProductAttentionGQADecode::create_program( - const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector& output_tensors) const { - auto& input_tensor_q = input_tensors.at(0); - auto& input_tensor_k = 
input_tensors.at(1); - auto& input_tensor_v = input_tensors.at(2); - - auto& cur_pos_tensor = optional_input_tensors.at(0); - - auto& output_tensor = output_tensors.at(0); - - auto scale = this->scale; - if (not scale.has_value()) { - scale = 1.0f / std::sqrt(static_cast(input_tensor_q.get_legacy_shape()[-1])); - } - - // TODO: get this from program_config - // std::size_t q_chunk_size; - // std::size_t k_chunk_size; - - // std::visit( - // [&](const auto& program_config) { - // using ProgramConfigType = std::decay_t; - // if constexpr (std::is_same_v< - // ProgramConfigType, - // tt::operations::primary::transformers::SDPAMultiCoreProgramConfig>) { - // q_chunk_size = program_config.q_chunk_size; - // k_chunk_size = program_config.k_chunk_size; - // } else { - // q_chunk_size = k_chunk_size = 32; - // } - // }, - // this->program_config); - - return detail::sdpa_decode_multi_core( - input_tensor_q, - input_tensor_k, - input_tensor_v, - cur_pos_tensor, - std::nullopt, - output_tensor, - this->cur_pos, - scale, - this->compute_kernel_config, - this->program_config, - this->k_chunk_size, - this->share_cache); -} - -operation::Hash ScaledDotProductAttentionGQADecode::compute_program_hash( - const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { - return operation::hash_operation( - this->scale, - this->output_mem_config, - this->program_config, - this->compute_kernel_config, - this->k_chunk_size, - input_tensors, - optional_input_tensors); -} - -} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.hpp deleted file mode 100644 index 0ecb6a0c052..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/device/sdpa_decode_gqa_op.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include - -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -#include "ttnn/operation.hpp" -#include "ttnn/operations/transformer/sdpa_config.hpp" -#include "ttnn/tensor/tensor.hpp" - -namespace ttnn::operations::transformer { - -struct ScaledDotProductAttentionGQADecode { - std::vector cur_pos; - const std::optional share_cache; - const std::optional scale; - const MemoryConfig output_mem_config; - const std::optional program_config; - const DeviceComputeKernelConfig compute_kernel_config; - const uint32_t k_chunk_size; - - void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - - std::vector compute_output_shapes(const std::vector& input_tensors) const; - - std::vector create_output_tensors(const std::vector& input_tensors) const; - - operation::ProgramWithCallbacks create_program( - const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector& output_tensors) const; - - operation::Hash compute_program_hash(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; -}; - -} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp deleted file mode 100644 index 490f4ba68dd..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 -#include "sdpa_decode_gqa.hpp" - -#include "ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp" -#include "ttnn/common/constants.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/data_movement/transpose/transpose.hpp" -#include "ttnn/run_operation.hpp" - -namespace { -uint32_t get_chunk_size(uint32_t s) { - if (s <= 128) { - return 32; - } - if (s <= 256) { - return 256; - } - return 512; -} -} // namespace - -namespace ttnn::operations::transformer { - -ttnn::Tensor ExecuteScaledDotProductAttentionGQADecode::invoke( - uint8_t queue_id, - const ttnn::Tensor &input_tensor_q, - const ttnn::Tensor &input_tensor_k, - const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, - const std::optional cur_pos_tensor, - std::optional transpose_q, - std::optional share_cache, - std::optional scale, - const std::optional &memory_config, - std::optional program_config, - std::optional compute_kernel_config) { - - // default transpose_q to true and share_cache to false - if (!transpose_q.has_value()) { - transpose_q = true; - } - if (!share_cache.has_value()) { - share_cache = false; - } - - auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? input_tensor_q.device()->arch() - : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - - // Q (if transpose q): 1, heads, batch, dim -> 1, batch, heads, dim - auto input_tensor_q_gqa = input_tensor_q; - if (transpose_q.value()) { - // formatting input tensors - auto q_shape = input_tensor_q.get_shape(); - uint32_t Bq = transpose_q.value() ? q_shape[2] : q_shape[1]; - uint32_t NQH = transpose_q.value() ? q_shape[1] : q_shape[2]; - uint32_t D = q_shape[3]; - - input_tensor_q_gqa = - ttnn::to_layout(input_tensor_q, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); - input_tensor_q_gqa = ttnn::transpose(input_tensor_q_gqa, 1, 2); - input_tensor_q_gqa = ttnn::reshape(input_tensor_q_gqa, ttnn::SimpleShape{std::array{1, Bq, NQH, D}}); - input_tensor_q_gqa = - ttnn::to_layout(input_tensor_q_gqa, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); - } - - uint32_t k_chunk_size; - // since we can't get the max cur_pos value from the tensor, we default to 512 - if (cur_pos_tensor.has_value()) { - k_chunk_size = 512; - } else{ - uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end()); - k_chunk_size = get_chunk_size(max_cur_pos + 1); - } - - // get chunk size and then pass to sdpa decode as an attribute for prgm cache - auto kernel_config_val = init_device_compute_kernel_config( - input_tensor_q.device()->arch(), compute_kernel_config, MathFidelity::HiFi2, true, false, false); - - auto output_tensors = operation::run( - ScaledDotProductAttentionDecode{ - .cur_pos = cur_pos, - .scale = scale, - .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), - .program_config = program_config, - .compute_kernel_config = kernel_config_val, - .k_chunk_size = k_chunk_size, - .paged_attention = false, - .share_cache = share_cache}, - {input_tensor_q_gqa, input_tensor_k, input_tensor_v}, - {cur_pos_tensor, std::nullopt}, - {}, - queue_id); - - // formatting output tensor - auto output_tensor = output_tensors.at(0); - return output_tensor; -} - -ttnn::Tensor ExecuteScaledDotProductAttentionGQADecode::invoke( - const ttnn::Tensor &input_tensor_q, - const ttnn::Tensor &input_tensor_k, - const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, - const std::optional 
cur_pos_tensor, - std::optional transpose_q, - std::optional share_cache, - std::optional scale, - const std::optional &memory_config, - std::optional program_config, - std::optional compute_kernel_config) { - return invoke( - DefaultQueueId, - input_tensor_q, - input_tensor_k, - input_tensor_v, - cur_pos, - cur_pos_tensor, - transpose_q, - share_cache, - scale, - memory_config, - program_config, - compute_kernel_config); -} - -} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.hpp deleted file mode 100644 index 7cdb1a486ce..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.hpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -#include "ttnn/decorators.hpp" -#include "ttnn/operations/transformer/sdpa_config.hpp" - -namespace ttnn { -namespace operations::transformer { - -struct ExecuteScaledDotProductAttentionGQADecode { - static ttnn::Tensor invoke( - uint8_t queue_id, - const ttnn::Tensor &input_tensor_q, - const ttnn::Tensor &input_tensor_k, - const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, - const std::optional cur_pos_tensor = std::nullopt, - std::optional transpose_q = true, - std::optional share_cache = false, - std::optional scale = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional program_config = std::nullopt, - std::optional compute_kernel_config = std::nullopt); - - static ttnn::Tensor invoke( - const ttnn::Tensor &input_tensor_q, - const ttnn::Tensor &input_tensor_k, - const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, - const std::optional cur_pos_tensor= std::nullopt, - std::optional transpose_q = true, - std::optional share_cache = false, - std::optional scale = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional program_config = std::nullopt, - std::optional compute_kernel_config = std::nullopt); -}; - -} // namespace operations::transformer - -namespace transformer { - -constexpr auto scaled_dot_product_attention_decode_gqa = ttnn::register_operation_with_auto_launch_op< - "ttnn::transformer::scaled_dot_product_attention_decode_gqa", - ttnn::operations::transformer::ExecuteScaledDotProductAttentionGQADecode>(); - -} // namespace transformer - -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.cpp deleted file mode 100644 index 92624e51e85..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.cpp +++ /dev/null @@ -1,101 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "sdpa_decode_gqa_pybind.hpp" - -#include -#include - -#include "sdpa_decode_gqa.hpp" -#include "ttnn/cpp/pybind11/decorators.hpp" - -namespace ttnn::operations::transformer { - -void py_bind_sdpa_gqa_decode(py::module &module) { - auto doc = - R"doc( - A version of scaled dot product attention specifically for GQA decode. - - - Accepts a `SDPAMultiCoreProgramConfig` which specifies the grid size and chunk tiles in the K/V/Mask sequence lengths (Q chunk tiles is not used). 
The op parallelizes over `b` and K/V/Mask's `s` dimension. - - - Args: - input_tensor_q (ttnn.Tensor): the input tensor [1 x qh x b x dh] - input_tensor_k (ttnn.Tensor): the input tensor [b x kh x s x dh] - input_tensor_v (ttnn.Tensor): the input tensor [b x kh x s x dh] - cur_pos (List of int): list of integers of length b. - - - - Keyword args: - memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. - queue_id (int, optional): command queue id. Defaults to `0`. - scale (float, optional): Defaults to `None`. - program_config (SDPAProgramConfig, optional): Defaults to `None`. - compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional): Defaults to `None`. - - - Returns: - ttnn.Tensor: the output tensor [1 x b x qh x dh]. - - "Q: [1 x qh x b x dh] or [1 x b x qh x dh]" - "K: [b x kh x s x dh] or [1 x kh x s x dh]" - "V: [b x kh x s x dh] or [1 x kh x s x dh]" - "cur_pos: list of integers of length b" - "transpose_q: bool default true. If true, expects Q in [1 x qh x b x dh] format" - "share_cache: bool default false. If true, shares cache across all batch, so K and V are [1 x kh x s x dh]" - "output: [1 x b x qh x dh]" - - )doc"; - - using OperationType = decltype(ttnn::transformer::scaled_dot_product_attention_decode_gqa); - ttnn::bind_registered_operation( - module, - ttnn::transformer::scaled_dot_product_attention_decode_gqa, - doc, - ttnn::pybind_overload_t{ - [](const OperationType &self, - const ttnn::Tensor &input_tensor_q, - const ttnn::Tensor &input_tensor_k, - const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, - const std::optional cur_pos_tensor, - std::optional transpose_q, - std::optional share_cache, - std::optional scale, - const std::optional &memory_config, - std::optional program_config, - std::optional compute_kernel_config, - uint8_t queue_id) { - return self( - queue_id, - input_tensor_q, - input_tensor_k, - input_tensor_v, - cur_pos, - cur_pos_tensor, - transpose_q, - share_cache, - scale, - memory_config, - program_config, - compute_kernel_config); - }, - py::arg("input_tensor_q").noconvert(), - py::arg("input_tensor_k").noconvert(), - py::arg("input_tensor_v").noconvert(), - py::arg("cur_pos").noconvert() = std::vector(), - py::kw_only(), - py::arg("cur_pos_tensor").noconvert() = std::nullopt, - py::arg("transpose_q") = true, - py::arg("share_cache") = false, - py::arg("scale").noconvert() = std::nullopt, - py::arg("memory_config").noconvert() = std::nullopt, - py::arg("program_config").noconvert() = std::nullopt, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - py::arg("queue_id") = 0, - }); -} -} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.hpp deleted file mode 100644 index 8b0e896ad6e..00000000000 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa_pybind.hpp +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "pybind11/pybind_fwd.hpp" - -namespace ttnn::operations::transformer { - -void py_bind_sdpa_gqa_decode(pybind11::module &module); -} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp index da34a5182f8..a1c1129cea6 100644 --- a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp @@ -12,7 +12,6 @@ #include "sdpa/sdpa_pybind.hpp" #include "sdpa_config.hpp" #include "sdpa_decode/sdpa_decode_pybind.hpp" -#include "sdpa_decode_gqa/sdpa_decode_gqa_pybind.hpp" #include "split_query_key_value_and_split_heads/split_query_key_value_and_split_heads_pybind.hpp" namespace ttnn::operations::transformer { @@ -39,7 +38,6 @@ void py_module(py::module& module) { py_bind_sdpa(module); py_bind_sdpa_decode(module); - py_bind_sdpa_gqa_decode(module); } } // namespace ttnn::operations::transformer diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 45e024ac82f..2c820effa02 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -64,18 +64,13 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor: if len(slices) > input_rank: raise RuntimeError(f"Too many slices for tensor of rank {input_rank}") - if input_rank <= 4: - slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices] - slice_end = [ - _slice.stop if _slice.stop is not None else input_tensor.shape[i] for i, _slice in enumerate(slices) - ] - slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices] + slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices] + slice_end = [_slice.stop if _slice.stop is not None else input_tensor.shape[i] for i, _slice in enumerate(slices)] + slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices] - output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step) + output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step) - return output - - raise NotImplementedError + return output def _preprocess_shape(input_shape, shape): diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index b7d3aa79a53..aa3b2e048d7 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -289,7 +289,7 @@ def _golden_function_polygamma(input_tensor_a, k, *args, **kwargs): def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs): import torch - return torch.clamp(input=input_tensor_a, min=min, max=max) + return torch.clamp(input_tensor_a, min, max) ttnn.attach_golden_function(ttnn.clamp, golden_function=_golden_function_clamp) @@ -298,7 +298,7 @@ def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs): def _golden_function_clip(input_tensor_a, min=None, max=None, *args, **kwargs): import torch - return torch.clip(input=input_tensor_a, min=min, max=max) + return torch.clip(input_tensor_a, min, max) ttnn.attach_golden_function(ttnn.clip, golden_function=_golden_function_clip) diff --git a/ttnn/ttnn/operations/unary_backward.py b/ttnn/ttnn/operations/unary_backward.py index f0f05460952..d09da6a4bae 100644 --- a/ttnn/ttnn/operations/unary_backward.py +++ b/ttnn/ttnn/operations/unary_backward.py @@ -33,27 +33,8 @@ def _golden_function_div_no_nan(torch_op, grad_tensor, input_tensor, alpha, *arg return golden_tensor 
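The __getitem__ change above now routes tensors of any rank through ttnn.slice, building start/end/step lists from the Python slice objects. A small pure-Python sketch of that decomposition (the shape and slices are illustrative):

def decompose_slices(slices, shape):
    # Mirrors the list comprehensions in __getitem__: a missing start defaults to 0,
    # a missing stop defaults to the full extent of that dimension, a missing step to 1.
    slice_start = [s.start if s.start is not None else 0 for s in slices]
    slice_end = [s.stop if s.stop is not None else shape[i] for i, s in enumerate(slices)]
    slice_step = [s.step if s.step is not None else 1 for s in slices]
    return slice_start, slice_end, slice_step

# t[1:2, :, 0:32] on a tensor of shape (4, 8, 64) becomes:
print(decompose_slices((slice(1, 2), slice(None), slice(0, 32)), (4, 8, 64)))
# ([1, 0, 0], [2, 8, 32], [1, 1, 1])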
-def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha=None, *args, **kwargs): - if torch_op == "leaky_relu": - if alpha != None: - pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha) - else: - pyt_y = torch.nn.functional.leaky_relu(input_tensor) - elif torch_op == "elu": - if alpha != None: - pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha) - else: - pyt_y = torch.nn.functional.elu(input_tensor) - elif torch_op == "celu": - if alpha != None: - pyt_y = torch.nn.functional.celu(input_tensor, alpha) - else: - pyt_y = torch.nn.functional.celu(input_tensor) - else: - if alpha != None: - pyt_y = torch_op(input_tensor, alpha) - else: - pyt_y = torch_op(input_tensor) +def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, *args, **kwargs): + pyt_y = torch_op(input_tensor, *args, **kwargs) input_tensor.retain_grad() pyt_y.backward(gradient=grad_tensor) golden_tensor = [input_tensor.grad] @@ -165,36 +146,36 @@ def _golden_function_backward_with_reverse_string( ttnn.attach_golden_function( ttnn.hardshrink_bw, - golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - torch.hardshrink, grad, input, alpha, *args, **kwargs + golden_function=lambda grad, input, alpha=0.5, *args, **kwargs: _golden_function_unary_backward_with_float( + torch.hardshrink, grad, input, lambd=alpha, *args, **kwargs ), ) ttnn.attach_golden_function( ttnn.softshrink_bw, - golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - torch.softshrink, grad, input, alpha, *args, **kwargs + golden_function=lambda grad, input, alpha=0.5, *args, **kwargs: _golden_function_unary_backward_with_float( + torch.nn.functional.softshrink, grad, input, lambd=alpha, *args, **kwargs ), ) ttnn.attach_golden_function( ttnn.leaky_relu_bw, - golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - "leaky_relu", grad, input, alpha, *args, **kwargs + golden_function=lambda grad, input, alpha=1e-2, *args, **kwargs: _golden_function_unary_backward_with_float( + torch.nn.functional.leaky_relu, grad, input, negative_slope=alpha, *args, **kwargs ), ) ttnn.attach_golden_function( ttnn.elu_bw, - golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - "elu", grad, input, alpha, *args, **kwargs + golden_function=lambda grad, input, alpha=1.0, *args, **kwargs: _golden_function_unary_backward_with_float( + torch.nn.functional.elu, grad, input, alpha=alpha, *args, **kwargs ), ) ttnn.attach_golden_function( ttnn.celu_bw, - golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - "celu", grad, input, alpha, *args, **kwargs + golden_function=lambda grad, input, alpha=1.0, *args, **kwargs: _golden_function_unary_backward_with_float( + torch.nn.functional.celu, grad, input, alpha=alpha, *args, **kwargs ), ) @@ -208,7 +189,7 @@ def _golden_function_backward_with_reverse_string( ttnn.attach_golden_function( ttnn.logiteps_bw, golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float( - torch.logit, grad, input, alpha, *args, **kwargs + torch.logit, grad, input, eps=alpha, *args, **kwargs ), )
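The reworked _golden_function_unary_backward_with_float now simply forwards its keyword arguments to the torch op and backpropagates the incoming gradient, instead of special-casing each op by name. A standalone torch example of the same pattern, here applied to elu with an explicit alpha (values are illustrative):

import torch

def golden_unary_backward(torch_op, grad_tensor, input_tensor, **kwargs):
    # Run the forward op with the forwarded kwargs, then backpropagate the incoming
    # gradient through it, as the rewritten helper does.
    pyt_y = torch_op(input_tensor, **kwargs)
    input_tensor.retain_grad()
    pyt_y.backward(gradient=grad_tensor)
    return [input_tensor.grad]

x = torch.randn(2, 3, requires_grad=True)
grad = torch.ones_like(x)
(dx,) = golden_unary_backward(torch.nn.functional.elu, grad, x, alpha=1.0)
print(dx)  # 1 where x > 0, alpha * exp(x) elsewhere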