Skip to content

Commit

Permalink
Merge pull request #20185 from matthewhall2/fmin_fmax_dmin_dmax
Browse files Browse the repository at this point in the history
Support Java Behaviour w.r.t Math.max and Math.min for Floating Points
  • Loading branch information
r30shah authored Oct 30, 2024
2 parents 7a366ac + bc2c3cf commit c0eabdd
Show file tree
Hide file tree
Showing 8 changed files with 495 additions and 102 deletions.
11 changes: 11 additions & 0 deletions runtime/compiler/codegen/J9CodeGenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,16 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz);
*/
void setSupportsInlineVectorizedHashCode() { _j9Flags.set(SupportsInlineVectorizedHashCode); }

/** \brief
* Determines whether the code generator supports inlining of java_lang_Math_max/min_F/D
*/
bool getSupportsInlineMath_MaxMin_FD() { return _j9Flags.testAny(SupportsInlineMath_MaxMin_FD); }

/** \brief
* The code generator supports inlining of java_lang_Math_max/min_F/D
*/
void setSupportsInlineMath_MaxMin_FD() { _j9Flags.set(SupportsInlineMath_MaxMin_FD); }

/**
* \brief
* The number of nodes between a monext and the next monent before
Expand Down Expand Up @@ -677,6 +687,7 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz);
SavesNonVolatileGPRsForGC = 0x00000800,
SupportsInlineVectorizedMismatch = 0x00001000,
SupportsInlineVectorizedHashCode = 0x00002000,
SupportsInlineMath_MaxMin_FD = 0x00004000,
};

flags32_t _j9Flags;
Expand Down
4 changes: 4 additions & 0 deletions runtime/compiler/env/j9method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5029,6 +5029,10 @@ TR_ResolvedJ9Method::setRecognizedMethodInfo(TR::RecognizedMethod rm)
case TR::java_lang_Math_min_I:
case TR::java_lang_Math_max_L:
case TR::java_lang_Math_min_L:
case TR::java_lang_Math_max_F:
case TR::java_lang_Math_min_F:
case TR::java_lang_Math_max_D:
case TR::java_lang_Math_min_D:
case TR::java_lang_Math_abs_I:
case TR::java_lang_Math_abs_L:
case TR::java_lang_Math_abs_F:
Expand Down
17 changes: 17 additions & 0 deletions runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,11 @@ bool J9::RecognizedCallTransformer::isInlineable(TR::TreeTop* treetop)
case TR::java_lang_Math_max_L:
case TR::java_lang_Math_min_L:
return !comp()->getOption(TR_DisableMaxMinOptimization);
case TR::java_lang_Math_max_F:
case TR::java_lang_Math_min_F:
case TR::java_lang_Math_max_D:
case TR::java_lang_Math_min_D:
return !comp()->getOption(TR_DisableMaxMinOptimization) && cg()->getSupportsInlineMath_MaxMin_FD();
case TR::java_lang_Math_multiplyHigh:
return cg()->getSupportsLMulHigh();
case TR::java_lang_StringUTF16_toBytes:
Expand Down Expand Up @@ -1495,6 +1500,18 @@ void J9::RecognizedCallTransformer::transform(TR::TreeTop* treetop)
case TR::java_lang_Math_min_L:
processIntrinsicFunction(treetop, node, TR::lmin);
break;
case TR::java_lang_Math_max_F:
processIntrinsicFunction(treetop, node, TR::fmax);
break;
case TR::java_lang_Math_min_F:
processIntrinsicFunction(treetop, node, TR::fmin);
break;
case TR::java_lang_Math_max_D:
processIntrinsicFunction(treetop, node, TR::dmax);
break;
case TR::java_lang_Math_min_D:
processIntrinsicFunction(treetop, node, TR::dmin);
break;
case TR::java_lang_Math_multiplyHigh:
processIntrinsicFunction(treetop, node, TR::lmulh);
break;
Expand Down
24 changes: 17 additions & 7 deletions runtime/compiler/z/codegen/J9CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ J9::Z::CodeGenerator::initialize()
cg->setSupportsInlineEncodeASCII();
}

static bool disableInlineMath_MaxMin_FD = feGetEnv("TR_disableInlineMaxMin") != NULL;
if (!disableInlineMath_MaxMin_FD)
{
cg->setSupportsInlineMath_MaxMin_FD();
}

static bool disableInlineVectorizedMismatch = feGetEnv("TR_disableInlineVectorizedMismatch") != NULL;
if (cg->getSupportsArrayCmpLen() &&
#if defined(J9VM_GC_ENABLE_SPARSE_HEAP_ALLOCATION)
Expand Down Expand Up @@ -4118,20 +4124,24 @@ J9::Z::CodeGenerator::inlineDirectCall(
}
}

if (!comp->getOption(TR_DisableSIMDDoubleMaxMin) && cg->getSupportsVectorRegisters())
{
switch (methodSymbol->getRecognizedMethod())
{
if (!self()->comp()->getOption(TR_DisableMaxMinOptimization) && cg->getSupportsInlineMath_MaxMin_FD()) {
switch (methodSymbol->getRecognizedMethod()) {
case TR::java_lang_Math_max_D:
resultReg = TR::TreeEvaluator::inlineDoubleMax(node, cg);
resultReg = J9::Z::TreeEvaluator::dmaxEvaluator(node, cg);
return true;
case TR::java_lang_Math_min_D:
resultReg = TR::TreeEvaluator::inlineDoubleMin(node, cg);
resultReg = J9::Z::TreeEvaluator::dminEvaluator(node, cg);
return true;
case TR::java_lang_Math_max_F:
resultReg = J9::Z::TreeEvaluator::fmaxEvaluator(node, cg);
return true;
case TR::java_lang_Math_min_F:
resultReg = J9::Z::TreeEvaluator::fminEvaluator(node, cg);
return true;
default:
break;
}
}
}

switch (methodSymbol->getRecognizedMethod())
{
Expand Down
110 changes: 35 additions & 75 deletions runtime/compiler/z/codegen/J9TreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -906,76 +906,48 @@ allocateWriteBarrierInternalPointerRegister(TR::CodeGenerator * cg, TR::Node * s
}


extern TR::Register *
doubleMaxMinHelper(TR::Node *node, TR::CodeGenerator *cg, bool isMaxOp)
TR::Register *
J9::Z::TreeEvaluator::dmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
TR_ASSERT(node->getNumChildren() >= 1 || node->getNumChildren() <= 2, "node has incorrect number of children");

/* ===================== Allocating Registers ===================== */

