diff --git a/docs/spec.md b/docs/spec.md index 2503e7d59ce..e92b06d9bd3 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -2179,16 +2179,18 @@ For quantized types, performs `dequantize_op_quantize( // "i" is input feature dimension, "o" is output feature dimension, // "0/1/etc" are spatial dimensions. dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, - feature_group_count = 1 : i64, batch_group_count = 1 : i64, + feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] -} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32> +} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> // %result: [[ // [[10], [26]], // [[46], [62]] // ]] ``` + [More Examples](../stablehlo/tests/interpret_convolution.mlir) + ### cosine #### Semantics diff --git a/docs/status.md b/docs/status.md index d845d723099..e4943bce62e 100644 --- a/docs/status.md +++ b/docs/status.md @@ -68,7 +68,7 @@ one of the following tracking labels. | concatenate | yes | yes | yes | yes | yes | | constant | yes | yes | yes | yes | yes | | convert | yes | yes | infeasible | yes | yes | -| convolution | yes | yes | infeasible | revisit | no | +| convolution | yes | yes | infeasible | revisit | yes | | cosine | yes | yes | yes | yes | yes | | count_leading_zeros | yes | yes | yes | yes | yes | | create_token | no | yes\* | yes\* | yes | revisit | diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 776a599a926..9e98bb4909a 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2066,24 +2066,32 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> { Example: ```mlir - %result = "stablehlo.convolution"(%lhs, %rhs) { - window_strides = dense<4> : tensor<2xi64>, - padding = dense<0> : tensor<2x2xi64>, - lhs_dilation = dense<2> : tensor<2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_reversal = dense : tensor<2xi1>, - dimension_numbers = 
#stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, - feature_group_count = 1 : i64, - batch_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32> + %result = stablehlo.convolution(%lhs, %rhs) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [4, 4], + pad = [[0, 0], [0, 0]], + lhs_dilate = [2, 2], + rhs_dilate = [1, 1], + reverse = [0, 0] + } { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> ``` }]; let arguments = !con( (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs), - StableHLO_ConvolutionAttributes.attributes); + HLO_Tensor:$lhs, /*convolution_i1*/ + HLO_Tensor:$rhs), /*convolution_i2*/ + StableHLO_ConvolutionAttributes.attributes /*convolution_i3, convolution_i4, + convolution_i5, convolution_i6, convolution_i7, convolution_i8, + convolution_i9, convolution_i10, convolution_i11, convolution_i12, + convolution_i13, convolution_i14, convolution_i15, convolution_i16, + convolution_i17, convolution_i18, convolution_i19*/ + ); let results = (outs HLO_Tensor); let hasVerifier = 1; diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 038e670625b..e0c1faaef54 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -318,25 +318,30 @@ verifyWindowAttributesAndInferWindowDimensions( " to have same dimension-size as size of window dimensions (", windowDimensions.size(), "), but got: ", attrSize, "."); }; - // reduce_window_c6, select_and_scatter_c6 + // convolution_c2, reduce_window_c6, select_and_scatter_c6 if (failed(verifySize(windowStrides.size(), "window-strides"))) return failure(); - // reduce_window_c8 + + // convolution_c5, reduce_window_c8 if (failed(verifySize(lhsDilation.size(), "base-dilation factors"))) return failure(); - // 
reduce_window_c10 + + // convolution_c7, reduce_window_c10 if (failed(verifySize(rhsDilation.size(), "window-dilation factors"))) return failure(); - // reduce_window_c12 + + // convolution_c4, reduce_window_c12 if (failed(verifySize(padding.size(), "padding-entries"))) return failure(); + + // convolution_c9 if (failed(verifySize(windowReversal.size(), "window-reversal"))) return failure(); SmallVector window(windowDimensions.size()); for (size_t i = 0; i < windowDimensions.size(); i++) { WindowDimension& dim = window[i]; - dim.size = windowDimensions[i]; + // reduce_window_c5, select_and_scatter_c5 if (!isDynamicDimSize(dim.size) && dim.size <= 0) return emitOptionalError(loc, @@ -344,21 +349,24 @@ verifyWindowAttributesAndInferWindowDimensions( "-th window dimension, but got ", dim.size, "."); if (!windowStrides.empty()) dim.stride = windowStrides[i]; - // reduce_window_c7, select_and_scatter_c7 + + // convolution_c3, reduce_window_c7, select_and_scatter_c7 if (dim.stride <= 0) return emitOptionalError( loc, "expects window to have positive stride for ", i, "-th window dimension, but got ", dim.stride, "."); if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i]; - // reduce_window_c9 + + // convolution_c6, reduce_window_c9 if (dim.baseDilation <= 0) return emitOptionalError( loc, "expects window to have positive base dilation factor for ", i, "-th window dimension, but got ", dim.baseDilation, "."); if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i]; - // reduce_window_c11 + + // convolution_c8, reduce_window_c11 if (dim.windowDilation <= 0) return emitOptionalError( loc, "expects window to have positive window dilation factor for ", i, @@ -745,12 +753,6 @@ LogicalResult verifyRegionNotEmpty(std::optional location, return success(); } -// Checks: -// P1. Same sizes for input, kernel and output spatialDims. -// P2. 
Spatial and non-spatial dimensions (for input,kernel, &output) should -// be unique and in range [0, num_dims), where num_dims = rank of input -// (lhs/rhs) tensors. -// // Note that the spatial + non-spatial dimensions may not cover all the // dimensions in the range [0,num) because of the presence of 'unknown' // dimensions (ref. `printConvolutionDimensions()`) @@ -762,7 +764,7 @@ LogicalResult isSpatialDimensionsValid( int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, std::optional location) { uint64_t spatialDimNum = inputSpatialDimensions.size(); - // P1. + // convolution_c17, convolution_c19 if ((spatialDimNum != kernelSpatialDimensions.size()) || (spatialDimNum != outputSpatialDimensions.size())) return emitOptionalError(location, @@ -772,7 +774,6 @@ LogicalResult isSpatialDimensionsValid( kernelSpatialDimensions.size(), ", and ", outputSpatialDimensions.size(), " resp."); - // P2. SmallVector inputDimNums(spatialDimNum + 2); inputDimNums[0] = inputBatchDimension; inputDimNums[1] = inputFeatureDimension; @@ -785,37 +786,40 @@ LogicalResult isSpatialDimensionsValid( std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(), windowDimNums.begin() + 2); - SmallVector OutputDimNums(spatialDimNum + 2); - OutputDimNums[0] = outputBatchDimension; - OutputDimNums[1] = outputFeatureDimension; + SmallVector outputDimNums(spatialDimNum + 2); + outputDimNums[0] = outputBatchDimension; + outputDimNums[1] = outputFeatureDimension; std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(), - OutputDimNums.begin() + 2); + outputDimNums.begin() + 2); auto numDims = lhsType.cast().getRank(); const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; }; - + // convolution_c13, convolution_c18, convolution_c20 if (!llvm::all_of(inputDimNums, inRange) || !llvm::all_of(windowDimNums, inRange) || - !llvm::all_of(OutputDimNums, inRange)) + !llvm::all_of(outputDimNums, inRange)) return emitOptionalError(location, "expects 
input, kernel, and output " "dimension-numbers to be in-range [0, ", numDims, ")."); + // convolution_c13 if (hasDuplicates(inputDimNums)) return emitOptionalError( location, "expects input dimension-numbers to be unique, got {", inputDimNums, "}."); + // convolution_c18 if (hasDuplicates(windowDimNums)) return emitOptionalError( location, "expects kernel dimension-numbers to be unique, got {", windowDimNums, "}."); - if (hasDuplicates(OutputDimNums)) + // convolution_c20 + if (hasDuplicates(outputDimNums)) return emitOptionalError( location, "expects output dimension-numbers to be unique, got {", - OutputDimNums, "}."); + outputDimNums, "}."); return success(); } @@ -833,27 +837,6 @@ LogicalResult verifyPrecisionConfig(std::optional loc, "<= 2 elements."); } -// Verifies the following properties: -// P1. The input, kernel, and output spatial-dimensions are valid. -// P2. Given, -// input-dimensions: b * input-spatial-dims * f -// kernel-dimensions: kernel-spatial-dims * i * o -// output-dimensions: b' * out-spatial-dims * f' -// where b = input-batch-dim -// where f = input-feature-dim -// where i = kernel-input-feature-dim -// where o = kernel-output-feature-dim -// where b' = output-batch-dim -// where f' = output-feature-dim -// Check the following properties w.r.t feature_group_count (fgc) and -// batch_group_count (bgc). -// * fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1) -// * dim(lhs, b) % bgc == 0 -// * dim(lhs, f) % fgc == 0 and -// dim(lhs, f) / fgc = dim(rhs, i) -// * dim(rhs, o) (or dim(output, f')) % bgc == 0 and -// dim(rhs, o) (or dim(output, f')) % fgc == 0 -// P3. Precision config is null, of size 0 or of size 2. 
LogicalResult verifyConvolutionAttributes( std::optional location, Type lhsType, Type rhsType, int64_t inputBatchDimension, int64_t inputFeatureDimension, @@ -863,7 +846,6 @@ LogicalResult verifyConvolutionAttributes( int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, int64_t featureGroupCount, int64_t batchGroupCount, std::optional precisionConfig) { - // P1. if (failed(isSpatialDimensionsValid( lhsType, inputBatchDimension, inputFeatureDimension, inputSpatialDimensions, kernelInputFeatureDimension, @@ -872,17 +854,19 @@ LogicalResult verifyConvolutionAttributes( location))) return failure(); - // P2. + // convolution_c21 if (featureGroupCount <= 0) return emitOptionalError( location, "expects feature_group_count to be a positive number, got ", featureGroupCount, "."); + // convolution_c22 if (batchGroupCount <= 0) return emitOptionalError( location, "expects batch_group_count to be a positive number, got ", batchGroupCount, "."); + // convolution_c23 if (batchGroupCount > 1 && featureGroupCount > 1) return emitOptionalError( location, @@ -900,24 +884,16 @@ LogicalResult verifyConvolutionAttributes( const int64_t kernelOutputFeatures = rankedRhsType.getShape()[kernelOutputFeatureDimension]; - if (!isDynamicDimSize(kernelOutputFeatures)) { - if (kernelOutputFeatures % batchGroupCount != 0) - return emitOptionalError( - location, "expects output feature dimension size (", - kernelOutputFeatures, - ") to be a multiple of batch_group_count. Got batch_group_count = ", - batchGroupCount, "."); - - if (kernelOutputFeatures % featureGroupCount != 0) - return emitOptionalError(location, - "expects kernel output feature dimension (", - kernelOutputFeatures, - ") to be divisible by feature_group_count. 
For " - "feature_group_count = ", - featureGroupCount, "."); - } + // convolution_c10 + if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0) + return emitOptionalError(location, "expects input batch dimension (", + inputBatch, + ") to be divisible by " + "batch_group_count. Got batch_group_count = ", + batchGroupCount, "."); if (!isDynamicDimSize(inputFeatures)) { + // convolution_c11 if (inputFeatures % featureGroupCount != 0) return emitOptionalError(location, "expects input feature dimension (", inputFeatures, @@ -925,6 +901,7 @@ LogicalResult verifyConvolutionAttributes( "feature_group_count = ", featureGroupCount, "."); + // convolution_c14 if (!isDynamicDimSize(kernelInputFeatures) && inputFeatures / featureGroupCount != kernelInputFeatures) return emitOptionalError( @@ -935,14 +912,26 @@ LogicalResult verifyConvolutionAttributes( "). Got feature_group_count = ", featureGroupCount, "."); } - if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0) - return emitOptionalError(location, "expects input batch dimension (", - inputBatch, - ") to be divisible by " - "batch_group_count. Got batch_group_count = ", - batchGroupCount, "."); + if (!isDynamicDimSize(kernelOutputFeatures)) { + // convolution_c15 + if (kernelOutputFeatures % batchGroupCount != 0) + return emitOptionalError( + location, "expects output feature dimension size (", + kernelOutputFeatures, + ") to be a multiple of batch_group_count. Got batch_group_count = ", + batchGroupCount, "."); - // P3. + // convolution_c16 + if (kernelOutputFeatures % featureGroupCount != 0) + return emitOptionalError(location, + "expects kernel output feature dimension (", + kernelOutputFeatures, + ") to be divisible by feature_group_count. 
For " + "feature_group_count = ", + featureGroupCount, "."); + } + + // convolution_c24 if (failed(verifyPrecisionConfig(location, precisionConfig))) return failure(); @@ -1719,15 +1708,6 @@ LogicalResult inferConvertOp( return success(); } -/* - * We intend to verify the following properties - * P1. Verify the input, kernel types. - * P2. Verify the convolution atributes. - * P3. Verify and collect the window atributes. - * P4. Verify precision_config attribute. - * P5. Verify the return shape. - * TODO(b/232574102): Verify the element-type of return-value. - */ LogicalResult inferConvolutionOp( std::optional location, Type lhsType, Type rhsType, std::optional windowStrides, @@ -1750,19 +1730,29 @@ LogicalResult inferConvolutionOp( return success(); } - // P1. + // convolution_c13 int numDims = rankedLhsType.getRank(); + if (numDims < 2) + return emitOptionalError( + location, + "expects convolution arguments to have >= 2 dimensions. Got: ", + rankedLhsType, " and ", rankedRhsType, "."); + + // convolution_c1 if (numDims != rankedRhsType.getRank()) return emitOptionalError(location, "expects convolution arguments to have same " "number of dimensions. Got: ", rankedLhsType, " and ", rankedRhsType, "."); - if (numDims < 2) + + // convolution_c27 + if (!isCompatibleForHloTypeInference(rankedLhsType.getElementType(), + rankedRhsType.getElementType())) return emitOptionalError( - location, - "expects convolution arguments to have >= 2 dimensions. Got: ", - rankedLhsType, " and ", rankedRhsType, "."); - // P2. + location, "expects lhs and rhs to have compatible element type. 
Got: ", + rankedLhsType.getElementType(), " and ", + rankedRhsType.getElementType()); + if (failed(verifyConvolutionAttributes( location, lhsType, rhsType, inputBatchDimension, inputFeatureDimension, inputSpatialDimensions, @@ -1772,29 +1762,36 @@ LogicalResult inferConvolutionOp( precisionConfig))) return failure(); + // convolution_c12 if ((size_t)numDims != inputSpatialDimensions.size() + 2) return emitOptionalError(location, "expects convolution arguments to have ", inputSpatialDimensions.size() + 2, " dimensions. Got: ", numDims); - // P3. SmallVector windowDimensions(kernelSpatialDimensions.size()); for (size_t i = 0; i < windowDimensions.size(); i++) windowDimensions[i] = rankedRhsType.getShape()[kernelSpatialDimensions[i]]; + // convolution_c4, convolution_i4 auto paddingOrErr = convertPaddingAttribute(padding, location); if (failed(paddingOrErr)) return failure(); - // TODO: add missing tests for ConvolutionOp. + // convolution_i3 auto windowStridesOrErr = convert1DAttribute(windowStrides, location, "window_strides"); if (failed(windowStridesOrErr)) return failure(); + + // convolution_i5 auto lhsDilationOrErr = convert1DAttribute(lhsDilation, location, "lhs_dilation"); if (failed(lhsDilationOrErr)) return failure(); + + // convolution_i6 auto rhsDilationOrErr = convert1DAttribute(rhsDilation, location, "rhs_dilation"); if (failed(rhsDilationOrErr)) return failure(); + + // convolution_i7 auto windowReversalOrErr = convertWindowReversalAttribute( windowReversal, location, "window_reversal"); if (failed(windowReversalOrErr)) return failure(); @@ -1804,15 +1801,9 @@ LogicalResult inferConvolutionOp( *rhsDilationOrErr, *windowReversalOrErr, location); if (failed(windowOrErr)) return failure(); - // P3. - if (failed(verifyPrecisionConfig(location, precisionConfig))) - return failure(); - - // P5. + // convolution_c25, convolution_c26 SmallVector outputDimensions(rankedLhsType.getShape().size(), ShapedType::kDynamic); - - // Infer the output spatial dimensions. 
auto numSpatialDims = inputSpatialDimensions.size(); SmallVector inputSpatialDimVals(numSpatialDims); for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) @@ -1824,7 +1815,6 @@ LogicalResult inferConvolutionOp( for (int64_t i = 0; i < static_cast(windowOrErr->size()); ++i) outputDimensions[outputSpatialDimensions[i]] = windowOutputShape[i]; - // Infer the output-batch-dimension and output-feature-dimension. const int64_t inputBatch = rankedLhsType.getShape()[inputBatchDimension]; const int64_t kernelOutputFeatures = rankedRhsType.getShape()[kernelOutputFeatureDimension]; @@ -3358,13 +3348,14 @@ LogicalResult verifyConvolutionOp( auto inferredShape = inferredReturnShapes[0]; auto shapedResultType = resultType.cast(); + // convolution_c25 if (inferredShape.hasRank() && shapedResultType.hasRank() && failed(verifyCompatibleShape(inferredShape.getDims(), shapedResultType.getShape()))) return emitOptionalError(location, "inferred shape '", dimSizesToString(inferredShape.getDims()), "' ", "is incompatible with return type of operation ", - shapedResultType, ""); + shapedResultType); return success(); } diff --git a/stablehlo/reference/Element.h b/stablehlo/reference/Element.h index da98940aedf..8ea5c58022c 100644 --- a/stablehlo/reference/Element.h +++ b/stablehlo/reference/Element.h @@ -51,6 +51,7 @@ class Element { Element(Type type, std::complex value); Element(const Element &other) = default; + Element() = default; /// @} /// Assignment operator. 
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index f6311171e06..a29ad084735 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -52,6 +52,28 @@ Index evalIndex(Tensor tensor) { return result; } +Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, + const Axes &lhsBatchingDimensions, + const Axes &rhsBatchingDimensions, + const Axes &lhsContractingDimensions, + const Axes &rhsContractingDimensions) { + SmallVector inferredDotGeneralType; + auto dotGeneralStatus = hlo::inferDotGeneralOp( + /*location=*/{}, lhs.getType(), rhs.getType(), + /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, + lhsContractingDimensions, rhsContractingDimensions, + /*precisionConfig=*/{}, inferredDotGeneralType); + if (failed(dotGeneralStatus)) + report_fatal_error( + invalidArgument("Could not infer DotGeneralOp's return type")); + + return evalDotGeneralOp( + lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions, + lhsContractingDimensions, rhsContractingDimensions, + RankedTensorType::get(inferredDotGeneralType[0].getDims(), + lhs.getElementType())); +} + Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh, const Sizes &interiorPadding) { @@ -132,6 +154,12 @@ Tensor evalSliceOp(const Tensor &operand, const Index &index) { return evalSliceOp(operand, start, limit, strides); } +Sizes extractElements(ArrayRef arr, ArrayRef indices) { + Sizes elements; + for (auto index : indices) elements.push_back(arr[index]); + return elements; +} + void failOnDecomposableOp(Operation &op) { report_fatal_error(invalidArgument( "Operation %s is unsupported at the moment. " @@ -142,6 +170,18 @@ void failOnDecomposableOp(Operation &op) { op.getName().getStringRef().str().c_str())); } +template +DenseIntElementsAttr getDenseIntElementsAttr( + Type elementType, T values, + std::optional> valuesShape) { + SmallVector shape = + valuesShape.has_value() + ? 
*valuesShape + : SmallVector({static_cast(values.size())}); + return DenseIntElementsAttr::get(RankedTensorType::get(shape, elementType), + values); +} + SmallVector> getReplicaGroups( DenseIntElementsAttr replicaGroupsAttr) { auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape(); @@ -157,6 +197,74 @@ SmallVector> getReplicaGroups( return replicaGroups; } +Tensor evalConvolutionOp( + const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, + std::optional precisionConfig, ShapedType resultType) { + auto i64Type = IntegerType::get(lhs.getType().getContext(), 64); + auto i1Type = IntegerType::get(lhs.getType().getContext(), 1); + + SmallVector paddingVector; + for (auto pair : padding) { + paddingVector.push_back(pair.first); + paddingVector.push_back(pair.second); + } + + SmallVector paddingShape{static_cast(padding.size()), 2}; + SmallVector inferredConvolutionType; + auto convolutionStatus = hlo::inferConvolutionOp( + /*location=*/{}, lhs.getType(), rhs.getType(), + getDenseIntElementsAttr(i64Type, windowStrides, {}), + getDenseIntElementsAttr(i64Type, paddingVector, paddingShape), + getDenseIntElementsAttr(i64Type, lhsDilation, {}), + getDenseIntElementsAttr(i64Type, rhsDilation, {}), + getDenseIntElementsAttr(i1Type, windowReversal, {}), + static_cast(inputBatchDimension), + static_cast(inputFeatureDimension), + ArrayRef(inputSpatialDimensions), + static_cast(kernelInputFeatureDimension), + static_cast(kernelOutputFeatureDimension), + ArrayRef(kernelSpatialDimensions), + 
static_cast(outputBatchDimension), + static_cast(outputFeatureDimension), + ArrayRef(outputSpatialDimensions), + /*featureGroupCount=*/1, batchGroupCount, + /*precisionConfig=*/{}, inferredConvolutionType); + if (failed(convolutionStatus)) + report_fatal_error( + invalidArgument("Could not infer ConvolutionOp's return type")); + return evalConvolutionOp( + lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, + windowReversal, inputBatchDimension, inputFeatureDimension, + inputSpatialDimensions, kernelInputFeatureDimension, + kernelOutputFeatureDimension, kernelSpatialDimensions, + outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, + featureGroupCount, batchGroupCount, + RankedTensorType::get(inferredConvolutionType[0].getDims(), + resultType.getElementType())); +} + +// Returns `result` with the effect of applying `permutation` +// (= [dimA] + dimsB + [dimC]) to `input` (= [n] + hw + [c]) such that +// result[permutation[i]] = input[i]. +template +SmallVector concatAndPermute(T n, SmallVector hw, T c, + const Axes &permutation) { + SmallVector result(permutation.size()); + result[permutation[0]] = n; + result[permutation[permutation.size() - 1]] = c; + for (uint64_t i = 1; i < permutation.size() - 1; ++i) + result[permutation[i]] = hw[i - 1]; + return result; +} + Tensor makeScalar(const Element &initValue) { Tensor result(RankedTensorType::get({}, initValue.getType())); result.set({}, initValue); @@ -375,6 +483,56 @@ SmallVector eval( auto operand = scope.findTensor(convertOp.getOperand()); auto result = evalConvertOp(operand, convertOp.getType()); scope.add(convertOp.getResult(), result); + } else if (auto convolutionOp = dyn_cast(op)) { + auto lhs = scope.findTensor(convolutionOp.getLhs()); + auto rhs = scope.findTensor(convolutionOp.getRhs()); + auto rank = lhs.getRank(); + + SmallVector windowStrides(rank - 2, 1); + if (auto windowStridesAttr = convolutionOp.getWindowStridesAttr()) + 
windowStrides.assign(windowStridesAttr.value_begin(), + windowStridesAttr.value_end()); + + SmallVector> padding(rank - 2, {0, 0}); + if (auto paddingAttr = convolutionOp.getPaddingAttr()) { + auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {}); + if (failed(paddingOrErr)) + report_fatal_error(invalidArgument("Invalid padding format found.")); + padding = *paddingOrErr; + } + + SmallVector lhsDilation(rank - 2, 1); + if (auto lhsDilationAttr = convolutionOp.getLhsDilationAttr()) + lhsDilation.assign(lhsDilationAttr.value_begin(), + lhsDilationAttr.value_end()); + + SmallVector rhsDilation(rank - 2, 1); + if (auto rhsDilationAttr = convolutionOp.getRhsDilationAttr()) + rhsDilation.assign(rhsDilationAttr.value_begin(), + rhsDilationAttr.value_end()); + + SmallVector windowReversal(rank - 2, false); + if (auto windowReversalAttr = convolutionOp.getWindowReversalAttr()) + windowReversal.assign(windowReversalAttr.value_begin(), + windowReversalAttr.value_end()); + + auto result = evalConvolutionOp( + lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, + windowReversal, + convolutionOp.getDimensionNumbers().getInputBatchDimension(), + convolutionOp.getDimensionNumbers().getInputFeatureDimension(), + Axes(convolutionOp.getDimensionNumbers().getInputSpatialDimensions()), + convolutionOp.getDimensionNumbers().getKernelInputFeatureDimension(), + convolutionOp.getDimensionNumbers().getKernelOutputFeatureDimension(), + Axes( + convolutionOp.getDimensionNumbers().getKernelSpatialDimensions()), + convolutionOp.getDimensionNumbers().getOutputBatchDimension(), + convolutionOp.getDimensionNumbers().getOutputFeatureDimension(), + Axes( + convolutionOp.getDimensionNumbers().getOutputSpatialDimensions()), + convolutionOp.getFeatureGroupCount(), + convolutionOp.getBatchGroupCount(), convolutionOp.getType()); + scope.add(convolutionOp.getResult(), result); } else if (auto cosineOp = dyn_cast(op)) { auto operand = scope.findTensor(cosineOp.getOperand()); auto result 
= evalCosineOp(operand, cosineOp.getType()); @@ -1173,6 +1331,155 @@ Tensor evalConvertOp(const Tensor &operand, ShapedType resultType) { return result; } +Tensor evalConvolutionOp( + const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType) { + Tensor result(resultType); + + if (featureGroupCount > 1) { + auto lhses = split(lhs, featureGroupCount, inputFeatureDimension, + resultType.getContext()); + auto rhses = split(rhs, featureGroupCount, kernelOutputFeatureDimension, + resultType.getContext()); + SmallVector results; + for (auto [left, right] : llvm::zip(lhses, rhses)) { + auto convolutionResult = evalConvolutionOp( + left, right, windowStrides, padding, lhsDilation, rhsDilation, + windowReversal, inputBatchDimension, inputFeatureDimension, + inputSpatialDimensions, kernelInputFeatureDimension, + kernelOutputFeatureDimension, kernelSpatialDimensions, + outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, + /*featureGroupCount=*/1, batchGroupCount, /*precisionConfig=*/{}, + resultType); + results.push_back(convolutionResult); + } + return evalConcatenateOp(results, outputFeatureDimension, result.getType()); + } + + if (batchGroupCount > 1) { + auto lhses = split(lhs, batchGroupCount, inputBatchDimension, + resultType.getContext()); + auto rhses = split(rhs, batchGroupCount, kernelOutputFeatureDimension, + resultType.getContext()); + SmallVector results; + for (auto [left, right] : llvm::zip(lhses, rhses)) { + auto convolutionResult = evalConvolutionOp( + left, right, 
windowStrides, padding, lhsDilation, rhsDilation, + windowReversal, inputBatchDimension, inputFeatureDimension, + inputSpatialDimensions, kernelInputFeatureDimension, + kernelOutputFeatureDimension, kernelSpatialDimensions, + outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, + featureGroupCount, /*batchGroupCount=*/1, /*precisionConfig=*/{}, + resultType); + results.push_back(convolutionResult); + } + return evalConcatenateOp(results, outputFeatureDimension, result.getType()); + } + + Axes lhsPermutation; + lhsPermutation.push_back(inputBatchDimension); + lhsPermutation.append(inputSpatialDimensions.begin(), + inputSpatialDimensions.end()); + lhsPermutation.push_back(inputFeatureDimension); + + auto lhsWindowDimensions = + concatAndPermute(lhs.getShape()[inputBatchDimension], + extractElements(rhs.getShape(), kernelSpatialDimensions), + lhs.getShape()[inputFeatureDimension], lhsPermutation); + + auto lhsWindowStrides = + concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation); + + auto lhsBaseDilations = + concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation); + + auto lhsWindowDilations = + concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation); + + Sizes lhsPaddingLow; + Sizes lhsPaddingHigh; + for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding), + {0, 0}, lhsPermutation)) { + lhsPaddingLow.push_back(paddingPair.first); + lhsPaddingHigh.push_back(paddingPair.second); + } + + auto paddingValue = makeScalar(convert(result.getElementType(), 0.0)); + auto paddedLhs = evalPadOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh, + Sizes(lhsBaseDilations)); + + IndexSpaceIterator outputSpatialIndexIt( + extractElements(result.getShape(), outputSpatialDimensions), + Index(outputSpatialDimensions.size())); + IndexSpaceIterator outputSpatialIndexItEnd( + extractElements(result.getShape(), outputSpatialDimensions), + std::nullopt); + for (; outputSpatialIndexIt != outputSpatialIndexItEnd; 
+ ++outputSpatialIndexIt) { + Sizes lhsWindowStart; + for (auto [i, offset] : llvm::enumerate(concatAndPermute( + 0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation))) + lhsWindowStart.push_back(lhsWindowStrides[i] * offset); + + Sizes limitIndices; + for (size_t i = 0; i < lhsWindowStart.size(); ++i) + limitIndices.push_back(std::min( + lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i], + paddedLhs.getShape()[i])); + + auto lhsWindow = evalSliceOp(paddedLhs, lhsWindowStart, limitIndices, + Sizes(lhsWindowDilations)); + + Axes reverseDims; + for (auto [i, isReverse] : llvm::enumerate(windowReversal)) + if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]); + auto reversedLhsWindow = + evalReverseOp(lhsWindow, reverseDims, lhsWindow.getType()); + + Axes lhsContractingDimensions(inputSpatialDimensions); + lhsContractingDimensions.push_back(inputFeatureDimension); + + Axes rhsContractingDimensions(kernelSpatialDimensions); + rhsContractingDimensions.push_back(kernelInputFeatureDimension); + + auto dotProduct = + evalDotGeneralOp(reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, lhsContractingDimensions, + rhsContractingDimensions); + + Sizes resultNonSpatialDims; + for (auto i = 0; i < result.getRank(); ++i) + if (llvm::find(outputSpatialDimensions, i) == + outputSpatialDimensions.end()) + resultNonSpatialDims.push_back(result.getShape()[i]); + + Axes resultPermutation; + resultPermutation.push_back(outputBatchDimension); + resultPermutation.append(outputSpatialDimensions.begin(), + outputSpatialDimensions.end()); + resultPermutation.push_back(outputFeatureDimension); + + IndexSpaceIterator resultNonSpatialIt(resultNonSpatialDims, + Index(resultNonSpatialDims.size())); + for (auto dotProductIt = dotProduct.index_begin(); + dotProductIt != dotProduct.index_end(); + ++dotProductIt, ++resultNonSpatialIt) { + Index resultIndex( + concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt, + 
(*resultNonSpatialIt)[1], resultPermutation)); + result.set(resultIndex, dotProduct.get(*dotProductIt)); + } + } + return result; +} + Tensor evalCosineOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index ef1a77a5fd0..6f17c320cb1 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -73,6 +73,16 @@ Tensor evalConcatenateOp(ArrayRef inputs, Axis dimension, ShapedType resultType); Tensor evalConstantOp(ElementsAttr value); Tensor evalConvertOp(const Tensor &operand, ShapedType resultType); +Tensor evalConvolutionOp( + const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, + ArrayRef> padding, + ArrayRef lhsDilation, ArrayRef rhsDilation, + ArrayRef windowReversal, Axis inputBatchDimension, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, + Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, + int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType); Tensor evalCosineOp(const Tensor &operand, ShapedType resultType); Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); diff --git a/stablehlo/testdata/conv_general_dilated_conv1d_lhs_float32_2_3_10__rhs_float32_3_3_5__windowstrides__1___padding___0_0_3336387849681708224.mlir b/stablehlo/testdata/conv_general_dilated_conv1d_lhs_float32_2_3_10__rhs_float32_3_3_5__windowstrides__1___padding___0_0_3336387849681708224.mlir index 821a9027e65..5ff1d73cc4b 100644 --- a/stablehlo/testdata/conv_general_dilated_conv1d_lhs_float32_2_3_10__rhs_float32_3_3_5__windowstrides__1___padding___0_0_3336387849681708224.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_conv1d_lhs_float32_2_3_10__rhs_float32_3_3_5__windowstrides__1___padding___0_0_3336387849681708224.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_same_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstrid-2801039169040378295.mlir b/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_same_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstrid-2801039169040378295.mlir index f2a3657c03c..0b337a48e41 100644 --- a/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_same_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstrid-2801039169040378295.mlir +++ b/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_same_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstrid-2801039169040378295.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_valid_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstri-8415303397313720110.mlir b/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_valid_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstri-8415303397313720110.mlir index 70d532a153a..218a118d0e3 100644 --- a/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_valid_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstri-8415303397313720110.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_conv_tranpose1d_valid_padding_lhs_float32_1_16_2__rhs_float32_3_2_2__windowstri-8415303397313720110.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_same_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__window-296817466026258822.mlir b/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_same_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__window-296817466026258822.mlir index a0ac2d15119..1d5e40510f7 100644 --- a/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_same_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__window-296817466026258822.mlir +++ b/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_same_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__window-296817466026258822.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_valid_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__windo216554503587864094.mlir b/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_valid_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__windo216554503587864094.mlir index b8bb41dbf6a..b02e01dccdf 100644 --- a/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_valid_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__windo216554503587864094.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_conv_tranpose2d_valid_padding_lhs_float32_1_16_16_2__rhs_float32_2_3_2_2__windo216554503587864094.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_depthwise1d_dilated_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___p4726467912149865720.mlir b/stablehlo/testdata/conv_general_dilated_depthwise1d_dilated_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___p4726467912149865720.mlir index 6f0b7b11a21..c18edc08f7e 100644 --- a/stablehlo/testdata/conv_general_dilated_depthwise1d_dilated_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___p4726467912149865720.mlir +++ b/stablehlo/testdata/conv_general_dilated_depthwise1d_dilated_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___p4726467912149865720.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_depthwise1d_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___padding__2025442285785167829.mlir b/stablehlo/testdata/conv_general_dilated_depthwise1d_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___padding__2025442285785167829.mlir index e596dba0b62..91fcbc658f7 100644 --- a/stablehlo/testdata/conv_general_dilated_depthwise1d_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___padding__2025442285785167829.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_depthwise1d_lhs_float32_2_3_9__rhs_float32_12_1_3__windowstrides__1___padding__2025442285785167829.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_depthwise2d_dilated_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1-5299255563318388077.mlir b/stablehlo/testdata/conv_general_dilated_depthwise2d_dilated_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1-5299255563318388077.mlir index efff86848f9..0bbedd84429 100644 --- a/stablehlo/testdata/conv_general_dilated_depthwise2d_dilated_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1-5299255563318388077.mlir +++ b/stablehlo/testdata/conv_general_dilated_depthwise2d_dilated_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1-5299255563318388077.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_depthwise2d_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1_1__padd-5494580149111736437.mlir b/stablehlo/testdata/conv_general_dilated_depthwise2d_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1_1__padd-5494580149111736437.mlir index 5b1660e8553..45dbb8c8e77 100644 --- a/stablehlo/testdata/conv_general_dilated_depthwise2d_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1_1__padd-5494580149111736437.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_depthwise2d_lhs_float32_2_3_9_9__rhs_float32_12_1_3_3__windowstrides__1_1__padd-5494580149111736437.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-940783895638600378.mlir b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-940783895638600378.mlir index e7bd941b85b..998526a2faf 100644 --- a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-940783895638600378.mlir +++ b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-940783895638600378.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin7803468828575064957.mlir b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin7803468828575064957.mlir index 1e02dff7465..fb066ad8a84 100644 --- a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin7803468828575064957.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin7803468828575064957.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin8362069580368565836.mlir b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin8362069580368565836.mlir index ecb98b81f6d..e393b3726fa 100644 --- a/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin8362069580368565836.mlir +++ b/stablehlo/testdata/conv_general_dilated_dilations_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin8362069580368565836.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_3_9_10__rhs_float32_4_5_3_3__windowstrides__1_1-8327739261064575227.mlir b/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_3_9_10__rhs_float32_4_5_3_3__windowstrides__1_1-8327739261064575227.mlir index 359b385ca6a..c4bff2a7f39 100644 --- a/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_3_9_10__rhs_float32_4_5_3_3__windowstrides__1_1-8327739261064575227.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_3_9_10__rhs_float32_4_5_3_3__windowstrides__1_1-8327739261064575227.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_9_10_3__rhs_float32_4_5_3_3__windowstrides__1_1-6479296361709234045.mlir b/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_9_10_3__rhs_float32_4_5_3_3__windowstrides__1_1-6479296361709234045.mlir index cd7cddd36c9..ef5cd0e6657 100644 --- a/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_9_10_3__rhs_float32_4_5_3_3__windowstrides__1_1-6479296361709234045.mlir +++ b/stablehlo/testdata/conv_general_dilated_dimension_numbers_lhs_float32_2_9_10_3__rhs_float32_4_5_3_3__windowstrides__1_1-6479296361709234045.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-1572008590406220780.mlir b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-1572008590406220780.mlir index 26a7e0b707e..db5547cf6fb 100644 --- a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-1572008590406220780.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-1572008590406220780.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-4848375839726769119.mlir b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-4848375839726769119.mlir index b0aac0f7d4f..dfe65f590aa 100644 --- a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-4848375839726769119.mlir +++ b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-4848375839726769119.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-6463715315589295392.mlir b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-6463715315589295392.mlir index 3d87f197e07..21da95d51c6 100644 --- a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-6463715315589295392.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-6463715315589295392.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-7128650963837464321.mlir b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-7128650963837464321.mlir index bf3ce2751e6..d508654eaec 100644 --- a/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-7128650963837464321.mlir +++ b/stablehlo/testdata/conv_general_dilated_dtype_precision_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__-7128650963837464321.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_2_6_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad-2212136815070163245.mlir b/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_2_6_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad-2212136815070163245.mlir index 46c0dded01a..7a1d1285bb8 100644 --- a/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_2_6_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad-2212136815070163245.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_2_6_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad-2212136815070163245.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_4_3_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad2141518546713332984.mlir b/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_4_3_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad2141518546713332984.mlir index 4c40261e415..77f634836ae 100644 --- a/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_4_3_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad2141518546713332984.mlir +++ b/stablehlo/testdata/conv_general_dilated_group_counts_lhs_float32_4_3_9_10__rhs_float32_6_3_4_5__windowstrides__1_1__pad2141518546713332984.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_-5318452038191342786.mlir b/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_-5318452038191342786.mlir index 0c79f3fabd2..ffdd2c9e42f 100644 --- a/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_-5318452038191342786.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_-5318452038191342786.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_5050050053143756869.mlir b/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_5050050053143756869.mlir index 57002dae224..ec0370c9d18 100644 --- a/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_5050050053143756869.mlir +++ b/stablehlo/testdata/conv_general_dilated_padding_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__padding_5050050053143756869.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-8230831430534381289.mlir b/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-8230831430534381289.mlir index bec43fb444f..05b6173672f 100644 --- a/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-8230831430534381289.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin-8230831430534381289.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin2430556623423240398.mlir b/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin2430556623423240398.mlir index da13920da69..cf4c9c26677 100644 --- a/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin2430556623423240398.mlir +++ b/stablehlo/testdata/conv_general_dilated_preferred_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__1_1__paddin2430556623423240398.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___-7077764382968112509.mlir b/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___-7077764382968112509.mlir index 3b1fba42a0f..6b97083a2d6 100644 --- a/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___-7077764382968112509.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___-7077764382968112509.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___8425439145703732762.mlir b/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___8425439145703732762.mlir index cb1bc57353e..8343bdba7c8 100644 --- a/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___8425439145703732762.mlir +++ b/stablehlo/testdata/conv_general_dilated_preferred_lhs_int32_2_3_9_10__rhs_int32_3_3_4_5__windowstrides__1_1__padding___8425439145703732762.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_rhs_oob_after_dilation_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides2604570134889024577.mlir b/stablehlo/testdata/conv_general_dilated_rhs_oob_after_dilation_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides2604570134889024577.mlir index 7492f2dfe79..edfdb78b0a1 100644 --- a/stablehlo/testdata/conv_general_dilated_rhs_oob_after_dilation_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides2604570134889024577.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_rhs_oob_after_dilation_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides2604570134889024577.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_rhs_oob_after_pading_lhs_float32_1_3_2_2__rhs_float32_64_3_7_7__windowstrides__5476883256254674781.mlir b/stablehlo/testdata/conv_general_dilated_rhs_oob_after_pading_lhs_float32_1_3_2_2__rhs_float32_64_3_7_7__windowstrides__5476883256254674781.mlir index 581ff61120e..9ea173cd60c 100644 --- a/stablehlo/testdata/conv_general_dilated_rhs_oob_after_pading_lhs_float32_1_3_2_2__rhs_float32_64_3_7_7__windowstrides__5476883256254674781.mlir +++ b/stablehlo/testdata/conv_general_dilated_rhs_oob_after_pading_lhs_float32_1_3_2_2__rhs_float32_64_3_7_7__windowstrides__5476883256254674781.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_rhs_oob_lhs_float32_2_3_9_10__rhs_float32_3_3_10_5__windowstrides__1_1__padding-7028470013130116023.mlir b/stablehlo/testdata/conv_general_dilated_rhs_oob_lhs_float32_2_3_9_10__rhs_float32_3_3_10_5__windowstrides__1_1__padding-7028470013130116023.mlir index 254e189b338..07048e9fc1a 100644 --- a/stablehlo/testdata/conv_general_dilated_rhs_oob_lhs_float32_2_3_9_10__rhs_float32_3_3_10_5__windowstrides__1_1__padding-7028470013130116023.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_rhs_oob_lhs_float32_2_3_9_10__rhs_float32_3_3_10_5__windowstrides__1_1__padding-7028470013130116023.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_rhs_oob_same_padding_lhs_float32_1_1_16_1__rhs_float32_4_1_1_2__windowstrides__-5378230060959849233.mlir b/stablehlo/testdata/conv_general_dilated_rhs_oob_same_padding_lhs_float32_1_1_16_1__rhs_float32_4_1_1_2__windowstrides__-5378230060959849233.mlir index cf4088ae8f2..90668369bda 100644 --- a/stablehlo/testdata/conv_general_dilated_rhs_oob_same_padding_lhs_float32_1_1_16_1__rhs_float32_4_1_1_2__windowstrides__-5378230060959849233.mlir +++ b/stablehlo/testdata/conv_general_dilated_rhs_oob_same_padding_lhs_float32_1_1_16_1__rhs_float32_4_1_1_2__windowstrides__-5378230060959849233.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3327288011231887622.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3327288011231887622.mlir index 10ab074e42f..e3413c0d640 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3327288011231887622.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3327288011231887622.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3622729120410096886.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3622729120410096886.mlir index 04676ee3dec..4b664e8e57d 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3622729120410096886.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-3622729120410096886.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-6986955039250405512.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-6986955039250405512.mlir index 3ccc19eb461..33dbaccfe98 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-6986955039250405512.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-6986955039250405512.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-9029857704645127306.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-9029857704645127306.mlir index 6d3367cf10a..fae21627d18 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-9029857704645127306.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_1d_lhs_float32_1_28_1__rhs_float32_3_1_16__windowstrides__1_-9029857704645127306.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-2949258448720886117.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-2949258448720886117.mlir index b5ba24eb51a..6d7e72f4648 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-2949258448720886117.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-2949258448720886117.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-3150713558304940791.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-3150713558304940791.mlir index 4b83166fc27..1a4faa5cdb8 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-3150713558304940791.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-3150713558304940791.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4268834939663085007.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4268834939663085007.mlir index f93ed584a56..936bf849c96 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4268834939663085007.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4268834939663085007.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4823547473556881342.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4823547473556881342.mlir index ecc053b5522..bb45f7c5f56 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4823547473556881342.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride-4823547473556881342.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride6887804375295982821.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride6887804375295982821.mlir index 218ee435eb8..44d44858772 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride6887804375295982821.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_1_28_28__rhs_float32_3_3_1_16__windowstride6887804375295982821.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride-5370080226275098061.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride-5370080226275098061.mlir index 8da8ecc13a6..f92256414b9 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride-5370080226275098061.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride-5370080226275098061.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride3500806609840728678.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride3500806609840728678.mlir index e7d49f7d5cc..0ad7355f21f 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride3500806609840728678.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride3500806609840728678.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride5787494791845602941.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride5787494791845602941.mlir index 287a8e66859..ac9d98a054f 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride5787494791845602941.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride5787494791845602941.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride6312536791243051545.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride6312536791243051545.mlir index 6cb4d4f8613..6803383c279 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride6312536791243051545.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride6312536791243051545.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride8472832706066243053.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride8472832706066243053.mlir index 074d1d9f20e..998983ad095 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride8472832706066243053.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_2d_lhs_float32_1_28_28_1__rhs_float32_3_3_1_16__windowstride8472832706066243053.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst-8282934958224398022.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst-8282934958224398022.mlir index 02be425b6ae..0cc45bbd789 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst-8282934958224398022.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst-8282934958224398022.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst3167308923230843499.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst3167308923230843499.mlir index cd7aff7bfa2..03a7a5644e2 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst3167308923230843499.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst3167308923230843499.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst6072174321640638276.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst6072174321640638276.mlir index ed1a6c7fc6d..1dc18b55f4c 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst6072174321640638276.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst6072174321640638276.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst8524424959485066052.mlir b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst8524424959485066052.mlir index 47ff8fde3ad..7ecf8022d8e 100644 --- a/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst8524424959485066052.mlir +++ b/stablehlo/testdata/conv_general_dilated_tf_conversion_path_3d_lhs_float32_1_4_28_28_1__rhs_float32_2_3_3_1_16__windowst8524424959485066052.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/conv_general_dilated_window_strides_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__2_3__p-3051008093469972974.mlir b/stablehlo/testdata/conv_general_dilated_window_strides_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__2_3__p-3051008093469972974.mlir index 56884a7dc10..f6e85ac504a 100644 --- a/stablehlo/testdata/conv_general_dilated_window_strides_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__2_3__p-3051008093469972974.mlir +++ 
b/stablehlo/testdata/conv_general_dilated_window_strides_lhs_float32_2_3_9_10__rhs_float32_3_3_4_5__windowstrides__2_3__p-3051008093469972974.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/tests/infer_stablehlo.mlir index 2b1b30d6ce7..0d985afb7fa 100644 --- a/stablehlo/tests/infer_stablehlo.mlir +++ b/stablehlo/tests/infer_stablehlo.mlir @@ -142,6 +142,19 @@ func.func @abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xindex> { // ----- +// CHECK-LABEL: func @collective_permute_c5 +func.func @collective_permute_c5(%arg0: tensor<2x2xi64>) -> tensor<2x2xindex> { + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: types0 = tensor<2x2xi64> + %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<2x2xi64>) -> tensor<2x2xindex> + func.return %1 : tensor<2x2xindex> +} + +// ----- + // CHECK-LABEL: @concat func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex> { %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> @@ -152,15 +165,182 @@ func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex // ----- -// CHECK-LABEL: func @collective_permute_c5 -func.func @collective_permute_c5(%arg0: tensor<2x2xi64>) -> tensor<2x2xindex> { - %0 = "stablehlo.collective_permute"(%arg0) { - source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, - channel_handle = #stablehlo.channel_handle - } : (tensor<2x2xi64>) -> tensor<2x2xi64> - // CHECK: types0 = tensor<2x2xi64> - %1 = 
"hlo_test_infer.get_return_types"(%0) : (tensor<2x2xi64>) -> tensor<2x2xindex> - func.return %1 : tensor<2x2xindex> +// Invalid rank of output-type. + +func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> { + // expected-error @+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x8x16xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> + func.return %0 : tensor<1x8x16xf32> +} + +// ----- + +// Invalid batch dimension in output-type. Should be equal to +// input-batch-dimension / batch_group_count. + +func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> { + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<2x8x8x16xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> + func.return %0 : tensor<2x8x8x16xf32> +} + +// ----- + +// Invalid feature dimension in output-type. Should be equal to +// kernel_output_feature_dimension. 
+ +func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> { + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x8x8x32xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> + func.return %0 : tensor<1x8x8x32xf32> +} + +// ----- + +// The following tests checks the inferred output-type of ConvolutionOp. We +// deliberately put an invalid output-type in these tests so that the +// inffered-type can be highlighted in the error message. + +// Dynamic input-batch-dimension +func.func @invalid_conv_dynamic_shapes(%arg0: tensor, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[?, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// Dynamic input-feature-dimension: No effect on output dimensions. 
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x?xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// Dynamic input-spatial-dimension +func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[1, ?, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x?x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// Dynamic kernel-input-feature-dimension: No effect on output dimensions. 
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// Dynamic kernel-output-feature-dimension +func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[1, 8, 8, ?]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// Dynamic kernel-spatial-dimension +func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> { + // expected-error@+1 {{inferred shape '[1, 8, ?, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + 
feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> + func.return %0 : tensor<1x1x1x1xf32> } // ----- diff --git a/stablehlo/tests/interpret_convolution.mlir b/stablehlo/tests/interpret_convolution.mlir new file mode 100644 index 00000000000..b5d0e327358 --- /dev/null +++ b/stablehlo/tests/interpret_convolution.mlir @@ -0,0 +1,79 @@ +// RUN: stablehlo-translate --interpret -split-input-file %s + +func.func @convolution_op_test_si64() { + %lhs = stablehlo.constant dense<[[ + [[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[10], [11], [14], [15]], + [[12], [13], [16], [17]] + ]]> : tensor<1x4x4x1xi64> + %rhs = stablehlo.constant dense<1> : tensor<3x3x1x1xi64> + %result = stablehlo.convolution(%lhs, %rhs) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [4, 4], + lhs_dilate = [2, 2] + } { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } + : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + check.expect_eq_const %result, dense<[[ + [[10], [26]], + [[46], [62]] + ]]> : tensor<1x2x2x1xi64> + func.return +} + +// ----- + +func.func @convolution_batch_group_count_4() { + %lhs = stablehlo.constant dense<[[ + [[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[10], [11], [14], [15]], + [[12], [13], [16], [17]] + ]]> : tensor<1x4x4x1xi64> + %rhs = stablehlo.constant dense<1> : tensor<1x2x1x4xi64> + %result = stablehlo.convolution(%lhs, %rhs) + dim_numbers = [0, b, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [4, 4], + lhs_dilate = [2, 2] + } { + batch_group_count = 4 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } + : (tensor<1x4x4x1xi64>, tensor<1x2x1x4xi64>) -> tensor<1x1x2x4xi64> + check.expect_eq_const %result, dense<[[[[1, 3, 10, 12], + [5, 7, 14, 16]]]]> : tensor<1x1x2x4xi64> + func.return +} + +// 
----- + +func.func @convolution_feature_group_count_4() { + %lhs = stablehlo.constant dense<[[ + [[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[10], [11], [14], [15]], + [[12], [13], [16], [17]] + ]]> : tensor<1x4x4x1xi64> + %rhs = stablehlo.constant dense<1> : tensor<1x2x1x4xi64> + %result = stablehlo.convolution(%lhs, %rhs) + dim_numbers = [b, 0, f, 1]x[0, i, 1, o]->[b, 0, 1, f], + window = { + stride = [4, 4], + lhs_dilate = [2, 2] + } { + batch_group_count = 1 : i64, + feature_group_count = 2 : i64, + precision_config = [#stablehlo, #stablehlo] + } + : (tensor<1x4x4x1xi64>, tensor<1x2x1x4xi64>) -> tensor<1x2x1x4xi64> + check.expect_eq_const %result, dense<[[[[3, 3, 11, 11]], + [[21, 21, 29, 29]]]]> : tensor<1x2x1x4xi64> + func.return +} diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 8b71d01646a..3e206dd28d1 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -4967,11 +4967,31 @@ func.func @eltwise_static_and_dynamic_type(%arg0: tensor<10x10xf32>, %arg1: tens // ----- -// CHECK: func @quantized_conv2d +// CHECK-LABEL: func @convolution_operand_element_type_i4 +func.func @convolution_operand_element_type_i4(%arg0: tensor<64x8x8x8xi4>, %arg1: tensor<4x4x8x32xi4>) -> tensor<64x3x3x32xi8> { + // Note: This has been lowered and adapted from: + // %0 = "tf.Conv2D"(%arg0, %arg1) { + // data_format = "NHWC", + // dilations = [1, 2, 2, 1], + // explicit_paddings = [0, 0, 0, 1, 0, 1, 0, 0], + // padding = "EXPLICIT", + // strides = [1, 1, 1, 1]} : + // (tensor<64x8x8x8xf32>, tensor<4x4x8x32xf32>) -> tensor<64x3x3x32xf32> + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : + (tensor<64x8x8x8xi4>, tensor<4x4x8x32xi4>) -> tensor<64x3x3x32xi8> + func.return %0 : tensor<64x3x3x32xi8> +} + +// ----- + +// 
CHECK: func @convolution_quantized_conv2d // CHECK: stablehlo.convolution // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @quantized_conv2d(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> { +func.func @convolution_quantized_conv2d(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> { %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} @@ -5116,26 +5136,6 @@ func.func @einsum_i8xi8_i16(%arg0: tensor<1x2xi8>, %arg1: tensor<2x1xi8>) -> ten // ----- -// CHECK-LABEL: func @conv_i4 -func.func @conv_i4(%arg0: tensor<64x8x8x8xi4>, %arg1: tensor<4x4x8x32xi4>) -> tensor<64x3x3x32xi8> { - // Note: This has been lowered and adapted from: - // %0 = "tf.Conv2D"(%arg0, %arg1) { - // data_format = "NHWC", - // dilations = [1, 2, 2, 1], - // explicit_paddings = [0, 0, 0, 1, 0, 1, 0, 0], - // padding = "EXPLICIT", - // strides = [1, 1, 1, 1]} : - // (tensor<64x8x8x8xf32>, tensor<4x4x8x32xf32>) -> tensor<64x3x3x32xf32> - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]} - {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : - (tensor<64x8x8x8xi4>, tensor<4x4x8x32xi4>) -> tensor<64x3x3x32xi8> - func.return %0 : tensor<64x3x3x32xi8> -} - -// ----- - // CHECK-LABEL: func @pad func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xf16> { %0 = "stablehlo.pad"(%arg0, %arg1) { diff --git a/stablehlo/tests/verify_conv.mlir b/stablehlo/tests/verify_convolution.mlir similarity index 66% rename from stablehlo/tests/verify_conv.mlir 
rename to stablehlo/tests/verify_convolution.mlir index 746197f71af..c9197e8aca7 100644 --- a/stablehlo/tests/verify_conv.mlir +++ b/stablehlo/tests/verify_convolution.mlir @@ -1,9 +1,7 @@ // RUN: stablehlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s -// Valid: Generic convolution - -// CHECK-LABEL: func @main -func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> +// CHECK-LABEL: func @convolution +func.func @convolution(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { %result = "stablehlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, @@ -30,7 +28,28 @@ func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) // ----- -// Valid: Test convolution i8xi8 -> i32. +// CHECK: func @convolution_empty_spatial_dimensions +// CHECK: stablehlo.convolution +// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f] +// CHECK-SAME: window = {stride = [], pad = [], lhs_dilate = [], +// CHECK-SAME: rhs_dilate = [], reverse = []} +func.func @convolution_empty_spatial_dimensions(%arg0: tensor<3x2xf16>, + %arg1: tensor<2x2xf16>) -> tuple> { + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, f]x[i, o]->[b, f], + window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], + reverse = []} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } + : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16> + %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> + func.return %1 : tuple> +} + +// ----- // CHECK-LABEL: func @convolution_upcast func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>, @@ -59,71 +78,45 @@ func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>, // ----- -// Valid: Empty spatial dimensions - -// CHECK: func @conv_empty_spatial_dimensions -// CHECK: stablehlo.convolution -// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f] -// CHECK-SAME: window = {stride = 
[], pad = [], lhs_dilate = [], -// CHECK-SAME: rhs_dilate = [], reverse = []} -func.func @conv_empty_spatial_dimensions(%arg0: tensor<3x2xf16>, - %arg1: tensor<2x2xf16>) -> tuple> { +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // expected-error@+3{{Unexpected keyword stide}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], - reverse = []} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } - : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16> - %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> - func.return %1 : tuple> + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<2x4x5x2xf32>, - %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> { - // expected-error@+1 {{expects input dimension-numbers to be unique, got {0, 0}.}} - %1 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 2 : i64, - someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> - tensor<2x3x4x6xf32> - func.return %1 : tensor<2x3x4x6xf32> +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // expected-error@+3{{expected integer value}} + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> 
tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects convolution arguments to have same number of dimensions. Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}} +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // expected-error@+3{{Unexpected keyword stride}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride=[2,1]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) - -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects convolution arguments to have >= 2 dimensions. Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}} +func.func @convolution_c1(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects convolution arguments to have same number of dimensions. 
Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], @@ -133,35 +126,37 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor<1xf32>, tensor<3xf32>) -> tensor<1x8x8x16xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 3, 2, and 2 resp.}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 2, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> +func.func @convolution_c2(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects lhs and rhs to have compatible element type. 
Got: 'i64' and 'i32'}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c3(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}} + // expected-error@+1 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, @@ -174,33 +169,32 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c4(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}} + // expected-error@+1 {{expects window to have positive stride for 1-th window dimension, but got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f], - window = {stride = [1, 1], pad = 
[[1, 1], [1, 1]], + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 0], pad = [[1, 1], [1,1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : + precision_config = [#stablehlo, #stablehlo]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, - %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} +func.func @convolution_c5(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.}} %result = "stablehlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, + padding = dense<2> : tensor<3x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> @@ -220,71 +214,34 @@ func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, // ----- -func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, - %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - // expected-error@+1 {{expects kernel dimension-numbers to be unique, got {3, 2, 0, 0}.}} - %result = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, +func.func @convolution_c5(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> 
tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of padding-attribute to be {N, 2}, but got {2, 3}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2xi64>, + padding = dense<0> : tensor<2x3xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64> - } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> - tensor<100x28x28x1xf32> - func.return %result : tensor<100x28x28x1xf32> -} - -// ----- - -func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, - %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - // expected-error@+1 {{expects output dimension-numbers to be unique, got {0, 3, 0, 3}.}} - %result = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, - lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64> - } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> - tensor<100x28x28x1xf32> - func.return %result : tensor<100x28x28x1xf32> + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c5(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects batch_group_count to be a positive number, got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], + // expected-error@+1 {{Expected array with 2 elements, 
got 4 elements instead}} + window = {stride = [1, 1], pad = [[1, 1, 1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { - batch_group_count = 0 : i64, + batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : @@ -294,34 +251,33 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c6(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects feature_group_count to be a positive number, got 0.}} + // expected-error@+1 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + lhs_dilate = [1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, - feature_group_count = 0 : i64, - precision_config = [#stablehlo, #stablehlo] - } : + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c7(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects batch_group_count and feature_group_count not to be both greater than 1. 
Got 2 and 2 resp.}} + // expected-error@+1 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [0, 1], rhs_dilate = [1, 1]} { - batch_group_count = 2 : i64, - feature_group_count = 2 : i64, + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> @@ -330,79 +286,59 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c8(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}} + // expected-error@+1 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + lhs_dilate = [1, 1], rhs_dilate = [1]} { - batch_group_count = 3 : i64, + batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : + precision_config = [#stablehlo, #stablehlo]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects input feature dimension (207) to be a multiple of feature_group_count. 
Got feature_group_count = 2.}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 2 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> -} - -// ----- - -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). Got feature_group_count = 1.}} +func.func @convolution_c9(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [0, 1]} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. 
For feature_group_count = 3.}} +func.func @convolution_c10(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects window-reversal to have same dimension-size as size of window dimensions (2), but got: 1.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false]} { batch_group_count = 1 : i64, - feature_group_count = 3 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo]} : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<5x8x8x207xf32>, +func.func @convolution_c11(%arg0: tensor<5x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+1 {{expects input batch dimension (5) to be divisible by batch_group_count. Got batch_group_count = 2.}} %0 = stablehlo.convolution(%arg0, %arg1) @@ -420,98 +356,212 @@ func.func @invalid_conv_dimensions(%arg0: tensor<5x8x8x207xf32>, // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}} +func.func @convolution_c12(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects input feature dimension (207) to be a multiple of feature_group_count. 
 Got feature_group_count = 2.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1], pad = [[1, 1], [1, 1]], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, - feature_group_count = 1 : i64, + feature_group_count = 2 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} - %0 = stablehlo.convolution(%arg0, %arg1) +// This is a positive test in MLIR-HLO: +// https://github.com/tensorflow/mlir-hlo/blob/master/tests/Dialect/mhlo/ops.mlir#L3829 +// but negative here: stablehlo.convolution does not support unknown dimensions +// dim_numbers = [b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f] +// window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} +func.func @convolution_c13(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { + // expected-error@+1{{expects convolution arguments to have 4 dimensions. 
