Skip to content

Commit

Permalink
#0: Fix Squeezebert
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt committed Dec 9, 2024
1 parent 28c8e15 commit 0bb687c
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions models/demos/squeezebert/tt/ttnn_functional_squeezebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def ttnn_conv1d(
conv_config = ttnn.Conv1dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat8_b,
math_approx_mode_enabled=math_approx,
fp32_dest_acc_enabled=fp32_accum,
packer_l1_accum_enabled=packer_l1_acc,
activation=activation,
input_channels_alignment=(16 if use_shallow_conv_variant else 32),
deallocate_activation=deallocate_activation,
Expand All @@ -86,6 +83,13 @@ def ttnn_conv1d(
core_grid=get_shard_grid_from_num_cores(56, device),
math_fidelity=math_fidelity,
)
compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=math_fidelity,
math_approx_mode=math_approx,
fp32_dest_acc_en=fp32_accum,
packer_l1_acc=packer_l1_acc,
)

[tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d(
input_tensor=tt_input_tensor,
Expand All @@ -100,6 +104,7 @@ def ttnn_conv1d(
batch_size=tt_input_tensor.shape[0],
input_length=tt_input_tensor.shape[1],
conv_config=conv_config,
compute_config=compute_config,
conv_op_cache={},
debug=debug,
groups=groups,
Expand Down

0 comments on commit 0bb687c

Please sign in to comment.