From 925d44a2d4af54ea11cbe80cfd293fb06cbd740d Mon Sep 17 00:00:00 2001 From: Matthew Hall Date: Fri, 25 Oct 2024 17:34:28 -0500 Subject: [PATCH 1/3] Add new build test for Math.max and Math.min Add functional tests for `Math.max` and `Math.min`: - to test specific float and double corner cases - test with NaN, +0, & -0, values to confirm respective omr changes - test all possible execution paths Signed-off-by: Matthew Hall --- .../recognizedMethod/TestJavaLangMath.java | 125 ++++++-- .../test/recognizedMethod/TestMathUtils.java | 300 ++++++++++++++++++ 2 files changed, 407 insertions(+), 18 deletions(-) create mode 100644 test/functional/JIT_Test/src/jit/test/recognizedMethod/TestMathUtils.java diff --git a/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestJavaLangMath.java b/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestJavaLangMath.java index adea8099ce5..ce1a5b5e4cb 100644 --- a/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestJavaLangMath.java +++ b/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestJavaLangMath.java @@ -23,28 +23,35 @@ package jit.test.recognizedMethod; import org.testng.AssertJUnit; import org.testng.annotations.Test; +import java.util.Random; +import org.testng.asserts.SoftAssert; +import static jit.test.recognizedMethod.TestMathUtils.*; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Listeners; + +@Test(singleThreaded=true) public class TestJavaLangMath { /** - * Tests the constant corner cases defined by the {@link Math.sqrt} method. - *

- * The JIT compiler will transform calls to {@link Math.sqrt} within this test - * into the following tree sequence: - * - * - * dsqrt - * dconst - * - * - * Subsequent tree simplification passes will attempt to reduce this constant - * operation to a dsqrt IL by performing the square root at compile - * time. The transformation will be performed when the function get executed - * twice, therefore, the "invocationCount=2" is needed. However we must ensure the - * result of the square root done by the compiler at compile time will be exactly - * the same as the result had it been done by the Java runtime at runtime. This - * test validates the results are the same. - */ + * Tests the constant corner cases defined by the {@link Math.sqrt} method. + *

+ * The JIT compiler will transform calls to {@link Math.sqrt} within this test + * into the following tree sequence: + * + * + * dsqrt + * dconst + * + * + * Subsequent tree simplification passes will attempt to reduce this constant + * operation to a dsqrt IL by performing the square root at compile + * time. The transformation will be performed when the function get executed + * twice, therefore, the "invocationCount=2" is needed. However we must ensure the + * result of the square root done by the compiler at compile time will be exactly + * the same as the result had it been done by the Java runtime at runtime. This + * test validates the results are the same. + */ @Test(groups = {"level.sanity"}, invocationCount=2) public void test_java_lang_Math_sqrt() { AssertJUnit.assertTrue(Double.isNaN(Math.sqrt(Double.NEGATIVE_INFINITY))); @@ -55,4 +62,86 @@ public void test_java_lang_Math_sqrt() { AssertJUnit.assertEquals(Double.POSITIVE_INFINITY, Math.sqrt(Double.POSITIVE_INFINITY)); AssertJUnit.assertTrue(Double.isNaN(Math.sqrt(Double.NaN))); } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_min_zeros_FD(Number a, Number b, boolean isFirstArg) { + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2); + } + } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_max_zeros_FD(Number a, Number b, boolean isFirstArg) { + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1); + } + } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_min_nan_FD(Number a, Number b, Number expected) { + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + int exp = expected.intValue(); + assertEquals(Math.min(f1, f2), exp); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + long exp = expected.longValue(); + assertEquals(Math.min(f1, f2), exp); + } + } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_max_nan_FD(Number a, Number b, Number expected) { + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + int exp = expected.intValue(); + assertEquals(Math.max(f1, f2), exp); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + long exp = expected.longValue(); + assertEquals(Math.max(f1, f2), exp); + } + } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_min_normal_FD(Number a, Number b){ + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2); + } + } + + @Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class) + public void test_java_lang_Math_max_normal_FD(Number a, Number b){ + if (a instanceof Float) { + float f1 = a.floatValue(); + float f2 = b.floatValue(); + assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2); + } else { + double f1 = a.doubleValue(); + double f2 = b.doubleValue(); + assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2); + } + } } diff --git a/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestMathUtils.java b/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestMathUtils.java new file mode 100644 index 00000000000..fab79f5d601 --- /dev/null +++ b/test/functional/JIT_Test/src/jit/test/recognizedMethod/TestMathUtils.java @@ -0,0 +1,300 @@ +/* + * Copyright IBM Corp. and others 2024 + * + * This program and the accompanying materials are made available under + * the terms of the Eclipse Public License 2.0 which accompanies this + * distribution and is available at https://www.eclipse.org/legal/epl-2.0/ + * or the Apache License, Version 2.0 which accompanies this distribution and + * is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * This Source Code may also be made available under the following + * Secondary Licenses when the conditions for such availability set + * forth in the Eclipse Public License, v. 2.0 are satisfied: GNU + * General Public License, version 2 with the GNU Classpath + * Exception [1] and GNU General Public License, version 2 with the + * OpenJDK Assembly Exception [2]. + * + * [1] https://www.gnu.org/software/classpath/license.html + * [2] https://openjdk.org/legal/assembly-exception.html + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 OR GPL-2.0-only WITH OpenJDK-assembly-exception-1.0 + */ + +package jit.test.recognizedMethod; +import java.util.Random; +import org.testng.asserts.SoftAssert; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.testng.annotations.DataProvider; +import org.testng.AssertJUnit; + +public class TestMathUtils { + // constants used to various min/max tests + private static final int fqNaNBits = 0x7fcabdef; + private static final float fqNaN = Float.intBitsToFloat(fqNaNBits); + private static final int fsNaNBits = 0x7faedbaf; + private static final float fsNaN = Float.intBitsToFloat(fsNaNBits); + private static final int fnZeroBits = 0x80000000; + private static final float fpInf = Float.POSITIVE_INFINITY; + private static final float fnInf = Float.NEGATIVE_INFINITY; + private static final int fpInfBits = Float.floatToRawIntBits(fpInf); + private static final int fnInfBits = Float.floatToRawIntBits(fnInf); + private static final int fquietBit = 0x00400000; + private static final int fNaNExpStart = 0x7f800000; + private static final int fNaNMantisaMax = 0x00400000; + + private static final long dqNaNBits = 0x7ff800000000000fL; + private static final double dqNaN = Double.longBitsToDouble(dqNaNBits); + private static final long dsNaNBits = 0x7ff400000000000fL; + private static final double dsNaN = Double.longBitsToDouble(dsNaNBits); + private static final long dnZeroBits = 0x8000000000000000L; + private static final double dpInf = Double.POSITIVE_INFINITY; + private static final double dnInf = Double.NEGATIVE_INFINITY; + private static final long dpInfBits = Double.doubleToRawLongBits(dpInf); + private static final long dnInfBits = Double.doubleToRawLongBits(dnInf); + private static final long dNaNMantisaMax = 0x0008000000000000L; + private static final long dNaNExpStart = 0x7ff0000000000000L; + private static final long dquietBit = 0x0008000000000000L; + private static final Random r = new Random(); + + public static class NaNTestPair { + T first; + T second; + BitType bFirst; + BitType bSecond; + BitType expected; + } + + @DataProvider(name="zeroProviderFD") + public static Object[][] zeroProviderFD(){ + return new Object[][]{ + // arg1, arg2, expected + {+0.0f, +0.0f, true}, + {-0.0f, -0.0f, true}, + {+0.0f, -0.0f, false}, + {-0.0f, +0.0f, true}, + + {+0.0f, fpInf, true}, + {fpInf, +0.0f, false}, + {+0.0f, fnInf, false}, + {fnInf, +0.0f, true}, + {-0.0f, fpInf, true}, + {fpInf, -0.0f, false}, + {-0.0f, fnInf, false}, + {fnInf, -0.0f, true}, + + {+0.0d, +0.0d, true}, + {-0.0d, -0.0d, true}, + {+0.0d, -0.0d, false}, + {-0.0d, +0.0d, true}, + + {+0.0d, dpInf, true}, + {dpInf, +0.0d, false}, + {+0.0d, dnInf, false}, + {dnInf, +0.0d, true}, + {-0.0d, dpInf, true}, + {dpInf, -0.0d, false}, + {-0.0d, dnInf, false}, + {dnInf, -0.0d, true} + }; + } + + @DataProvider(name="nanProviderFD") + public static Object[][] nanProviderFD(){ + Object[][] constNanPairs = new Object[][] { + {fqNaN, fqNaN, fqNaNBits}, + {fsNaN, fsNaN, fsNaNBits}, + {fqNaN, fsNaN, fqNaNBits}, + {fsNaN, fqNaN, fsNaNBits}, + + {+0.0f, fqNaN, fqNaNBits}, + {fqNaN, +0.0f, fqNaNBits}, + {+0.0f, fsNaN, fsNaNBits}, + {fsNaN, +0.0f, fsNaNBits}, + {-0.0f, fqNaN, fqNaNBits}, + {fqNaN, -0.0f, fqNaNBits}, + {-0.0f, fsNaN, fsNaNBits}, + {fsNaN, -0.0f, fsNaNBits}, + + {dqNaN, dqNaN, dqNaNBits}, + {dsNaN, dsNaN, dsNaNBits}, + {dqNaN, dsNaN, dqNaNBits}, + {dsNaN, dqNaN, dsNaNBits}, + + {+0.0d, dqNaN, dqNaNBits}, + {dqNaN, +0.0d, dqNaNBits}, + {+0.0d, dsNaN, dsNaNBits}, + {dsNaN, +0.0d, dsNaNBits}, + {-0.0d, dqNaN, dqNaNBits}, + {dqNaN, -0.0d, dqNaNBits}, + {-0.0d, dsNaN, dsNaNBits}, + {dsNaN, -0.0d, dsNaNBits} + }; + + Object[][] nanPairs = new Object[60][3]; + int i = 0; + while (i < constNanPairs.length) { + nanPairs[i] = constNanPairs[i]; + i++; + } + while (i < nanPairs.length) { + NaNTestPair fpair = getFloatNaNPair(); + NaNTestPair dpair = getDoubleNaNPair(); + if (fpair == null) { + continue; + } + nanPairs[i][0] = fpair.first; + nanPairs[i][1] = fpair.second; + nanPairs[i][2] = fpair.expected; + i++; + if (dpair == null | i >= nanPairs.length){ + continue; + } + nanPairs[i][0] = dpair.first; + nanPairs[i][1] = dpair.second; + nanPairs[i][2] = dpair.expected; + i++; + } + return nanPairs; + } + + @DataProvider(name="normalNumberProviderFD") + public static Object[][] normalNumberProviderFD() { + Object [][] numList = new Object[20][2]; + for (int i = 0; i < numList.length / 2; i++) { + float a = r.nextFloat() * 1000000.0f; + float b = r.nextFloat() * 1000000.0f; + int sign1 = r.nextInt(2) == 0 ? -1 : 1; + int sign2 = r.nextInt(2) == 0 ? -1 : 1; + numList[i][0] = a * sign1; + numList[i][1] = b * sign2; + } + for (int i = numList.length / 2; i < 20; i++) { + double a = r.nextDouble() * 1000000000.0d; + double b = r.nextDouble() * 1000000000.0d; + int sign1 = r.nextInt(2) == 0 ? -1 : 1; + int sign2 = r.nextInt(2) == 0 ? -1 : 1; + numList[i][0] = a * sign1; + numList[i][1] = b * sign2; + } + return numList; + } + + private static NaNTestPair getDoubleNaNPair(){ + NaNTestPair p = new NaNTestPair(); + double farg1, farg2; + farg1 = farg2 = 0.0f; + + long iarg1; + long min = 1L << 13; + do { + iarg1 = r.nextLong(); + } while (iarg1 < min); + // shift right so that mantisa is non-zero + // shift of 13 instead of 12 ensures that the quiet bit is not set + iarg1 = iarg1 >> 13; + + int arg1Type = r.nextInt(3); // 0 -> quiet, 1 -> signalling + if (arg1Type == 0) { + iarg1 = iarg1 | dNaNExpStart | dquietBit; + farg1 = Double.longBitsToDouble(iarg1); + } else if (arg1Type == 1) { + iarg1 = iarg1 | dNaNExpStart; + farg1 = Double.longBitsToDouble(iarg1); + } else { + // Normal number (non NaN) + farg1 = r.nextDouble() * 1000000.0 + 1.0; + } + + // arg 2 + long iarg2; + do { + iarg2 = r.nextLong(); + iarg2 = iarg2 >> 13; + } while (iarg2 < min); + + int arg2Type = r.nextInt(3); + if (arg2Type == 0) { + iarg2 = iarg2 | dNaNExpStart | dquietBit; + farg2 = Double.longBitsToDouble(iarg2); + } + else if (arg2Type == 1) { + iarg2 = iarg2 | dNaNExpStart; + farg2 = Double.longBitsToDouble(iarg2); + } else { + if (arg1Type != 2) { + farg2 = r.nextDouble() * 1000000.0 + 1; + iarg2 = Double.doubleToRawLongBits(farg2); + } else { + return null; + } + } + + p.first = farg1; + p.bFirst = iarg1; + p.second = farg2; + p.bSecond = iarg2; + p.expected = (arg1Type != 2 ? iarg1 : iarg2); + return p; + } + + private static NaNTestPair getFloatNaNPair(){ // 0: max, 1: min + NaNTestPair p = new NaNTestPair(); + float farg1, farg2; + farg1 = farg2 = 0.0f; + + int iarg1 = r.nextInt(fNaNMantisaMax - 1) + 1; + int arg1Type = r.nextInt(3); // 0 -> quiet, 1 -> signalling + if (arg1Type == 0) { + iarg1 = iarg1 | fNaNExpStart | fquietBit; + farg1 = Float.intBitsToFloat(iarg1); + } else if (arg1Type == 1) { + iarg1 = iarg1 | fNaNExpStart; + farg1 = Float.intBitsToFloat(iarg1); + } else { + // Normal number (non NaN) + farg1 = r.nextFloat() * 1000000.0f + 1.0f; + } + + // arg 2 + int arg2Type = r.nextInt(3); + int iarg2 = r.nextInt(fNaNMantisaMax - 1) + 1; + if (arg2Type == 0) { + iarg2 = (r.nextInt(fNaNMantisaMax - 1) + 1) | fNaNExpStart | fquietBit; + farg2 = Float.intBitsToFloat(iarg2); + } else if (arg2Type == 1){ + iarg2 = (r.nextInt(fNaNMantisaMax - 1) + 1) | fNaNExpStart; + farg2 = Float.intBitsToFloat(iarg2); + } else { + if (arg1Type != 2) { + farg2 = r.nextFloat() * 1000000.0f + 1f; + iarg2 = Float.floatToRawIntBits(farg2); + } else { + return null; + } + } + + p.first = farg1; + p.second = farg2; + p.bFirst = iarg1; + p.bSecond = iarg2; + p.expected = (arg1Type != 2 ? iarg1 : iarg2); + return p; + } + + public static void assertEquals(float actual, float expected){ + AssertJUnit.assertEquals(Float.floatToRawIntBits(expected), Float.floatToRawIntBits(actual)); + } + + public static void assertEquals(double actual, double expected){ + AssertJUnit.assertEquals(Double.doubleToRawLongBits(expected), Double.doubleToRawLongBits(actual)); + } + + public static void assertEquals(float actual, int expected) { + AssertJUnit.assertEquals(expected, Float.floatToRawIntBits(actual)); + } + + public static void assertEquals(double actual, long expected){ + AssertJUnit.assertEquals(expected, Double.doubleToRawLongBits(actual)); + } +} From aefcc0774dca554fac4148cc0a1b6bb2262313fd Mon Sep 17 00:00:00 2001 From: Sarwat Shaheen Date: Mon, 22 Jan 2024 12:39:24 -0500 Subject: [PATCH 2/3] Enable inlining of fmax/fmin/dmax/dmin on Z - Adds java_lang_Math_max/min_float/double as a recognized method - Adds a SupportsInlineMath_MaxMin_FD flag to the Z code generator - Flag is only set in Z if the TR_disableInlineMath_MaxMin_FD environment variable is not set - If the flag is set, call nodes are transformed to a functionally equivalent tree that uses fmin/fmax/dmin/dmax nodes Signed-off-by: Sarwat Shaheen --- runtime/compiler/codegen/J9CodeGenerator.hpp | 11 +++++++++++ runtime/compiler/env/j9method.cpp | 4 ++++ .../optimizer/J9RecognizedCallTransformer.cpp | 16 ++++++++++++++++ runtime/compiler/z/codegen/J9CodeGenerator.cpp | 6 ++++++ 4 files changed, 37 insertions(+) diff --git a/runtime/compiler/codegen/J9CodeGenerator.hpp b/runtime/compiler/codegen/J9CodeGenerator.hpp index b5d425b2fe2..e286569e533 100644 --- a/runtime/compiler/codegen/J9CodeGenerator.hpp +++ b/runtime/compiler/codegen/J9CodeGenerator.hpp @@ -512,6 +512,16 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz); */ void setSupportsInlineVectorizedHashCode() { _j9Flags.set(SupportsInlineVectorizedHashCode); } + /** \brief + * Determines whether the code generator supports inlining of java_lang_Math_max/min_F/D + */ + bool getSupportsInlineMath_MaxMin_FD() { return _j9Flags.testAny(SupportsInlineMath_MaxMin_FD); } + + /** \brief + * The code generator supports inlining of java_lang_Math_max/min_F/D + */ + void setSupportsInlineMath_MaxMin_FD() { _j9Flags.set(SupportsInlineMath_MaxMin_FD); } + /** * \brief * The number of nodes between a monext and the next monent before @@ -677,6 +687,7 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz); SavesNonVolatileGPRsForGC = 0x00000800, SupportsInlineVectorizedMismatch = 0x00001000, SupportsInlineVectorizedHashCode = 0x00002000, + SupportsInlineMath_MaxMin_FD = 0x00002000, }; flags32_t _j9Flags; diff --git a/runtime/compiler/env/j9method.cpp b/runtime/compiler/env/j9method.cpp index 41c7b767625..69a742f18f9 100644 --- a/runtime/compiler/env/j9method.cpp +++ b/runtime/compiler/env/j9method.cpp @@ -5014,6 +5014,10 @@ TR_ResolvedJ9Method::setRecognizedMethodInfo(TR::RecognizedMethod rm) case TR::java_lang_Math_min_I: case TR::java_lang_Math_max_L: case TR::java_lang_Math_min_L: + case TR::java_lang_Math_max_F: + case TR::java_lang_Math_min_F: + case TR::java_lang_Math_max_D: + case TR::java_lang_Math_min_D: case TR::java_lang_Math_abs_I: case TR::java_lang_Math_abs_L: case TR::java_lang_Math_abs_F: diff --git a/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp b/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp index ca4274de589..615a91b12fb 100644 --- a/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp +++ b/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp @@ -1366,6 +1366,10 @@ bool J9::RecognizedCallTransformer::isInlineable(TR::TreeTop* treetop) case TR::java_lang_Math_min_I: case TR::java_lang_Math_max_L: case TR::java_lang_Math_min_L: + case TR::java_lang_Math_max_F: + case TR::java_lang_Math_min_F: + case TR::java_lang_Math_max_D: + case TR::java_lang_Math_min_D: return !comp()->getOption(TR_DisableMaxMinOptimization); case TR::java_lang_Math_multiplyHigh: return cg()->getSupportsLMulHigh(); @@ -1495,6 +1499,18 @@ void J9::RecognizedCallTransformer::transform(TR::TreeTop* treetop) case TR::java_lang_Math_min_L: processIntrinsicFunction(treetop, node, TR::lmin); break; + case TR::java_lang_Math_max_F: + processIntrinsicFunction(treetop, node, TR::fmax); + break; + case TR::java_lang_Math_min_F: + processIntrinsicFunction(treetop, node, TR::fmin); + break; + case TR::java_lang_Math_max_D: + processIntrinsicFunction(treetop, node, TR::dmax); + break; + case TR::java_lang_Math_min_D: + processIntrinsicFunction(treetop, node, TR::dmin); + break; case TR::java_lang_Math_multiplyHigh: processIntrinsicFunction(treetop, node, TR::lmulh); break; diff --git a/runtime/compiler/z/codegen/J9CodeGenerator.cpp b/runtime/compiler/z/codegen/J9CodeGenerator.cpp index ca13349d986..8bc9478abca 100644 --- a/runtime/compiler/z/codegen/J9CodeGenerator.cpp +++ b/runtime/compiler/z/codegen/J9CodeGenerator.cpp @@ -125,6 +125,12 @@ J9::Z::CodeGenerator::initialize() cg->setSupportsInlineEncodeASCII(); } + static bool disableInlineMath_MaxMin_FD = feGetEnv("TR_disableInlineMath_MaxMin_FD") != NULL; + if (!disableInlineMath_MaxMin_FD) + { + cg->setSupportsInlineMath_MaxMin_FD(); + } + static bool disableInlineVectorizedMismatch = feGetEnv("TR_disableInlineVectorizedMismatch") != NULL; if (cg->getSupportsArrayCmpLen() && #if defined(J9VM_GC_ENABLE_SPARSE_HEAP_ALLOCATION) From bc2c3cf7222cf7f650369593a1a18f9fe9d4bcef Mon Sep 17 00:00:00 2001 From: Matthew Hall Date: Wed, 2 Oct 2024 16:12:00 -0400 Subject: [PATCH 3/3] Support Math.max/min for floating points w.r.t java spec - spearate evaluators for J9 vs OMR to support differing behaviour (OMR complies with IEEE_754, while J9 returns the first NaN (if present) - +0.0 compares as strictly greater than -0.0 Signed-off-by: Matthew Hall --- runtime/compiler/codegen/J9CodeGenerator.hpp | 2 +- .../optimizer/J9RecognizedCallTransformer.cpp | 3 +- .../compiler/z/codegen/J9CodeGenerator.cpp | 20 ++-- .../compiler/z/codegen/J9TreeEvaluator.cpp | 110 ++++++------------ .../compiler/z/codegen/J9TreeEvaluator.hpp | 6 +- 5 files changed, 54 insertions(+), 87 deletions(-) diff --git a/runtime/compiler/codegen/J9CodeGenerator.hpp b/runtime/compiler/codegen/J9CodeGenerator.hpp index e286569e533..a1a2e545ce1 100644 --- a/runtime/compiler/codegen/J9CodeGenerator.hpp +++ b/runtime/compiler/codegen/J9CodeGenerator.hpp @@ -687,7 +687,7 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz); SavesNonVolatileGPRsForGC = 0x00000800, SupportsInlineVectorizedMismatch = 0x00001000, SupportsInlineVectorizedHashCode = 0x00002000, - SupportsInlineMath_MaxMin_FD = 0x00002000, + SupportsInlineMath_MaxMin_FD = 0x00004000, }; flags32_t _j9Flags; diff --git a/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp b/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp index 615a91b12fb..3341837ba7d 100644 --- a/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp +++ b/runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp @@ -1366,11 +1366,12 @@ bool J9::RecognizedCallTransformer::isInlineable(TR::TreeTop* treetop) case TR::java_lang_Math_min_I: case TR::java_lang_Math_max_L: case TR::java_lang_Math_min_L: + return !comp()->getOption(TR_DisableMaxMinOptimization); case TR::java_lang_Math_max_F: case TR::java_lang_Math_min_F: case TR::java_lang_Math_max_D: case TR::java_lang_Math_min_D: - return !comp()->getOption(TR_DisableMaxMinOptimization); + return !comp()->getOption(TR_DisableMaxMinOptimization) && cg()->getSupportsInlineMath_MaxMin_FD(); case TR::java_lang_Math_multiplyHigh: return cg()->getSupportsLMulHigh(); case TR::java_lang_StringUTF16_toBytes: diff --git a/runtime/compiler/z/codegen/J9CodeGenerator.cpp b/runtime/compiler/z/codegen/J9CodeGenerator.cpp index 8bc9478abca..73f81a8ac87 100644 --- a/runtime/compiler/z/codegen/J9CodeGenerator.cpp +++ b/runtime/compiler/z/codegen/J9CodeGenerator.cpp @@ -125,7 +125,7 @@ J9::Z::CodeGenerator::initialize() cg->setSupportsInlineEncodeASCII(); } - static bool disableInlineMath_MaxMin_FD = feGetEnv("TR_disableInlineMath_MaxMin_FD") != NULL; + static bool disableInlineMath_MaxMin_FD = feGetEnv("TR_disableInlineMaxMin") != NULL; if (!disableInlineMath_MaxMin_FD) { cg->setSupportsInlineMath_MaxMin_FD(); @@ -4079,20 +4079,24 @@ J9::Z::CodeGenerator::inlineDirectCall( } } - if (!comp->getOption(TR_DisableSIMDDoubleMaxMin) && cg->getSupportsVectorRegisters()) - { - switch (methodSymbol->getRecognizedMethod()) - { + if (!self()->comp()->getOption(TR_DisableMaxMinOptimization) && cg->getSupportsInlineMath_MaxMin_FD()) { + switch (methodSymbol->getRecognizedMethod()) { case TR::java_lang_Math_max_D: - resultReg = TR::TreeEvaluator::inlineDoubleMax(node, cg); + resultReg = J9::Z::TreeEvaluator::dmaxEvaluator(node, cg); return true; case TR::java_lang_Math_min_D: - resultReg = TR::TreeEvaluator::inlineDoubleMin(node, cg); + resultReg = J9::Z::TreeEvaluator::dminEvaluator(node, cg); + return true; + case TR::java_lang_Math_max_F: + resultReg = J9::Z::TreeEvaluator::fmaxEvaluator(node, cg); + return true; + case TR::java_lang_Math_min_F: + resultReg = J9::Z::TreeEvaluator::fminEvaluator(node, cg); return true; default: break; - } } + } switch (methodSymbol->getRecognizedMethod()) { diff --git a/runtime/compiler/z/codegen/J9TreeEvaluator.cpp b/runtime/compiler/z/codegen/J9TreeEvaluator.cpp index 2c7dbe4b29b..1c9d4eecf91 100644 --- a/runtime/compiler/z/codegen/J9TreeEvaluator.cpp +++ b/runtime/compiler/z/codegen/J9TreeEvaluator.cpp @@ -906,76 +906,48 @@ allocateWriteBarrierInternalPointerRegister(TR::CodeGenerator * cg, TR::Node * s } -extern TR::Register * -doubleMaxMinHelper(TR::Node *node, TR::CodeGenerator *cg, bool isMaxOp) +TR::Register * +J9::Z::TreeEvaluator::dmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg) { - TR_ASSERT(node->getNumChildren() >= 1 || node->getNumChildren() <= 2, "node has incorrect number of children"); - - /* ===================== Allocating Registers ===================== */ - - TR::Register * v16 = cg->allocateRegister(TR_VRF); - TR::Register * v17 = cg->allocateRegister(TR_VRF); - TR::Register * v18 = cg->allocateRegister(TR_VRF); - - /* ===================== Generating instructions ===================== */ - - /* ====== LD FPR0,16(GPR5) Load a ====== */ - TR::Register * v0 = cg->fprClobberEvaluate(node->getFirstChild()); - - /* ====== LD FPR2, 0(GPR5) Load b ====== */ - TR::Register * v2 = cg->evaluate(node->getSecondChild()); - - /* ====== WFTCIDB V16,V0,X'F' a == NaN ====== */ - generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v16, v0, 0xF, 8, 3); - - /* ====== For Max: WFCHE V17,V0,V2 Compare a >= b ====== */ - if(isMaxOp) + if (cg->getSupportsVectorRegisters()) { - generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v0, v2, 0, 8, 3); + cg->generateDebugCounter("z13/simd/doubleMax", 1, TR::DebugCounter::Free); + return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg); } - /* ====== For Min: WFCHE V17,V0,V2 Compare a <= b ====== */ - else + return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg); + } + +TR::Register * +J9::Z::TreeEvaluator::dminEvaluator(TR::Node * node, TR::CodeGenerator * cg) + { + if (cg->getSupportsVectorRegisters()) { - generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v2, v0, 0, 8, 3); + cg->generateDebugCounter("z13/simd/doubleMin", 1, TR::DebugCounter::Free); + return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg); } + return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg); + } - /* ====== VO V16,V16,V17 (a >= b) || (a == NaN) ====== */ - generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0); - - /* ====== For Max: WFTCIDB V17,V0,X'800' a == +0 ====== */ - if(isMaxOp) - { - generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x800, 8, 3); - } - /* ====== For Min: WFTCIDB V17,V0,X'400' a == -0 ====== */ - else - { - generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x400, 8, 3); - } - /* ====== WFTCIDB V18,V2,X'C00' b == 0 ====== */ - generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v18, v2, 0xC00, 8, 3); - - /* ====== VN V17,V17,V18 (a == -0) && (b == 0) ====== */ - generateVRRcInstruction(cg, TR::InstOpCode::VN, node, v17, v17, v18, 0, 0, 0); - - /* ====== VO V16,V16,V17 (a >= b) || (a == NaN) || ((a == -0) && (b == 0)) ====== */ - generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0); - - /* ====== VSEL V0,V0,V2,V16 ====== */ - generateVRReInstruction(cg, TR::InstOpCode::VSEL, node, v0, v0, v2, v16); - - /* ===================== Deallocating Registers ===================== */ - cg->stopUsingRegister(v2); - cg->stopUsingRegister(v16); - cg->stopUsingRegister(v17); - cg->stopUsingRegister(v18); - - node->setRegister(v0); - - cg->decReferenceCount(node->getFirstChild()); - cg->decReferenceCount(node->getSecondChild()); +TR::Register * +J9::Z::TreeEvaluator::fmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg) + { + if (cg->getSupportsVectorRegisters()) + { + cg->generateDebugCounter("z13/simd/floatMax", 1, TR::DebugCounter::Free); + return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg); + } + return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg); + } - return node->getRegister(); +TR::Register * +J9::Z::TreeEvaluator::fminEvaluator(TR::Node * node, TR::CodeGenerator * cg) + { + if (cg->getSupportsVectorRegisters()) + { + cg->generateDebugCounter("z13/simd/floatMin", 1, TR::DebugCounter::Free); + return OMR::Z::TreeEvaluator::fpMinMaxVectorHelper(node, cg); + } + return OMR::Z::TreeEvaluator::xmaxxminHelper(node, cg); } TR::Register* @@ -2945,19 +2917,7 @@ J9::Z::TreeEvaluator::toLowerIntrinsic(TR::Node *node, TR::CodeGenerator *cg, bo return caseConversionHelper(node, cg, false, isCompressedString); } -TR::Register* -J9::Z::TreeEvaluator::inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg) - { - cg->generateDebugCounter("z13/simd/doubleMax", 1, TR::DebugCounter::Free); - return doubleMaxMinHelper(node, cg, true); - } -TR::Register* -J9::Z::TreeEvaluator::inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg) - { - cg->generateDebugCounter("z13/simd/doubleMin", 1, TR::DebugCounter::Free); - return doubleMaxMinHelper(node, cg, false); - } TR::Register * J9::Z::TreeEvaluator::inlineMathFma(TR::Node *node, TR::CodeGenerator *cg) diff --git a/runtime/compiler/z/codegen/J9TreeEvaluator.hpp b/runtime/compiler/z/codegen/J9TreeEvaluator.hpp index da2286d3b73..c21264611ae 100644 --- a/runtime/compiler/z/codegen/J9TreeEvaluator.hpp +++ b/runtime/compiler/z/codegen/J9TreeEvaluator.hpp @@ -126,8 +126,10 @@ class OMR_EXTENSIBLE TreeEvaluator: public J9::TreeEvaluator */ static TR::Register *inlineVectorizedStringIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isCompressed); static TR::Register *inlineIntrinsicIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isLatin1); - static TR::Register *inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg); - static TR::Register *inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg); + static TR::Register *fminEvaluator(TR::Node *node, TR::CodeGenerator *cg); + static TR::Register *dminEvaluator(TR::Node *node, TR::CodeGenerator *cg); + static TR::Register *fmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg); + static TR::Register *dmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg); static TR::Register *inlineMathFma(TR::Node *node, TR::CodeGenerator *cg); /* This Evaluator generates the SIMD routine for methods