Skip to content

Commit

Permalink
#0: testing
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Nov 24, 2024
1 parent dc668ab commit f1b228d
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 4 deletions.
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,19 @@ Result conv2d(
conv_config,
weight_config
);

// tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device(
// weight_tensor,
// bias_tensor,
// conv_config.input_channels_alignment,
// conv_config.weights_dtype,
// opt_conv_op_block_config.act_block_w_ntiles,
// opt_conv_op_block_config.out_subblock_w_ntiles,
// parallel_config,
// device,
// groups,
// opt_conv_op_block_config.act_block_h_ntiles,
// input_width);
weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
if(bias_tensor.has_value()){
bias_tensor_on_device = ttnn::operations::core::to_device(bias_tensor_on_device.value(), device, std::nullopt);
Expand Down
140 changes: 136 additions & 4 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ using sliding_window::ParallelConfig;

namespace conv2d {

// Sanity-check conv2d weight (and optional bias) tensors before preparation:
// both must be host-resident, ROW_MAJOR, rank-4 tensors.
void validate_weight_and_bias_tensors(
    const ttnn::Tensor& weight_tensor, std::optional<const ttnn::Tensor>& bias_tensor) {
    TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE));
    TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR);
    TT_ASSERT(weight_tensor.get_shape().rank() == 4);
    // TODO: enable this assert
    // TT_ASSERT(weight_tensor.get_shape() == weight_tensor.get_legacy_shape());
    if (!bias_tensor.has_value()) {
        return;
    }
    const auto& bias = bias_tensor.value();
    TT_ASSERT(!ttnn::has_storage_type_of(bias, ttnn::DEVICE_STORAGE_TYPE));
    TT_ASSERT(bias.get_shape().rank() == 4);
    TT_ASSERT(bias.get_layout() == Layout::ROW_MAJOR);
    // TODO: enable this assert
    // TT_ASSERT(bias.get_shape() == bias.get_legacy_shape());
}


void validate_weight_tensor(const ttnn::Tensor& weight_tensor) {
TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE));
TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR);
Expand Down Expand Up @@ -117,6 +134,116 @@ OptimizedConvBlockConfig get_opt_block_config(
conv_config.enable_split_reader);
}



// Prepares conv2d weight (and optional bias) tensors and moves them to the device.
// Steps: (1) convert grouped / depthwise-1d weights to the expected layout,
// (2) zero-pad channels up to tile (32) and input-channel-alignment boundaries,
// (3) tilize into the block layout required by the conv op's shard scheme,
// (4) reshape to the flattened [1, 1, C_in*KH*KW, C_out] matrix view,
// (5) transfer weight (and padded/tilized bias, if present) to `device`.
// Returns {on-device weight, on-device bias or std::nullopt}.
// NOTE(review): assumes weight tensor is OIHW ([out_ch, in_ch, kh, kw]) — the
// shape indexing below depends on that ordering; confirm against callers.
template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_and_move_to_device(
    const ttnn::Tensor& weight_tensor,
    std::optional<const ttnn::Tensor>& bias_tensor,
    uint32_t input_channels_alignment,
    DataType weights_bias_dtype,
    uint32_t weight_block_h_ntiles,
    uint32_t weight_block_w_ntiles,
    const ParallelConfig& parallel_config,
    T * device,
    uint32_t groups,
    uint32_t act_block_h_ntiles,
    uint32_t input_width) {

    validate_weight_and_bias_tensors(weight_tensor, bias_tensor);
    ttnn::Tensor weight_tensor_; // tensor to return
    ttnn::Tensor bias_tensor_;

    auto original_weights_shape = weight_tensor.get_shape();
    uint32_t original_weights_out_channels = original_weights_shape[0];
    uint32_t original_weights_in_channels = original_weights_shape[1];
    uint32_t original_weights_window_h = original_weights_shape[2];
    uint32_t original_weights_window_w = original_weights_shape[3];

    // conv1d: kernel width and input width are both 1.
    bool is_conv1d = original_weights_window_w == 1 && input_width == 1;
    // depthwise: one group per output channel and a single input channel per group.
    bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1;

    weight_tensor_ = weight_tensor;

    // Convert weight tensor to 0 padded shape if groups > 1
    if (!is_conv1d and groups > 1) {
        weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
    }
    else if (is_conv1d and groups > 1) {
        if (is_depthwise_conv) {
            weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype);
            // Depthwise layout ties the weight block height to the activation block height.
            weight_block_h_ntiles = act_block_h_ntiles;
        }
        else{
            weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
        }
    }

    auto weights_shape = weight_tensor_.get_shape();
    uint32_t out_channels = weights_shape[0];
    uint32_t in_channels = weights_shape[1];
    uint32_t window_h = weights_shape[2];
    uint32_t window_w = weights_shape[3];
    // Pad out_channels to a tile boundary (32) and in_channels to the conv
    // config's alignment requirement.
    uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels;
    tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
        {tt::round_up(out_channels, 32), tt::round_up(in_channels, input_channels_alignment), window_h, window_w}));
    if (weights_bias_dtype == DataType::BFLOAT8_B) {
        // BFLOAT8_B conversion happens later during tilization; inputs must be FLOAT32.
        TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
        if (bias_tensor.has_value()) {
            TT_ASSERT(bias_tensor.value().get_dtype() == DataType::FLOAT32);
        }
    } else {
        // TODO: fix the need to check this. We should be able to accept any datatype and convert
        TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype);
        if (bias_tensor.has_value()) {
            TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype);
        }
    }
    weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

    // for conv op, pad the weights to block shape
    if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
        weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
            weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
    } else {
        weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
            weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
    }

    uint32_t weight_matrix_height = in_channels * window_h * window_w;
    int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height;
    TT_FATAL(weight_matrix_height_padding >= 0, "Matrix Height Padding can't be negative");

    // convert_conv_weight_tensor adds the padding to the base shape.
    // Reshape the weights to remove padding from the base shape.
    weight_tensor_.set_shape(
        ttnn::Shape(std::array<uint32_t,4>{1, 1, weight_matrix_height, out_channels},
        std::array<std::array<uint32_t, 2>, 4>{
            std::array<uint32_t, 2>{0, 0},
            std::array<uint32_t, 2>{0, 0},
            std::array<uint32_t, 2>{0, weight_matrix_height_padding},
            std::array<uint32_t, 2>{0, out_channel_padding}
        }));

    weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt);
    if (bias_tensor.has_value()) {
        bias_tensor_ = bias_tensor.value();
        auto bias_shape = bias_tensor_.get_shape();
        TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1);
        // Bias is padded to one tile row (32) high and to the weight block width.
        tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape(
            std::array<uint32_t, 4>({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)}));
        bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);
        // Tilize on host (null device pointer selects the host path).
        bias_tensor_ = ttnn::to_layout(
            bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, static_cast<T*>(nullptr));
        if (bias_tensor_.get_dtype() != weights_bias_dtype) {
            bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype);
        }
        bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
    }

    return {weight_tensor_, bias_tensor.has_value() ? std::optional<ttnn::Tensor>(bias_tensor_) : std::nullopt};
}

