Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/branch-24.02' into gerashegalov/…
Browse files Browse the repository at this point in the history
…issue9992
  • Loading branch information
gerashegalov committed Dec 21, 2023
2 parents 6050f4b + bb235c9 commit 7119ec4
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
8 changes: 2 additions & 6 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,21 +456,17 @@ def test_date_format_for_date(data_gen, date_format):
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_date_format_for_time(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
conf)
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu('ProjectExec')
def test_date_format_for_time_fall_back(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
'ProjectExec',
conf)
'ProjectExec')

@pytest.mark.parametrize('date_format', supported_date_formats + ['yyyyMMdd'], ids=idfn)
# from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
val (args, rest) = stack.splitAt(n + 1)
(args.reverse, rest)
})
case Opcode.INVOKEDYNAMIC =>
invokedynamic(lambdaReflection, state,
(stack, n) => {
val (args, rest) = stack.splitAt(n)
(args.reverse, rest)
})
case _ => throw new SparkException("Unsupported instruction: " + instructionStr)
}
logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
Expand Down Expand Up @@ -563,6 +569,31 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
}
}

/**
 * Symbolically executes an `invokedynamic` instruction.
 *
 * The only call site handled here is one whose bootstrap method is
 * `java.lang.invoke.StringConcatFactory.makeConcatWithConstants` with exactly one bootstrap
 * argument (the recipe string). Every `\u0001` in the recipe consumes one expression from the
 * operand stack; a recipe containing `\u0002` is rejected. The result pushed on the stack is a
 * `Concat` of (literal fragment, argument) pairs.
 *
 * @param lambdaReflection used to resolve the bootstrap method referenced by `operand`
 * @param state            the state before this instruction
 * @param getArgs          given the stack and an argument count, returns (arguments, rest of stack)
 * @return the state with the arguments replaced by the concatenation expression
 * @throws SparkException for any other bootstrap method or an unsupported recipe
 */
private def invokedynamic(lambdaReflection: LambdaReflection, state: State,
    getArgs: (List[Expression], Int) =>
        (List[Expression], List[Expression])): State = {
  val State(locals, stack, cond, expr) = state
  val (bsm, bsmArgs) = lambdaReflection.lookupBootstrapMethod(operand)

  val isSimpleStringConcat =
    bsm.getDeclaringClass.getName.equals("java.lang.invoke.StringConcatFactory") &&
      bsm.getName.equals("makeConcatWithConstants") &&
      bsmArgs.length == 1
  if (!isSimpleStringConcat) {
    throw new SparkException("Unsupported instruction: " + instructionStr)
  }

  val recipe = bsmArgs.head.toString
  if (recipe.contains('\u0002')) {
    throw new SparkException("Unsupported instruction: " + instructionStr)
  }

  val placeholderCount = recipe.count{c => c == '\u0001'}
  val (args, rest) = getArgs(stack, placeholderCount)
  // split() drops trailing empty fragments, so zipAll pads whichever side is shorter:
  // a missing fragment becomes "" and a missing argument becomes Literal("").
  val pieces = recipe.split('\u0001')
    .zipAll(args, "", Literal(""))
    .map{ case (fragment, arg) => Concat(Seq(Literal(fragment), arg)) }
  State(locals, Concat(pieces.toSeq) :: rest, cond, expr)
}

private def checkArgs(methodName: String,
expectedTypes: List[DataType],
args: List[Expression]): Unit = {
Expand Down Expand Up @@ -958,7 +989,7 @@ object Instruction {
codeIterator.byteAt(offset + 1)
case Opcode.BIPUSH =>
codeIterator.signedByteAt(offset + 1)
case Opcode.LDC_W | Opcode.LDC2_W | Opcode.NEW | Opcode.CHECKCAST |
case Opcode.LDC_W | Opcode.LDC2_W | Opcode.NEW | Opcode.CHECKCAST | Opcode.INVOKEDYNAMIC |
Opcode.INVOKESTATIC | Opcode.INVOKEVIRTUAL | Opcode.INVOKEINTERFACE |
Opcode.INVOKESPECIAL | Opcode.GETSTATIC =>
codeIterator.u16bitAt(offset + 1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,6 +78,32 @@ class LambdaReflection private(private val classPool: ClassPool,
}
}

/**
 * Resolves the bootstrap method and its static arguments for the `invokedynamic` call site
 * stored at `constPoolIndex`.
 *
 * @param constPoolIndex constant-pool index; its tag must be `ConstPool.CONST_InvokeDynamic`
 * @return the bootstrap method (looked up through `classPool`) paired with its arguments, each
 *         resolved through `lookupConstant`
 * @throws SparkException if the index does not refer to an invokedynamic entry, if the class
 *                        declaring `ctMethod` does not carry exactly one BootstrapMethods
 *                        attribute, or if an argument cannot be resolved by `lookupConstant`
 */
def lookupBootstrapMethod(constPoolIndex: Int): (CtMethod, Seq[Any]) = {
  if (constPool.getTag(constPoolIndex) != ConstPool.CONST_InvokeDynamic) {
    throw new SparkException(s"Unexpected index ${constPoolIndex} for bootstrap")
  }
  val bootstrapMethodIndex = constPool.getInvokeDynamicBootstrap(constPoolIndex)
  // Pattern-match instead of isInstanceOf/asInstanceOf so the element type is already narrowed.
  val bootstrapMethodsAttributes = ctMethod.getDeclaringClass.getClassFile.getAttributes
    .toArray.collect { case attr: javassist.bytecode.BootstrapMethodsAttribute => attr }
  val bootstrapMethods = bootstrapMethodsAttributes match {
    case Array(single) => single.getMethods
    case Array() =>
      // Previously reported as "Multiple ...", which was misleading when none were present.
      throw new SparkException("No bootstrap methods attribute found in the declaring class")
    case _ =>
      throw new SparkException("Multiple bootstrap methods attributes aren't supported")
  }
  val bootstrapMethod = bootstrapMethods(bootstrapMethodIndex)
  val bootstrapMethodArguments = try {
    bootstrapMethod.arguments.map(lookupConstant)
  } catch {
    // Only translate non-fatal failures; let OutOfMemoryError, InterruptedException, etc.
    // propagate, and keep the original failure as the cause for diagnosis.
    case scala.util.control.NonFatal(e) =>
      throw new SparkException("only constants are supported as bootstrap method arguments", e)
  }
  // The bootstrap entry points at a MethodHandle constant, which in turn points at a Methodref.
  val constPoolIndexMethodref = constPool.getMethodHandleIndex(bootstrapMethod.methodRef)
  val methodName = constPool.getMethodrefName(constPoolIndexMethodref)
  val descriptor = constPool.getMethodrefType(constPoolIndexMethodref)
  val className = constPool.getMethodrefClassName(constPoolIndexMethodref)
  (classPool.getCtClass(className).getMethod(methodName, descriptor), bootstrapMethodArguments)
}

def lookupClassName(constPoolIndex: Int): String = {
if (constPool.getTag(constPoolIndex) != ConstPool.CONST_Class) {
throw new SparkException("Unexpected index for class")
Expand Down

0 comments on commit 7119ec4

Please sign in to comment.