
Commit

#14257: ci fix
shwetankTT committed Nov 25, 2024
1 parent 66b25f2 commit 884aff5
Showing 2 changed files with 3 additions and 15 deletions.
tests/scripts/run_python_model_tests.sh (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ run_python_model_tests_wormhole_b0() {
# higher sequence lengths and different formats trigger memory issues
pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"
pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py -k "pretrained_weight_false"
pytest models/demos/yolov4/demo/demo.py -k "pretrained_weight_false"
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/demo/demo.py -k "pretrained_weight_false"

# Unet Shallow
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -svv models/experimental/functional_unet/tests/test_unet_model.py
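
The two yolov4 lines above differ only in the WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml prefix, the same prefix the Unet Shallow line already uses. As a rough sketch of how such an environment-variable override is typically consumed (this is not tt-metal's actual config code, and the fallback file name is a made-up placeholder):

#include <cstdlib>
#include <iostream>
#include <string>

// Return the dispatch/arch descriptor to use: the WH_ARCH_YAML environment
// variable wins when it is set, otherwise fall back to a default name.
// "default_wormhole_b0.yaml" is a placeholder, not a real tt-metal file.
std::string arch_yaml_or_default(const std::string& fallback) {
    if (const char* overridden = std::getenv("WH_ARCH_YAML")) {  // nullptr when unset
        return overridden;
    }
    return fallback;
}

int main() {
    // Run as `WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ./a.out` to see
    // the override take effect, the same way the pytest invocation above sets it.
    std::cout << arch_yaml_or_default("default_wormhole_b0.yaml") << "\n";
    return 0;
}
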
Second changed file (2 additions & 14 deletions)
@@ -25,7 +25,7 @@ ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strat

void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
const auto &first_input = input_tensors[0];
ttnn::SimpleShape shape_first = first_input.get_logical_shape();
tt::tt_metal::LegacyShape shape_first = first_input.get_legacy_shape();
TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank.");
shape_first[this->dim] = 0;
bool shard_first = input_tensors[0].is_sharded();
@@ -38,24 +38,12 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device.");
TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts.");
TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes.");
ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();

tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape();
TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal");
curr_shape[this->dim] = 0;
// last tensor can support without any kernel changes
if(in_ref.get_layout() == Layout::TILE and in_ref.get_shape().has_tile_padding(this->dim)) {
warn_about_alignment = true;
/* // last tensor can support without any kernel changes
TT_FATAL(
!in_ref.get_shape().has_tile_padding(this->dim),
"Tile padding along concatenated dim ({}) not supported for concat yet (tensor: {}).",
this->dim,
i);
TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
if (in_ref.get_layout() == Layout::ROW_MAJOR && this->dim == shape_first.rank() - 1) {
TT_FATAL(
(in_ref.get_logical_shape()[this->dim] * in_ref.element_size()) % in_ref.buffer()->alignment() == 0,
"Current concat implementation requires aligned last dim when concatting on last dim");*/
}
TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
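
For reference, below is a minimal standalone sketch (plain C++17, not the ttnn API): Shape, padded_to_tile, and validate_concat are made-up names, and TILE_DIM = 32 is assumed as the usual tile size for tiled layouts. It contrasts a logical shape with a tile-padded (legacy-style) shape and mirrors the rank and non-concat-dimension checks that ConcatDeviceOperation::validate performs above.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

using Shape = std::vector<uint32_t>;
constexpr uint32_t TILE_DIM = 32;

// A padded (legacy-style) shape rounds the last two dims up to the tile size;
// the logical shape keeps the original extents.
Shape padded_to_tile(Shape shape) {
    auto round_up = [](uint32_t v) { return ((v + TILE_DIM - 1) / TILE_DIM) * TILE_DIM; };
    const size_t rank = shape.size();
    if (rank >= 1) shape[rank - 1] = round_up(shape[rank - 1]);
    if (rank >= 2) shape[rank - 2] = round_up(shape[rank - 2]);
    return shape;
}

// Mirrors the validate() checks above: all inputs must have equal rank and must
// match on every dimension except the concat dimension.
void validate_concat(const std::vector<Shape>& inputs, size_t dim) {
    if (inputs.empty()) throw std::invalid_argument("concat needs at least one input");
    Shape first = inputs[0];
    if (dim >= first.size()) throw std::invalid_argument("concat dim is larger than input tensor rank");
    first[dim] = 0;  // zero out the concat dim, like shape_first[this->dim] = 0
    for (size_t i = 1; i < inputs.size(); ++i) {
        Shape curr = inputs[i];
        if (curr.size() != first.size()) throw std::invalid_argument("input tensor ranks must be equal");
        curr[dim] = 0;
        if (curr != first) throw std::invalid_argument("tensors differ in shape across non-concat dimensions");
    }
}

int main() {
    Shape a = {1, 1, 30, 64};
    std::cout << "padded dim -2: " << padded_to_tile(a)[2] << "\n";   // 30 rounds up to 32

    validate_concat({{1, 1, 32, 64}, {1, 1, 32, 16}}, /*dim=*/3);     // passes: only dim 3 differs
    try {
        validate_concat({{1, 1, 32, 64}, {1, 2, 32, 16}}, /*dim=*/3); // dim 1 differs, rejected
    } catch (const std::exception& e) {
        std::cout << "rejected: " << e.what() << "\n";
    }
    return 0;
}

Roughly, the has_tile_padding(dim) condition in the block above corresponds to the logical extent along dim being smaller than its padded extent in this toy model.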
