Bringup cache updates on ttir level, add silicon test (#1437)
Model memory effects of cache fill/update ops

Create TTNN_InplaceOp with MemWrite trait
LPanosTT authored Dec 5, 2024
1 parent 0a172a2 commit 0a33667
Showing 31 changed files with 626 additions and 39 deletions.
41 changes: 41 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -707,6 +707,47 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
let hasVerifier = 1;
}

def TTIR_UpdateCacheOp : TTIR_DPSOp<"update_cache"> {
let summary = "Update static cache tensor.";
let description = [{
Updates the `cache` tensor in-place with values from `input` at `update_index` and `batch_offset`.
}];

let arguments = (ins AnyRankedTensor:$cache,
AnyRankedTensor:$input,
AnyRankedTensor:$update_index,
I32Attr:$batch_offset,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getCacheMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_FillCacheOp : TTIR_DPSOp<"fill_cache"> {
let summary = "Fill static cache tensor.";
let description = [{
Fills the `cache` tensor in-place with values from `input` at `batch_offset`.
}];

let arguments = (ins AnyRankedTensor:$cache,
AnyRankedTensor:$input,
I32Attr:$batch_offset,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getCacheMutable(); }
}];
}

def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
let summary = "Broadcast operation.";
let description = [{
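For intuition about what these two ops mean, here is a minimal C++ sketch of the intended semantics over a row-major [batch, heads, seq, dim] tensor, consistent with the verifier constraints added later in this commit. The struct and function names are illustrative, not types from the repository.

#include <cstdint>
#include <vector>

// Illustrative 4D row-major tensor; not a type from the codebase.
struct Tensor4D {
  std::vector<float> data;
  int64_t b, h, s, d; // batch, heads, seq, dim
  float &at(int64_t i, int64_t j, int64_t k, int64_t l) {
    return data[((i * h + j) * s + k) * d + l];
  }
};

// update_cache: write the single new sequence position of `input`
// (seq size 1) into `cache` at `updateIndex`.
void referenceUpdateCache(Tensor4D &cache, Tensor4D &input,
                          int64_t updateIndex) {
  for (int64_t i = 0; i < cache.b; ++i)
    for (int64_t j = 0; j < cache.h; ++j)
      for (int64_t l = 0; l < cache.d; ++l)
        cache.at(i, j, updateIndex, l) = input.at(i, j, 0, l);
}

// fill_cache: copy the first input.s sequence positions of `input`
// into the front of `cache` (requires input.s <= cache.s).
void referenceFillCache(Tensor4D &cache, Tensor4D &input) {
  for (int64_t i = 0; i < cache.b; ++i)
    for (int64_t j = 0; j < cache.h; ++j)
      for (int64_t k = 0; k < input.s; ++k)
        for (int64_t l = 0; l < cache.d; ++l)
          cache.at(i, j, k, l) = input.at(i, j, k, l);
}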
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
@@ -48,4 +48,7 @@ def TTNN_Dialect : Dialect {
class TTNN_Op<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [Pure, TTNN_OpModelInterface, TTNN_WorkaroundInterface])>;

class TTNN_InplaceOp<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [MemoryEffects<[MemWrite]>, TTNN_OpModelInterface, TTNN_WorkaroundInterface])>;

#endif
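The practical consequence of swapping Pure for MemoryEffects<[MemWrite]> is that generic transforms stop treating these ops as erasable or freely reorderable. A minimal sketch of how a pass would observe this via the standard MLIR side-effect helpers; the helper function itself is hypothetical:

#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Hypothetical helper: an op defined via TTNN_InplaceOp declares MemWrite,
// so isMemoryEffectFree() returns false and DCE/CSE must keep it even
// though it produces no results.
static bool isSafeToErase(mlir::Operation *op) {
  return op->use_empty() && mlir::isMemoryEffectFree(op);
}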
27 changes: 27 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -577,6 +577,33 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let hasVerifier = 1;
}

def TTNN_UpdateCacheOp : TTNN_InplaceOp<"update_cache"> {
let summary = "Update static cache tensor.";
let description = [{
Updates the `cache` tensor in-place with values from `input` at `update_index` and `batch_offset`.
}];

let arguments = (ins Arg<AnyRankedTensor, "cache tensor", [MemWrite]>:$cache,
AnyRankedTensor:$input,
AnyRankedTensor:$update_index,
I32Attr:$batch_offset);

let hasVerifier = 1;
}

def TTNN_FillCacheOp : TTNN_InplaceOp<"fill_cache"> {
let summary = "Fill static cache tensor.";
let description = [{
Fills the `cache` tensor in-place with values from `input` at `batch_offset`.
}];

let arguments = (ins Arg<AnyRankedTensor, "cache tensor", [MemWrite]>:$cache,
AnyRankedTensor:$input,
I32Attr:$batch_offset);

let hasVerifier = 1;
}

def TTNN_SoftmaxOp : TTNN_Op<"softmax"> {
let summary = "Softmax op.";
let description = [{
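Because the effect is attached to the $cache operand via Arg<..., [MemWrite]>, analyses can ask exactly which operand is written. A hedged sketch using the standard MemoryEffectOpInterface; the helper name is illustrative:

#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: returns true if `op` declares a write effect on
// `value`, as ttnn.update_cache / ttnn.fill_cache do on their cache operand.
static bool writesTo(mlir::Operation *op, mlir::Value value) {
  auto iface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!iface)
    return false;
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects;
  iface.getEffectsOnValue(value, effects);
  return llvm::any_of(effects,
                      [](const mlir::MemoryEffects::EffectInstance &e) {
                        return llvm::isa<mlir::MemoryEffects::Write>(
                            e.getEffect());
                      });
}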
15 changes: 15 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -37,6 +37,19 @@ table ToDeviceOp {
out: tt.target.TensorRef;
}

table UpdateCacheOp {
cache: tt.target.TensorRef;
input: tt.target.TensorRef;
update_index: tt.target.TensorRef;
batch_offset: uint32;
}

table FillCacheOp {
cache: tt.target.TensorRef;
input: tt.target.TensorRef;
batch_offset: uint32;
}

table FromDeviceOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
@@ -283,6 +296,8 @@ union OpType {
DeallocateOp,
AllGatherOp,
ArangeOp,
UpdateCacheOp,
FillCacheOp,
}

table Operation {
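On the runtime side, the new union entries are dispatched like any other op. A hedged sketch of what that looks like with flatc-generated C++: flatc normally emits `<field>_type()` / `<field>_as_<Table>()` accessors and `OpType_<Table>` enumerators, but the union field name (`type`), the namespace, and the header path below are assumptions, not copied from the repository.

#include "program_generated.h" // hypothetical flatc output for program.fbs

void runOp(const tt::target::ttnn::Operation *op) {
  switch (op->type_type()) { // union discriminant, assuming field name `type`
  case tt::target::ttnn::OpType_UpdateCacheOp: {
    const auto *u = op->type_as_UpdateCacheOp();
    // u->cache(), u->input(), u->update_index(), u->batch_offset()
    break;
  }
  case tt::target::ttnn::OpType_FillCacheOp: {
    const auto *f = op->type_as_FillCacheOp();
    // f->cache(), f->input(), f->batch_offset()
    break;
  }
  default:
    break;
  }
}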
90 changes: 81 additions & 9 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -23,6 +23,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::tt;
@@ -334,6 +335,78 @@ class ClampOpConversionPattern : public OpConversionPattern<ttir::ClampOp> {
}
};

class UpdateCacheOpConversionPattern
: public OpConversionPattern<ttir::UpdateCacheOp> {
public:
using OpConversionPattern<ttir::UpdateCacheOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::UpdateCacheOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// The TTIR version of this op is pure. In TTNN this op is in-place.
// We need to replace uses of the result of the TTIR op with uses
// of the cache argument.
//
// The presence of the MemWrite trait on this op should preserve
// its order relative to the cache argument's uses, preserving
// program correctness.

// This op can only work if it is the final use of the cache tensor in the
// order of execution. For now, checking that there is only one user (this
// op) of the cache tensor will suffice.
std::vector<mlir::Operation *> users(op.getCache().getUsers().begin(),
op.getCache().getUsers().end());
if (users.size() != 1) {
return rewriter.notifyMatchFailure(
op, "UpdateCacheOp must have exactly one user");
}

rewriter.create<ttnn::UpdateCacheOp>(
op.getLoc(), adaptor.getCache(), adaptor.getInput(),
adaptor.getUpdateIndex(), adaptor.getBatchOffset());

rewriter.replaceOp(op, adaptor.getCache());
return success();
}
};

class FillCacheOpConversionPattern
: public OpConversionPattern<ttir::FillCacheOp> {
public:
using OpConversionPattern<ttir::FillCacheOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::FillCacheOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// The TTIR version of this op is pure. In TTNN this op is in-place.
// We need to replace uses of the result of the TTIR op with uses
// of the cache argument.
//
// The presence of the MemWrite trait on this op should preserve
// its order relative to the cache argument's uses, preserving
// program correctness.

// This op can only work if it is the final use of the cache tensor in the
// order of execution. For now, checking that there is only one user (this
// op) of the cache tensor will suffice.
std::vector<mlir::Operation *> users(op.getCache().getUsers().begin(),
op.getCache().getUsers().end());
if (users.size() != 1) {
return rewriter.notifyMatchFailure(
op, "FillCacheOp must have exactly one user");
}

rewriter.create<ttnn::FillCacheOp>(op.getLoc(), adaptor.getCache(),
adaptor.getInput(),
adaptor.getBatchOffset());

rewriter.replaceOp(op, adaptor.getCache());
return success();
}
};

template <typename TTIROpTy, typename TTNNOpTy,
typename OpAdaptor = typename TTIROpTy::Adaptor>
class ElementwiseUnaryWithFloatParameterOpConversionPattern
@@ -506,15 +579,12 @@ class ConstantOpConversionPattern
valueAttr.getElementType().isInteger()
? getIntegerValue(valueAttr)
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();
-      if (fillValue == 0) {
-        rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
-            op, this->getTypeConverter()->convertType(op.getType()), device);
-      } else {
-        ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
-        rewriter.replaceOpWithNewOp<ttnn::FullOp>(
-            op, this->getTypeConverter()->convertType(op.getType()), device,
-            fillValueAttr);
-      }
+
+      ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
+      rewriter.replaceOpWithNewOp<ttnn::FullOp>(
+          op, this->getTypeConverter()->convertType(op.getType()), device,
+          fillValueAttr);
+
} else {
return rewriter.notifyMatchFailure(
op, "TTNN doesn't currently support tensor creation from multiple "
@@ -980,6 +1050,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SubtractOpConversionPattern,
AllGatherOpConversionPattern,
ArangeOpConversionPattern,
UpdateCacheOpConversionPattern,
FillCacheOpConversionPattern,
ScatterOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
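The single-user restriction in both patterns deserves a concrete illustration. A hedged before/after sketch in comment form, with IR abbreviated and `ttir.foo` as a stand-in for any other reader of the cache:

// Pure TTIR form: the op returns a new SSA value; %cache is unchanged.
//
//   %updated = "ttir.update_cache"(%cache, %input, %idx) // user #1 of %cache
//   %other   = "ttir.foo"(%cache)                        // user #2 of %cache
//
// After conversion, ttnn.update_cache mutates %cache in place and uses of
// %updated are redirected to %cache. If %other existed, it would now observe
// the post-update contents even though the pure program gave it the
// pre-update value. The patterns therefore bail out (notifyMatchFailure)
// unless the op is the cache tensor's only user.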
7 changes: 7 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -757,6 +757,13 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Module op
//
patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);

// KV Cache ops
//
patterns.add<DefaultOpConversionPattern<ttnn::UpdateCacheOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::FillCacheOp>>(typeConverter,
ctx);
}

} // namespace mlir::tt
121 changes: 121 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1394,6 +1394,127 @@ ::mlir::LogicalResult mlir::tt::ttir::ScatterOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// UpdateCacheOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttir::UpdateCacheOp::verify() {
if (getBatchOffset() != 0) {
return emitOpError(
"Only single-batch is supported. Batch offset must be 0");
}

const ::mlir::RankedTensorType cacheType = getCache().getType();
const ::mlir::RankedTensorType inputType = getInput().getType();

const DataType cacheDataType =
elementTypeToDataType(cacheType.getElementType());
const DataType inputDataType =
elementTypeToDataType(inputType.getElementType());

if (cacheDataType != inputDataType) {
return emitOpError(
"Cache and input tensors must have the same dtype. "
"Got cache dtype = " +
DataTypeEnumToString(cacheDataType) +
", input dtype = " + DataTypeEnumToString(inputDataType));
}

if (cacheType.getRank() != 4) {
return emitOpError("Cache tensor must be a 4D tensor");
}

if (inputType.getRank() != 4) {
return emitOpError("Input tensor must be a 4D tensor");
}

if (inputType.getShape()[2] != 1) {
return emitOpError("Input tensor requires that dim 2 have size 1, got "
"input dim 2 size = " +
std::to_string(inputType.getShape()[2]));
}

if (cacheType.getShape()[0] != inputType.getShape()[0] ||
cacheType.getShape()[1] != inputType.getShape()[1] ||
cacheType.getShape()[3] != inputType.getShape()[3]) {
return emitOpError("Cache tensor shape must match input tensor shape on "
"all dimensions except dim 2. Got cache shape (" +
std::to_string(cacheType.getShape()[0]) + ", " +
std::to_string(cacheType.getShape()[1]) + ", " +
std::to_string(cacheType.getShape()[2]) + ", " +
std::to_string(cacheType.getShape()[3]) +
"), input shape ()" +
std::to_string(inputType.getShape()[0]) + "x" +
std::to_string(inputType.getShape()[1]) + "x" +
std::to_string(inputType.getShape()[2]) + "x" +
std::to_string(inputType.getShape()[3]) + ")");
}

return success();
}

//===----------------------------------------------------------------------===//
// FillCacheOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttir::FillCacheOp::verify() {
if (getBatchOffset() != 0) {
return emitOpError(
"Only single-batch is supported. Batch offset must be 0");
}

const ::mlir::RankedTensorType cacheType = getCache().getType();
const ::mlir::RankedTensorType inputType = getInput().getType();

const DataType cacheDataType =
elementTypeToDataType(cacheType.getElementType());
const DataType inputDataType =
elementTypeToDataType(inputType.getElementType());

if (cacheDataType != inputDataType) {
return emitOpError(
"Cache and input tensors must have the same dtype. "
"Got cache dtype = " +
DataTypeEnumToString(cacheDataType) +
", input dtype = " + DataTypeEnumToString(inputDataType));
}

if (cacheType.getRank() != 4) {
return emitOpError("Cache tensor must be a 4D tensor");
}

if (inputType.getRank() != 4) {
return emitOpError("Input tensor must be a 4D tensor");
}

if (inputType.getShape()[2] > cacheType.getShape()[2]) {
return emitOpError(
"Input tensor requires that dim 2 have a size which is less than or "
"equal to the size of dim 2 of the cache tensor. Got cache dim 2 size "
"= " +
std::to_string(cacheType.getShape()[2]) +
", input dim 2 size = " + std::to_string(inputType.getShape()[2]));
}

if (cacheType.getShape()[0] != inputType.getShape()[0] ||
cacheType.getShape()[1] != inputType.getShape()[1] ||
cacheType.getShape()[3] != inputType.getShape()[3]) {
return emitOpError("Cache tensor shape must match input tensor shape on "
"all dimensions except dim 2. Got cache shape (" +
std::to_string(cacheType.getShape()[0]) + ", " +
std::to_string(cacheType.getShape()[1]) + ", " +
std::to_string(cacheType.getShape()[2]) + ", " +
std::to_string(cacheType.getShape()[3]) +
"), input shape (" +
std::to_string(inputType.getShape()[0]) + ", " +
std::to_string(inputType.getShape()[1]) + ", " +
std::to_string(inputType.getShape()[2]) + ", " +
std::to_string(inputType.getShape()[3]) + ")");
}

return success();
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
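As a quick sanity check of what these verifiers accept, a small self-contained sketch with illustrative decode-style KV-cache shapes; the values are examples, not taken from the tests:

#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  // Illustrative shapes: [batch, heads, seq, head_dim].
  std::array<int64_t, 4> cache{1, 8, 1024, 64};
  std::array<int64_t, 4> updateInput{1, 8, 1, 64};  // update_cache input
  std::array<int64_t, 4> fillInput{1, 8, 128, 64};  // fill_cache input

  // Dims 0, 1, and 3 must match the cache for both ops.
  for (int i : {0, 1, 3}) {
    assert(cache[i] == updateInput[i]);
    assert(cache[i] == fillInput[i]);
  }
  assert(updateInput[2] == 1);      // update_cache: exactly one new position
  assert(fillInput[2] <= cache[2]); // fill_cache: input must fit in the cache
  return 0;
}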
