diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py index 5ce73f487049..1de8adf1380b 100644 --- a/ttnn/ttnn/device.py +++ b/ttnn/ttnn/device.py @@ -157,7 +157,7 @@ def GetComputeKernelConfig( dst_full_sync_en=False, ): device = ttnn.GetDefaultDevice() - if is_wormhole_b0(device) or is_blackhole(device): + if is_wormhole_b0(device): return ttnn.WormholeComputeKernelConfig( math_fidelity=math_fidelity, fp32_dest_acc_en=fp32_dest_acc_en, @@ -165,6 +165,14 @@ def GetComputeKernelConfig( math_approx_mode=math_approx_mode, dst_full_sync_en=dst_full_sync_en, ) + elif is_blackhole(device): + return ttnn.BlackholeComputeKernelConfig( + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_dest_acc_en, + packer_l1_acc=packer_l1_acc, + math_approx_mode=math_approx_mode, + dst_full_sync_en=dst_full_sync_en, + ) else: return ttnn.GrayskullComputeKernelConfig( math_fidelity=math_fidelity, math_approx_mode=math_approx_mode, dst_full_sync_en=dst_full_sync_en