From 4045a5039f8292ecbb2a1e95351db86f3f3bf8a8 Mon Sep 17 00:00:00 2001 From: Gita Koblents Date: Fri, 25 Nov 2022 19:56:24 -0500 Subject: [PATCH 1/3] Recognize convert() Vector API intrinsic in VectorAPIExpansion --- .../codegen/J9RecognizedMethodsEnum.hpp | 1 + runtime/compiler/env/j9method.cpp | 1 + .../compiler/optimizer/J9ValuePropagation.cpp | 5 +++ .../compiler/optimizer/VectorAPIExpansion.cpp | 8 +++++ .../compiler/optimizer/VectorAPIExpansion.hpp | 32 +++++++++++++++++++ 5 files changed, 47 insertions(+) diff --git a/runtime/compiler/codegen/J9RecognizedMethodsEnum.hpp b/runtime/compiler/codegen/J9RecognizedMethodsEnum.hpp index 39dfa9b5a7d..268b1995a99 100644 --- a/runtime/compiler/codegen/J9RecognizedMethodsEnum.hpp +++ b/runtime/compiler/codegen/J9RecognizedMethodsEnum.hpp @@ -457,6 +457,7 @@ jdk_internal_vm_vector_VectorSupport_binaryOp, jdk_internal_vm_vector_VectorSupport_blend, jdk_internal_vm_vector_VectorSupport_compare, + jdk_internal_vm_vector_VectorSupport_convert, jdk_internal_vm_vector_VectorSupport_fromBitsCoerced, jdk_internal_vm_vector_VectorSupport_maskReductionCoerced, jdk_internal_vm_vector_VectorSupport_reductionCoerced, diff --git a/runtime/compiler/env/j9method.cpp b/runtime/compiler/env/j9method.cpp index a9ab11572ba..1a57351296d 100644 --- a/runtime/compiler/env/j9method.cpp +++ b/runtime/compiler/env/j9method.cpp @@ -3013,6 +3013,7 @@ void TR_ResolvedJ9Method::construct() {x(TR::jdk_internal_vm_vector_VectorSupport_binaryOp, "binaryOp", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$BinaryOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;" )}, {x(TR::jdk_internal_vm_vector_VectorSupport_blend, "blend", 
"(Ljava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorBlendOp;)Ljdk/internal/vm/vector/VectorSupport$Vector;")}, {x(TR::jdk_internal_vm_vector_VectorSupport_compare, "compare", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorCompareOp;)Ljdk/internal/vm/vector/VectorSupport$VectorMask;")}, + {x(TR::jdk_internal_vm_vector_VectorSupport_convert, "convert", "(ILjava/lang/Class;Ljava/lang/Class;ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$VectorConvertOp;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;")}, {x(TR::jdk_internal_vm_vector_VectorSupport_fromBitsCoerced, "fromBitsCoerced", "(Ljava/lang/Class;Ljava/lang/Class;IJILjdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$FromBitsCoercedOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;")}, {x(TR::jdk_internal_vm_vector_VectorSupport_maskReductionCoerced, "maskReductionCoerced", "(ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorMaskOp;)J")}, {x(TR::jdk_internal_vm_vector_VectorSupport_reductionCoerced, "reductionCoerced", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$ReductionOperation;)J")}, diff --git a/runtime/compiler/optimizer/J9ValuePropagation.cpp b/runtime/compiler/optimizer/J9ValuePropagation.cpp index 5e97d8b3ee9..52f28cdecb4 100644 --- 
a/runtime/compiler/optimizer/J9ValuePropagation.cpp +++ b/runtime/compiler/optimizer/J9ValuePropagation.cpp @@ -3553,6 +3553,8 @@ J9::ValuePropagation::innerConstrainAcall(TR::Node *node) method->getRecognizedMethod() == TR::jdk_internal_vm_vector_VectorSupport_ternaryOp; bool isVectorSupportCompare = method->getRecognizedMethod() == TR::jdk_internal_vm_vector_VectorSupport_compare; + bool isVectorSupportConvert = + method->getRecognizedMethod() == TR::jdk_internal_vm_vector_VectorSupport_convert; bool isVectorSupportBlend = method->getRecognizedMethod() == TR::jdk_internal_vm_vector_VectorSupport_blend; @@ -3562,6 +3564,7 @@ J9::ValuePropagation::innerConstrainAcall(TR::Node *node) isVectorSupportUnaryOp || isVectorSupportTernaryOp || isVectorSupportCompare || + isVectorSupportConvert || isVectorSupportBlend) { bool isGlobal; // dummy @@ -3573,6 +3576,8 @@ J9::ValuePropagation::innerConstrainAcall(TR::Node *node) typeChildIndex = 0; else if (isVectorSupportCompare) typeChildIndex = 2; + else if (isVectorSupportConvert) + typeChildIndex = 4; else typeChildIndex = 1; diff --git a/runtime/compiler/optimizer/VectorAPIExpansion.cpp b/runtime/compiler/optimizer/VectorAPIExpansion.cpp index 83138437b1d..6ad57751002 100644 --- a/runtime/compiler/optimizer/VectorAPIExpansion.cpp +++ b/runtime/compiler/optimizer/VectorAPIExpansion.cpp @@ -1889,6 +1889,13 @@ TR::Node *TR_VectorAPIExpansion::compareIntrinsicHandler(TR_VectorAPIExpansion * return naryIntrinsicHandler(opt, treeTop, node, elementType, vectorLength, numLanes, mode, 2, Compare); } +TR::Node *TR_VectorAPIExpansion::convertIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, + TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, + handlerMode mode) + { + return NULL; + } + TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(int32_t vectorAPIOpCode, TR::DataType elementType, TR::VectorLength vectorLength, vapiOpCodeType opCodeType, bool withMask) { @@ 
-2173,6 +2180,7 @@ TR_VectorAPIExpansion::methodTable[] = {binaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_binaryOp {blendIntrinsicHandler, Vector, {Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Unknown}}, // jdk_internal_vm_vector_VectorSupport_blend {compareIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_compare + {convertIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_convert {fromBitsCoercedIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes, Unknown, Unknown, Unknown}}, // jdk_internal_vm_vector_VectorSupport_fromBitsCoerced {maskReductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask}}, // jdk_internal_vm_vector_VectorSupport_maskReductionCoerced {reductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_reductionCoerced diff --git a/runtime/compiler/optimizer/VectorAPIExpansion.hpp b/runtime/compiler/optimizer/VectorAPIExpansion.hpp index b343e063b3a..68a02c943bd 100644 --- a/runtime/compiler/optimizer/VectorAPIExpansion.hpp +++ b/runtime/compiler/optimizer/VectorAPIExpansion.hpp @@ -1076,6 +1076,38 @@ class TR_VectorAPIExpansion : public TR::Optimization static TR::Node *compareIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode); + /** \brief + * Scalarizes or vectorizes a node that is a call to \c VectorSupport.convert() intrinsic. + * In both cases, the node is modified in place. 
+ * In the case of scalarization, extra nodes are created (number of lanes minus one) + * + * \param opt + * This optimization object + * + * \param treeTop + * Tree top of the \c node + * + * \param node + * Node to transform + * + * \param elementType + * Element type + * + * \param vectorLength + * Vector length + * + * \param numLanes + * Number of elements + * + * \param mode + * Handler mode + * + * \return + * Transformed node + */ + static TR::Node *convertIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode); + + /** \brief * Helper method to transform a load from array node * From c9dd3a716792478df95881925c075318d08283d7 Mon Sep 17 00:00:00 2001 From: Gita Koblents Date: Fri, 17 Feb 2023 18:02:04 -0500 Subject: [PATCH 2/3] Allow multi-type webs in VectorAPIExpansion - VectorAPIExpansion will create one web that consists of temps and intrinsic calls. In this web, multiple types can coexist. - It will further create webs of temps assigned to each other. In these webs, only one type is allowed. - During actual transformation, the type of intrinsic call will be determined from its parameters. The type of a temp will be determined by the type of the temp web it belongs to. 
--- runtime/compiler/il/J9DataTypes.cpp | 2 +- .../compiler/optimizer/VectorAPIExpansion.cpp | 468 +++++++++++++----- .../compiler/optimizer/VectorAPIExpansion.hpp | 122 ++++- 3 files changed, 443 insertions(+), 149 deletions(-) diff --git a/runtime/compiler/il/J9DataTypes.cpp b/runtime/compiler/il/J9DataTypes.cpp index 8f0b0dee33d..36171b59af5 100644 --- a/runtime/compiler/il/J9DataTypes.cpp +++ b/runtime/compiler/il/J9DataTypes.cpp @@ -180,7 +180,7 @@ J9::ILOpCode::getDataTypeConversion(TR::DataType t1, TR::DataType t2) if (t1.isMask() || t2.isMask()) return TR::BadILOp; - if (t1.isVector() && t2.isVector()) return TR::ILOpCode::createVectorOpCode(TR::vcast, t1, t2); + if (t1.isVector() && t2.isVector()) return TR::ILOpCode::createVectorOpCode(TR::vconv, t1, t2); if (t1.isVector() || t2.isVector()) return TR::BadILOp; diff --git a/runtime/compiler/optimizer/VectorAPIExpansion.cpp b/runtime/compiler/optimizer/VectorAPIExpansion.cpp index 6ad57751002..34c30494fb5 100644 --- a/runtime/compiler/optimizer/VectorAPIExpansion.cpp +++ b/runtime/compiler/optimizer/VectorAPIExpansion.cpp @@ -74,15 +74,70 @@ TR_VectorAPIExpansion::getReturnType(TR::MethodSymbol * methodSymbol) return methodTable[index - _firstMethod]._returnType; } -bool -TR_VectorAPIExpansion::isArgType(TR::MethodSymbol *methodSymbol, int32_t i, vapiObjType type) +int32_t +TR_VectorAPIExpansion::getElementTypeIndex(TR::MethodSymbol *methodSymbol) + { + TR_ASSERT_FATAL(isVectorAPIMethod(methodSymbol), "getElementTypeIndex should be called on VectorAPI method"); + + TR::RecognizedMethod index = methodSymbol->getRecognizedMethod(); + + return methodTable[index - _firstMethod]._elementTypeIndex; + } + +int32_t +TR_VectorAPIExpansion::getNumLanesIndex(TR::MethodSymbol *methodSymbol) { - if (!isVectorAPIMethod(methodSymbol) || i < 0 ) return false; + TR_ASSERT_FATAL(isVectorAPIMethod(methodSymbol), "getNumLanesIndex should be called on VectorAPI method"); TR::RecognizedMethod index = 
methodSymbol->getRecognizedMethod(); - TR_ASSERT_FATAL(i < _maxNumberArguments, "Argument index %d is too big", i); - return (methodTable[index - _firstMethod]._argumentTypes[i] == type); + return methodTable[index - _firstMethod]._numLanesIndex; + } + +int32_t +TR_VectorAPIExpansion::getFirstOperandIndex(TR::MethodSymbol *methodSymbol) + { + TR_ASSERT_FATAL(isVectorAPIMethod(methodSymbol), "getFirstOperandIndex should be called on VectorAPI method"); + + TR::RecognizedMethod index = methodSymbol->getRecognizedMethod(); + + return methodTable[index - _firstMethod]._firstOperandIndex; + } + +int32_t +TR_VectorAPIExpansion::getNumOperands(TR::MethodSymbol *methodSymbol) + { + TR_ASSERT_FATAL(isVectorAPIMethod(methodSymbol), "getNumOperands should be called on VectorAPI method"); + + TR::RecognizedMethod index = methodSymbol->getRecognizedMethod(); + + return methodTable[index - _firstMethod]._numOperands; + } + +int32_t +TR_VectorAPIExpansion::getMaskIndex(TR::MethodSymbol *methodSymbol) + { + TR_ASSERT_FATAL(isVectorAPIMethod(methodSymbol), "getMaskIndex should be called on VectorAPI method"); + + TR::RecognizedMethod index = methodSymbol->getRecognizedMethod(); + + return methodTable[index - _firstMethod]._maskIndex; + } + +void +TR_VectorAPIExpansion::getElementTypeAndNumLanes(TR::Node *node, TR::DataType &elementType, int32_t &numLanes) + { + TR_ASSERT_FATAL(node->getOpCode().isFunctionCall(), "getElementTypeAndVectorLength can only be called on a call node"); + + TR::MethodSymbol *methodSymbol = node->getSymbolReference()->getSymbol()->castToMethodSymbol(); + + int32_t i = getElementTypeIndex(methodSymbol); + TR::Node *elementTypeNode = node->getChild(i); + elementType = getDataTypeFromClassNode(comp(), elementTypeNode); + + i = getNumLanesIndex(methodSymbol); + TR::Node *numLanesNode = node->getChild(i); + numLanes = numLanesNode->get32bitIntegralValue(); } void @@ -90,10 +145,11 @@ TR_VectorAPIExpansion::invalidateSymRef(TR::SymbolReference *symRef) { int32_t 
id = symRef->getReferenceNumber(); _aliasTable[id]._classId = -1; + _aliasTable[id]._tempClassId = -1; } void -TR_VectorAPIExpansion::alias(TR::Node *node1, TR::Node *node2) +TR_VectorAPIExpansion::alias(TR::Node *node1, TR::Node *node2, bool aliasTemps) { TR_ASSERT_FATAL(node1->getOpCode().hasSymbolReference() && node2->getOpCode().hasSymbolReference(), "%s nodes should have symbol references %p %p", OPT_DETAILS_VECTOR, node1, node2); @@ -114,10 +170,26 @@ TR_VectorAPIExpansion::alias(TR::Node *node1, TR::Node *node2) _aliasTable[id2]._aliases = new (comp()->trStackMemory()) TR_BitVector(symRefCount, comp()->trMemory(), stackAlloc); if (_trace) - traceMsg(comp(), "%s aliasing symref #%d to symref #%d (nodes %p %p)\n", OPT_DETAILS_VECTOR, id1, id2, node1, node2); + traceMsg(comp(), "%s aliasing symref #%d to symref #%d (nodes %p %p) for the whole class\n", OPT_DETAILS_VECTOR, id1, id2, node1, node2); _aliasTable[id1]._aliases->set(id2); _aliasTable[id2]._aliases->set(id1); + + if (aliasTemps) + { + if (_aliasTable[id1]._tempAliases == NULL) + _aliasTable[id1]._tempAliases = new (comp()->trStackMemory()) TR_BitVector(symRefCount, comp()->trMemory(), stackAlloc); + + if (_aliasTable[id2]._tempAliases == NULL) + _aliasTable[id2]._tempAliases = new (comp()->trStackMemory()) TR_BitVector(symRefCount, comp()->trMemory(), stackAlloc); + + if (_trace) + traceMsg(comp(), "%s aliasing symref #%d to symref #%d (nodes %p %p) as temps\n", OPT_DETAILS_VECTOR, id1, id2, node1, node2); + + _aliasTable[id1]._tempAliases->set(id2); + _aliasTable[id2]._tempAliases->set(id1); + } + } @@ -159,14 +231,48 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) { if (!node->chkStoredValueIsIrrelevant()) { + int32_t id1 = node->getSymbolReference()->getReferenceNumber(); TR::Node *rhs = (opCodeValue == TR::astore) ? 
node->getFirstChild() : node->getSecondChild(); + if (rhs->getOpCode().hasSymbolReference()) { - alias(node, rhs); - - int32_t id1 = node->getSymbolReference()->getReferenceNumber(); int32_t id2 = rhs->getSymbolReference()->getReferenceNumber(); + bool aliasTemps = false; + + if (opCodeValue == TR::astore && + rhs->getOpCode().isFunctionCall() && + isVectorAPIMethod(rhs->getSymbolReference()->getSymbol()->castToMethodSymbol())) + { + // propagate vector info from VectorAPI call to temp + TR::DataType elementType; + int32_t numLanes; + + getElementTypeAndNumLanes(rhs, elementType, numLanes); + + int32_t elementSize = OMR::DataType::getSize(elementType); + int32_t bitsLength = numLanes*elementSize*8; + + + if ((_aliasTable[id1]._elementType != TR::NoType && _aliasTable[id1]._elementType != elementType) || + (_aliasTable[id1]._vecLen != vec_len_default && _aliasTable[id1]._vecLen != bitsLength)) + { + if (_trace) + traceMsg(comp(), "Invalidating #%d due to rhs %p in node %p\n", id1, rhs, node); + invalidateSymRef(node->getSymbolReference()); + } + else + { + _aliasTable[id1]._elementType = elementType; + _aliasTable[id1]._vecLen = bitsLength; + } + } + + if (opCodeValue == TR::astore && rhs->getOpCodeValue() == TR::aload) + aliasTemps = true; + + alias(node, rhs, aliasTemps); + if (_aliasTable[id1]._objectType == Unknown && _aliasTable[id2]._objectType == Unknown) { @@ -190,7 +296,7 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) else { if (_trace) - traceMsg(comp(), "Invalidating #%p due to rhs %p in node %p\n", node->getSymbolReference()->getReferenceNumber(), rhs, node); + traceMsg(comp(), "Invalidating #%d due to rhs %p in node %p\n", id1, rhs, node); invalidateSymRef(node->getSymbolReference()); } } @@ -235,13 +341,14 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) for (int32_t i = 0; i < numChildren; i++) { + bool isMask = false; + if (!isVectorAPICall || - isArgType(methodSymbol, i, Vector) || - 
isArgType(methodSymbol, i, Mask)) + (i >= getFirstOperandIndex(methodSymbol) && i < (getFirstOperandIndex(methodSymbol) + getNumOperands(methodSymbol))) || + (isMask = (i == getMaskIndex(methodSymbol)))) { TR::Node *child = node->getChild(i); bool hasSymbolReference = child->getOpCode().hasSymbolReference(); - bool isMask = isVectorAPICall && isArgType(methodSymbol, i, Mask); bool isNullMask = isMask && child->isConstZeroValue(); if (hasSymbolReference) @@ -267,42 +374,8 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) continue; } - // Update type and length - if (isArgType(methodSymbol, i, Species)) - { - vec_sz_t methodLen = _aliasTable[methodRefNum]._vecLen; - - TR::Node *speciesNode = node->getChild(i); - int32_t speciesRefNum = speciesNode->getSymbolReference()->getReferenceNumber(); - vec_sz_t speciesLen; - - if (_aliasTable[speciesRefNum]._vecLen == vec_len_default) - { - speciesLen = getVectorSizeFromVectorSpecies(speciesNode); - if (_trace) - traceMsg(comp(), "%snode n%dn (#%d) was updated with vecLen : %d\n", - OPT_DETAILS_VECTOR, speciesNode->getGlobalIndex(), speciesRefNum, speciesLen); - } - else - { - speciesLen = _aliasTable[speciesRefNum]._vecLen; - } - - if (methodLen != vec_len_default && speciesLen != methodLen) - { - if (_trace) - traceMsg(comp(), "%snode n%dn (#%d) species are %d but method is : %d\n", - OPT_DETAILS_VECTOR, node->getGlobalIndex(), methodRefNum, speciesLen, methodLen); - speciesLen = vec_len_unknown; - } - - _aliasTable[methodRefNum]._vecLen = speciesLen; - - if (_trace) - traceMsg(comp(), "%snode n%dn (#%d) was updated with vecLen : %d\n", - OPT_DETAILS_VECTOR, node->getGlobalIndex(), methodRefNum, speciesLen); - } - else if (isArgType(methodSymbol, i, ElementType)) + // Update method element type and vector length + if (i == getElementTypeIndex(methodSymbol)) { TR::Node *elementTypeNode = node->getChild(i); methodElementType = getDataTypeFromClassNode(comp(), elementTypeNode); @@ -313,7 +386,7 @@ 
TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) if (methodSymbol->getRecognizedMethod() != TR::jdk_internal_vm_vector_VectorSupport_maskReductionCoerced) _aliasTable[methodRefNum]._elementType = methodElementType; } - else if (isArgType(methodSymbol, i, NumLanes)) + else if (i == getNumLanesIndex(methodSymbol)) { TR::Node *numLanesNode = node->getChild(i); _aliasTable[methodRefNum]._vecLen = vec_len_unknown; @@ -440,27 +513,32 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) void -TR_VectorAPIExpansion::findAllAliases(int32_t classId, int32_t id) +TR_VectorAPIExpansion::findAllAliases(int32_t classId, int32_t id, + TR_BitVector * vectorAliasTableElement::* aliasesField, + int32_t vectorAliasTableElement::* classField) { - if (_aliasTable[id]._aliases == NULL) + bool tempAliases = &vectorAliasTableElement::_tempAliases == aliasesField; + + if (_aliasTable[id].*aliasesField == NULL) { - TR_ASSERT_FATAL(_aliasTable[id]._classId <= 0 , "#%d should have class -1 or 0, but it's %d\n", - id, _aliasTable[id]._classId); + TR_ASSERT_FATAL(_aliasTable[id].*classField <= 0 , "#%d should have class -1 or 0, but it's %d\n", + id, _aliasTable[id].*classField); - if (_aliasTable[id]._classId == 0) - _aliasTable[id]._classId = id; // in their own empty class + if (_aliasTable[id].*classField == 0) + _aliasTable[id].*classField = id; // in their own empty class return; } + if (_trace) { - traceMsg(comp(), "Iterating through aliases for #%d:\n", id); - _aliasTable[id]._aliases->print(comp()); + traceMsg(comp(), "Iterating through %saliases for #%d:\n", tempAliases ? "temp " : "", id); + (_aliasTable[id].*aliasesField)->print(comp()); traceMsg(comp(), "\n"); } // we need to create a new bit vector so that we don't iterate and modify at the same time - TR_BitVector *aliasesToIterate = (classId == id) ? 
new (comp()->trStackMemory()) TR_BitVector(*_aliasTable[id]._aliases) - : _aliasTable[id]._aliases; + TR_BitVector *aliasesToIterate = (classId == id) ? new (comp()->trStackMemory()) TR_BitVector(*(_aliasTable[id].*aliasesField)) + : _aliasTable[id].*aliasesField; TR_BitVectorIterator bvi(*aliasesToIterate); @@ -468,33 +546,31 @@ TR_VectorAPIExpansion::findAllAliases(int32_t classId, int32_t id) { int32_t i = bvi.getNextElement(); - if (_aliasTable[i]._classId > 0) + if (_aliasTable[i].*classField > 0) { - TR_ASSERT_FATAL(_aliasTable[i]._classId == classId, "#%d should belong to class %d but it belongs to class %d\n", - i, classId, _aliasTable[i]._classId ); + TR_ASSERT_FATAL(_aliasTable[i].*classField == classId, "#%d should belong to class %d but it belongs to class %d\n", + i, classId, _aliasTable[i].*classField ); continue; } - _aliasTable[classId]._aliases->set(i); + (_aliasTable[classId].*aliasesField)->set(i); - if (_aliasTable[i]._classId == -1) + if (_aliasTable[i].*classField == -1) { if (_trace) - traceMsg(comp(), "Invalidating the whole class #%d due to #%d\n", classId, i); - _aliasTable[classId]._classId = -1; // invalidate the whole class + traceMsg(comp(), "Invalidating %sclass #%d since #%d is already invalid\n", tempAliases ? "temp " : "", classId, i); + _aliasTable[classId].*classField = -1; // invalidate the whole class } - if (_aliasTable[i]._classId != -1 || i != classId) + if (_aliasTable[i].*classField != -1 || i != classId) { if (_trace) - traceMsg(comp(), "Set class #%d for symref #%d\n", classId, i); - _aliasTable[i]._classId = classId; + traceMsg(comp(), "Set %sclass #%d for symref #%d\n", tempAliases ? 
"temp " : "", classId, i); + _aliasTable[i].*classField = classId; } if (i != classId) - findAllAliases(classId, i); + findAllAliases(classId, i, aliasesField, classField); } - - } @@ -506,10 +582,25 @@ TR_VectorAPIExpansion::buildAliasClasses() int32_t symRefCount = comp()->getSymRefTab()->getNumSymRefs(); + TR_BitVector * vectorAliasTableElement::* aliasesField = &vectorAliasTableElement::_aliases; + int32_t vectorAliasTableElement::* classField = &vectorAliasTableElement::_classId; + for (int32_t i = 0; i < symRefCount; i++) { - if (_aliasTable[i]._classId <= 0) - findAllAliases(i, i); + if (_aliasTable[i].*classField <= 0) + findAllAliases(i, i, aliasesField, classField); + } + + if (_trace) + traceMsg(comp(), "%s Building temp alias classes\n", OPT_DETAILS_VECTOR); + + aliasesField = &vectorAliasTableElement::_tempAliases; + classField = &vectorAliasTableElement::_tempClassId; + + for (int32_t i = 0; i < symRefCount; i++) + { + if (_aliasTable[i].*classField <= 0) + findAllAliases(i, i, aliasesField, classField); } } @@ -650,14 +741,15 @@ TR_VectorAPIExpansion::findVectorMethods(TR::Compilation *comp) } bool -TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLength, TR::DataType &classType) +TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLength, TR::DataType &classType, + int32_t vectorAliasTableElement::* classField) { TR::SymbolReference *symRef = comp()->getSymRefTab()->getSymRef(i); if (!symRef || !symRef->getSymbol()) return false; - if (_aliasTable[i]._classId == -1) + if (_aliasTable[i].*classField == -1) { if (_trace) traceMsg(comp(), "%s invalidating1 class #%d due to symref #%d\n", OPT_DETAILS_VECTOR, id, i); @@ -681,57 +773,71 @@ TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLeng traceMsg(comp(), "%s invalidating3 class #%d due to non-API method #%d\n", OPT_DETAILS_VECTOR, id, i); return false; } - - vec_sz_t methodLength = _aliasTable[i]._vecLen; - TR::DataType 
methodType = _aliasTable[i]._elementType; + } + else + { + vec_sz_t tempLength = _aliasTable[i]._vecLen; + TR::DataType tempType = _aliasTable[i]._elementType; if (classLength == vec_len_default) { - classLength = methodLength; + classLength = tempLength; } - else if (methodLength != vec_len_default && - methodLength != classLength) + else if (tempLength != vec_len_default && + tempLength != classLength) { if (_trace) - traceMsg(comp(), "%s invalidating5 class #%d due to symref #%d method length %d, seen length %d\n", - OPT_DETAILS_VECTOR, id, i, methodLength, classLength); + traceMsg(comp(), "%s invalidating5 class #%d due to symref #%d temp length %d, seen length %d\n", + OPT_DETAILS_VECTOR, id, i, tempLength, classLength); return false; } if (classType == TR::NoType) { - classType = methodType; + classType = tempType; } - else if (methodType != TR::NoType && - methodType != classType) + else if (tempType != TR::NoType && + tempType != classType) { if (_trace) - traceMsg(comp(), "%s invalidating6 class #%d due to symref #%d method type %s, seen type %s\n", - OPT_DETAILS_VECTOR, id, i, TR::DataType::getName(methodType), TR::DataType::getName(classType)); + traceMsg(comp(), "%s invalidating6 class #%d due to symref #%d temp type %s, seen type %s\n", + OPT_DETAILS_VECTOR, id, i, TR::DataType::getName(tempType), TR::DataType::getName(classType)); return false; } } + return true; } void -TR_VectorAPIExpansion::validateVectorAliasClasses() +TR_VectorAPIExpansion::validateVectorAliasClasses(TR_BitVector * vectorAliasTableElement::* aliasesField, + int32_t vectorAliasTableElement::* classField) { + bool tempClasses = &vectorAliasTableElement::_tempAliases == aliasesField; + if (_trace) - traceMsg(comp(), "%s Validating alias classes\n", OPT_DETAILS_VECTOR); + traceMsg(comp(), "%s Verifying all %salias classes\n", OPT_DETAILS_VECTOR, tempClasses ? 
"temp " : ""); int32_t symRefCount = comp()->getSymRefTab()->getNumSymRefs(); for (int32_t id = 1; id < symRefCount; id++) { - if (_aliasTable[id]._classId != id) + TR::SymbolReference *symRef = comp()->getSymRefTab()->getSymRef(id); + + if (tempClasses && + symRef && + symRef->getSymbol() && + symRef->getSymbol()->isMethod()) + continue; // classes of temps should not include methods + + if ((_aliasTable[id].*classField) != id) continue; // not an alias class or is already invalid - if (_aliasTable[id]._aliases && _trace) + if (_aliasTable[id].*aliasesField && _trace) { - traceMsg(comp(), "Verifying class: %d\n", id); - _aliasTable[id]._aliases->print(comp()); + traceMsg(comp(), "Verifying %sclass: %d\n", tempClasses ? "temp " : "", id); + (_aliasTable[id].*aliasesField)->print(comp()); traceMsg(comp(), "\n"); } @@ -739,25 +845,30 @@ TR_VectorAPIExpansion::validateVectorAliasClasses() vec_sz_t classLength = vec_len_default; TR::DataType classType = TR::NoType; - if (!_aliasTable[id]._aliases) + if (!(_aliasTable[id].*aliasesField)) { // class might consist of just the symref itself - vectorClass = validateSymRef(id, id, classLength, classType); + vectorClass = validateSymRef(id, id, classLength, classType, classField); } else { - TR_BitVectorIterator bvi(*_aliasTable[id]._aliases); + TR_BitVectorIterator bvi(*(_aliasTable[id].*aliasesField)); while (bvi.hasMoreElements()) { int32_t i = bvi.getNextElement(); - vectorClass = validateSymRef(id, i, classLength, classType); + vectorClass = validateSymRef(id, i, classLength, classType, classField); + if (!vectorClass) + { + if (_trace) + traceMsg(comp(), "Class #%d can't be vectorized or scalarized due to invalid symRef #%d\n", id, i); break; + } if (_aliasTable[i]._objectType == Invalid) { if (_trace) - traceMsg(comp(), "Class #%d can't be vectorized or scalarized due to invalid object type of #%d\n", id, i); + traceMsg(comp(), "Class #%d can't be vectorized or scalarized due to invalid object type of #%d\n", id, i); 
_aliasTable[id]._cantVectorize = true; _aliasTable[id]._cantScalarize = true; @@ -799,6 +910,8 @@ TR_VectorAPIExpansion::validateVectorAliasClasses() _aliasTable[id]._vecLen = classLength; _aliasTable[id]._elementType = classType; + if (vectorClass && !tempClasses) + continue; if (vectorClass && classLength != vec_len_unknown && @@ -806,9 +919,22 @@ TR_VectorAPIExpansion::validateVectorAliasClasses() continue; // invalidate the whole class - if (_trace && _aliasTable[id]._aliases) // to reduce number of messages - traceMsg(comp(), "Invalidating class #%d\n", id); - _aliasTable[id]._classId = -1; + if (_trace && _aliasTable[id].*aliasesField) // to reduce number of messages + traceMsg(comp(), "Invalidating %sclass #%d\n", tempClasses ? "temp " : "", id); + + _aliasTable[id].*classField = -1; + + int32_t &wholeClass = _aliasTable[id]._classId; + + if (tempClasses && wholeClass >= 0) + { + // invalidate the whole class that temp class belongs to + if (_trace) + traceMsg(comp(), "Invalidating class #%d due to temp class #%d\n", wholeClass, id); + + _aliasTable[wholeClass]._classId = -1; + wholeClass = -1; + } } } @@ -823,7 +949,8 @@ TR_VectorAPIExpansion::expandVectorAPI() buildVectorAliases(); buildAliasClasses(); - validateVectorAliasClasses(); + validateVectorAliasClasses(&vectorAliasTableElement::_aliases, &vectorAliasTableElement::_classId); + validateVectorAliasClasses(&vectorAliasTableElement::_tempAliases, &vectorAliasTableElement::_tempClassId); if (_trace) traceMsg(comp(), "%s Starting Expansion\n", OPT_DETAILS_VECTOR); @@ -863,6 +990,7 @@ TR_VectorAPIExpansion::expandVectorAPI() TR_ASSERT_FATAL(node->getOpCode().hasSymbolReference(), "Node %p should have symbol reference\n", node); int32_t classId = _aliasTable[node->getSymbolReference()->getReferenceNumber()]._classId; + int32_t tempClassId = _aliasTable[node->getSymbolReference()->getReferenceNumber()]._tempClassId; if (_trace) traceMsg(comp(), "#%d classId = %d\n", 
node->getSymbolReference()->getReferenceNumber(), classId); @@ -906,16 +1034,20 @@ TR_VectorAPIExpansion::expandVectorAPI() if (_trace) traceMsg(comp(), "Transforming node %p of class #%d\n", node, classId); - TR::DataType elementType = _aliasTable[classId]._elementType; - int32_t bitsLength = _aliasTable[classId]._vecLen; - TR::VectorLength vectorLength = OMR::DataType::bitsToVectorLength(bitsLength); - int32_t elementSize = OMR::DataType::getSize(elementType); - int32_t numLanes = bitsLength/8/elementSize; + + int32_t numLanes; if (opCodeValue == TR::astore) { if (_trace) traceMsg(comp(), "handling astore %p\n", node); + + TR::DataType elementType = _aliasTable[tempClassId]._elementType; + int32_t bitsLength = _aliasTable[tempClassId]._vecLen; + TR::VectorLength vectorLength = OMR::DataType::bitsToVectorLength(bitsLength); + int32_t elementSize = OMR::DataType::getSize(elementType); + numLanes = bitsLength/8/elementSize; + astoreHandler(this, treeTop, node, elementType, vectorLength, numLanes, doMode); } else if (opCode.isFunctionCall()) @@ -924,6 +1056,13 @@ TR_VectorAPIExpansion::expandVectorAPI() TR::RecognizedMethod index = methodSymbol->getRecognizedMethod(); int32_t handlerIndex = index - _firstMethod; + TR::DataType elementType; + + getElementTypeAndNumLanes(node, elementType, numLanes); + + int32_t elementSize = OMR::DataType::getSize(elementType); + int32_t bitsLength = numLanes*elementSize*8; + TR::VectorLength vectorLength = OMR::DataType::bitsToVectorLength(bitsLength); TR_ASSERT_FATAL(methodTable[handlerIndex]._methodHandler(this, treeTop, node, elementType, vectorLength, numLanes, checkMode), "Analysis should've proved that method is supported"); @@ -1668,9 +1807,12 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt if (opCodeType == Test || opCodeType == MaskReduction || opCodeType == Blend) firstOperand = 4; + if (opCodeType == Convert) + firstOperand = 7; + bool withMask = false; - if (opCodeType != MaskReduction) 
+ if (opCodeType != MaskReduction && opCodeType != Convert) { TR::Node *maskNode = node->getChild(firstOperand + numChildren); // each intrinsic has a mask argument withMask = !maskNode->isConstZeroValue(); @@ -1703,7 +1845,7 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt // and all operations should be done in Int in the case of scalarization if (elementType == TR::Int8 || elementType == TR::Int16) opType = TR::Int32; - scalarOpCode = ILOpcodeFromVectorAPIOpcode(vectorAPIOpcode, opType, TR::NoVectorLength, opCodeType, withMask); + scalarOpCode = ILOpcodeFromVectorAPIOpcode(comp, vectorAPIOpcode, opType, TR::NoVectorLength, opCodeType, withMask); if (mode == checkScalarization) { @@ -1731,13 +1873,41 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt } else { + TR::DataType resultElementType = TR::NoType; + TR::VectorLength resultVectorLength = TR::NoVectorLength; + + if (opCodeType == Convert) + { + // result vector type info is in children 5 and 6 + TR::Node *resultElementTypeNode = node->getChild(5); + resultElementType = getDataTypeFromClassNode(comp, resultElementTypeNode); + + TR::Node *resultNumLanesNode = node->getChild(6); + + if (resultNumLanesNode->getOpCode().isLoadConst()) + { + int32_t elementSize = OMR::DataType::getSize(resultElementType); + vec_sz_t bitsLength = resultNumLanesNode->get32bitIntegralValue()*8*elementSize; + + if (supportedOnPlatform(comp, bitsLength) == TR::NoVectorLength) + return NULL; + + resultVectorLength = OMR::DataType::bitsToVectorLength(bitsLength); + } + + if (resultElementType == TR::NoType || resultVectorLength == TR::NoVectorLength) + return NULL; + } + if (mode == checkVectorization) { - vectorOpCode = ILOpcodeFromVectorAPIOpcode(vectorAPIOpcode, opType, vectorLength, opCodeType, withMask); + vectorOpCode = ILOpcodeFromVectorAPIOpcode(comp, vectorAPIOpcode, opType, vectorLength, opCodeType, withMask, + resultElementType, resultVectorLength); if 
(vectorOpCode == TR::BadILOp || !comp->cg()->getSupportsOpCodeForAutoSIMD(vectorOpCode)) { - if (opt->_trace) traceMsg(comp, "Unsupported vector opcode in node %p\n", node); + if (opt->_trace) traceMsg(comp, "Unsupported vector opcode in node %p %s\n", node, + vectorOpCode == TR::BadILOp ? "(no IL)" : "(no codegen)"); return NULL; } else @@ -1747,7 +1917,8 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt } else { - vectorOpCode = ILOpcodeFromVectorAPIOpcode(vectorAPIOpcode, opType, vectorLength, opCodeType, withMask); + vectorOpCode = ILOpcodeFromVectorAPIOpcode(comp, vectorAPIOpcode, opType, vectorLength, opCodeType, withMask, + resultElementType, resultVectorLength); TR_ASSERT_FATAL(vectorOpCode != TR::BadILOp, "Vector opcode should exist for node %p\n", node); @@ -1893,18 +2064,49 @@ TR::Node *TR_VectorAPIExpansion::convertIntrinsicHandler(TR_VectorAPIExpansion * TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode) { - return NULL; + return naryIntrinsicHandler(opt, treeTop, node, elementType, vectorLength, numLanes, mode, 1, Convert); } -TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(int32_t vectorAPIOpCode, TR::DataType elementType, - TR::VectorLength vectorLength, vapiOpCodeType opCodeType, bool withMask) +TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(TR::Compilation *comp, int32_t vectorAPIOpCode, TR::DataType elementType, + TR::VectorLength vectorLength, vapiOpCodeType opCodeType, + bool withMask, + TR::DataType resultElementType, + TR::VectorLength resultVectorLength) { // TODO: support more scalarization bool scalar = (vectorLength == TR::NoVectorLength); TR::DataType vectorType = scalar ? TR::NoType : TR::DataType::createVectorType(elementType, vectorLength); + TR::DataType resultVectorType = TR::NoType; + + if (resultElementType != TR::NoType) + resultVectorType = scalar ? 
TR::NoType : TR::DataType::createVectorType(resultElementType, resultVectorLength); + + + if (opCodeType == Convert) + { + switch (vectorAPIOpCode) + { + case VECTOR_OP_CAST: return TR::BadILOp; + case VECTOR_OP_UCAST: return TR::BadILOp; + case VECTOR_OP_REINTERPRET: + if (scalar) return TR::BadILOp; + + if (OMR::DataType::getSize(resultElementType) != OMR::DataType::getSize(elementType) || + resultVectorLength != vectorLength) + { + traceMsg(comp, "\nCalling VECTOR_OP_REINTERPRET on %s to %s in %s\n", TR::DataType::getName(vectorType), + TR::DataType::getName(resultVectorType), + comp->signature()); + return TR::BadILOp; + } - if (opCodeType == Blend) + return TR::ILOpCode::createVectorOpCode(TR::vcast, vectorType, resultVectorType); + default: + return TR::BadILOp; + } + } + else if (opCodeType == Blend) { if (scalar) return TR::BadILOp; @@ -2175,18 +2377,20 @@ TR::Node *TR_VectorAPIExpansion::transformNary(TR_VectorAPIExpansion *opt, TR::T TR_VectorAPIExpansion::methodTableEntry TR_VectorAPIExpansion::methodTable[] = { - {loadIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes}}, // jdk_internal_vm_vector_VectorSupport_load - {storeIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes, Unknown, Unknown, Vector}}, // jdk_internal_vm_vector_VectorSupport_store - {binaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_binaryOp - {blendIntrinsicHandler, Vector, {Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Unknown}}, // jdk_internal_vm_vector_VectorSupport_blend - {compareIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_compare - {convertIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_convert - {fromBitsCoercedIntrinsicHandler, Unknown, {Unknown, ElementType, 
NumLanes, Unknown, Unknown, Unknown}}, // jdk_internal_vm_vector_VectorSupport_fromBitsCoerced - {maskReductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask}}, // jdk_internal_vm_vector_VectorSupport_maskReductionCoerced - {reductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_reductionCoerced - {ternaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_ternaryOp - {testIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask, Mask, Unknown}}, // jdk_internal_vm_vector_VectorSupport_test - {unaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_unaryOp + {loadIntrinsicHandler, Unknown, 1, 2, -1, 0, -1, {Unknown, ElementType, NumLanes}}, // jdk_internal_vm_vector_VectorSupport_load + {storeIntrinsicHandler, Unknown, 1, 2, 5, 1, -1, {Unknown, ElementType, NumLanes, Unknown, Unknown, Vector}}, // jdk_internal_vm_vector_VectorSupport_store + {binaryIntrinsicHandler, Vector, 3, 4, 5, 2, 7, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_binaryOp + {blendIntrinsicHandler, Vector, 2, 3, 4, 3, -1, {Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Unknown}}, // jdk_internal_vm_vector_VectorSupport_blend + {compareIntrinsicHandler, Mask, 3, 4, 5, 2, 7, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_compare + + {convertIntrinsicHandler, Vector, 2, 3, 7, 1, -1, {Unknown, Unknown, ElementType, NumLanes, Unknown, Unknown, Unknown, Vector}}, // jdk_internal_vm_vector_VectorSupport_convert + + {fromBitsCoercedIntrinsicHandler, Unknown, 1, 2, -1, 0, -1, {Unknown, ElementType, NumLanes, Unknown, Unknown, Unknown}}, 
// jdk_internal_vm_vector_VectorSupport_fromBitsCoerced + {maskReductionCoercedIntrinsicHandler, Scalar, 2, 3, 4, 1, -1, {Unknown, Unknown, ElementType, NumLanes, Mask}}, // jdk_internal_vm_vector_VectorSupport_maskReductionCoerced + {reductionCoercedIntrinsicHandler, Scalar, 3, 4, 5, 1, 6, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_reductionCoerced + {ternaryIntrinsicHandler, Vector, 3, 4, 5, 3, 8, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_ternaryOp + {testIntrinsicHandler, Scalar, 2, 3, 4, 1, 5, {Unknown, Unknown, ElementType, NumLanes, Mask, Mask, Unknown}}, // jdk_internal_vm_vector_VectorSupport_test + {unaryIntrinsicHandler, Vector, 3, 4, 5, 1, 6, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_unaryOp }; diff --git a/runtime/compiler/optimizer/VectorAPIExpansion.hpp b/runtime/compiler/optimizer/VectorAPIExpansion.hpp index 68a02c943bd..706427de7af 100644 --- a/runtime/compiler/optimizer/VectorAPIExpansion.hpp +++ b/runtime/compiler/optimizer/VectorAPIExpansion.hpp @@ -193,6 +193,7 @@ class TR_VectorAPIExpansion : public TR::Optimization Reduction, Test, Blend, + Convert, Other }; @@ -204,6 +205,11 @@ class TR_VectorAPIExpansion : public TR::Optimization { TR::Node * (* _methodHandler)(TR_VectorAPIExpansion *, TR::TreeTop *, TR::Node *, TR::DataType, TR::VectorLength, int32_t, handlerMode); vapiObjType _returnType; + int32_t _elementTypeIndex; + int32_t _numLanesIndex; + int32_t _firstOperandIndex; + int32_t _numOperands; + int32_t _maskIndex; vapiObjType _argumentTypes[_maxNumberArguments]; }; @@ -233,7 +239,8 @@ class TR_VectorAPIExpansion : public TR::Optimization vectorAliasTableElement() : _symRef(NULL), _vecSymRef(NULL), _vecLen(vec_len_default), _elementType(TR::NoType), _aliases(NULL), _classId(0), - _cantVectorize(false), _cantScalarize(false), 
_objectType(Unknown) {} + _cantVectorize(false), _cantScalarize(false), _objectType(Unknown), + _tempAliases(NULL), _tempClassId(0) {} TR::SymbolReference *_symRef; union @@ -252,6 +259,9 @@ class TR_VectorAPIExpansion : public TR::Optimization bool _cantVectorize; bool _cantScalarize; vapiObjType _objectType; + + TR_BitVector *_tempAliases; + int32_t _tempClassId; }; @@ -354,22 +364,73 @@ class TR_VectorAPIExpansion : public TR::Optimization /** \brief - * Checks if method's argument is one the \c vapiObjType types + * Returns index of a child node that contains element type * * \param methodSymbol * Method symbol * - * \param i - * argument's number + * \return + * Index of a child node that contains element type + */ + int32_t getElementTypeIndex(TR::MethodSymbol *methodSymbol); + + /** \brief + * Returns index of a child node that contains number of lanes * - * \param type - * \c vapiObjType + * \param methodSymbol + * Method symbol * * \return - * \c true if the argument is the same as \c type, - * \c false otherwise + * Index of a child node that contains number of lanes + */ + int32_t getNumLanesIndex(TR::MethodSymbol *methodSymbol); + + /** \brief + * Returns index of a child node that contains first operand + * + * \param methodSymbol + * Method symbol + * + * \return + * Index of a child node that contains first operand + */ + int32_t getFirstOperandIndex(TR::MethodSymbol *methodSymbol); + + /** \brief + * Returns number of operands + * + * \param methodSymbol + * Method symbol + * + * \return + * Number of operands + */ + int32_t getNumOperands(TR::MethodSymbol *methodSymbol); + + /** \brief + * Returns index of a child node that contains mask + * + * \param methodSymbol + * Method symbol + * + * \return + * Index of a child node that contains mask + */ + int32_t getMaskIndex(TR::MethodSymbol *methodSymbol); + + /** \brief + * Determines element type and number of lanes of a node + * + * \param node + * Call node + * + * \param elementType + * Element 
type + * + * \param numLanes + * Number of lanes */ - bool isArgType(TR::MethodSymbol *methodSymbol, int32_t i, vapiObjType type); + void getElementTypeAndNumLanes(TR::Node *node, TR::DataType &elementType, int32_t &numLanes); /** \brief * Aliases symbol references with each other as described above @@ -395,13 +456,25 @@ class TR_VectorAPIExpansion : public TR::Optimization * * \param id * Symbol reference for which all transitive aliases need to be found + * + * \param aliasesField + * Pointer to the struct member that contains aliases + * + * \param classField + * Pointer to the struct member that contains class */ - void findAllAliases(int32_t classId, int32_t id); + void findAllAliases(int32_t classId, int32_t id, TR_BitVector * vectorAliasTableElement::* aliasesField, int32_t vectorAliasTableElement::* classField); /** \brief * Validates classes found by \c buildAliasClasses() + * + * \param aliasesField + * Pointer to the struct member that contains aliases + * + * \param classField + * Pointer to the struct member that contains class */ - void validateVectorAliasClasses(); + void validateVectorAliasClasses(TR_BitVector * vectorAliasTableElement::* aliasesField, int32_t vectorAliasTableElement::* classField); /** \brief * Used by \c validateVectorAliasClasses() to check individual symbol reference @@ -418,8 +491,11 @@ class TR_VectorAPIExpansion : public TR::Optimization * \param classType * Element type of the class * + * \param classField + * Pointer to the struct member that contains class + * */ - bool validateSymRef(int32_t classId, int32_t i, vec_sz_t &classLength, TR::DataType &classType); + bool validateSymRef(int32_t classId, int32_t i, vec_sz_t &classLength, TR::DataType &classType, int32_t vectorAliasTableElement::* classField); /** \brief * Sets \c _classId of a symbol reference to -1 to indicate it's invalid @@ -438,8 +514,11 @@ class TR_VectorAPIExpansion : public TR::Optimization * * \param node2 * Second node + * + * \param aliasTemps + * true 
if aliasing is caused by storing one temp into another */ - void alias(TR::Node *node1, TR::Node *node2); + void alias(TR::Node *node1, TR::Node *node2, bool aliasTemps = false); /** \brief * Finds vector length from SPECIES node if it's a known object @@ -521,6 +600,9 @@ class TR_VectorAPIExpansion : public TR::Optimization /** \brief * Maps Vector API opcode enum into scalar or vector TR::ILOpCodes * + * \param comp + * Compilation + * * \param vectorOpCode * Vector API opcode enum * @@ -536,11 +618,19 @@ class TR_VectorAPIExpansion : public TR::Optimization * \param withMask * true if mask is present, false otherwise * + * \param resultElementType + * Result element type + * + * \param resultVectorLength + * Result vector length + * * \return * scalar TR::IL opcode if scalar is true, otherwise vector opcode */ - static TR::ILOpCodes ILOpcodeFromVectorAPIOpcode(int32_t vectorOpCode, TR::DataType elementType, - TR::VectorLength vectorLength, vapiOpCodeType opCodeType, bool withMask); + static TR::ILOpCodes ILOpcodeFromVectorAPIOpcode(TR::Compilation *comp, int32_t vectorOpCode, TR::DataType elementType, + TR::VectorLength vectorLength, vapiOpCodeType opCodeType, bool withMask, + TR::DataType resultElementType = TR::NoType, + TR::VectorLength resultVectorLength = TR::NoVectorLength); /** \brief * For the node's symbol reference, creates and records(if it does not exist yet) @@ -1106,7 +1196,7 @@ class TR_VectorAPIExpansion : public TR::Optimization * Transformed node */ static TR::Node *convertIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode); - + /** \brief * Helper method to transform a load from array node From bbe4e851cfc5169dab3c5e3db554f76a6f931811 Mon Sep 17 00:00:00 2001 From: Gita Koblents Date: Mon, 27 Mar 2023 20:26:42 -0400 Subject: [PATCH 3/3] Handle vector compare opcodes as two-type opcodes - since vector compare opcodes take 
vectors and return a mask (with possibly different element type) they should be two-type opcodes - mask element type is Int32 for Float vectors and Int64 for Double vectors - fix temp class invalidation during VectorAPIExpansion --- .../compiler/optimizer/VectorAPIExpansion.cpp | 69 +++++++++++++------ 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/runtime/compiler/optimizer/VectorAPIExpansion.cpp b/runtime/compiler/optimizer/VectorAPIExpansion.cpp index 34c30494fb5..d10ed4f96ae 100644 --- a/runtime/compiler/optimizer/VectorAPIExpansion.cpp +++ b/runtime/compiler/optimizer/VectorAPIExpansion.cpp @@ -250,6 +250,10 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) getElementTypeAndNumLanes(rhs, elementType, numLanes); + if (getReturnType(rhs->getSymbolReference()->getSymbol()->castToMethodSymbol()) == Mask && + (elementType == TR::Float || elementType == TR::Double)) + elementType = (elementType == TR::Float) ? TR::Int32 : TR::Int64; + int32_t elementSize = OMR::DataType::getSize(elementType); int32_t bitsLength = numLanes*elementSize*8; @@ -379,12 +383,7 @@ TR_VectorAPIExpansion::visitNodeToBuildVectorAliases(TR::Node *node) { TR::Node *elementTypeNode = node->getChild(i); methodElementType = getDataTypeFromClassNode(comp(), elementTypeNode); - - // maskReductionCoerced intrinsic has element type Int for Float vectors and Long for Double vectors. 
- // For the sake of class verification we can leave _elementType field unset - // and it will be automatically derived from the child type (since they will be in the same class) - if (methodSymbol->getRecognizedMethod() != TR::jdk_internal_vm_vector_VectorSupport_maskReductionCoerced) - _aliasTable[methodRefNum]._elementType = methodElementType; + _aliasTable[methodRefNum]._elementType = methodElementType; } else if (i == getNumLanesIndex(methodSymbol)) { @@ -744,6 +743,8 @@ bool TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLength, TR::DataType &classType, int32_t vectorAliasTableElement::* classField) { + bool tempClasses = &vectorAliasTableElement::_tempClassId == classField; + TR::SymbolReference *symRef = comp()->getSymRefTab()->getSymRef(i); if (!symRef || !symRef->getSymbol()) @@ -774,33 +775,41 @@ TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLeng return false; } } - else + else if (tempClasses) { vec_sz_t tempLength = _aliasTable[i]._vecLen; TR::DataType tempType = _aliasTable[i]._elementType; if (classLength == vec_len_default) { + if (_trace) + traceMsg(comp(), "%s assigning length to class #%d from symref #%d temp length %d\n", + OPT_DETAILS_VECTOR, id, i, tempLength); + classLength = tempLength; } else if (tempLength != vec_len_default && tempLength != classLength) { if (_trace) - traceMsg(comp(), "%s invalidating5 class #%d due to symref #%d temp length %d, seen length %d\n", + traceMsg(comp(), "%s invalidating5 class #%d due to symref #%d temp length %d, class length %d\n", OPT_DETAILS_VECTOR, id, i, tempLength, classLength); return false; } if (classType == TR::NoType) { + if (_trace) + traceMsg(comp(), "%s assigning element type to class #%d from symref #%d temp type %s\n", + OPT_DETAILS_VECTOR, id, i, TR::DataType::getName(tempType)); + classType = tempType; } else if (tempType != TR::NoType && tempType != classType) { if (_trace) - traceMsg(comp(), "%s invalidating6 class #%d due to 
symref #%d temp type %s, seen type %s\n", + traceMsg(comp(), "%s invalidating6 class #%d due to symref #%d temp type %s, class type %s\n", OPT_DETAILS_VECTOR, id, i, TR::DataType::getName(tempType), TR::DataType::getName(classType)); return false; } @@ -1899,6 +1908,19 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt return NULL; } + + if (opCodeType == Compare) + { + resultElementType = elementType; + resultVectorLength = vectorLength; + + if (elementType == TR::Float) + resultElementType = TR::Int32; + + if (elementType == TR::Double) + resultElementType = TR::Int64; + } + if (mode == checkVectorization) { vectorOpCode = ILOpcodeFromVectorAPIOpcode(comp, vectorAPIOpcode, opType, vectorLength, opCodeType, withMask, @@ -2125,28 +2147,33 @@ TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(TR::Compilation } else if ((opCodeType == Compare) && withMask) { + TR::DataType resultMaskType = scalar ? TR::NoType : TR::DataType::createMaskType(resultElementType, resultVectorLength); + switch (vectorAPIOpCode) { - case BT_eq: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpeq, vectorType); - case BT_ne: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpne, vectorType); - case BT_le: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmple, vectorType); - case BT_ge: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpge, vectorType); - case BT_lt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmplt, vectorType); - case BT_gt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpgt, vectorType); + case BT_eq: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpeq, vectorType, resultMaskType); + case BT_ne: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpne, vectorType, resultMaskType); + case BT_le: return scalar ? 
TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmple, vectorType, resultMaskType); + case BT_ge: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpge, vectorType, resultMaskType); + case BT_lt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmplt, vectorType, resultMaskType); + case BT_gt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vmcmpgt, vectorType, resultMaskType); default: return TR::BadILOp; } } else if (opCodeType == Compare) { + TR::DataType resultMaskType = scalar ? TR::NoType : TR::DataType::createMaskType(resultElementType, resultVectorLength); + switch (vectorAPIOpCode) { - case BT_eq: return scalar ? TR::ILOpCode::cmpeqOpCode(elementType) : TR::ILOpCode::createVectorOpCode(TR::vcmpeq, vectorType); - case BT_ne: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpne, vectorType); - case BT_le: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmple, vectorType); - case BT_ge: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpge, vectorType); - case BT_lt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmplt, vectorType); - case BT_gt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpgt, vectorType); + case BT_eq: return scalar ? TR::ILOpCode::cmpeqOpCode(elementType) + : TR::ILOpCode::createVectorOpCode(TR::vcmpeq, vectorType, resultMaskType); + case BT_ne: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpne, vectorType, resultMaskType); + case BT_le: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmple, vectorType, resultMaskType); + case BT_ge: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpge, vectorType, resultMaskType); + case BT_lt: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmplt, vectorType, resultMaskType); + case BT_gt: return scalar ? 
TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vcmpgt, vectorType, resultMaskType); default: return TR::BadILOp; }