diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py index 17a3d5de322..3c541204f4c 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py @@ -22,18 +22,25 @@ def dropout(hidden_states, p, training): def calculate_key_values(config, key_value_states, *, parameters): bsz, tgt_len, hidden_size = key_value_states.shape + bsz, tgt_len_padded, _ = key_value_states.shape.padded() head_size = hidden_size // config.encoder_attention_heads fused_qkv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias - fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + fused_qkv = ttnn.to_layout(fused_qkv, ttnn.ROW_MAJOR_LAYOUT) fused_qkv = ttnn.reshape(fused_qkv, shape=(bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] - key_states = ttnn.reshape(key_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) key_states = ttnn.permute(key_states, (0, 2, 1, 3)) - - value_states = ttnn.reshape(value_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, tgt_len, head_size], + [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], + ) + key_states = ttnn.reshape(key_states, shape=desired_shape) + value_states = ttnn.reshape(value_states, shape=desired_shape) return key_states, value_states @@ -42,23 +49,29 @@ def split_query_key_value_and_split_heads( config, fused_qkv: ttnn.Tensor ) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]: head_size = config.d_model // 
config.encoder_attention_heads - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + batch_size, *_, seq_length, three_times_hidden_size = fused_qkv.shape + batch_size, *_, padded_seq_length, three_times_hidden_size = fused_qkv.shape.padded() hidden_size = three_times_hidden_size // 3 encoder_attention_heads = hidden_size // head_size - fused_qkv = ttnn.to_layout(fused_qkv, ttnn.ROW_MAJOR_LAYOUT) + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) fused_qkv = ttnn.reshape(fused_qkv, shape=(batch_size, seq_length, 3, encoder_attention_heads, head_size)) query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :] - - query_states = ttnn.reshape(query_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size)) query_states = ttnn.permute(query_states, (0, 2, 1, 3)) - - key_states = ttnn.reshape(key_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size)) key_states = ttnn.permute(key_states, (0, 2, 1, 3)) - - value_states = ttnn.reshape(value_states, shape=(batch_size, seq_length, encoder_attention_heads, head_size)) value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [batch_size, encoder_attention_heads, seq_length, head_size], + [batch_size, encoder_attention_heads, padded_seq_length, head_size], + ) + query_states = ttnn.reshape(query_states, shape=desired_shape) + key_states = ttnn.reshape(key_states, shape=desired_shape) + value_states = ttnn.reshape(value_states, shape=desired_shape) return query_states, key_states, value_states @@ -70,26 +83,37 @@ def calculate_query_key_values(config, hidden_states, *, parameters): def whisper_attention(config, hidden_states, attention_mask, key_value_states=None, *, parameters): 
head_size = config.d_model // config.encoder_attention_heads scaling = head_size**-0.5 - bsz, tgt_len, _ = hidden_states.shape + bsz, *_, padded_tgt_len, _ = hidden_states.shape.padded() + bsz, *_, tgt_len, _ = hidden_states.shape is_cross_attention = key_value_states is not None if is_cross_attention: query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias - query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) + query_states = ttnn.to_layout(query_states, ttnn.ROW_MAJOR_LAYOUT) query_states = ttnn.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT) query_states = ttnn.permute(query_states, (0, 2, 1, 3)) key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) + padded_key_value_tgt_len = key_states.shape.padded()[2] + key_value_tgt_len = key_states.shape[2] else: query_states, key_states, value_states = calculate_query_key_values( config, hidden_states, parameters=parameters ) + padded_key_value_tgt_len = padded_tgt_len + key_value_tgt_len = tgt_len - query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) query_states *= scaling - query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) - proj_shape = (bsz * config.encoder_attention_heads, -1, head_size) + proj_shape = ttnn.Shape( + [bsz * config.encoder_attention_heads, tgt_len, head_size], + [bsz * config.encoder_attention_heads, padded_tgt_len, head_size], + ) query_states = ttnn.reshape(query_states, shape=proj_shape) + proj_shape = ttnn.Shape( + [bsz * config.encoder_attention_heads, key_value_tgt_len, head_size], + [bsz * config.encoder_attention_heads, padded_key_value_tgt_len, head_size], + ) key_states = ttnn.reshape(key_states, shape=proj_shape) value_states = ttnn.reshape(value_states, shape=proj_shape) @@ -116,26 +140,29 @@ def whisper_attention(config, hidden_states, attention_mask, 
key_value_states=No attn_output = ttnn.to_layout(attn_output, layout=ttnn.ROW_MAJOR_LAYOUT) attn_output = ttnn.reshape(attn_output, shape=(bsz, config.encoder_attention_heads, tgt_len, head_size)) attn_output = ttnn.permute(attn_output, (0, 2, 1, 3)) - attn_output = ttnn.reshape(attn_output, shape=(bsz, tgt_len, config.d_model)) - attn_output = ttnn.to_layout(attn_output, layout=ttnn.TILE_LAYOUT) + attn_output = ttnn.to_layout(attn_output, ttnn.ROW_MAJOR_LAYOUT) + attn_output = ttnn.reshape( + attn_output, + shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, tgt_len, config.d_model]), + ) + attn_output = ttnn.to_layout(attn_output, ttnn.TILE_LAYOUT) attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias return attn_output def encoder_layer(config, hidden_states, *, parameters): residual = hidden_states - hidden_states = ttnn.layer_norm( hidden_states, weight=parameters.self_attn_layer_norm.weight, bias=parameters.self_attn_layer_norm.bias, ) + hidden_states = whisper_attention(config, hidden_states, attention_mask=None, parameters=parameters.self_attn) hidden_states = dropout(hidden_states, p=0, training=False) hidden_states = residual + hidden_states residual = hidden_states - hidden_states = ttnn.layer_norm( hidden_states, weight=parameters.final_layer_norm.weight, @@ -156,7 +183,7 @@ def encoder_layer(config, hidden_states, *, parameters): def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + ttnn.to_layout(parameters.embed_positions.weight, layout=ttnn.TILE_LAYOUT) + hidden_states = inputs_embeds + parameters.embed_positions.weight hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -358,13 +385,14 @@ def preprocess_inputs( def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) - return decoder( + 
last_hidden_state = decoder( config, decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_hidden_states, parameters=parameters.decoder, ) + return last_hidden_state def custom_preprocessor(torch_model, name): @@ -393,4 +421,8 @@ def custom_preprocessor(torch_model, name): parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) + elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): + embeddings = ttnn.from_torch(torch_model.weight, dtype=ttnn.bfloat16) + embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT) + parameters["weight"] = embeddings return parameters diff --git a/tests/ttnn/unit_tests/operations/test_group_norm.py b/tests/ttnn/unit_tests/operations/test_group_norm.py index 2bb17bcd4e1..1796d2fb8c7 100644 --- a/tests/ttnn/unit_tests/operations/test_group_norm.py +++ b/tests/ttnn/unit_tests/operations/test_group_norm.py @@ -22,8 +22,7 @@ def test_group_norm(device, h, w, num_groups): torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) torch_output_tensor = torch.nn.functional.group_norm(torch_input_tensor, num_groups) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.group_norm(input_tensor, num_groups=num_groups) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) @@ -46,13 +45,9 @@ def test_group_norm_with_weight_and_bias(device, h, w, num_groups): torch_input_tensor, num_groups, weight=torch_weight, bias=torch_bias ) - input_tensor = ttnn.from_torch(torch_input_tensor) - weight = ttnn.from_torch(torch_weight) - bias = ttnn.from_torch(torch_bias) - - input_tensor = 
ttnn.to_device(input_tensor, device) - weight = ttnn.to_device(weight, device) - bias = ttnn.to_device(bias, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device) + bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.group_norm(input_tensor, num_groups=num_groups, weight=weight, bias=bias) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) @@ -79,17 +74,9 @@ def test_group_norm_with_tile_layout(device, h, w, num_groups): torch_bias, ) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_layout(input_tensor, ttnn.TILE_LAYOUT) - input_tensor = ttnn.to_device(input_tensor, device) - - weight = ttnn.from_torch(torch_weight) - weight = ttnn.to_layout(weight, ttnn.TILE_LAYOUT) - weight = ttnn.to_device(weight, device) - - bias = ttnn.from_torch(torch_bias) - bias = ttnn.to_layout(bias, ttnn.TILE_LAYOUT) - bias = ttnn.to_device(bias, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device) + bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.group_norm( input_tensor, diff --git a/tests/ttnn/unit_tests/operations/test_layer_norm.py b/tests/ttnn/unit_tests/operations/test_layer_norm.py index 02a15bc9b9f..1bfba14d130 100644 --- a/tests/ttnn/unit_tests/operations/test_layer_norm.py +++ b/tests/ttnn/unit_tests/operations/test_layer_norm.py @@ -21,8 +21,7 @@ def test_layer_norm(device, h, w): torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w]) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = 
ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.layer_norm(input_tensor) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) @@ -45,13 +44,9 @@ def test_layer_norm_with_weight_and_bias(device, h, w): torch_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias ) - input_tensor = ttnn.from_torch(torch_input_tensor) - weight = ttnn.from_torch(torch_weight) - bias = ttnn.from_torch(torch_bias) - - input_tensor = ttnn.to_device(input_tensor, device) - weight = ttnn.to_device(weight, device) - bias = ttnn.to_device(bias, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device) + bias = ttnn.from_torch(torch_bias, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.layer_norm(input_tensor, weight=weight, bias=bias) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) @@ -75,15 +70,10 @@ def test_layer_norm_with_weight_bias_and_residual_input(device, h, w): torch_input_tensor + torch_residual_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias ) - input_tensor = ttnn.from_torch(torch_input_tensor) - residual_input_tensor = ttnn.from_torch(torch_residual_input_tensor) - weight = ttnn.from_torch(torch_weight) - bias = ttnn.from_torch(torch_bias) - - input_tensor = ttnn.to_device(input_tensor, device) - residual_input_tensor = ttnn.to_device(residual_input_tensor, device) - weight = ttnn.to_device(weight, device) - bias = ttnn.to_device(bias, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + residual_input_tensor = ttnn.from_torch(torch_residual_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device) + bias = ttnn.from_torch(torch_bias, 
layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.layer_norm(input_tensor, residual_input_tensor=residual_input_tensor, weight=weight, bias=bias) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) diff --git a/tests/ttnn/unit_tests/operations/test_mean.py b/tests/ttnn/unit_tests/operations/test_mean.py index e94d40d035a..4b2780081fe 100644 --- a/tests/ttnn/unit_tests/operations/test_mean.py +++ b/tests/ttnn/unit_tests/operations/test_mean.py @@ -17,17 +17,14 @@ @pytest.mark.parametrize("h", [32, 64]) @pytest.mark.parametrize("w", [32, 64]) @pytest.mark.parametrize("dim", [-1, -2]) -@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) -def test_mean(device, batch_size, h, w, dim, input_layout): +def test_mean(device, batch_size, h, w, dim): torch.manual_seed(0) torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_layout(input_tensor, input_layout) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) - input_tensor = ttnn.to_device(input_tensor, device) output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=True) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) diff --git a/tests/ttnn/unit_tests/operations/test_softmax.py b/tests/ttnn/unit_tests/operations/test_softmax.py index 3622f3c9676..c175823927b 100644 --- a/tests/ttnn/unit_tests/operations/test_softmax.py +++ b/tests/ttnn/unit_tests/operations/test_softmax.py @@ -18,18 +18,13 @@ @pytest.mark.parametrize("h", [32, 64]) @pytest.mark.parametrize("w", [32, 64]) @pytest.mark.parametrize("dim", [-1, -2, -3]) -@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) -def test_softmax(device, batch_size, h, w, dim, 
input_layout): - if dim != -1 and input_layout != ttnn.TILE_LAYOUT: - pytest.skip("Not supported yet") - +def test_softmax(device, batch_size, h, w, dim): torch.manual_seed(0) torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) torch_output_tensor = F.softmax(torch_input_tensor, dim=dim, dtype=torch.bfloat16) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_layout(input_tensor, input_layout) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = ttnn.to_device(input_tensor, device) output_tensor = ttnn.softmax(input_tensor, dim=dim) @@ -45,8 +40,7 @@ def test_softmax_with_3D(device): torch.manual_seed(0) torch_input_tensor = torch_random((8, 1500, 1500), -10, 10, dtype=torch.bfloat16) torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.softmax(input_tensor, dim=-1) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) @@ -80,8 +74,7 @@ def test_specific_tensor_combination(device): torch_output_tensor = torch.softmax(torch_input_tensor, -1) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output = ttnn.softmax(input_tensor, -1) output = ttnn.from_device(output) diff --git a/tests/ttnn/unit_tests/operations/test_split.py b/tests/ttnn/unit_tests/operations/test_split.py index 66249e9e461..8ab5fd93774 100644 --- a/tests/ttnn/unit_tests/operations/test_split.py +++ b/tests/ttnn/unit_tests/operations/test_split.py @@ -23,8 +23,7 @@ def test_split(device, h, w, split_size, dim): torch_input_tensor = 
torch.rand((h, w), dtype=torch.bfloat16) torch_output_tensors = torch.split(torch_input_tensor, split_size, dim=dim) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensors = ttnn.split(input_tensor, split_size=split_size, dim=dim) for torch_output_tensor, output_tensor in zip(torch_output_tensors, output_tensors): diff --git a/tests/ttnn/unit_tests/operations/test_unary.py b/tests/ttnn/unit_tests/operations/test_unary.py index d71669a710f..a0a073fe9db 100644 --- a/tests/ttnn/unit_tests/operations/test_unary.py +++ b/tests/ttnn/unit_tests/operations/test_unary.py @@ -17,8 +17,7 @@ def run_unary_test(device, h, w, ttnn_function, torch_function, pcc=0.9999): torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) torch_output_tensor = torch_function(torch_input_tensor) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn_function(input_tensor) output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.from_device(output_tensor) diff --git a/tests/ttnn/unit_tests/test_debug_decorator.py b/tests/ttnn/unit_tests/test_debug_decorator.py index acefde69c56..eeb9ead5f57 100644 --- a/tests/ttnn/unit_tests/test_debug_decorator.py +++ b/tests/ttnn/unit_tests/test_debug_decorator.py @@ -18,18 +18,13 @@ @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [32]) @pytest.mark.parametrize("dim", [-1]) -@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) -def test_softmax(device, ttnn_enable_debug_decorator, batch_size, h, w, dim, input_layout): - if dim != -1 and input_layout != ttnn.TILE_LAYOUT: - pytest.skip("Not supported yet") - +def test_softmax(device, 
ttnn_enable_debug_decorator, batch_size, h, w, dim): torch.manual_seed(0) torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) torch_output_tensor = torch.nn.functional.softmax(torch_input_tensor, dim=dim, dtype=torch.bfloat16) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_layout(input_tensor, input_layout) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = ttnn.to_device(input_tensor, device) with override_pearson_correlation_coefficient(0.99): @@ -43,17 +38,13 @@ def test_softmax(device, ttnn_enable_debug_decorator, batch_size, h, w, dim, inp @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [32]) -@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) -def test_exp(device, ttnn_enable_debug_decorator, batch_size, h, w, input_layout): +def test_exp(device, ttnn_enable_debug_decorator, batch_size, h, w): torch.manual_seed(0) torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) torch_output_tensor = torch.exp(torch_input_tensor) - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_layout(input_tensor, input_layout) - - input_tensor = ttnn.to_device(input_tensor, device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.exp(input_tensor) output_tensor = ttnn.to_torch(output_tensor) diff --git a/tests/ttnn/unit_tests/test_preprocess_model_parameters.py b/tests/ttnn/unit_tests/test_preprocess_model_parameters.py index 6155f3d83b3..73ab2b60988 100644 --- a/tests/ttnn/unit_tests/test_preprocess_model_parameters.py +++ b/tests/ttnn/unit_tests/test_preprocess_model_parameters.py @@ -248,7 +248,9 @@ def custom_preprocessor(model, name): output_tensor = conv1(output_tensor) output_tensor = conv1.copy_output_from_device(output_tensor) output_tensor = 
ttnn.to_device(output_tensor, device=device) + output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT) output_tensor = ttnn.relu(output_tensor) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = conv2.copy_input_to_device(output_tensor) output_tensor = conv2(output_tensor) output_tensor = conv2.copy_output_from_device(output_tensor) diff --git a/ttnn/README.md b/ttnn/README.md index 59708165ad5..1988f47e97d 100644 --- a/ttnn/README.md +++ b/ttnn/README.md @@ -37,15 +37,36 @@ followed the instructions for [installing and building the software](https://git * When using the ttnn library, the operation requires the first parameter to be the tensor and the next to be the new order of dimensions within a parenthesis. #### Frequently asked questions +* Where are the tests for ttnn? + * All tests can be found under tests/ttnn +* What are the differences between the kinds of tests under tests/ttnn? + * tests/ttnn/integration_tests + * Demonstrates the inference models built with ttnn + * tests/ttnn/sweep_tests + * Used to check coverage of what is and what is NOT supported + * Tests can be added well before the actual implementation is finished + * These tests do not block the continuous integration pipeline + * They are built so that their results can be uniformly reported + * Can be run with pytest although they follow a strict format when built + * tests/ttnn/unit_tests + * These are traditional unit tests written with pytest + * Failures on these tests will cause alarms if code is merged into main +* Why do the sweep tests use a dictionary for all the combinations of input and then use a special run method? Could you not have done this with a traditional pytest instead? + * The primary reason was that we needed a way to create a consolidated report per operation in the form of a csv file. The idea was that each operation would get its own python file where all the test combinations are handled by a single run method.
Each permutation of the input combinations would become the header for the resulting csv which is then uploaded and reported on. +* How do I run sweep tests with pytest? + * To run all of the sweep tests for a given python operation file: + * pytest /tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_ + * Example for matmul: pytest /home/ubuntu/git/tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_matmul + * To run just one sample combination for an operation: + * pytest /tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_[.py-] + * Example for matmul: pytest /home/ubuntu/git/tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_matmul[matmul.py-0] * What if my device hangs? - * Try resetting the board on the command line with: `tt-smi -tr all` + * Be sure that you have a clean build with the latest code update. Updating code without rebuilding is often the source of the issue. + * If you have a clean build, you can reset the board on the command line using the tt-smi command and the device id with: `tt-smi -tr 0` where 0 represents the device id. * What types are supported on device? * We currently support ttnn.bfloat16, ttnn.bfloat8 and ttnn.uint32. * What shapes are supported on device? - * The last dimension of the shape multiplied by the number of bytes of the sizeof the dataype must be a multiple of four. For example, ttnn.bloat16 would need to have the last dimension be even. - * TODO : address ttnn.bfloat8_b and how mantissa is stored per tile - * TODO : address converting from int in data type for torch tensors to ttnn.uint32 - * TODO : mention how ttnn.blfloat32 is not supported on device + * The last dimension of the shape multiplied by the size of the datatype in bytes must be a multiple of four. For example, ttnn.bfloat16 would need to have the last dimension be even for a tensor using ttnn.ROW_MAJOR_LAYOUT.
For ttnn.TILE_LAYOUT the to_layout operation will automatically do padding to ensure that the last two dimensions (height and width) are multiples of 32. * Is slicing available? * Slicing is supported. At the moment this feature falls back to using PyTorch slicing on the host. * Example: @@ -63,7 +84,55 @@ followed the instructions for [installing and building the software](https://git * `export TT_METAL_LOGGER_LEVEL=DEBUG` * For the location of the operations use the following environment variable * `export OPERATION_HISTORY_CSV=` +* What is the format for git commit messages? + * As mentioned in other documentation, the use of the '#' symbol to identify an issue request number is expected on each commit message. + * For example your git commit message might be: "#4003: Your message here" for github issue 4003. + * Consider using: `git config --global core.commentChar '>'` so that a message line starting with '#' is not treated as a comment +#### Steps to set up ttnn tests from vscode +* Add the Makefile Tools extension to vscode +* Add the Python extension to vscode +* Update settings.json to make vscode aware of the pytests + * ``` + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": [ + "tests/ttnn" + ], + "python.autoComplete.extraPaths": [ + "${workspaceFolder}/tests/ttnn" + ], + "python.analysis.extraPaths": [ + "${workspaceFolder}/tests/ttnn" + ], + ``` + +#### Steps to launch the tt-eager example code from within vscode +* Add the Makefile Tools extension +* Be sure to build with `make tests/tt_eager` +* Update launch.json to debug the code sample you want to run.
For example, if you want to run test_bert, your update to launch.json might look like: + ``` + { + "name": "test_bert", + "type": "cppdbg", + "request": "launch", + "args": [], + "stopAtEntry": false, + "externalConsole": false, + "cwd": "${workspaceFolder}/build", + "program": "${workspaceFolder}/build/test/tt_eager/integration_tests/test_bert", + "MIMode": "gdb", + "miDebuggerPath": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for gdb", + "text": "-enable-pretty-printing", + "ignoreFailures": true + } + ] + }, + + ``` + * Debug with vscode by launching it from the "Run and Debug" view #### How to debug from python and C++ at the same time within vscode * `export CONFIG=debug` within your virtual environment and run `make build` diff --git a/ttnn/ttnn/core.py b/ttnn/ttnn/core.py index a1afbd714e3..5f0fc05d1e4 100644 --- a/ttnn/ttnn/core.py +++ b/ttnn/ttnn/core.py @@ -52,6 +52,21 @@ def device(self: "Tensor") -> DataType: else: return Cpu() + def _getitem_validate_input_tensors(operation_name, input_tensor, padding, *args, **kwargs): + validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4, 5, 6, 7, 8), + dtypes=(bfloat16, bfloat8_b, uint16, uint32), + layouts=(ROW_MAJOR_LAYOUT, TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + @register_operation( + name="ttnn.pad", + validate_input_tensors=_getitem_validate_input_tensors, + ) @register_operation(name="ttnn.Tensor.__getitem__") def __getitem__(self: "Tensor", slices) -> "Tensor": if self.layout != ROW_MAJOR_LAYOUT: @@ -361,7 +376,23 @@ def has_padding(tensor): return False -@register_operation(name="ttnn.from_torch") +def _from_torch_validate_input_tensors(operation_name, tensor, *args, **kwargs): + import torch + + ranks = (1, 2, 3, 4, 5, 6, 7, 8) + if len(tensor.shape) not in ranks: + raise RuntimeError(f"{operation_name}: Tensor must be of rank {ranks}, but got {len(tensor.shape)}") + dtypes = (torch.bfloat16, torch.float32, torch.int16,
torch.int32, torch.int64) + if tensor.dtype not in dtypes: + raise RuntimeError(f"{operation_name}: Tensor must be of type {dtypes}, but got {tensor.dtype}") + # if not tensor.is_contiguous(): + # raise RuntimeError(f"{operation_name}: Tensor must be contiguous") + + +@register_operation( + name="ttnn.from_torch", + validate_input_tensors=_from_torch_validate_input_tensors, +) def from_torch( tensor: "torch.Tensor", dtype: Optional[DataType] = None, @@ -407,7 +438,22 @@ def impl(tensor, dtype): return tensor -@register_operation(name="ttnn.to_torch") +def _to_torch_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4, 5, 6, 7, 8), + dtypes=(bfloat16, bfloat8_b, uint16, uint32), + layouts=(ROW_MAJOR_LAYOUT, TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + +@register_operation( + name="ttnn.to_torch", + validate_input_tensors=_to_torch_validate_input_tensors, +) def to_torch(tensor: Tensor, *, torch_rank: Optional[int] = None) -> "torch.Tensor": """ to_torch(tensor: ttnn.Tensor) -> torch.Tensor @@ -573,7 +619,22 @@ def impl(tensor): ttl.tensor.decorate_external_operation(impl, function_name="ttnn.deallocate")(tensor) -@register_operation(name="ttnn.to_memory_config") +def _to_memory_config_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4), + dtypes=(bfloat16, bfloat8_b, uint16, uint32), + layouts=(ROW_MAJOR_LAYOUT, TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@register_operation( + name="ttnn.to_memory_config", + validate_input_tensors=_to_memory_config_validate_input_tensors, +) def to_memory_config(tensor, memory_config: MemoryConfig): """ to_memory_config(tensor: ttnn.Tensor, memory_config: MemoryConfig) -> ttnn.Tensor @@ -826,7 +887,23 @@ def _torch_identity(input_tensor): return input_tensor.clone() 
-@register_operation(name="ttnn.reallocate", torch_function=_torch_identity) +def _reallocate_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4), + dtypes=(bfloat16, bfloat8_b, uint16, uint32), + layouts=(ROW_MAJOR_LAYOUT, TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@register_operation( + name="ttnn.reallocate", + validate_input_tensors=_reallocate_validate_input_tensors, + torch_function=_torch_identity, +) def reallocate(input_tensor: Tensor) -> Tensor: def impl(input_tensor): ttl_input_tensor = input_tensor.value @@ -836,7 +913,17 @@ def impl(input_tensor): return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.reallocate")(input_tensor) -@register_operation(name="ttnn.load_tensor") +def _load_tensor_validate_input_tensors(operation_name, file_name, *args, **kwargs): + if not isinstance(file_name, str) and not isinstance(file_name, pathlib.Path): + raise RuntimeError( + f"Unable to dump the tensor to the type {type(file_name)}. The file_name must either be a str or pathlib.Path." + ) + + +@register_operation( + name="ttnn.load_tensor", + validate_input_tensors=_load_tensor_validate_input_tensors, +) def load_tensor(file_name: Union[str, pathlib.Path]) -> Tensor: def impl(file_name): return Tensor(ttl.tensor.load_tensor(str(file_name))) @@ -844,7 +931,26 @@ def impl(file_name): return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.load_tensor")(file_name) -@register_operation(name="ttnn.dump_tensor") +def _dump_tensor_validate_input_tensors(operation_name, file_name, tensor, *args, **kwargs): + if not isinstance(file_name, str) and not isinstance(file_name, pathlib.Path): + raise RuntimeError( + f"Unable to dump the tensor to the type {type(file_name)}. The file_name must either be a str or pathlib.Path." 
+ ) + validate_input_tensor( + operation_name, + tensor, + ranks=(1, 2, 3, 4, 5, 6, 7, 8), + dtypes=(bfloat16, bfloat8_b, uint16, uint32), + layouts=(ROW_MAJOR_LAYOUT, TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + +@register_operation( + name="ttnn.dump_tensor", + validate_input_tensors=_dump_tensor_validate_input_tensors, +) def dump_tensor(file_name: Union[str, pathlib.Path], tensor: Tensor) -> None: def impl(file_name, tensor): ttl_tensor = tensor.value diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index e4501c7fcdc..9639dec23b0 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -19,7 +19,7 @@ def register_ttl_binary_function(name, ttl_binary_function, doc): - def _torch_unary(input_tensor: ttnn.Tensor, parameter, **_): + def _torch_binary(input_tensor: ttnn.Tensor, parameter, **_): import torch name_to_torch_function = {"pow": torch.pow} @@ -27,7 +27,22 @@ def _torch_unary(input_tensor: ttnn.Tensor, parameter, **_): input_tensor = ttnn.to_torch(input_tensor) return torch_function(input_tensor, parameter) - @register_operation(torch_function=_torch_unary, name=f"ttnn.{name}") + def _binary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + @ttnn.register_operation( + name=f"ttnn.{name}", + validate_input_tensors=_binary_validate_input_tensors, + torch_function=_torch_binary, + ) def binary_function( input_tensor: ttnn.Tensor, parameter: float, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG ) -> ttnn.Tensor: @@ -35,13 +50,6 @@ def binary_function( input_tensor = ttnn.unsqueeze_to_4D(input_tensor) ttl_input_tensor = input_tensor.value - if not isinstance(input_tensor, ttnn.Tensor): - raise TypeError("Expected 
first argument to be a ttnn.Tensor") - - if not ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE): - raise RuntimeError("input_tensor must be on device!") - ttl_input_tensor = input_tensor.value - ttl_output_tensor = ttl_binary_function(ttl_input_tensor, parameter, output_mem_config=memory_config) output_tensor = ttnn.Tensor(ttl_output_tensor) diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py index c16bf04145f..039d9a7521a 100644 --- a/ttnn/ttnn/operations/data_movement.py +++ b/ttnn/ttnn/operations/data_movement.py @@ -7,7 +7,6 @@ import tt_lib as ttl import ttnn.core as ttnn -from ttnn.decorators import register_operation def _torch_pad(input_tensor: ttnn.Tensor, padding, value): @@ -25,7 +24,23 @@ def _torch_pad(input_tensor: ttnn.Tensor, padding, value): return torch.nn.functional.pad(input_tensor, pad=torch_padding, mode="constant", value=value) -@register_operation(torch_function=_torch_pad, name="ttnn.pad") +def _pad_validate_input_tensors(operation_name, input_tensor, padding, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.pad", + validate_input_tensors=_pad_validate_input_tensors, + torch_function=_torch_pad, +) def pad(input_tensor: ttnn.Tensor, padding: Tuple[Tuple[int, int], ...], value: Union[int, float]) -> ttnn.Tensor: r""" @@ -42,9 +57,6 @@ def pad(input_tensor: ttnn.Tensor, padding: Tuple[Tuple[int, int], ...], value: """ - if not ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE): - raise RuntimeError("pad expects input tensor to be on device!") - output_tensor = _torch_pad(input_tensor, padding, value) output_tensor = ttnn.from_torch( output_tensor, dtype=input_tensor.dtype, device=input_tensor.device, 
layout=input_tensor.layout @@ -64,7 +76,31 @@ def _torch_permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...], **_): return torch.permute(input_tensor, order).contiguous().clone() -@register_operation(torch_function=_torch_permute, name="ttnn.permute") +def _permute_validate_input_tensors(operation_name, input_tensor, order, *args, **kwargs): + if not isinstance(order, tuple): + raise RuntimeError("order must be a tuple") + + if len(input_tensor.shape) != len(order): + raise RuntimeError( + "The number of dimensions in the tensor input does not match the length of the desired ordering" + ) + + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.permute", + validate_input_tensors=_permute_validate_input_tensors, + torch_function=_torch_permute, +) def permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor: r""" permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor @@ -84,17 +120,6 @@ def permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor: """ - if not isinstance(order, tuple): - raise RuntimeError("order must be a tuple") - - if not ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE): - RuntimeError("input_tensor must be on device!") - - if len(input_tensor.shape) != len(order): - raise RuntimeError( - "The number of dimensions in the tensor input does not match the length of the desired ordering" - ) - on_device = ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE) device = input_tensor.device layout = input_tensor.layout @@ -144,7 +169,46 @@ def _torch_concat(tensors, dim=0, **_): return torch.concat(torch_tensors, dim) -@register_operation(torch_function=_torch_concat, name="ttnn.concat") +def 
_concat_validate_input_tensors(operation_name, tensors, dim, *args, **kwargs): + if len(tensors) < 2: + raise RuntimeError("You must have at least two tensors to concat!") + + first_tensor = tensors[0] + first_tensor_shape = first_tensor.shape + for input_tensor in tensors: + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + for tensor in tensors: + shape = tensor.shape + if len(shape) != len(first_tensor_shape) or any( + shape[i] != first_tensor_shape[i] for i in range(len(shape)) if i != dim + ): + raise ValueError( + "All dimensions must be the same size except for the dimension along which the contenation is taking place." + ) + + rank = len(tensors[0].shape) + original_dim = dim + if dim < 0: + dim = rank + dim + if dim < 0 or dim >= rank: + raise RuntimeError( + f"ttnn: Dimension out of range: dim {original_dim} cannot be used for tensors of rank {rank}" + ) + + +@ttnn.register_operation( + name="ttnn.concat", + validate_input_tensors=_concat_validate_input_tensors, + torch_function=_torch_concat, +) def concat( tensors: List[ttnn.Tensor], dim: int = 0, @@ -171,38 +235,13 @@ def concat( """ - if len(tensors) < 2: - raise RuntimeError("You must have at least two tensors to concat!") - rank = len(tensors[0].shape) - original_dim = dim - if dim < 0: - dim = rank + dim - if dim < 0 or dim >= rank: - raise RuntimeError( - f"ttnn: Dimension out of range: dim {original_dim} cannot be used for tensors of rank {rank}" - ) - - for input_tensor in tensors: - if not ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE): - raise RuntimeError("ttnn: All tensors must be on device!") dtype = tensors[0].dtype device = tensors[0].device layout = tensors[0].layout rank = len(tensors[0].shape) - first_tensor = tensors[0] - first_tensor_shape = first_tensor.shape - for tensor in 
tensors: - shape = tensor.shape - if len(shape) != len(first_tensor_shape) or any( - shape[i] != first_tensor_shape[i] for i in range(len(shape)) if i != dim - ): - raise ValueError( - "All dimensions must be the same size except for the dimension along which the contenation is taking place." - ) - all_tensors_are_tile_layout_without_padding = not any( tensor.layout != ttnn.TILE_LAYOUT or ttnn.has_padding(tensor) for tensor in tensors ) @@ -241,7 +280,23 @@ def _torch_split(input_tensor: ttnn.Tensor, split_size, dim): return torch.split(input_tensor, split_size, dim=dim) -@register_operation(torch_function=_torch_split, name="ttnn.split") +def _split_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.split", + validate_input_tensors=_split_validate_input_tensors, + torch_function=_torch_split, +) def split(input_tensor: ttnn.Tensor, split_size: int, dim: int) -> ttnn.Tensor: r""" split(input_tensor: ttnn.Tensor, split_size: int, dim: int) -> Tuple[ttnn.Tensor, ...] @@ -254,9 +309,6 @@ def split(input_tensor: ttnn.Tensor, split_size: int, dim: int) -> ttnn.Tensor: * :attr:`dim`: dimension along which to split the tensor. 
""" - if not ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE): - raise RuntimeError("pad expects input tensor to be on device!") - output_tensors = _torch_split(input_tensor, split_size, dim) output_tensors = tuple( ttnn.from_torch(output_tensor, device=input_tensor.device, dtype=input_tensor.dtype, layout=input_tensor.layout) @@ -274,17 +326,33 @@ def _torch_repeat_interleave(tensor, repeats, dim=0, **_): return torch.repeat_interleave(ttnn.to_torch(tensor), repeats, dim=dim) -@register_operation(torch_function=_torch_repeat_interleave, name="ttnn.repeat_interleave") -def repeat_interleave(tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim: int = 0) -> ttnn.Tensor: +def _repeat_interleave_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + +@ttnn.register_operation( + name="ttnn.repeat_interleave", + validate_input_tensors=_repeat_interleave_validate_input_tensors, + torch_function=_torch_repeat_interleave, +) +def repeat_interleave(input_tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim: int = 0) -> ttnn.Tensor: r""" - repeat_interleave(tensors: ttnn.Tensor, repeats : Union[ttnn.Tensor,int], dim: int = 0) -> ttnn.Tensor + repeat_interleave(input_tensor: ttnn.Tensor, repeats : Union[ttnn.Tensor,int], dim: int = 0) -> ttnn.Tensor Repeats elements of a :attr:`tensor` in the given :attr:`dim`. Args: - * :attr:`tensors`: the tensors to be concatenated. + * :attr:`input_tensor`: the input_tensor to apply the repeate interleave operation. * :attr:`repeats`: The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. - * :attr:`dim`: the concatenating dimension. + * :attr:`dim`: the dimension to expand with the repetitions. 
Example:: @@ -297,20 +365,10 @@ def repeat_interleave(tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim """ - if not isinstance(tensor, ttnn.Tensor): - raise RuntimeError("ttnn: Expected tensor argument to be a ttnn.Tensor") - - if not ttnn.has_storage_type_of(tensor, ttl.tensor.StorageType.DEVICE): - raise RuntimeError("ttnn: Tensor must be on device!") - if not isinstance(repeats, int) and not isinstance(repeats, ttnn.Tensor): raise RuntimeError("ttnn: Expected repeat to either be an int or a ttnn.Tensor") - # For now, don't require the repeat tensor to be on device. - # if type(repeats) == type(tensor) and not ttnn.has_storage_type_of(repeats, ttl.tensor.StorageType.DEVICE): - # raise RuntimeError("Repeats tensor must be on device!") - - rank_of_tensor = len(tensor.shape) + rank_of_tensor = len(input_tensor.shape) if dim >= rank_of_tensor: dimension_range = f"[{-rank_of_tensor}, {rank_of_tensor - 1}]" raise RuntimeError( @@ -324,21 +382,21 @@ def custom_numel(tensor): return total_elements if isinstance(repeats, ttnn.Tensor): - if tensor.shape[dim] != custom_numel(repeats): + if input_tensor.shape[dim] != custom_numel(repeats): raise RuntimeError("ttnn: repeats must have the same size as input along dim") elif len(repeats.shape) != 1: raise RuntimeError("ttnn: repeats must be 0-dim or 1-dim tensor") - dtype = tensor.dtype - device = tensor.device - layout = tensor.layout - rank = len(tensor.shape) + dtype = input_tensor.dtype + device = input_tensor.device + layout = input_tensor.layout + rank = len(input_tensor.shape) if dtype == ttnn.bfloat16 and rank == 4 and dim != 2 and dim != 3: - ttl_input_tensor = tensor.value + ttl_input_tensor = input_tensor.value output_tensor = ttnn.Tensor(ttl.tensor.repeat_interleave(ttl_input_tensor, repeats, dim=dim)) *batch, _, _ = output_tensor.shape - *_, h, w = tensor.shape - *_, padded_h, padded_w = tensor.shape.padded() + *_, h, w = input_tensor.shape + *_, padded_h, padded_w = input_tensor.shape.padded() if dim == 
2: *_, h, _ = output_tensor.shape *_, padded_h, _ = output_tensor.shape.padded() @@ -354,7 +412,7 @@ def torch_repeat_interleave(tensor, repeats, dim=dim): output_tensor = ttl.tensor.decorate_external_operation( torch_repeat_interleave, function_name="torch_repeat_interleave" - )(tensor, repeats, dim=dim) + )(input_tensor, repeats, dim=dim) return ttnn.from_torch(output_tensor, device=device, dtype=dtype, layout=layout) diff --git a/ttnn/ttnn/operations/matmul.py b/ttnn/ttnn/operations/matmul.py index 50af5366bc2..8e4859fc689 100644 --- a/ttnn/ttnn/operations/matmul.py +++ b/ttnn/ttnn/operations/matmul.py @@ -474,17 +474,6 @@ def linear( padded_output_shape_list.append(input_shape_b.padded()[-1]) output_shape = ttnn.Shape(output_shape_list, padded_output_shape_list) - if not isinstance(input_tensor_a, ttnn.Tensor): - raise RuntimeError("Expected first argument to be a ttnn.Tensor") - if not isinstance(input_tensor_b, ttnn.Tensor): - raise RuntimeError("Expected second argument to be a ttnn.Tensor") - - if input_tensor_a.value.storage_type() != ttl.tensor.StorageType.DEVICE: - raise RuntimeError("input_tensor_a must be on device!") - - if input_tensor_b.value.storage_type() != ttl.tensor.StorageType.DEVICE: - raise RuntimeError("input_tensor_b must be on device!") - # The idea is to make the shapes "possibly" broadcastable. 
if len(input_tensor_a.shape) > 4: raise RuntimeError("There is currently no support for ranks greater than 4.") diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py index 05b435ae85e..af1cf0a6d10 100644 --- a/ttnn/ttnn/operations/normalization.py +++ b/ttnn/ttnn/operations/normalization.py @@ -42,7 +42,23 @@ def _torch_layer_norm( return torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), weight, bias, eps=epsilon) -@ttnn.register_operation(torch_function=_torch_layer_norm, name="ttnn.layer_norm") +def _layer_norm_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.layer_norm", + validate_input_tensors=_layer_norm_validate_input_tensors, + torch_function=_torch_layer_norm, +) def layer_norm( input_tensor: ttnn.Tensor, *, @@ -59,9 +75,6 @@ def layer_norm( """ - if not ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE): - raise RuntimeError("layer_norm only supports device storage type") - original_shape = input_tensor.shape input_tensor = ttnn.unsqueeze_to_4D(input_tensor) if residual_input_tensor is not None: @@ -139,7 +152,23 @@ def _torch_group_norm(input_tensor: ttnn.Tensor, *, num_groups, epsilon=1e-05, w return torch.nn.functional.group_norm(input_tensor, num_groups, weight, bias, eps=epsilon) -@ttnn.register_operation(torch_function=_torch_group_norm, name="ttnn.group_norm") +def _group_norm_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + 
+@ttnn.register_operation( + name="ttnn.group_norm", + validate_input_tensors=_group_norm_validate_input_tensors, + torch_function=_torch_group_norm, +) def group_norm( input_tensor: ttnn.Tensor, *, @@ -154,10 +183,6 @@ def group_norm( Compute group_norm over :attr:`input_tensor`. """ - - if not ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE): - raise RuntimeError("group_norm expects input tensor to be on device!") - output = _torch_group_norm(input_tensor, num_groups=num_groups, epsilon=epsilon, weight=weight, bias=bias) return ttnn.from_torch(output, dtype=input_tensor.dtype, layout=input_tensor.layout, device=input_tensor.device) diff --git a/ttnn/ttnn/operations/others.py b/ttnn/ttnn/operations/others.py index 51eeacc7df1..acdd59093bd 100644 --- a/ttnn/ttnn/operations/others.py +++ b/ttnn/ttnn/operations/others.py @@ -22,6 +22,23 @@ def _torch_pad_to_tile(padded_tensor: ttnn.Tensor): return output_tensor +def _pad_to_tile_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.ROW_MAJOR_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + +@ttnn.register_operation( + name="ttnn.pad_to_tile", + validate_input_tensors=_pad_to_tile_validate_input_tensors, + torch_function=_torch_pad_to_tile, +) @register_operation(torch_function=_torch_pad_to_tile, name="ttnn.pad_to_tile") def pad_to_tile(input_tensor: ttnn.Tensor) -> ttnn.Tensor: r""" @@ -107,7 +124,23 @@ def _torch_unpad_from_tile(padded_tensor: ttnn.Tensor): return output_tensor -@register_operation(torch_function=_torch_unpad_from_tile, name="ttnn.unpad_from_tile") +def _unpad_from_tile_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, 
ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=True, + ) + + +@ttnn.register_operation( + name="ttnn.unpad_from_tile", + validate_input_tensors=_unpad_from_tile_validate_input_tensors, + torch_function=_torch_unpad_from_tile, +) def unpad_from_tile(input_tensor: ttnn.Tensor) -> ttnn.Tensor: r""" unpad(input_tensor: ttnn.Tensor) -> ttnn.Tensor @@ -171,6 +204,23 @@ def _torch_embedding(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, **_): return output_tensor +def _embedding_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4), + dtypes=(ttnn.uint16, ttnn.uint32), + layouts=(ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.embedding", + validate_input_tensors=_embedding_validate_input_tensors, + torch_function=_torch_embedding, +) @register_operation(torch_function=_torch_embedding, name="ttnn.embedding") def embedding( input_tensor: ttnn.Tensor, @@ -236,7 +286,23 @@ def _torch_softmax(input_tensor: ttnn.Tensor, dim: int, **_): return torch.softmax(input_tensor, dim) -@register_operation(torch_function=_torch_softmax, name="ttnn.softmax") +def _softmax_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.softmax", + validate_input_tensors=_softmax_validate_input_tensors, + torch_function=_torch_softmax, +) def softmax( input_tensor: ttnn.Tensor, dim: int, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG ) -> ttnn.Tensor: @@ -297,7 +363,23 @@ def _torch_mean(input_tensor: ttnn.Tensor, dim: int, keepdim=False, **_): return 
torch.mean(input_tensor, dim=dim, keepdim=keepdim) -@register_operation(torch_function=_torch_mean, name="ttnn.mean") +def _mean_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.mean", + validate_input_tensors=_mean_validate_input_tensors, + torch_function=_torch_mean, +) def mean(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int]], keepdim: bool = False) -> ttnn.Tensor: input_shape = tuple(input_tensor.shape) rank = len(input_shape) diff --git a/ttnn/ttnn/operations/pooling.py b/ttnn/ttnn/operations/pooling.py index 0030389b8f8..839400d0216 100644 --- a/ttnn/ttnn/operations/pooling.py +++ b/ttnn/ttnn/operations/pooling.py @@ -96,7 +96,27 @@ def _torch_average_pool2d(input_tensor: ttnn.Tensor): return torch.nn.AdaptiveAvgPool2d(output_size)(input_tensor) -@ttnn.register_operation(torch_function=_torch_average_pool2d, name="ttnn.average_pool2d") +def _average_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(4,), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + b, s, *_ = input_tensor.shape + b, s_padded, *_ = input_tensor.shape.padded() + if s != s_padded: + raise RuntimeError("There can be no padding on the second dimension.") + + +@ttnn.register_operation( + name="ttnn.average_pool2d", + validate_input_tensors=_average_pool2d_validate_input_tensors, + torch_function=_torch_average_pool2d, +) def average_pool2d(input_tensor: ttnn.Tensor) -> ttnn.Tensor: output = ttl.tensor.average_pool_2d(input_tensor.value) diff --git a/ttnn/ttnn/operations/transformer.py 
b/ttnn/ttnn/operations/transformer.py index dee9cef861e..ebe30fd578f 100644 --- a/ttnn/ttnn/operations/transformer.py +++ b/ttnn/ttnn/operations/transformer.py @@ -243,7 +243,27 @@ def _fallback_attention_softmax(input_tensor: ttnn.Tensor, *, head_size: int, at return _torch_attention_softmax(input_tensor, head_size=head_size, attention_mask=attention_mask) -@register_operation(torch_function=_fallback_attention_softmax, name="ttnn.transformer.attention_softmax") +def _attention_softmax_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(4,), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + b, s, *_ = input_tensor.shape + b, s_padded, *_ = input_tensor.shape.padded() + if s != s_padded: + raise RuntimeError("There can be no padding on the second dimension.") + + +@ttnn.register_operation( + name="ttnn.transformer.attention_softmax", + validate_input_tensors=_attention_softmax_validate_input_tensors, + torch_function=_fallback_attention_softmax, +) def attention_softmax( input_tensor: ttnn.Tensor, *, @@ -262,12 +282,6 @@ def attention_softmax( * :attr:`memory_config`: Memory Config of the output tensor """ - if len(input_tensor.shape) != 4: - raise RuntimeError("Input Tensor must have strictly 4 dimensions!") - - if input_tensor.layout != ttnn.TILE_LAYOUT: - raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!") - if head_size is not None: scaler = 1 / (head_size**0.5) else: @@ -285,7 +299,11 @@ def attention_softmax( return ttnn.Tensor(ttl_output_tensor) -@register_operation(torch_function=_torch_attention_softmax, name="ttnn.transformer.attention_softmax_") +@ttnn.register_operation( + name="ttnn.transformer.attention_softmax_", + validate_input_tensors=_attention_softmax_validate_input_tensors, + torch_function=_torch_attention_softmax, +) def attention_softmax_( 
input_tensor: ttnn.Tensor, *, @@ -340,7 +358,23 @@ def _fallback_concatenate_heads(input_tensor: ttnn.Tensor, **_): return _torch_concatenate_heads(input_tensor) -@register_operation(torch_function=_fallback_concatenate_heads, name="ttnn.transformer.concatenate_heads") +def _concatenate_heads_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(4,), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation( + name="ttnn.transformer.concatenate_heads", + validate_input_tensors=_concatenate_heads_validate_input_tensors, + torch_function=_fallback_concatenate_heads, +) def concatenate_heads( input_tensor: ttnn.Tensor, *, @@ -356,12 +390,6 @@ def concatenate_heads( * :attr:`memory_config`: Memory Config of the output tensor """ - if len(input_tensor.shape) != 4: - raise RuntimeError("Input Tensor must have strictly 4 dimensions!") - - if input_tensor.layout != ttnn.TILE_LAYOUT: - raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!") - batch_size, num_heads, sequence_size, head_size = input_tensor.shape ttl_input_tensor = input_tensor.value diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index a31ca81341c..583d7d2a3f3 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -32,7 +32,22 @@ def _torch_unary(input_tensor: ttnn.Tensor, **_): input_tensor = ttnn.to_torch(input_tensor) return torch_function(input_tensor) - @register_operation(torch_function=_torch_unary, name=f"ttnn.{name}") + def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(2, 3, 4), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), + layouts=(ttnn.TILE_LAYOUT,), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + 
@ttnn.register_operation( + name=f"ttnn.{name}", + validate_input_tensors=_unary_validate_input_tensors, + torch_function=_torch_unary, + ) def unary_function( input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG ) -> ttnn.Tensor: