#0: add pybinds for softmax apis, change python tests to ttnn
yugaoTT committed Jun 28, 2024
1 parent db25a35 commit 20f82a7
Showing 59 changed files with 3,017 additions and 663 deletions.
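Most of the diff is a mechanical namespace migration: the softmax ops and their program configs, previously reached through ttnn.experimental.operations.primary.transformers (or tt_lib.operations.primary), are now bound directly under ttnn. A minimal before/after sketch of the pattern; the tensor and argument names are illustrative assumptions, not taken from any one file below:

import ttnn

def scaled_masked_softmax(attn_weights, scale, attention_mask):
    # Old path (pre-commit), kept as a comment for comparison:
    # attn_weights = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
    #     attn_weights, scale, attention_mask,
    #     program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
    # )
    # New path: the same in-place op, now exposed directly on ttnn.
    return ttnn.scale_mask_softmax_in_place(
        attn_weights,
        scale,
        attention_mask,
        program_config=ttnn.SoftmaxDefaultProgramConfig(),
    )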
6 changes: 1 addition & 5 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -255,8 +255,6 @@ autofunction:: tt_lib.operations.primary.matmul

.. autofunction:: tt_lib.operations.primary.add_layernorm

- .. autofunction:: tt_lib.operations.primary.softmax_in_place

.. autofunction:: tt_lib.operations.primary.moreh_softmax

.. autofunction:: tt_lib.operations.primary.moreh_softmax_backward
@@ -269,8 +267,6 @@ autofunction:: tt_lib.operations.primary.matmul

.. autofunction:: tt_lib.operations.primary.moreh_logsoftmax_backward

- .. autofunction:: tt_lib.operations.primary.transformers.scale_mask_softmax_in_place

.. autofunction:: tt_lib.operations.primary.moreh_mean

.. autofunction:: tt_lib.operations.primary.moreh_mean_backward
@@ -414,7 +410,7 @@ Tensor elementwise operations
.. autofunction:: tt_lib.tensor.unary_remainder

.. autofunction:: tt_lib.tensor.remainder

.. autofunction:: tt_lib.tensor.unary_fmod

.. autofunction:: tt_lib.tensor.fmod
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_optimized_sharded_bert.py
@@ -83,7 +83,7 @@ def update_model_config(config, batch_size):
block_w=4,
inplace=True,
),
- "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(core_grid.x, core_grid.y),
subblock_w=6,
block_h=24,
@@ -402,7 +402,7 @@ def test_falcon7b_attnention_sliced(
subblock_w = 1
if seq_len == 2048:
subblock_w = 8
- softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=subblock_w,
block_h=mm_output_height_shard_spec[0] // 32,
@@ -669,14 +669,14 @@ def test_falcon7b_attention_softmax_sequence(
subblock_w = 1
if seq_len == 2048:
subblock_w = 8
- softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=subblock_w,
block_h=mm_output_height_shard_spec[0] // 32,
block_w=mm_output_height_shard_spec[1] // 32,
)

- mm_slice = ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place(
+ mm_slice = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
mm_slice,
scalar_value,
attention_masks_per_slice[i],
@@ -725,11 +725,11 @@ def test_falcon7b_attention_softmax_sequence(
reference_query_layer, reference_key_layer_transposed, memory_config=dram_interleaved_memory_config
)

- attn_weights = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ attn_weights = ttnn.scale_mask_softmax_in_place(
attn_weights,
scalar_value,
attention_mask_proper_dim,
- program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
+ program_config=ttnn.SoftmaxDefaultProgramConfig(),
is_causal_mask=True,
compute_kernel_config=compute_kernel_config,
)
@@ -876,14 +876,14 @@ def test_softmax(device, num_cores, seq_len):
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)

- softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=1,
block_h=height_shard_spec[0] // 32,
block_w=height_shard_spec[1] // 32,
)

- input_slice = ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place(
+ input_slice = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
input_slice,
scalar_value,
tt_attention_masks_per_slice[i],
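The sharded softmax configs in the tests above all derive block_h and block_w by dividing the shard shape by 32, the edge length of the tiles these kernels operate on. A small helper sketch of that pattern; the helper name and the tile-alignment assertion are assumptions, not part of this commit:

import ttnn

def sharded_softmax_config(grid_size, shard_shape, subblock_w=1):
    TILE = 32  # block dimensions are counted in 32x32 tiles
    assert shard_shape[0] % TILE == 0 and shard_shape[1] % TILE == 0
    return ttnn.SoftmaxShardedMultiCoreProgramConfig(
        compute_with_storage_grid_size=grid_size,
        subblock_w=subblock_w,
        block_h=shard_shape[0] // TILE,  # shard height in tiles
        block_w=shard_shape[1] // TILE,  # shard width in tiles
    )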
14 changes: 5 additions & 9 deletions models/demos/falcon7b/tt/falcon_attention.py
@@ -322,9 +322,7 @@ def forward(
### SOFTMAX ###
###############
for i in range(self.num_devices):
- attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
- attn_weights[i]
- )
+ attn_weights[i] = ttnn.scale_mask_softmax_in_place(attn_weights[i])

######################
### V CACHE UPDATE ###
@@ -499,7 +497,7 @@ def _optimized_forward(
)
### SOFTMAX ###
mm_slices = [
- ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place(
+ ttnn.scale_causal_mask_hw_dims_softmax_in_place(
mm_slices[device_id],
self.scalar_for_optimized_prefill,
attention_mask[device_id][i],
@@ -842,19 +840,17 @@ def forward(
### SOFTMAX ###
###############
for i in range(self.num_devices):
- attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
- attn_weights[i]
- )
+ attn_weights[i] = ttnn.scale_mask_softmax_in_place(attn_weights[i])
else:
###############
### SOFTMAX ###
###############
for i in range(self.num_devices):
- attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ attn_weights[i] = ttnn.scale_mask_softmax_in_place(
attn_weights[i],
scale=self.scale,
mask=attention_mask[i],
- program_config=ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ program_config=ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 4),
subblock_w=1,
block_h=self.padded_local_heads // 32,
2 changes: 1 addition & 1 deletion models/demos/falcon7b/tt/model_config.py
@@ -403,7 +403,7 @@ def set_prefill_config(model_config, seq_len, dram_memcfg):

model_config[
"SOFTMAX_OPTIMIZED_PROGCFG"
- ] = lambda grid_size, subblock_w, block_h, block_w: ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ ] = lambda grid_size, subblock_w, block_h, block_w: ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=subblock_w,
block_h=block_h,
6 changes: 2 additions & 4 deletions models/demos/metal_BERT_large_11/tt/mha.py
@@ -84,9 +84,7 @@ def op3_bmm(Q_heads, K_T_heads):
)
return qkt

- softmax_program_config = model_config.get(
- "OP4_SOFTMAX_CONFIG", tt_lib.operations.primary.transformers.SoftmaxDefaultProgramConfig()
- )
+ softmax_program_config = model_config.get("OP4_SOFTMAX_CONFIG", ttnn.SoftmaxDefaultProgramConfig())

def op4_scale_mask_softmax(qkt, attention_mask):
# Attention scores computation
@@ -95,7 +93,7 @@ def op4_scale_mask_softmax(qkt, attention_mask):
# No-op reshapes are handled within pre-softmax (op 7) and post-softmax bmms (op 9)
shape = qkt.get_legacy_shape()
qkt = qkt.reshape(shape[0], 1, shape[1] * shape[2], shape[3])
- attention_scores = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place(
+ attention_scores = ttnn.scale_mask_softmax_in_place(
qkt, freciprocal_of_sqrt_hidden_dim, attention_mask, program_config=softmax_program_config
)
attention_scores = attention_scores.reshape(shape)
4 changes: 2 additions & 2 deletions models/demos/metal_BERT_large_11/tt/model_config.py
@@ -244,7 +244,7 @@ def get_model_config(batch, device_grid_size, model_config_str):
transpose_mcast=False,
fused_activation=None,
),
- "OP4_SOFTMAX_CONFIG": tt_lib.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
+ "OP4_SOFTMAX_CONFIG": ttnn.SoftmaxDefaultProgramConfig(),
"OP8_LAYERNORM_CONFIG": tt_lib.operations.primary.LayerNormDefaultProgramConfig(),
"OP11_LAYERNORM_CONFIG": tt_lib.operations.primary.LayerNormDefaultProgramConfig(),
}
@@ -397,7 +397,7 @@ def get_model_config(batch, device_grid_size, model_config_str):
block_w=4,
inplace=True,
),
- "OP4_SOFTMAX_CONFIG": tt_lib.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ "OP4_SOFTMAX_CONFIG": ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=6,
block_h=24,
2 changes: 1 addition & 1 deletion models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -635,7 +635,7 @@ def fwd_decode(
softmax_progcfg = self.model_config["SOFTMAX_PROGCFG"]
softmax_progcfg.block_w = padded_layer_past_len // 32

- attn_weights = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ attn_weights = ttnn.scale_mask_softmax_in_place(
attn_weights,
self.scalar,
attention_mask,
8 changes: 2 additions & 6 deletions models/demos/t3000/falcon40b/tt/model_config.py
@@ -526,9 +526,7 @@ def get_decode_model_config(model_config_str, input_shape, num_devices):
)
model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG
model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG
- model_config[
- "SOFTMAX_PROGCFG"
- ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 2),
subblock_w=1,
block_h=row_height // 32,
@@ -835,9 +833,7 @@ def get_prefill_model_config(model_config_str, input_shape, num_devices):
fused_activation=None,
mcast_in0=False,
)
- model_config[
- "SOFTMAX_PROGCFG"
- ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=attention_mm_grid_size,
subblock_w=1,
block_h=attetnion_mm_M,
8 changes: 4 additions & 4 deletions models/demos/t3000/falcon40b/tt/ops/falcon_softmax.py
@@ -28,24 +28,24 @@ def __call__(

if self.is_sharded:
# Subtract max value from activation before softmax
- out = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ out = ttnn.scale_mask_softmax_in_place(
x,
self.scalar,
attention_mask,
program_config=softmax_progcfg,
# output_mem_config=self.model_config["DEFAULT_MEMCFG"],
- # program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
+ # program_config=ttnn.SoftmaxDefaultProgramConfig(),
is_causal_mask=True,
)
else:
# Subtract max value from activation before softmax
- out = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ out = ttnn.scale_mask_softmax_in_place(
x,
self.scalar,
attention_mask,
# program_config=softmax_progcfg,
# output_mem_config=self.model_config["DEFAULT_MEMCFG"],
- program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
+ program_config=ttnn.SoftmaxDefaultProgramConfig(),
is_causal_mask=True,
)

4 changes: 2 additions & 2 deletions models/demos/t3000/llama2_70b/scripts/model_config_n150.py
@@ -680,14 +680,14 @@ def get_model_config(model_config_str, num_devices=1, all_gather=True):
model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG
model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG
if num_devices == 4:
- model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 2),
subblock_w=1,
block_h=1,
block_w=1, # Dynamic
)
elif num_devices == 8:
- model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 1),
subblock_w=1,
block_h=1,
@@ -74,7 +74,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask):
# attn_mask = torch2tt_tensor(attn_mask, self.device)

attn = ttnn.add(attn, attn_mask)
- attn = tt_lib.tensor.softmax(attn)
+ attn = ttnn.softmax(attn)
return attn

def forward(self, xq, keys, values, attn_mask):
@@ -95,7 +95,7 @@ def forward(self, xq, keys, values, attn_mask):
# TODO: This op expects attn_mask to be sharded such that each core has 1 head
# This is illegal on single chip since we need 8x8 coregrid to shard
# 64 heads on. Until we fracture on multi-chip, we can't use this op.
- # attn = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place(
+ # attn = ttnn.scale_mask_softmax_in_place(
# attn,
# 1 / math.sqrt(self.head_dim),
# attn_mask,
@@ -311,7 +311,7 @@ def attn_mqa(
softmax_progcfg = self.model_config["BATCHED_SOFTMAX_PROGCFG"]
softmax_progcfg.block_w = padded_layer_past_len // 32

- attn_weights = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place(
+ attn_weights = ttnn.scale_mask_softmax_in_place(
attn_weights,
self.scale,
attn_masks,
4 changes: 2 additions & 2 deletions models/demos/t3000/llama2_70b/tt/model_config.py
@@ -542,7 +542,7 @@ def get_model_config(
if llm_mode == "decode":
model_config[
"BATCHED_SOFTMAX_PROGCFG"
- ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ ] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 4), # In-place softmax on 32 cores sharded on batch dim
subblock_w=1,
block_h=shard_height // 32,
@@ -551,7 +551,7 @@
else:
model_config[
"BATCHED_SOFTMAX_PROGCFG"
- ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ ] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 4 if seq_len == 128 else 8),
subblock_w=1,
block_h=32 // 32, # 128 * 8 // 32 cores // TILE_SIZE
2 changes: 1 addition & 1 deletion models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
@@ -270,7 +270,7 @@ def forward(

# Softmax and scaling

- attn_1B4P = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ attn_1B4P = ttnn.scale_mask_softmax_in_place(
attn_1B4P,
self.scale,
attn_mask_1B4P,
2 changes: 1 addition & 1 deletion models/demos/t3000/mixtral8x7b/tt/model_config.py
@@ -180,7 +180,7 @@ def __init__(self, device=None, instruct=False, dummy_weights=False):
)

self.model_config["ATTN_BATCHED_SOFTMAX_PROGCFG"] = cached_lambda(
- lambda padded_layer_past_len: ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ lambda padded_layer_past_len: ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(8, 4), # In-place softmax on 32 cores sharded on batch dim
subblock_w=1,
block_h=1, # Shard_height // 32,
@@ -324,9 +324,7 @@ def __init__(self, device, parameters, seq_len):
mcast_in0=False,
)

- self.program_configs[
- "tsa_softmax"
- ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ self.program_configs["tsa_softmax"] = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=self.tsa_grid_size,
subblock_w=1,
block_h=mm_output_height_shard_spec[0] // 32,
@@ -393,7 +391,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type):

use_mask = False
if use_mask:
- mm_slice = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ mm_slice = ttnn.scale_mask_softmax_in_place(
mm_slice,
1 / math.sqrt(head_size),
attention_mask,
@@ -523,15 +521,15 @@ def sharded_attention(self, query, key, value, head_size, attn_type):
output_mem_config,
)
# attention_scores = ttnn.experimental.tensor.move_sharded(attention_scores)
- softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=compute_with_storage_grid_size,
subblock_w=1,
block_h=height_per_core // 32,
block_w=key_len // 32,
)
use_mask = attn_type == "cross"
if use_mask:
- attention_scores = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place(
+ attention_scores = ttnn.scale_mask_softmax_in_place(
attention_scores,
1 / math.sqrt(head_size),
attention_mask,
@@ -89,11 +89,11 @@ def run_softmax_tests(dev, test_id, batch, dtype, in0_mem_config):
logger.info("Running scale_mask_softmax")
torch_scale, tt_scale = generate_recip_tensor(dev, 0.5 + random.random())
torch_attn_mask, tt_attn_mask = generate_attn_mask(N, C, W, dev, -4.2 * 1, dtype, in0_mem_config)
- t1_fused = ttl.operations.primary.transformers.scale_mask_softmax_in_place(t0, tt_scale, tt_attn_mask)
+ t1_fused = ttnn.scale_mask_softmax_in_place(t0, tt_scale, tt_attn_mask)
ref_sm = ref_scale_mask_softmax(torch_scale, torch_attn_mask, x)
elif test_id == 1:
logger.info("Running softmax")
- t1_fused = ttl.operations.primary.softmax_in_place(t0)
+ t1_fused = ttnn.softmax_in_place(t0)
ref_sm = ref_stable_softmax(x)
else:
assert False
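The test above validates the fused kernels against a host-side reference. A condensed sketch of that check for the plain in-place softmax; the device bring-up calls and the tolerance value are assumptions, and it needs a Tenstorrent device to run:

import torch
import ttnn

device = ttnn.open_device(device_id=0)
x = torch.randn(1, 1, 32, 64)
t0 = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
t1 = ttnn.softmax_in_place(t0)  # the new ttnn binding exercised by test_id == 1 above
result = ttnn.to_torch(t1).float()
reference = torch.softmax(x, dim=-1)
assert torch.allclose(result, reference, atol=2e-2)  # loose tolerance for bfloat16 math
ttnn.close_device(device)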
@@ -135,7 +135,7 @@ def update_model_config(config, batch_size):
# out_data_format=ttnn.bfloat8_b,
inplace=False,
),
- "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(core_grid.x, core_grid.y),
subblock_w=7,
block_h=7,
@@ -135,7 +135,7 @@ def update_model_config(config, batch_size):
# out_data_format=ttnn.bfloat8_b,
inplace=False,
),
- "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
+ "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(core_grid.x, core_grid.y),
subblock_w=7,
block_h=7,