From be01f8b2d2c8f61f8106ff38e04c631401d809b2 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 22 Sep 2023 22:17:16 -0700 Subject: [PATCH] lower poly sub, to_tensor, from_tensor ops --- .../PolyToStandard/PolyToStandard.cpp | 71 ++++++++++++++++++- tests/poly_to_standard.mlir | 32 +++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index 7db9b25..e182660 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -2,8 +2,10 @@ #include "lib/Dialect/Poly/PolyOps.h" #include "lib/Dialect/Poly/PolyTypes.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -50,6 +52,72 @@ struct ConvertAdd : public OpConversionPattern { } }; +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::SubIOp subOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {subOp}); + return success(); + } +}; + +struct ConvertFromTensor : public OpConversionPattern { + ConvertFromTensor(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + FromTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTensorTy = cast( + typeConverter->convertType(op->getResultTypes()[0])); + auto resultShape = resultTensorTy.getShape()[0]; + auto resultEltTy = resultTensorTy.getElementType(); + + auto inputTensorTy = op.getInput().getType(); + auto inputShape = inputTensorTy.getShape()[0]; + + // Zero pad the tensor if the coefficients' size is less than the polynomial + // degree. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto coeffValue = adaptor.getInput(); + if (inputShape < resultShape) { + SmallVector low, high; + low.push_back(rewriter.getIndexAttr(0)); + high.push_back(rewriter.getIndexAttr(resultShape - inputShape)); + coeffValue = b.create( + resultTensorTy, coeffValue, low, high, + b.create(rewriter.getIntegerAttr(resultEltTy, 0)), + /*nofold=*/false); + } + + rewriter.replaceOp(op, coeffValue); + return success(); + } +}; + +struct ConvertToTensor : public OpConversionPattern { + ConvertToTensor(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ToTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + struct ConvertConstant : public OpConversionPattern { ConvertConstant(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -79,7 +147,8 @@ struct PolyToStandard : impl::PolyToStandardBase { RewritePatternSet patterns(context); PolyToStandardTypeConverter typeConverter(context); - patterns.add(typeConverter, context); + patterns.add(typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter); diff --git a/tests/poly_to_standard.mlir b/tests/poly_to_standard.mlir index b9f4ad4..4edcc4c 100644 --- a/tests/poly_to_standard.mlir +++ b/tests/poly_to_standard.mlir @@ -18,3 +18,35 @@ func.func @test_lower_add_and_fold() { %2 = poly.add %0, %1: !poly.poly<10> return } + +// CHECK-LABEL: test_lower_sub +func.func @test_lower_sub(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { + // CHECK: arith.subi + %2 = poly.sub %0, %1: !poly.poly<10> + return %2 : !poly.poly<10> +} + +// CHECK-LABEL: test_lower_to_tensor( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { +// CHECK-NEXT: return %[[V0]] : [[T]] +func.func @test_lower_to_tensor(%0: !poly.poly<10>) -> tensor<10xi32> { + %2 = poly.to_tensor %0: !poly.poly<10> -> tensor<10xi32> + return %2 : tensor<10xi32> +} + +// CHECK-LABEL: test_lower_from_tensor( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { +// CHECK-NEXT: return %[[V0]] : [[T]] +func.func @test_lower_from_tensor(%0 : tensor<10xi32>) -> !poly.poly<10> { + %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<10> + return %2 : !poly.poly<10> +} + +// CHECK-LABEL: test_lower_from_tensor_extend( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T2:tensor<20xi32>]] { +// CHECK: %[[V1:.*]] = tensor.pad %[[V0]] low[0] high[10] +// CHECK: return %[[V1]] : [[T2]] +func.func @test_lower_from_tensor_extend(%0 : tensor<10xi32>) -> !poly.poly<20> { + %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<20> + return %2 : !poly.poly<20> +}