TR::Register * v16 = cg->allocateRegister(TR_VRF);
TR::Register * v17 = cg->allocateRegister(TR_VRF);
TR::Register * v18 = cg->allocateRegister(TR_VRF);

/* ===================== Generating instructions ===================== */

/* ====== LD FPR0,16(GPR5) Load a ====== */
TR::Register * v0 = cg->fprClobberEvaluate(node->getFirstChild());

/* ====== LD FPR2, 0(GPR5) Load b ====== */
TR::Register * v2 = cg->evaluate(node->getSecondChild());

/* ====== WFTCIDB V16,V0,X'F' a == NaN ====== */
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v16, v0, 0xF, 8, 3);

/* ====== For Max: WFCHE V17,V0,V2 Compare a >= b ====== */
if(isMaxOp)
if (cg->getSupportsVectorRegisters())
{
generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v0, v2, 0, 8, 3);
cg->generateDebugCounter("z13/simd/doubleMax", 1, TR::DebugCounter::Free);
return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg);
}
/* ====== For Min: WFCHE V17,V0,V2 Compare a <= b ====== */
else
return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg);
}

TR::Register *
J9::Z::TreeEvaluator::dminEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
if (cg->getSupportsVectorRegisters())
{
generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v2, v0, 0, 8, 3);
cg->generateDebugCounter("z13/simd/doubleMin", 1, TR::DebugCounter::Free);
return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg);
}
return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg);
}

/* ====== VO V16,V16,V17 (a >= b) || (a == NaN) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0);

/* ====== For Max: WFTCIDB V17,V0,X'800' a == +0 ====== */
if(isMaxOp)
{
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x800, 8, 3);
}
/* ====== For Min: WFTCIDB V17,V0,X'400' a == -0 ====== */
else
{
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x400, 8, 3);
}
/* ====== WFTCIDB V18,V2,X'C00' b == 0 ====== */
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v18, v2, 0xC00, 8, 3);

/* ====== VN V17,V17,V18 (a == -0) && (b == 0) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VN, node, v17, v17, v18, 0, 0, 0);

/* ====== VO V16,V16,V17 (a >= b) || (a == NaN) || ((a == -0) && (b == 0)) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0);

/* ====== VSEL V0,V0,V2,V16 ====== */
generateVRReInstruction(cg, TR::InstOpCode::VSEL, node, v0, v0, v2, v16);

/* ===================== Deallocating Registers ===================== */
cg->stopUsingRegister(v2);
cg->stopUsingRegister(v16);
cg->stopUsingRegister(v17);
cg->stopUsingRegister(v18);

node->setRegister(v0);

cg->decReferenceCount(node->getFirstChild());
cg->decReferenceCount(node->getSecondChild());
TR::Register *
J9::Z::TreeEvaluator::fmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
if (cg->getSupportsVectorRegisters())
{
cg->generateDebugCounter("z13/simd/floatMax", 1, TR::DebugCounter::Free);
return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg);
}
return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg);
}

