Commit e9de010
#0: testing
shwetankTT committed Nov 24, 2024 · 1 parent dc668ab
Showing 3 changed files with 158 additions and 18 deletions.
49 changes: 31 additions & 18 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -112,26 +112,39 @@ Result conv2d(
.act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles,
.shard_layout = parallel_config.shard_scheme,
};
// tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights(
//     weight_tensor,
//     input_tensor_post_tm.memory_config(),
//     input_tensor_post_tm.layout(),
//     "OIHW",
//     in_channels,
//     out_channels,
//     batch_size,
//     input_height,
//     input_width,
//     kernel_size,
//     stride,
//     padding,
//     dilation,
//     groups,
//     device,
//     bias_tensor,
//     conv_config,
//     weight_config
// );

tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device(
    weight_tensor,
    bias_tensor,
    conv_config.input_channels_alignment,
    conv_config.weights_dtype,
    opt_conv_op_block_config.act_block_w_ntiles,
    opt_conv_op_block_config.out_subblock_w_ntiles,
    parallel_config,
    device,
    groups,
    opt_conv_op_block_config.act_block_h_ntiles,
    input_width);
weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt);
if (bias_tensor.has_value()) {
bias_tensor_on_device = ttnn::operations::core::to_device(bias_tensor_on_device.value(), device, std::nullopt);
113 changes: 113 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -117,6 +117,116 @@ OptimizedConvBlockConfig get_opt_block_config(
conv_config.enable_split_reader);
}



template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_and_move_to_device(
const ttnn::Tensor& weight_tensor,
std::optional<const ttnn::Tensor>& bias_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const ParallelConfig& parallel_config,
T * device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width) {

validate_weight_and_bias_tensors(weight_tensor, bias_tensor);
ttnn::Tensor weight_tensor_; // tensor to return
ttnn::Tensor bias_tensor_;

auto original_weights_shape = weight_tensor.get_shape();
uint32_t original_weights_out_channels = original_weights_shape[0];
uint32_t original_weights_in_channels = original_weights_shape[1];
uint32_t original_weights_window_h = original_weights_shape[2];
uint32_t original_weights_window_w = original_weights_shape[3];

bool is_conv1d = original_weights_window_w == 1 && input_width == 1;
bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1;
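// is_conv1d: the kernel and the input both have width 1; is_depthwise_conv: one group
// per output channel, each carrying a single input channel in the OIHW weights.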

weight_tensor_ = weight_tensor;

// Convert weight tensor to 0 padded shape if groups > 1
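// (convert_conv_weight_tensor_to_grouped_layout expands the [out_c, in_c / groups, kH, kW]
// weights into a block-diagonal [out_c, in_c, kH, kW] tensor, zero-filled outside each
// group's input-channel slice.)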
if (!is_conv1d && groups > 1) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
} else if (is_conv1d && groups > 1) {
if (is_depthwise_conv) {
weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype);
weight_block_h_ntiles = act_block_h_ntiles;
} else {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
}
}

auto weights_shape = weight_tensor_.get_shape();
uint32_t out_channels = weights_shape[0];
uint32_t in_channels = weights_shape[1];
uint32_t window_h = weights_shape[2];
uint32_t window_w = weights_shape[3];
uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels;
tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
{tt::round_up(out_channels, 32), tt::round_up(in_channels, input_channels_alignment), window_h, window_w}));
if (weights_bias_dtype == DataType::BFLOAT8_B) {
TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
if (bias_tensor.has_value()) {
TT_ASSERT(bias_tensor.value().get_dtype() == DataType::FLOAT32);
}
} else {
// TODO: fix the need to check this. We should be able to accept any datatype and convert
TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype);
if (bias_tensor.has_value()) {
TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype);
}
}
weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

// for conv op, pad the weights to block shape
if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
} else {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
}
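// The tiled weights serve as the RHS of the conv-as-matmul: logically
// [in_channels * window_h * window_w, out_channels], with tiling possibly rounding
// the height up; the difference is tracked below as padding.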

uint32_t weight_matrix_height = in_channels * window_h * window_w;
int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height;
TT_FATAL(weight_matrix_height_padding >= 0, "Matrix height padding can't be negative");

// convert_conv_weight_tensor adds the padding to the base shape.
// Reshape the weights to remove padding from the base shape.
weight_tensor_.set_shape(
ttnn::Shape(std::array<uint32_t,4>{1, 1, weight_matrix_height, out_channels},
std::array<std::array<uint32_t, 2>, 4>{
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, weight_matrix_height_padding},
std::array<uint32_t, 2>{0, out_channel_padding}
}));

weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt);
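// If a bias is provided, pad it to a full tile height (32 rows), round its channel
// dimension up to the weight block width so it tiles cleanly against the weight matrix,
// then tile it and move it to device as well.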
if (bias_tensor.has_value()) {
bias_tensor_ = bias_tensor.value();
auto bias_shape = bias_tensor_.get_shape();
TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1);
tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape(
std::array<uint32_t, 4>({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)}));
bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);
bias_tensor_ = ttnn::to_layout(
bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr);
if (bias_tensor_.get_dtype() != weights_bias_dtype) {
bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype);
}
bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
}

return {weight_tensor_, bias_tensor.has_value() ? bias_tensor_ : std::optional<ttnn::Tensor>()};
}
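
For reference, the channel-padding arithmetic above reduces to the minimal standalone sketch below. The round_up helper is assumed to behave like tt::round_up, and the sample sizes (and the input_channels_alignment value) are hypothetical, not taken from this commit:

#include <cstdint>
#include <iostream>

// Assumed to match tt::round_up: smallest multiple of `mult` that is >= `v`.
static uint32_t round_up(uint32_t v, uint32_t mult) {
    return ((v + mult - 1) / mult) * mult;
}

int main() {
    // Hypothetical conv: 100 output channels, 3 input channels, 3x3 kernel.
    uint32_t out_channels = 100, in_channels = 3, window_h = 3, window_w = 3;
    uint32_t input_channels_alignment = 16;  // hypothetical config value

    // Padded weight shape, mirroring weights_channels_padded_shape above.
    uint32_t padded_out = round_up(out_channels, 32);                      // 128
    uint32_t padded_in = round_up(in_channels, input_channels_alignment);  // 16
    uint32_t out_channel_padding = padded_out - out_channels;              // 28

    // Logical (pre-padding) weight-matrix height used by the final reshape.
    uint32_t weight_matrix_height = in_channels * window_h * window_w;     // 27

    std::cout << "padded weights: [" << padded_out << ", " << padded_in << ", "
              << window_h << ", " << window_w << "]\n"
              << "out_channel_padding: " << out_channel_padding << "\n"
              << "weight_matrix_height: " << weight_matrix_height << "\n";
    return 0;
}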

template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
const ttnn::Tensor& weight_tensor,
@@ -175,6 +285,9 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights(
tensor_layout = conv_config.shard_layout.value();
}

// std::cout << "Weight Block H TILES: " << weight_block_h_ntiles << " weight_block_w_ntiles: " << weight_block_w_ntiles << "act_block_h_ntiles: " << act_block_h_ntiles
// << " act_block_h_ntiles: " << act_block_h_ntiles << " tensor_layout: "<< (int)tensor_layout << std::endl;

validate_weight_tensor(weight_tensor);
ttnn::Tensor weight_tensor_ = weight_tensor; // tensor to return

14 changes: 14 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
@@ -65,6 +65,20 @@ ttnn::Tensor prepare_conv_bias(
T *device,
std::optional<const Conv2dConfig> conv_config_);

template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_and_move_to_device(
const ttnn::Tensor& weight_tensor,
std::optional<const ttnn::Tensor>& bias_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const sliding_window::ParallelConfig& parallel_config,
T * device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width);

} // namespace conv2d
} // namespace operations::conv
} // namespace ttnn
