#4999: Adding validation checks
eyonland committed Jan 31, 2024
1 parent 27f13bc commit 8d1e89a
Showing 19 changed files with 616 additions and 226 deletions.
@@ -22,18 +22,25 @@ def dropout(hidden_states, p, training):

def calculate_key_values(config, key_value_states, *, parameters):
bsz, tgt_len, hidden_size = key_value_states.shape
bsz, tgt_len_padded, _ = key_value_states.shape.padded()
head_size = hidden_size // config.encoder_attention_heads

fused_qkv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias
fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT)
fused_qkv = ttnn.to_layout(fused_qkv, ttnn.ROW_MAJOR_LAYOUT)
fused_qkv = ttnn.reshape(fused_qkv, shape=(bsz, tgt_len, 2, config.encoder_attention_heads, head_size))
key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :]

key_states = ttnn.reshape(key_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size))
key_states = ttnn.permute(key_states, (0, 2, 1, 3))

value_states = ttnn.reshape(value_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size))
value_states = ttnn.permute(value_states, (0, 2, 1, 3))
key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)

desired_shape = ttnn.Shape(
[bsz, config.encoder_attention_heads, tgt_len, head_size],
[bsz, config.encoder_attention_heads, tgt_len_padded, head_size],
)
key_states = ttnn.reshape(key_states, shape=desired_shape)
value_states = ttnn.reshape(value_states, shape=desired_shape)

return key_states, value_states
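
For reference, a minimal sketch of the logical-versus-padded reshape pattern this hunk adopts: `shape` carries the logical dimensions while `shape.padded()` carries the tile-padded ones, and `ttnn.Shape` accepts both lists so the reshape can preserve the distinction. The helper name below is hypothetical and not part of the commit.

    import ttnn

    def reshape_with_padding(tensor, bsz, heads, seq_len, seq_len_padded, head_size):
        # Track both the logical sequence length and the tile-padded sequence
        # length, mirroring the desired_shape construction in the hunk above.
        desired_shape = ttnn.Shape(
            [bsz, heads, seq_len, head_size],         # logical dimensions
            [bsz, heads, seq_len_padded, head_size],  # tile-padded dimensions
        )
        return ttnn.reshape(tensor, shape=desired_shape)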

@@ -42,23 +49,29 @@ def split_query_key_value_and_split_heads(
config, fused_qkv: ttnn.Tensor
) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]:
head_size = config.d_model // config.encoder_attention_heads
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
batch_size, *_, seq_length, three_times_hidden_size = fused_qkv.shape
batch_size, *_, padded_seq_length, three_times_hidden_size = fused_qkv.shape.padded()
hidden_size = three_times_hidden_size // 3
encoder_attention_heads = hidden_size // head_size

fused_qkv = ttnn.to_layout(fused_qkv, ttnn.ROW_MAJOR_LAYOUT)
fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT)
fused_qkv = ttnn.reshape(fused_qkv, shape=(batch_size, seq_length, 3, encoder_attention_heads, head_size))
query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :]

query_states = ttnn.reshape(query_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size))
query_states = ttnn.permute(query_states, (0, 2, 1, 3))

key_states = ttnn.reshape(key_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size))
key_states = ttnn.permute(key_states, (0, 2, 1, 3))

value_states = ttnn.reshape(value_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size))
value_states = ttnn.permute(value_states, (0, 2, 1, 3))

query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT)
key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)

desired_shape = ttnn.Shape(
[batch_size, encoder_attention_heads, seq_length, head_size],
[batch_size, encoder_attention_heads, padded_seq_length, head_size],
)
query_states = ttnn.reshape(query_states, shape=desired_shape)
key_states = ttnn.reshape(key_states, shape=desired_shape)
value_states = ttnn.reshape(value_states, shape=desired_shape)
return query_states, key_states, value_states
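
The same split-heads arithmetic in plain PyTorch, as an illustrative equivalent of the function above (not part of the commit): reshape the fused projection to (batch, seq, 3, heads, head_size), index the size-3 axis, then permute each result to (batch, heads, seq, head_size).

    import torch

    def split_qkv_reference(fused_qkv: torch.Tensor, num_heads: int):
        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
        head_size = (three_times_hidden_size // 3) // num_heads
        # (batch, seq, 3, heads, head_size): index the "3" axis to peel off Q, K, V.
        fused_qkv = fused_qkv.reshape(batch_size, seq_length, 3, num_heads, head_size)
        query, key, value = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :]
        # (batch, heads, seq, head_size) for each of Q, K, V.
        return tuple(t.permute(0, 2, 1, 3) for t in (query, key, value))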


@@ -70,26 +83,37 @@ def calculate_query_key_values(config, hidden_states, *, parameters):
def whisper_attention(config, hidden_states, attention_mask, key_value_states=None, *, parameters):
head_size = config.d_model // config.encoder_attention_heads
scaling = head_size**-0.5
bsz, tgt_len, _ = hidden_states.shape
bsz, *_, padded_tgt_len, _ = hidden_states.shape.padded()
bsz, *_, tgt_len, _ = hidden_states.shape

is_cross_attention = key_value_states is not None
if is_cross_attention:
query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias
query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT)
query_states = ttnn.to_layout(query_states, ttnn.ROW_MAJOR_LAYOUT)
query_states = ttnn.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size))
query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT)
query_states = ttnn.permute(query_states, (0, 2, 1, 3))
key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters)
padded_key_value_tgt_len = key_states.shape.padded()[2]
key_value_tgt_len = key_states.shape[2]
else:
query_states, key_states, value_states = calculate_query_key_values(
config, hidden_states, parameters=parameters
)
padded_key_value_tgt_len = padded_tgt_len
key_value_tgt_len = tgt_len

query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT)
query_states *= scaling
query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT)

proj_shape = (bsz * config.encoder_attention_heads, -1, head_size)
proj_shape = ttnn.Shape(
[bsz * config.encoder_attention_heads, tgt_len, head_size],
[bsz * config.encoder_attention_heads, padded_tgt_len, head_size],
)
query_states = ttnn.reshape(query_states, shape=proj_shape)
proj_shape = ttnn.Shape(
[bsz * config.encoder_attention_heads, key_value_tgt_len, head_size],
[bsz * config.encoder_attention_heads, padded_key_value_tgt_len, head_size],
)
key_states = ttnn.reshape(key_states, shape=proj_shape)
value_states = ttnn.reshape(value_states, shape=proj_shape)

@@ -116,26 +140,29 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states=No
attn_output = ttnn.to_layout(attn_output, layout=ttnn.ROW_MAJOR_LAYOUT)
attn_output = ttnn.reshape(attn_output, shape=(bsz, config.encoder_attention_heads, tgt_len, head_size))
attn_output = ttnn.permute(attn_output, (0, 2, 1, 3))
attn_output = ttnn.reshape(attn_output, shape=(bsz, tgt_len, config.d_model))
attn_output = ttnn.to_layout(attn_output, layout=ttnn.TILE_LAYOUT)
attn_output = ttnn.to_layout(attn_output, ttnn.ROW_MAJOR_LAYOUT)
attn_output = ttnn.reshape(
attn_output,
shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, tgt_len, config.d_model]),
)
attn_output = ttnn.to_layout(attn_output, ttnn.TILE_LAYOUT)
attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias
return attn_output
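
A recurring pattern in this hunk: convert to ROW_MAJOR_LAYOUT for the reshape, then back to TILE_LAYOUT before the next matmul. A small helper packaging that round trip might look like this (the helper name is hypothetical):

    import ttnn

    def reshape_in_row_major(tensor, desired_shape):
        # ttnn.reshape is performed in ROW_MAJOR_LAYOUT, then the tensor is
        # returned to TILE_LAYOUT for the matmul that follows, as in the hunk above.
        tensor = ttnn.to_layout(tensor, ttnn.ROW_MAJOR_LAYOUT)
        tensor = ttnn.reshape(tensor, shape=desired_shape)
        return ttnn.to_layout(tensor, ttnn.TILE_LAYOUT)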


def encoder_layer(config, hidden_states, *, parameters):
residual = hidden_states

hidden_states = ttnn.layer_norm(
hidden_states,
weight=parameters.self_attn_layer_norm.weight,
bias=parameters.self_attn_layer_norm.bias,
)

hidden_states = whisper_attention(config, hidden_states, attention_mask=None, parameters=parameters.self_attn)
hidden_states = dropout(hidden_states, p=0, training=False)
hidden_states = residual + hidden_states

residual = hidden_states

hidden_states = ttnn.layer_norm(
hidden_states,
weight=parameters.final_layer_norm.weight,
@@ -156,7 +183,7 @@ def encoder_layer(config, hidden_states, *, parameters):


def encoder(config, inputs_embeds, *, parameters):
hidden_states = inputs_embeds + ttnn.to_layout(parameters.embed_positions.weight, layout=ttnn.TILE_LAYOUT)
hidden_states = inputs_embeds + parameters.embed_positions.weight
hidden_states = dropout(hidden_states, p=0, training=False)

for encoder_layer_parameter in parameters.layers:
@@ -358,13 +385,14 @@ def preprocess_inputs(

def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters):
encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder)
return decoder(
last_hidden_state = decoder(
config,
decoder_hidden_states,
decoder_attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
parameters=parameters.decoder,
)
return last_hidden_state


def custom_preprocessor(torch_model, name):
@@ -393,4 +421,8 @@ def custom_preprocessor(torch_model, name):

parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16)
parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16)
elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding):
embeddings = ttnn.from_torch(torch_model.weight, dtype=ttnn.bfloat16)
embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT)
parameters["weight"] = embeddings
return parameters
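
The new `encoder.embed_positions` branch pairs with the encoder change earlier in this diff: the positional-embedding table is converted to a bfloat16 tile-layout tensor once during preprocessing, so the encoder can add `parameters.embed_positions.weight` directly without a per-forward `ttnn.to_layout` call. A standalone sketch of that conversion (the function name is hypothetical):

    import torch
    import ttnn

    def preprocess_embedding_weight(embedding: torch.nn.Embedding):
        # Convert the embedding table up front so the model code can use it as-is.
        weight = ttnn.from_torch(embedding.weight, dtype=ttnn.bfloat16)
        return ttnn.to_layout(weight, ttnn.TILE_LAYOUT)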
27 changes: 7 additions & 20 deletions tests/ttnn/unit_tests/operations/test_group_norm.py
@@ -22,8 +22,7 @@ def test_group_norm(device, h, w, num_groups):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.group_norm(torch_input_tensor, num_groups)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.group_norm(input_tensor, num_groups=num_groups)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
@@ -46,13 +45,9 @@ def test_group_norm_with_weight_and_bias(device, h, w, num_groups):
torch_input_tensor, num_groups, weight=torch_weight, bias=torch_bias
)

input_tensor = ttnn.from_torch(torch_input_tensor)
weight = ttnn.from_torch(torch_weight)
bias = ttnn.from_torch(torch_bias)

input_tensor = ttnn.to_device(input_tensor, device)
weight = ttnn.to_device(weight, device)
bias = ttnn.to_device(bias, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.group_norm(input_tensor, num_groups=num_groups, weight=weight, bias=bias)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
@@ -79,17 +74,9 @@ def test_group_norm_with_tile_layout(device, h, w, num_groups):
torch_bias,
)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_layout(input_tensor, ttnn.TILE_LAYOUT)
input_tensor = ttnn.to_device(input_tensor, device)

weight = ttnn.from_torch(torch_weight)
weight = ttnn.to_layout(weight, ttnn.TILE_LAYOUT)
weight = ttnn.to_device(weight, device)

bias = ttnn.from_torch(torch_bias)
bias = ttnn.to_layout(bias, ttnn.TILE_LAYOUT)
bias = ttnn.to_device(bias, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.group_norm(
input_tensor,
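The test updates in this file, and in the files below, collapse the separate `ttnn.from_torch` / `ttnn.to_layout` / `ttnn.to_device` steps into a single `ttnn.from_torch(..., layout=ttnn.TILE_LAYOUT, device=device)` call. A small helper capturing the consolidated pattern (the name is hypothetical; `device` is the same fixture the tests receive):

    import torch
    import ttnn

    def to_tile_on_device(torch_tensor: torch.Tensor, device) -> ttnn.Tensor:
        # Equivalent to the older three-step pattern these tests used:
        #   t = ttnn.from_torch(torch_tensor)
        #   t = ttnn.to_layout(t, ttnn.TILE_LAYOUT)
        #   t = ttnn.to_device(t, device)
        return ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)
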
26 changes: 8 additions & 18 deletions tests/ttnn/unit_tests/operations/test_layer_norm.py
@@ -21,8 +21,7 @@ def test_layer_norm(device, h, w):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w])

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.layer_norm(input_tensor)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
@@ -45,13 +44,9 @@ def test_layer_norm_with_weight_and_bias(device, h, w):
torch_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias
)

input_tensor = ttnn.from_torch(torch_input_tensor)
weight = ttnn.from_torch(torch_weight)
bias = ttnn.from_torch(torch_bias)

input_tensor = ttnn.to_device(input_tensor, device)
weight = ttnn.to_device(weight, device)
bias = ttnn.to_device(bias, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.layer_norm(input_tensor, weight=weight, bias=bias)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
@@ -75,15 +70,10 @@ def test_layer_norm_with_weight_bias_and_residual_input(device, h, w):
torch_input_tensor + torch_residual_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias
)

input_tensor = ttnn.from_torch(torch_input_tensor)
residual_input_tensor = ttnn.from_torch(torch_residual_input_tensor)
weight = ttnn.from_torch(torch_weight)
bias = ttnn.from_torch(torch_bias)

input_tensor = ttnn.to_device(input_tensor, device)
residual_input_tensor = ttnn.to_device(residual_input_tensor, device)
weight = ttnn.to_device(weight, device)
bias = ttnn.to_device(bias, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
residual_input_tensor = ttnn.from_torch(torch_residual_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.layer_norm(input_tensor, residual_input_tensor=residual_input_tensor, weight=weight, bias=bias)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
7 changes: 2 additions & 5 deletions tests/ttnn/unit_tests/operations/test_mean.py
@@ -17,17 +17,14 @@
@pytest.mark.parametrize("h", [32, 64])
@pytest.mark.parametrize("w", [32, 64])
@pytest.mark.parametrize("dim", [-1, -2])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_mean(device, batch_size, h, w, dim, input_layout):
def test_mean(device, batch_size, h, w, dim):
torch.manual_seed(0)

torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16)
torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_layout(input_tensor, input_layout)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

input_tensor = ttnn.to_device(input_tensor, device)
output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=True)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
15 changes: 4 additions & 11 deletions tests/ttnn/unit_tests/operations/test_softmax.py
@@ -18,18 +18,13 @@
@pytest.mark.parametrize("h", [32, 64])
@pytest.mark.parametrize("w", [32, 64])
@pytest.mark.parametrize("dim", [-1, -2, -3])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_softmax(device, batch_size, h, w, dim, input_layout):
if dim != -1 and input_layout != ttnn.TILE_LAYOUT:
pytest.skip("Not supported yet")

def test_softmax(device, batch_size, h, w, dim):
torch.manual_seed(0)

torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16)
torch_output_tensor = F.softmax(torch_input_tensor, dim=dim, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_layout(input_tensor, input_layout)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

input_tensor = ttnn.to_device(input_tensor, device)
output_tensor = ttnn.softmax(input_tensor, dim=dim)
@@ -45,8 +40,7 @@ def test_softmax_with_3D(device):
torch.manual_seed(0)
torch_input_tensor = torch_random((8, 1500, 1500), -10, 10, dtype=torch.bfloat16)
torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.softmax(input_tensor, dim=-1)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
@@ -80,8 +74,7 @@ def test_specific_tensor_combination(device):

torch_output_tensor = torch.softmax(torch_input_tensor, -1)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output = ttnn.softmax(input_tensor, -1)
output = ttnn.from_device(output)
3 changes: 1 addition & 2 deletions tests/ttnn/unit_tests/operations/test_split.py
@@ -23,8 +23,7 @@ def test_split(device, h, w, split_size, dim):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensors = torch.split(torch_input_tensor, split_size, dim=dim)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensors = ttnn.split(input_tensor, split_size=split_size, dim=dim)

for torch_output_tensor, output_tensor in zip(torch_output_tensors, output_tensors):
3 changes: 1 addition & 2 deletions tests/ttnn/unit_tests/operations/test_unary.py
@@ -17,8 +17,7 @@ def run_unary_test(device, h, w, ttnn_function, torch_function, pcc=0.9999):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch_function(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn_function(input_tensor)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)