Got: 5}} + %0 = "stablehlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : + (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> + func.return %0 : tensor<32x1x8x8x16xf32> +} + +// ----- + +func.func @convolution_c14(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) + -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects convolution arguments to have >= 2 dimensions. Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}} + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1], rhs_dilate = [1, 1]} + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1xf32>, tensor<3xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } + // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}} +func.func @convolution_c14(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input dimension-numbers to be unique, got {0, 0, 1, 2}.}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = 
dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> +} + +// ----- + +func.func @convolution_c14(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> +} + +// ----- + +func.func @convolution_c14(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> +} + +// ----- + +func.func @convolution_c15(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). 
Got feature_group_count = 1.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1]} + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window-reversal to have same dimension-size as size of window dimensions (2), but got: 1.}} +func.func @convolution_c16(%arg0: tensor<3x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<3x8x8x16xf32> { + // expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. 
Got batch_group_count = 3.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false]} + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { - batch_group_count = 1 : i64, + batch_group_count = 3 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<3x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<3x8x8x16xf32> + func.return %0 : tensor<3x8x8x16xf32> } // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 1.}} +func.func @convolution_c17(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. 
For feature_group_count = 3.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1]], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + feature_group_count = 3 : i64, + precision_config = [#stablehlo, #stablehlo] + } : + (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- -func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c18(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - // expected-error@+1 {{Expected array with 2 elements, got 4 elements instead}} - window = {stride = [1, 1], pad = [[1, 1, 1, 1]], + dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, @@ -524,9 +574,9 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, // ----- -func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> - tensor<100x28x28x1xf32> { - // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.}} +func.func @convolution_c19(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects kernel dimension-numbers to be unique, got {3, 2, 0, 0}.}} %result = "stablehlo.convolution"(%arg0, %arg1) { 
batch_group_count = 1 : i64, dimension_numbers = #stablehlo.conv, %arg1 : ten input_spatial_dimensions = [1, 2], kernel_input_feature_dimension = 3, kernel_output_feature_dimension = 2, - kernel_spatial_dimensions = [0, 1], + kernel_spatial_dimensions = [0, 0], output_batch_dimension = 0, output_feature_dimension = 3, output_spatial_dimensions = [1, 2] >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<3x2xi64>, + padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> @@ -552,47 +602,69 @@ func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : ten // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> +func.func @convolution_c19(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : 
(tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> } // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window to have positive stride for 1-th window dimension, but got 0.}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 0], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> +func.func @convolution_c19(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> } // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, +func.func @convolution_c20(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}} + // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, 
i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [0, 1], rhs_dilate = [1, 1]} + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, @@ -604,243 +676,277 @@ func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, // ----- -func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [0, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> - func.return %0 : tensor<1x8x8x16xf32> +func.func @convolution_c21(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects output dimension-numbers to be unique, got {0, 3, 0, 3}.}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> } // ----- -// Invalid rank of output-type. 
- -func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> { - // expected-error @+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x8x16xf32>'}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> - func.return %0 : tensor<1x8x16xf32> +func.func @convolution_c21(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> } // ----- -// Invalid batch dimension in output-type. Should be equal to -// input-batch-dimension / batch_group_count. 
+func.func @convolution_c21(%arg0 : tensor<100x26x26x32xf32>, + %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}} + %result = "stablehlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #stablehlo.conv, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> + tensor<100x28x28x1xf32> + func.return %result : tensor<100x28x28x1xf32> +} + +// ----- -func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> { - // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<2x8x8x16xf32>'}} +func.func @convolution_c22(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects feature_group_count to be a positive number, got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { batch_group_count = 1 : i64, - feature_group_count = 1 : i64, + feature_group_count = 0 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> - func.return %0 : tensor<2x8x8x16xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> } // ----- -// Invalid feature dimension in output-type. Should be equal to -// kernel_output_feature_dimension. 
- -func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> { - // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x8x8x32xf32>'}} +func.func @convolution_c23(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects batch_group_count to be a positive number, got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { - batch_group_count = 1 : i64, + batch_group_count = 0 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> - func.return %0 : tensor<1x8x8x32xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> } // ----- -// The following tests checks the inferred output-type of ConvolutionOp. We -// deliberately put an invalid output-type in these tests so that the -// inffered-type can be highlighted in the error message. - -// Dynamic input-batch-dimension -func.func @invalid_conv_dynamic_shapes(%arg0: tensor, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[?, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} +func.func @convolution_c24(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects batch_group_count and feature_group_count not to be both greater than 1. 
Got 2 and 2 resp.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, + batch_group_count = 2 : i64, + feature_group_count = 2 : i64, precision_config = [#stablehlo, #stablehlo] } : - (tensor, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> } // ----- -// Dynamic input-feature-dimension: No effect on output dimensions. -func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} +func.func @convolution_c25(%arg0: tensor<3x2xf16>, + %arg1: tensor<2x2xf16>) -> tuple> { + // expected-error@+1{{expects precision config to be empty or have <= 2 elements}} %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + dim_numbers = [b, f]x[i, o]->[b, f], + window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], + reverse = []} { batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x?xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> + precision_config = [#stablehlo, #stablehlo, #stablehlo] + } + : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16> + %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> + func.return %1 : tuple> } // ----- -// Dynamic input-spatial-dimension -func.func @invalid_conv_dynamic_shapes(%arg0: 
tensor<1x?x8x207xf32>, - %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[1, ?, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x?x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> +func.func @convolution_i3(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of window_strides attribute to be 1-D, but got {2, 2}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2x2xi64>, + padding = dense<0> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -// Dynamic kernel-input-feature-dimension: No effect on output dimensions. 
-func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> +func.func @convolution_i4(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of padding-attribute to be {N, 2}, but got {2}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2xi64>, + padding = dense<0> : tensor<2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -// Dynamic kernel-output-feature-dimension -func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[1, 8, 8, ?]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - 
feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> +func.func @convolution_i5(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of lhs_dilation attribute to be 1-D, but got {2, 2}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -// Dynamic kernel-spatial-dimension -func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, - %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{inferred shape '[1, 8, ?, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1], [1,1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : - (tensor<1x8x8x207xf32>, tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> - func.return %0 : tensor<1x1x1x1xf32> +func.func @convolution_i6(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of rhs_dilation attribute to be 1-D, but got {2, 2}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : 
tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2x2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -// This is an positive test in MLIR-HLO: -// https://github.com/tensorflow/mlir-hlo/blob/master/tests/Dialect/mhlo/ops.mlir#L3829 -// but negative here: stablehlo.convolution does no support unknown dimenstion -// dim_numbers = [b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f] -// window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv2d_generic(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { - // expected-error@+1{{expects convolution arguments to have 4 dimensions. 
Got: 5}} - %0 = "stablehlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : - (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> - func.return %0 : tensor<32x1x8x8x16xf32> +func.func @convolution_i7(%arg0: tensor<1x4x4x1xi64>, + %arg1: tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> { + // expected-error@+1 {{expects the shape of window_reversal attribute to be 1-D, but got {2, 2}.}} + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<4> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2x2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> + func.return %0 : tensor<1x2x2x1xi64> } // ----- -func.func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error @+3 {{'stablehlo.convolution' Expected array with 2 elements, got 3 elements instead}} +func.func @convolution_invalid_window_attributes(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[1, 1, 1], [1, 1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - {batch_group_count = 1 : i64, feature_group_count = 1 : i64, 
precision_config = [#stablehlo, #stablehlo]} : - (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + window = {stride = [1, 1], pad = [[1, 1], [1,1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo]} : + (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> func.return %0 : tensor<1x8x8x16xf32> } // ----- // CHECK: module -// CHECK-SAME: stablehlo.conv = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 1, 0, f]> +// CHECK: stablehlo.conv = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 1, 0, f]> module attributes { stablehlo.conv = #stablehlo.conv[f, b, 0, 1]> } {} - -// ----- - -func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+3{{Unexpected keyword stide}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - -func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+3{{expected integer value}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - -func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { - // expected-error@+3{{Unexpected keyword stride}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 
f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride=[2,1]} - { batch_group_count = 1 : i64, feature_group_count = 1 : i64} - : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> - func.return %0 : tensor<3x5x5x4xf32> -} - -// ----- - -func.func @conv_invalid_precision_config(%arg0: tensor<3x2xf16>, - %arg1: tensor<2x2xf16>) -> tuple> { - // expected-error@+1{{expects precision config to be empty or have <= 2 elements}} - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], - reverse = []} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo, #stablehlo] - } - : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16> - %1 = "stablehlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> - func.return %1 : tuple> -}