template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
const ttnn::Tensor& weight_tensor,
Expand All @@ -141,9 +268,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation.");

const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
const uint32_t output_width =
((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;

uint32_t weight_block_h_ntiles=0, weight_block_w_ntiles=0, act_block_h_ntiles=0;
Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
TensorMemoryLayout tensor_layout;
Expand All @@ -154,6 +279,9 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
act_block_h_ntiles = weight_config.act_block_h_ntiles;
tensor_layout = weight_config.shard_layout;
}else{
const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
const uint32_t output_width =
((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
auto opt_conv_op_block_config = get_opt_block_config(
mm_conv,
in_channels,
Expand All @@ -175,7 +303,11 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
tensor_layout = conv_config.shard_layout.value();
}

validate_weight_tensor(weight_tensor);
// std::cout << "Weight Block H TILES: " << weight_block_h_ntiles << " weight_block_w_ntiles: " << weight_block_w_ntiles << "act_block_h_ntiles: " << act_block_h_ntiles
// << " act_block_h_ntiles: " << act_block_h_ntiles << " tensor_layout: "<< (int)tensor_layout << std::endl;

// validate_weight_tensor(weight_tensor);
validate_weight_and_bias_tensors(weight_tensor, bias_tensor);
ttnn::Tensor weight_tensor_ = weight_tensor; // tensor to return

// Permute to OIHW layout as thats what the preparation expects
Expand Down
14 changes: 14 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ ttnn::Tensor prepare_conv_bias(
T *device,
std::optional<const Conv2dConfig> conv_config_);

// Prepares conv2d weight (and optional bias) tensors for device execution:
// converts grouped/depthwise layouts, pads channels to tile/alignment
// boundaries, tilizes to the conv op's block layout, and moves the results
// to `device`. Returns {on-device weight, on-device bias or empty optional}.
// `weight_block_h_ntiles`/`weight_block_w_ntiles` give the block shape in
// tiles; `act_block_h_ntiles` is used for the depthwise-conv1d path.
template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_and_move_to_device(
    const ttnn::Tensor& weight_tensor,
    std::optional<const ttnn::Tensor>& bias_tensor,
    uint32_t input_channels_alignment,
    DataType weights_bias_dtype,
    uint32_t weight_block_h_ntiles,
    uint32_t weight_block_w_ntiles,
    const sliding_window::ParallelConfig& parallel_config,
    T * device,
    uint32_t groups,
    uint32_t act_block_h_ntiles,
    uint32_t input_width);

} // namespace conv2d
} // namespace operations::conv
} // namespace ttnn

0 comments on commit f1b228d

Please sign in to comment.