Skip to content

Commit

Permalink
Adding reshape op (#489)
Browse files Browse the repository at this point in the history
* Adding reshape op

* Fixup test...

* Addressing comments
  • Loading branch information
mtopalovicTT authored Aug 28, 2024
1 parent 0fb294e commit 1f92bc6
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 2 deletions.
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,26 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
let hasVerifier = 1;
}

def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
    let summary = "Reshape op.";
    let description = [{
      Returns a tensor containing the same data as `input` but with the
      dimensions given by the `shape` attribute. At most one entry of
      `shape` may be -1, in which case that dimension is inferred from the
      element count (see the op verifier for the full rules).
    }];

    // DPS form: `output` is the pre-allocated destination tensor that the
    // lowered op writes into; `result` aliases it.
    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         I32ArrayAttr:$shape,
                         TT_OperandConstraintArrayAttr:$operand_constraints);

    let results = (outs AnyRankedTensor:$result);

    // Identify `output` as the DPS init operand.
    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
Expand Down
19 changes: 19 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,25 @@ def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat"> {
let hasVerifier = 1;
}

def TTNN_ReshapeOp : TTNN_NamedDPSOp<"reshape"> {
    let summary = "Reshape op.";
    let description = [{
      TTNN counterpart of ttir.reshape: returns a tensor containing the same
      data as `input` but with the dimensions given by the `shape` attribute.
      At most one entry of `shape` may be -1, in which case that dimension is
      inferred from the element count (see the op verifier for the full rules).
    }];

    // DPS form: `output` is the pre-allocated destination tensor; unlike the
    // TTIR op, no operand constraints are carried at this level.
    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         I32ArrayAttr:$shape);

    let results = (outs AnyRankedTensor:$result);

    // Identify `output` as the DPS init operand.
    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
let arguments = (ins AnyRankedTensor:$a,
Expand Down
9 changes: 8 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ table ConcatOp {
dim: int32;
}

// Serialized form of the TTNN reshape op: `in` is reshaped to the
// dimensions listed in `shape`; `out` references the destination tensor.
table ReshapeOp {
  in: tt.target.TensorRef;
  out: tt.target.TensorRef;
  shape: [int32];
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
in0: tt.target.TensorRef;
Expand All @@ -105,7 +111,8 @@ union OpType {
EmbeddingOp,
SoftmaxOp,
TransposeOp,
ConcatOp
ConcatOp,
ReshapeOp
}

table Operation {
Expand Down
15 changes: 15 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
}
};

// Lowers ttir.reshape 1:1 onto ttnn.reshape, forwarding the input/output
// operands and the target-shape attribute unchanged; only the result type
// goes through the type converter.
class ReshapeOpConversionPattern : public OpConversionPattern<ttir::ReshapeOp> {
public:
  using OpConversionPattern<ttir::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(op, convertedType,
                                                 adaptor.getInput(),
                                                 adaptor.getOutput(),
                                                 adaptor.getShape());
    return success();
  }
};

} // namespace

// ANCHOR: adding_an_op_matmul_op_rewriter
Expand Down Expand Up @@ -214,6 +228,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
ConcatOpConversionPattern,
ReshapeOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::TransposeOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ConcatOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReshapeOp>>(typeConverter, ctx);

// Matmul ops
//
Expand Down
69 changes: 68 additions & 1 deletion lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,74 @@ ::mlir::LogicalResult mlir::tt::ttir::ConcatOp::verify() {
}
}

return mlir::success();
return success();
}

