From 28ce56c054d122ae68f605f88391009996e19d0c Mon Sep 17 00:00:00 2001
From: Zoltan Herczeg <zherczeg.u-szeged@partner.samsung.com>
Date: Fri, 18 Oct 2024 08:01:02 +0000
Subject: [PATCH] Implement relaxed simd operations on ARM-64

Signed-off-by: Zoltan Herczeg <zherczeg.u-szeged@partner.samsung.com>
---
 .github/workflows/actions.yml |   4 +-
 src/jit/ByteCodeParser.cpp    |   4 +-
 src/jit/SimdArm64Inl.h        | 162 ++++++++++++++++++++++++++++------
 3 files changed, 137 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 472b1763f..6502222db 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -182,7 +182,7 @@ jobs:
       with:
         submodules: true
     - name: Build in arm32 container
-      uses: uraimo/run-on-arch-action@v2.7.2
+      uses: uraimo/run-on-arch-action@v2.8.1
       with:
         arch: armv7
         distro: ubuntu_latest
@@ -214,7 +214,7 @@ jobs:
       with:
         submodules: true
     - name: Build in arm64 container
-      uses: uraimo/run-on-arch-action@v2.7.2
+      uses: uraimo/run-on-arch-action@v2.8.1
       with:
         arch: aarch64
         distro: ubuntu22.04
diff --git a/src/jit/ByteCodeParser.cpp b/src/jit/ByteCodeParser.cpp
index a1c0c9abc..d78e87ff1 100644
--- a/src/jit/ByteCodeParser.cpp
+++ b/src/jit/ByteCodeParser.cpp
@@ -295,7 +295,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
 #define OTPopcntV128 OTOp1V128
 #define OTSwizzleV128 OTOp2V128
 #define OTShiftV128Tmp OTShiftV128
-#define OTOp3DotAddV128 OTOp2V128
+#define OTOp3DotAddV128 OTOp3V128

 #elif (defined SLJIT_CONFIG_ARM_32 && SLJIT_CONFIG_ARM_32)

@@ -314,7 +314,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
 #define OTPMinMaxV128 OTOp2V128
 #define OTPopcntV128 OTOp1V128
 #define OTShiftV128Tmp OTShiftV128
-#define OTOp3DotAddV128 OTOp2V128
+#define OTOp3DotAddV128 OTOp3V128

 #endif /* SLJIT_CONFIG_ARM */

diff --git a/src/jit/SimdArm64Inl.h b/src/jit/SimdArm64Inl.h
index 304290f52..282c604c1 100644
--- a/src/jit/SimdArm64Inl.h
+++ b/src/jit/SimdArm64Inl.h
@@ -43,6 +43,8 @@ enum Type : uint32_t {
     fdiv = 0x6e20fc00,
     fmax = 0x4e20f400,
     fmin = 0x4ea0f400,
+    fmla = 0x4e20cc00,
+    fmls = 0x4ea0cc00,
     fmul = 0x6e20dc00,
     fneg = 0x6ea0f800,
     frintm = 0x4e219800, // floor
@@ -185,11 +187,15 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
         break;
     case ByteCode::I32X4TruncSatF32X4SOpcode:
     case ByteCode::I32X4TruncSatF32X4UOpcode:
+    case ByteCode::I32X4RelaxedTruncF32X4SOpcode:
+    case ByteCode::I32X4RelaxedTruncF32X4UOpcode:
         srcType = SLJIT_SIMD_ELEM_32 | SLJIT_SIMD_FLOAT;
         dstType = SLJIT_SIMD_ELEM_32;
         break;
     case ByteCode::I32X4TruncSatF64X2SZeroOpcode:
     case ByteCode::I32X4TruncSatF64X2UZeroOpcode:
+    case ByteCode::I32X4RelaxedTruncF64X2SZeroOpcode:
+    case ByteCode::I32X4RelaxedTruncF64X2UZeroOpcode:
         srcType = SLJIT_SIMD_ELEM_64 | SLJIT_SIMD_FLOAT;
         dstType = SLJIT_SIMD_ELEM_32;
         break;
@@ -336,16 +342,20 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitOp(compiler, SimdOp::uxtl | (0x1 << 20) | (0x1 << 30), dst, args[0].arg, 0);
         break;
     case ByteCode::I32X4TruncSatF32X4SOpcode:
+    case ByteCode::I32X4RelaxedTruncF32X4SOpcode:
         simdEmitOp(compiler, SimdOp::fcvtzs | SimdOp::FS4, dst, args[0].arg, 0);
         break;
     case ByteCode::I32X4TruncSatF32X4UOpcode:
+    case ByteCode::I32X4RelaxedTruncF32X4UOpcode:
         simdEmitOp(compiler, SimdOp::fcvtzu | SimdOp::FS4, dst, args[0].arg, 0);
         break;
     case ByteCode::I32X4TruncSatF64X2SZeroOpcode:
+    case ByteCode::I32X4RelaxedTruncF64X2SZeroOpcode:
         simdEmitOp(compiler, SimdOp::fcvtzs | SimdOp::FD2, dst, args[0].arg, 0);
         simdEmitOp(compiler, SimdOp::sqxtn | SimdOp::S4, dst, dst, 0);
         break;
     case ByteCode::I32X4TruncSatF64X2UZeroOpcode:
+    case ByteCode::I32X4RelaxedTruncF64X2UZeroOpcode:
         simdEmitOp(compiler, SimdOp::fcvtzu | SimdOp::FD2, dst, args[0].arg, 0);
         simdEmitOp(compiler, SimdOp::uqxtn | SimdOp::S4, dst, dst, 0);
         break;
@@ -587,28 +597,26 @@ static void simdEmitNarrowUnsigned(sljit_compiler* compiler, sljit_s32 rd, sljit
     simdEmitOp(compiler, SimdOp::sqxtun | size | (0x1 << 30), rd, rm, 0);
 }

-static void simdEmitDot(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
+static void simdEmitDot(sljit_compiler* compiler, uint32_t type, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
 {
+    // The rd can be tmpReg1
 #ifdef __ARM_FEATURE_DOTPROD
-    simdEmitOp(compiler, SimdOp::sdot | SimdOp::S4, rd, rn, rm);
+    simdEmitOp(compiler, SimdOp::sdot | type, rd, rn, rm);
 #else
-    auto tmpReg1 = SLJIT_TMP_FR0;
-    auto tmpReg2 = SLJIT_TMP_FR1;
+    sljit_s32 tmpReg1 = SLJIT_TMP_FR0;
+    sljit_s32 tmpReg2 = SLJIT_TMP_FR1;
+    uint32_t lowType = type - (0x1 << SimdOp::sizeOffset);

-    // tmpReg1 = rn * rm lower
-    simdEmitOp(compiler, SimdOp::smull | SimdOp::H8, tmpReg1, rn, rm);
     // tmpReg2 = rn * rm upper
-    simdEmitOp(compiler, SimdOp::smull | SimdOp::H8 | (0x1 << 30), tmpReg2, rn, rm);
-    // rd = tmpReg1[1], tmpReg2[1], tmpReg1[0], tmpReg2[0]
-    simdEmitOp(compiler, SimdOp::zip1 | SimdOp::S4, rd, tmpReg1, tmpReg2);
-    // tmpReg1 = tmpReg1[3], tmpReg2[3], tmpReg1[2], tmpReg2[2]
-    simdEmitOp(compiler, SimdOp::zip2 | SimdOp::S4, tmpReg1, tmpReg1, tmpReg2);
-    // rd = rd[3] + rd[2], rd[1] + rd[0]
-    simdEmitOp(compiler, SimdOp::saddlp | SimdOp::S4, rd, rd, 0);
-    // tmpReg1 = tmpReg1[3] + tmpReg1[2], tmpReg1[1] + tmpReg1[0]
-    simdEmitOp(compiler, SimdOp::saddlp | SimdOp::S4, tmpReg1, tmpReg1, 0);
-    // rd = rd[1], tmpReg1[1], rd[0], tmpReg1[0]
-    simdEmitOp(compiler, SimdOp::uzp1 | SimdOp::S4, rd, rd, tmpReg1);
+    simdEmitOp(compiler, SimdOp::smull | lowType | (0x1 << 30), tmpReg2, rn, rm);
+    // tmpReg1 = rn * rm lower
+    simdEmitOp(compiler, SimdOp::smull | lowType, tmpReg1, rn, rm);
+    // Widening result
+    simdEmitOp(compiler, SimdOp::saddlp | type, tmpReg2, tmpReg2, 0);
+    simdEmitOp(compiler, SimdOp::saddlp | type, tmpReg1, tmpReg1, 0);
+    // Combine + narrow
+    simdEmitOp(compiler, SimdOp::xtn | type, rd, tmpReg1, 0);
+    simdEmitOp(compiler, SimdOp::xtn | type | (0x1 << 30), rd, tmpReg2, 0);
 #endif
 }

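Note (illustration only, not part of the diff): the simdEmitDot rewrite above takes the element type as a parameter, so the same helper now serves both I32X4DotI16X8S and the new I16X8DotI8X16I7X16S opcode. With __ARM_FEATURE_DOTPROD it emits a single SDOT; without it, the fallback builds the dot product from SMULL/SMULL2, SADDLP and XTN/XTN2. A rough NEON-intrinsics sketch of that fallback for the SimdOp::H8 case follows; the helper name is invented for this note and the code only mirrors the emitted instruction sequence.

    #include <arm_neon.h>

    /* Sketch of the non-DOTPROD path in simdEmitDot for SimdOp::H8:
     *   SMULL/SMULL2: 16-bit products of the low/high byte halves
     *   SADDLP:       sum adjacent products into 32-bit lanes
     *   XTN/XTN2:     narrow and recombine into eight 16-bit results */
    static inline int16x8_t dot_i8x16_i7x16_s_sketch(int8x16_t a, int8x16_t b)
    {
        int16x8_t lo = vmull_s8(vget_low_s8(a), vget_low_s8(b));  /* SMULL  */
        int16x8_t hi = vmull_high_s8(a, b);                       /* SMULL2 */
        int32x4_t loSum = vpaddlq_s16(lo);                        /* SADDLP */
        int32x4_t hiSum = vpaddlq_s16(hi);                        /* SADDLP */
        return vcombine_s16(vmovn_s32(loSum), vmovn_s32(hiSum));  /* XTN, XTN2 */
    }
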
@@ -643,6 +651,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
     case ByteCode::I8X16MaxUOpcode:
     case ByteCode::I8X16AvgrUOpcode:
     case ByteCode::I8X16SwizzleOpcode:
+    case ByteCode::I8X16RelaxedSwizzleOpcode:
         srcType = SLJIT_SIMD_ELEM_8;
         dstType = SLJIT_SIMD_ELEM_8;
         break;
@@ -674,9 +683,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
     case ByteCode::I16X8MaxUOpcode:
     case ByteCode::I16X8AvgrUOpcode:
     case ByteCode::I16X8Q15mulrSatSOpcode:
+    case ByteCode::I16X8RelaxedQ15mulrSOpcode:
         srcType = SLJIT_SIMD_ELEM_16;
         dstType = SLJIT_SIMD_ELEM_16;
         break;
+    case ByteCode::I16X8DotI8X16I7X16SOpcode:
     case ByteCode::I16X8ExtmulLowI8X16SOpcode:
     case ByteCode::I16X8ExtmulHighI8X16SOpcode:
     case ByteCode::I16X8ExtmulLowI8X16UOpcode:
@@ -750,6 +761,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
     case ByteCode::F32X4PMaxOpcode:
     case ByteCode::F32X4MaxOpcode:
     case ByteCode::F32X4MinOpcode:
+    case ByteCode::F32X4RelaxedMaxOpcode:
+    case ByteCode::F32X4RelaxedMinOpcode:
         srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
         dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
         break;
@@ -768,6 +781,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
     case ByteCode::F64X2PMaxOpcode:
     case ByteCode::F64X2MaxOpcode:
     case ByteCode::F64X2MinOpcode:
+    case ByteCode::F64X2RelaxedMaxOpcode:
+    case ByteCode::F64X2RelaxedMinOpcode:
         srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
         dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
         break;
@@ -965,8 +980,12 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitNarrowUnsigned(compiler, dst, args[0].arg, args[1].arg, SimdOp::H8);
         break;
     case ByteCode::I16X8Q15mulrSatSOpcode:
+    case ByteCode::I16X8RelaxedQ15mulrSOpcode:
         simdEmitOp(compiler, SimdOp::sqrdmulh | SimdOp::H8, dst, args[0].arg, args[1].arg);
         break;
+    case ByteCode::I16X8DotI8X16I7X16SOpcode:
+        simdEmitDot(compiler, SimdOp::H8, dst, args[0].arg, args[1].arg);
+        break;
     case ByteCode::I32X4AddOpcode:
         simdEmitOp(compiler, SimdOp::add | SimdOp::S4, dst, args[0].arg, args[1].arg);
         break;
@@ -1032,7 +1051,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitOp(compiler, SimdOp::umull | SimdOp::H8 | (0x1 << 30), dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::I32X4DotI16X8SOpcode:
-        simdEmitDot(compiler, dst, args[0].arg, args[1].arg);
+        simdEmitDot(compiler, SimdOp::S4, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::I64X2AddOpcode:
         simdEmitOp(compiler, SimdOp::add | SimdOp::D2, dst, args[0].arg, args[1].arg);
         break;
@@ -1069,9 +1088,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitOp(compiler, SimdOp::fdiv | SimdOp::FS4, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F32X4MaxOpcode:
+    case ByteCode::F32X4RelaxedMaxOpcode:
         simdEmitOp(compiler, SimdOp::fmax | SimdOp::FS4, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F32X4MinOpcode:
+    case ByteCode::F32X4RelaxedMinOpcode:
         simdEmitOp(compiler, SimdOp::fmin | SimdOp::FS4, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F32X4MulOpcode:
@@ -1093,9 +1114,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitOp(compiler, SimdOp::fdiv | SimdOp::FD2, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F64X2MaxOpcode:
+    case ByteCode::F64X2RelaxedMaxOpcode:
         simdEmitOp(compiler, SimdOp::fmax | SimdOp::FD2, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F64X2MinOpcode:
+    case ByteCode::F64X2RelaxedMinOpcode:
         simdEmitOp(compiler, SimdOp::fmin | SimdOp::FD2, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::F64X2MulOpcode:
@@ -1154,6 +1177,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
         simdEmitOp(compiler, SimdOp::bic, dst, args[0].arg, args[1].arg);
         break;
     case ByteCode::I8X16SwizzleOpcode:
+    case ByteCode::I8X16RelaxedSwizzleOpcode:
         simdEmitOp(compiler, SimdOp::tbl, dst, args[0].arg, args[1].arg);
         break;
     default:
@@ -1166,26 +1190,106 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
     }
 }

+static void simdEmitDotAdd(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, sljit_s32 ro)
+{
+    sljit_s32 tmpReg1 = SLJIT_TMP_FR0;
+
+    simdEmitDot(compiler, SimdOp::H8, tmpReg1, rn, rm);
+    simdEmitOp(compiler, SimdOp::saddlp | SimdOp::H8, tmpReg1, tmpReg1, 0);
+    simdEmitOp(compiler, SimdOp::add | SimdOp::S4, rd, ro, tmpReg1);
+}
+
 static void emitTernarySIMD(sljit_compiler* compiler, Instruction* instr)
 {
     Operand* operands = instr->operands();
-    JITArg args[3];
+    JITArg args[4];
+
+    sljit_s32 srcType = SLJIT_SIMD_ELEM_128;
+    sljit_s32 dstType = SLJIT_SIMD_ELEM_128;
+    bool moveToDst = true;
+
+    switch (instr->opcode()) {
+    case ByteCode::V128BitSelectOpcode:
+        srcType = SLJIT_SIMD_ELEM_128;
+        dstType = SLJIT_SIMD_ELEM_128;
+        break;
+    case ByteCode::I8X16RelaxedLaneSelectOpcode:
+        srcType = SLJIT_SIMD_ELEM_8;
+        dstType = SLJIT_SIMD_ELEM_8;
+        break;
+    case ByteCode::I16X8RelaxedLaneSelectOpcode:
+        srcType = SLJIT_SIMD_ELEM_16;
+        dstType = SLJIT_SIMD_ELEM_16;
+        break;
+    case ByteCode::I32X4RelaxedLaneSelectOpcode:
+        srcType = SLJIT_SIMD_ELEM_32;
+        dstType = SLJIT_SIMD_ELEM_32;
+        break;
+    case ByteCode::I64X2RelaxedLaneSelectOpcode:
+        srcType = SLJIT_SIMD_ELEM_64;
+        dstType = SLJIT_SIMD_ELEM_64;
+        break;
+    case ByteCode::I32X4DotI8X16I7X16AddSOpcode:
+        srcType = SLJIT_SIMD_ELEM_8;
+        dstType = SLJIT_SIMD_ELEM_32;
+        moveToDst = false;
+        break;
+    case ByteCode::F32X4RelaxedMaddOpcode:
+    case ByteCode::F32X4RelaxedNmaddOpcode:
+        srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
+        dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
+        break;
+    case ByteCode::F64X2RelaxedMaddOpcode:
+    case ByteCode::F64X2RelaxedNmaddOpcode:
+        srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
+        dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
+        break;
+    default:
+        ASSERT_NOT_REACHED();
+        break;
+    }

-    simdOperandToArg(compiler, operands, args[0], SLJIT_SIMD_ELEM_128, instr->requiredReg(0));
-    simdOperandToArg(compiler, operands + 1, args[1], SLJIT_SIMD_ELEM_128, instr->requiredReg(1));
-    simdOperandToArg(compiler, operands + 2, args[2], SLJIT_SIMD_ELEM_128, instr->requiredReg(2));
+    simdOperandToArg(compiler, operands, args[0], srcType, instr->requiredReg(0));
+    simdOperandToArg(compiler, operands + 1, args[1], srcType, instr->requiredReg(1));
+    simdOperandToArg(compiler, operands + 2, args[2], dstType, instr->requiredReg(2));

-    sljit_s32 dst = instr->requiredReg(2);
+    args[3].set(operands + 3);
+    sljit_s32 dst = GET_TARGET_REG(args[3].arg, instr->requiredReg(2));

-    if (dst != args[2].arg) {
-        sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_128, dst, args[2].arg, args[2].argw);
+    if (moveToDst && dst != args[2].arg) {
+        sljit_emit_simd_mov(compiler, SLJIT_SIMD_REG_128 | srcType, dst, args[2].arg, 0);
     }

-    simdEmitOp(compiler, SimdOp::bsl, dst, args[0].arg, args[1].arg);
+    switch (instr->opcode()) {
+    case ByteCode::V128BitSelectOpcode:
+    case ByteCode::I8X16RelaxedLaneSelectOpcode:
+    case ByteCode::I16X8RelaxedLaneSelectOpcode:
+    case ByteCode::I32X4RelaxedLaneSelectOpcode:
+    case ByteCode::I64X2RelaxedLaneSelectOpcode:
+        simdEmitOp(compiler, SimdOp::bsl, dst, args[0].arg, args[1].arg);
+        break;
+    case ByteCode::I32X4DotI8X16I7X16AddSOpcode:
+        simdEmitDotAdd(compiler, dst, args[0].arg, args[1].arg, args[2].arg);
+        break;
+    case ByteCode::F32X4RelaxedMaddOpcode:
+        simdEmitOp(compiler, SimdOp::fmla | SimdOp::FS4, dst, args[0].arg, args[1].arg);
+        break;
+    case ByteCode::F32X4RelaxedNmaddOpcode:
+        simdEmitOp(compiler, SimdOp::fmls | SimdOp::FS4, dst, args[0].arg, args[1].arg);
+        break;
+    case ByteCode::F64X2RelaxedMaddOpcode:
+        simdEmitOp(compiler, SimdOp::fmla | SimdOp::FD2, dst, args[0].arg, args[1].arg);
+        break;
+    case ByteCode::F64X2RelaxedNmaddOpcode:
+        simdEmitOp(compiler, SimdOp::fmls | SimdOp::FD2, dst, args[0].arg, args[1].arg);
+        break;
+    default:
+        ASSERT_NOT_REACHED();
+        break;
+    }

-    args[2].set(operands + 3);
-    if (SLJIT_IS_MEM(args[2].arg)) {
-        sljit_emit_simd_mov(compiler, SLJIT_SIMD_STORE | SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_128, dst, args[2].arg, args[2].argw);
+    if (SLJIT_IS_MEM(args[3].arg)) {
+        sljit_emit_simd_mov(compiler, SLJIT_SIMD_STORE | SLJIT_SIMD_REG_128 | dstType, dst, args[3].arg, args[3].argw);
     }
 }

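Note (illustration only, not part of the diff): in emitTernarySIMD the relaxed lane selects share the existing BSL path with v128.bitselect, while the F32X4/F64X2 RelaxedMadd and RelaxedNmadd opcodes are lowered to FMLA and FMLS, with the destination register preloaded with the accumulator operand. I32X4DotI8X16I7X16AddS goes through the new simdEmitDotAdd helper: the 16-bit dot products from simdEmitDot, one more SADDLP to widen adjacent pairs, then an ADD with the accumulator. A rough NEON-intrinsics sketch of these two lowerings, with invented helper names, assuming <arm_neon.h>:

    #include <arm_neon.h>

    /* f32x4.relaxed_madd(a, b, acc): FMLA accumulates a * b into the register
     * that already holds acc (fused, single rounding on AArch64); the nmadd
     * form would use vfmsq_f32 / FMLS instead. */
    static inline float32x4_t relaxed_madd_f32x4_sketch(float32x4_t a, float32x4_t b, float32x4_t acc)
    {
        return vfmaq_f32(acc, a, b); /* FMLA: acc + a * b */
    }

    /* i32x4.dot_i8x16_i7x16_add_s(a, b, acc): the same SMULL/SADDLP/XTN
     * sequence as the earlier sketch, then SADDLP + ADD as in simdEmitDotAdd. */
    static inline int32x4_t dot_i8x16_i7x16_add_s_sketch(int8x16_t a, int8x16_t b, int32x4_t acc)
    {
        int16x8_t lo = vmull_s8(vget_low_s8(a), vget_low_s8(b));     /* SMULL  */
        int16x8_t hi = vmull_high_s8(a, b);                          /* SMULL2 */
        int16x8_t dot16 = vcombine_s16(vmovn_s32(vpaddlq_s16(lo)),
                                       vmovn_s32(vpaddlq_s16(hi)));  /* SADDLP + XTN/XTN2 */
        return vaddq_s32(vpaddlq_s16(dot16), acc);                   /* SADDLP + ADD */
    }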