return node->getRegister();
TR::Register *
J9::Z::TreeEvaluator::fminEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
if (cg->getSupportsVectorRegisters())
{
cg->generateDebugCounter("z13/simd/floatMin", 1, TR::DebugCounter::Free);
return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg);
}
return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg);
}

TR::Register*
Expand Down Expand Up @@ -2945,19 +2917,7 @@ J9::Z::TreeEvaluator::toLowerIntrinsic(TR::Node *node, TR::CodeGenerator *cg, bo
return caseConversionHelper(node, cg, false, isCompressedString);
}

TR::Register*
J9::Z::TreeEvaluator::inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg)
{
cg->generateDebugCounter("z13/simd/doubleMax", 1, TR::DebugCounter::Free);
return doubleMaxMinHelper(node, cg, true);
}

TR::Register*
J9::Z::TreeEvaluator::inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg)
{
cg->generateDebugCounter("z13/simd/doubleMin", 1, TR::DebugCounter::Free);
return doubleMaxMinHelper(node, cg, false);
}

TR::Register *
J9::Z::TreeEvaluator::inlineMathFma(TR::Node *node, TR::CodeGenerator *cg)
Expand Down
6 changes: 4 additions & 2 deletions runtime/compiler/z/codegen/J9TreeEvaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ class OMR_EXTENSIBLE TreeEvaluator: public J9::TreeEvaluator
*/
static TR::Register *inlineVectorizedStringIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isCompressed);
static TR::Register *inlineIntrinsicIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isLatin1);
static TR::Register *inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *fminEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *dminEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *fmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *dmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *inlineMathFma(TR::Node *node, TR::CodeGenerator *cg);

/* This Evaluator generates the SIMD routine for methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,35 @@
package jit.test.recognizedMethod;
import org.testng.AssertJUnit;
import org.testng.annotations.Test;
import java.util.Random;
import org.testng.asserts.SoftAssert;
import static jit.test.recognizedMethod.TestMathUtils.*;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Listeners;


@Test(singleThreaded=true)
public class TestJavaLangMath {

/**
* Tests the constant corner cases defined by the {@link Math.sqrt} method.
* <p>
* The JIT compiler will transform calls to {@link Math.sqrt} within this test
* into the following tree sequence:
*
* <code>
* dsqrt
* dconst <x>
* </code>
*
* Subsequent tree simplification passes will attempt to reduce this constant
* operation to a <code>dsqrt</code> IL by performing the square root at compile
* time. The transformation will be performed when the function get executed
* twice, therefore, the "invocationCount=2" is needed. However we must ensure the
* result of the square root done by the compiler at compile time will be exactly
* the same as the result had it been done by the Java runtime at runtime. This
* test validates the results are the same.
*/
* Tests the constant corner cases defined by the {@link Math.sqrt} method.
* <p>
* The JIT compiler will transform calls to {@link Math.sqrt} within this test
* into the following tree sequence:
*
* <code>
* dsqrt
* dconst <x>
* </code>
*
* Subsequent tree simplification passes will attempt to reduce this constant
* operation to a <code>dsqrt</code> IL by performing the square root at compile
* time. The transformation will be performed when the function get executed
* twice, therefore, the "invocationCount=2" is needed. However we must ensure the
* result of the square root done by the compiler at compile time will be exactly
* the same as the result had it been done by the Java runtime at runtime. This
* test validates the results are the same.
*/
@Test(groups = {"level.sanity"}, invocationCount=2)
public void test_java_lang_Math_sqrt() {
AssertJUnit.assertTrue(Double.isNaN(Math.sqrt(Double.NEGATIVE_INFINITY)));
Expand All @@ -55,4 +62,86 @@ public void test_java_lang_Math_sqrt() {
AssertJUnit.assertEquals(Double.POSITIVE_INFINITY, Math.sqrt(Double.POSITIVE_INFINITY));
AssertJUnit.assertTrue(Double.isNaN(Math.sqrt(Double.NaN)));
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_zeros_FD(Number a, Number b, boolean isFirstArg) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_zeros_FD(Number a, Number b, boolean isFirstArg) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_nan_FD(Number a, Number b, Number expected) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
int exp = expected.intValue();
assertEquals(Math.min(f1, f2), exp);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
long exp = expected.longValue();
assertEquals(Math.min(f1, f2), exp);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_nan_FD(Number a, Number b, Number expected) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
int exp = expected.intValue();
assertEquals(Math.max(f1, f2), exp);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
long exp = expected.longValue();
assertEquals(Math.max(f1, f2), exp);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_normal_FD(Number a, Number b){
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_normal_FD(Number a, Number b){
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2);
}
}
}
Loading

0 comments on commit c0eabdd

Please sign in to comment.