// Verifies the TTIR reshape op:
//  * the shape attribute's rank matches the output tensor rank and lies
//    in [1, 5],
//  * input and output element counts agree,
//  * every shape entry is positive and equal to the corresponding output
//    dimension, except for at most one -1 ("infer this dimension") entry,
//  * when a -1 is present, the product of the known dimensions evenly
//    divides the input element count.
::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  auto shapeAttr = getShape();
  const int64_t shapeSize = static_cast<int64_t>(shapeAttr.size());

  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
    return emitOpError("Shape attribute size must match output tensor rank");
  }
  if (shapeSize == 0) {
    return emitOpError("Shape attribute must be non-empty");
  }
  if (shapeSize > 5) {
    return emitOpError("Shape attribute must have at most 5 elements");
  }
  if (inputType.getNumElements() != outputType.getNumElements()) {
    return emitOpError(
        "Input and output tensors must have the same number of elements");
  }

  bool seenInferredDim = false;
  int64_t knownDimProduct = 1;
  auto outputShape = outputType.getShape();

  for (int64_t idx = 0; idx < shapeSize; ++idx) {
    const int64_t dim = mlir::cast<IntegerAttr>(shapeAttr[idx]).getInt();

    if (dim == -1) {
      // Only one dimension may be left to inference.
      if (seenInferredDim) {
        return emitOpError("Shape attribute must have at most one -1 element");
      }
      seenInferredDim = true;
      continue;
    }

    if (dim <= 0) {
      return emitOpError(
          "All dimensions must be positive except the one with -1");
    }

    // Explicit dimensions must agree with the output tensor's shape.
    if (dim != outputShape[idx]) {
      return emitOpError("Shape attribute must match the output tensor shape "
                         "for dimensions that are not -1");
    }

    knownDimProduct *= dim;
  }

  // An inferred dimension only exists if the known dims divide the total.
  if (seenInferredDim && inputType.getNumElements() % knownDimProduct != 0) {
    return emitOpError("Invalid shape: the dimensions do not multiply to the "
                       "total number of elements in the tensor");
  }

  return success();
}

// ANCHOR: adding_an_op_matmul_ttir_verify
Expand Down
67 changes: 67 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,73 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
return mlir::success();
}

// Verifies the TTNN reshape op:
//  * the shape attribute's rank matches the output tensor rank and lies
//    in [1, 5],
//  * input and output element counts agree,
//  * every shape entry is positive and equal to the corresponding output
//    dimension, except for at most one -1 ("infer this dimension") entry,
//  * when a -1 is present, the product of the known dimensions evenly
//    divides the input element count.
// NOTE(review): this duplicates ttir::ReshapeOp::verify() in TTIROps.cpp;
// keep the two in sync (or factor into a shared helper).
::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  auto shape = getShape();
  int64_t shape_size = static_cast<int64_t>(shape.size());

  // Check that the shape size matches the rank of the output tensor
  if (shape_size != static_cast<int64_t>(outputType.getRank())) {
    return emitOpError("Shape attribute size must match output tensor rank");
  }

  // Check that the shape attribute is non-empty
  if (shape_size == 0) {
    return emitOpError("Shape attribute must be non-empty");
  }

  // Check that the shape attribute has at most 5 elements
  if (shape_size > 5) {
    return emitOpError("Shape attribute must have at most 5 elements");
  }

  // Cardinality of the input and output tensors must be the same
  if (inputType.getNumElements() != outputType.getNumElements()) {
    return emitOpError(
        "Input and output tensors must have the same number of elements");
  }

  bool has_negative = false;
  int64_t known_dim_product = 1;
  auto outputShape = outputType.getShape();

  // Check that all dimensions are positive except for at most one -1
  // Check that the non-negative dimensions match the output tensor shape
  // Calculate the product of the known dimensions
  for (int64_t i = 0; i < shape_size; i++) {
    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();

    if (dim_value == -1) {
      // Only one dimension may be left to inference.
      if (has_negative) {
        return emitOpError("Shape attribute must have at most one -1 element");
      }
      has_negative = true;
    } else {
      if (dim_value <= 0) {
        return emitOpError(
            "All dimensions must be positive except the one with -1");
      }

      // Ensure that the non-negative dimensions match the output tensor shape
      if (dim_value != outputShape[i]) {
        return emitOpError("Shape attribute must match the output tensor shape "
                           "for dimensions that are not -1");
      }

      known_dim_product *= dim_value;
    }
  }

  // If there's a -1, ensure that it can be inferred correctly
  if (has_negative && inputType.getNumElements() % known_dim_product != 0) {
    return emitOpError("Invalid shape: the dimensions do not multiply to the "
                       "total number of elements in the tensor");
  }

  return success();
}

// ANCHOR: adding_an_op_matmul_ttnn_verify
::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
::mlir::RankedTensorType inputAType = getA().getType();
Expand Down
18 changes: 18 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,19 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
}

// Serializes a TTNN reshape op into its flatbuffer table: the input tensor
// (traced through DPS ops), the result tensor, and the target-shape attribute
// converted to an int32 vector.
template <typename ReshapeOp>
::flatbuffers::Offset<::tt::target::ttnn::ReshapeOp>
createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) {
  auto input = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getInput()));
  auto output = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getResult()));
  auto shapeVec =
      arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getShape());

  return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, input, output,
                                             shapeVec);
}

