From 23116b3c6aa0754e12c6f155f012dba2d1384f70 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 28 Nov 2023 14:36:50 -0800 Subject: [PATCH] Fix typo in all_reduce constraint comment (#1860) --- stablehlo/dialect/TypeInference.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 7f023191114..77fae7b64aa 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -536,20 +536,20 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, ArrayRef allowedDimensions) { int64_t numInputs = inputTypes.size(); - // all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c10 if (static_cast(block.getArguments().size()) != numInputs * 2) return emitOptionalError(loc, "Reduction-region must take ", numInputs * 2, " parameters, but takes ", block.getArguments().size(), " parameter(s)"); - // all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c10 if (block.getTerminator()->getOperands().empty()) return emitOptionalError( loc, "The reduction-region expected to return some value(s)"); - // all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c10 if (static_cast(block.getTerminator()->getOperands().size()) != numInputs) @@ -558,7 +558,7 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, block.getTerminator()->getOperands().size(), " instead"); - // all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c10 SmallVector accumulatorSubShapes; for (Value retOperand : block.getTerminator()->getOperands()) { @@ -573,7 +573,7 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, } for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - // all_reduce_c6, reduce_c2, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c10 if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx], block.getArgument(inputIdx).getType())) @@ -583,7 +583,7 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, block.getArgument(inputIdx).getType(), " vs ", accumulatorSubShapes[inputIdx]); - // all_reduce_c6, reduce_c2, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13, // scatter_c15, select_and_scatter_c3, select_and_scatter_c10 if (!compatibleShapeAndElementType( accumulatorSubShapes[inputIdx], @@ -596,7 +596,7 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, block.getArgument(numInputs + inputIdx).getType(), " vs ", accumulatorSubShapes[inputIdx]); - // all_reduce_c6, reduce_c6, reduce_scatter_c7, reduce_window_c13, + // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10 if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx], initValueTypes[inputIdx],