template <typename SoftmaxOp>
::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp>
createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
Expand Down Expand Up @@ -309,6 +322,11 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto concatOp = dyn_cast<ConcatOp>(op); concatOp) {
return createOperation(cache, createConcatOp(cache, concatOp), debugString);
}
if (auto reshapeOp = dyn_cast<ReshapeOp>(op); reshapeOp) {
return createOperation(cache, createReshapeOp(cache, reshapeOp),
debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}

Expand Down
57 changes: 57 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,60 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
}
}

// Converts a runtime-sized vector into a fixed-size std::array of rank Rank.
//
// Throws std::invalid_argument when the sizes disagree — silently truncating
// or zero-padding would corrupt the reshape target shape.
template <int32_t Rank>
static std::array<int32_t, Rank>
vectorToArray(const std::vector<int32_t> &vec) {
  // Cast Rank to size_t to avoid a signed/unsigned comparison.
  if (vec.size() != static_cast<size_t>(Rank)) {
    throw std::invalid_argument("Vector size does not match array size");
  }
  // Value-initialize so the array is never read uninitialized, even though
  // std::copy overwrites every element on the success path.
  std::array<int32_t, Rank> arr{};
  std::copy(vec.begin(), vec.end(), arr.begin());
  return arr;
}

// Adapter bridging the runtime (vector-based) shape to the
// compile-time-ranked std::array form passed to ::ttnn::reshape.
// Throws std::invalid_argument (via vectorToArray) if `shape` does not
// contain exactly Rank elements.
template <int32_t Rank>
static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor,
                                     const std::vector<int32_t> &shape) {
  return ::ttnn::reshape(tensor, vectorToArray<Rank>(shape));
}

// Executes a deserialized ReshapeOp: looks up the live input tensor,
// reshapes it to the flatbuffer-encoded target shape, and registers the
// result under the op's output id.
static void
run(::tt::target::ttnn::ReshapeOp const *op, ::ttnn::Device &device,
    std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
    std::list<::ttnn::Tensor> &tensorPool) {
  ::ttnn::Tensor &input = *liveTensors.at(op->in()->global_id());
  const auto *fbShape = op->shape();
  const std::vector<int32_t> targetShape(fbShape->begin(), fbShape->end());

  // ::ttnn::reshape takes the shape as a compile-time-sized std::array, so
  // dispatch on the runtime rank to the matching template instantiation.
  switch (fbShape->size()) {
  case 1:
    tensorPool.push_back(invoke_reshape<1>(input, targetShape));
    break;
  case 2:
    tensorPool.push_back(invoke_reshape<2>(input, targetShape));
    break;
  case 3:
    tensorPool.push_back(invoke_reshape<3>(input, targetShape));
    break;
  case 4:
    tensorPool.push_back(invoke_reshape<4>(input, targetShape));
    break;
  case 5:
    tensorPool.push_back(invoke_reshape<5>(input, targetShape));
    break;
  default:
    throw std::invalid_argument("Unsupported rank for reshape");
  }

  // Publish the freshly reshaped tensor as this op's output.
  liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}

static void
run(::tt::target::ttnn::EmbeddingOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
Expand Down Expand Up @@ -456,10 +510,13 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
}
case ::tt::target::ttnn::OpType::ConcatOp: {
return run(op->type_as_ConcatOp(), device, liveTensors, tensorPool);
case ::tt::target::ttnn::OpType::ReshapeOp: {
return run(op->type_as_ReshapeOp(), device, liveTensors, tensorPool);
}
default:
throw std::runtime_error("Unsupported operation type");
}
}
}

// Nop is single input, output tensor where input is returned as output.
Expand Down
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_reshape.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
// Checks that ttir.reshape lowers to ttnn.reshape through the TTNN backend
// pipeline. The target shape swaps the two leading dims (4x2 -> 2x4) while
// keeping the total element count unchanged.
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
  func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
    %0 = tensor.empty() : tensor<2x4x32x32xbf16>
    // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
    %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>
    return %1 : tensor<2x4x32x32xbf16>
  }
}

0 comments on commit 1f92bc6

Please sign in to comment.