Fix Java OOM in non-UTC case with lots of xfail (#9944) (#10007)
Signed-off-by: Ferdinand Xu <[email protected]>
Co-authored-by: Ferdinand Xu <[email protected]>
res-life and winningsix authored Dec 11, 2023
Commit d6bc300, parent a5c37fb.
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
integration_tests/src/main/python/asserts.py (24 additions, 17 deletions)
@@ -351,24 +351,31 @@ def assert_gpu_fallback_write(write_func,
     jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.startCapture()
     gpu_start = time.time()
     gpu_path = base_path + '/GPU'
-    with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)
-    gpu_end = time.time()
-    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name_list, 10000)
-    print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format(
-        gpu_end - gpu_start, cpu_end - cpu_start))
-
-    (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(
-        lambda spark: read_func(spark, cpu_path), 'COLLECT')
-    (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(
-        lambda spark: read_func(spark, gpu_path), 'COLLECT')
-
-    from_cpu = with_cpu_session(cpu_bring_back, conf=conf)
-    from_gpu = with_cpu_session(gpu_bring_back, conf=conf)
-    if should_sort_locally():
-        from_cpu.sort(key=_RowCmp)
-        from_gpu.sort(key=_RowCmp)
+    try:
+        with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)
+        gpu_end = time.time()
+        jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name_list, 10000)
+        print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format(
+            gpu_end - gpu_start, cpu_end - cpu_start))
+
+        (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(
+            lambda spark: read_func(spark, cpu_path), 'COLLECT')
+        (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(
+            lambda spark: read_func(spark, gpu_path), 'COLLECT')
+
+        from_cpu = with_cpu_session(cpu_bring_back, conf=conf)
+        from_gpu = with_cpu_session(gpu_bring_back, conf=conf)
+        if should_sort_locally():
+            from_cpu.sort(key=_RowCmp)
+            from_gpu.sort(key=_RowCmp)
+
+        assert_equal(from_cpu, from_gpu)
+    finally:
+        # Ensure the `shouldCapture` state is restored. If the GPU plan fails to
+        # execute, `assertCapturedAndGpuFellBack` never restores it. This mostly
+        # happens within an xfail case, where the error may be ignored.
+        jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.endCapture()
 
-    assert_equal(from_cpu, from_gpu)
 
 def assert_cpu_and_gpu_are_equal_collect_with_capture(func,
                                                       exist_classes='',
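The try/finally is the heart of the fix: `startCapture()` sets a `shouldCapture` flag on the JVM side, and before this change the flag was cleared only by `assertCapturedAndGpuFellBack`, which is never reached when the GPU write throws. Under pytest's xfail the exception is recorded as an expected failure and swallowed, so the callback silently kept buffering every plan executed afterwards until the JVM ran out of heap. A minimal, self-contained sketch of that leak pattern follows; it models the plugin's callback rather than reusing it, and every name in it is illustrative:

class PlanCapture:
    def __init__(self):
        self.should_capture = False
        self.plans = []

    def start_capture(self):
        self.should_capture = True

    def capture_if_needed(self, plan):
        # Buffers plans while capturing, like the JVM-side callback does.
        if self.should_capture:
            self.plans.append(plan)

    def assert_and_stop(self):
        # Stands in for assertCapturedAndGpuFellBack: reached only on success.
        self.should_capture = False
        self.plans.clear()

    def end_capture(self):
        # Stands in for the new endCapture(): idempotent, safe in a finally block.
        if self.should_capture:
            self.should_capture = False
            self.plans.clear()

def flaky_gpu_write():
    raise RuntimeError("GPU plan failed")

capture = PlanCapture()
for _ in range(1000):                    # many xfail tests in one session
    capture.start_capture()
    try:
        flaky_gpu_write()                # raises, so assert_and_stop() is skipped
        capture.assert_and_stop()
    except RuntimeError:
        pass                             # xfail-style: the failure is ignored
    finally:
        capture.end_capture()            # the fix: capture state is always reset
    capture.capture_if_needed("a plan")  # a later query; no longer buffered

assert not capture.plans                 # without end_capture() this grows to ~1000 plans

Delete the finally block and the final assertion fails with 1,000 buffered entries, which is exactly the accumulation that turned into an OOM across a large xfail-heavy test run.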
@@ -26,6 +26,8 @@ trait ExecutionPlanCaptureCallbackBase {
   def captureIfNeeded(qe: QueryExecution): Unit
   def startCapture(): Unit
   def startCapture(timeoutMillis: Long): Unit
+  def endCapture(): Unit
+  def endCapture(timeoutMillis: Long): Unit
   def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan]
   def extractExecutedPlan(plan: SparkPlan): SparkPlan
   def assertContains(gpuPlan: SparkPlan, className: String): Unit
@@ -57,6 +59,10 @@ object ExecutionPlanCaptureCallback extends ExecutionPlanCaptureCallbackBase {
   override def startCapture(timeoutMillis: Long): Unit =
     impl.startCapture(timeoutMillis)
 
+  override def endCapture(): Unit = impl.endCapture()
+
+  override def endCapture(timeoutMillis: Long): Unit = impl.endCapture(timeoutMillis)
+
   override def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan] =
     impl.getResultsWithTimeout(timeoutMs)
 
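The trait plus the forwarding object give the Python integration tests a stable entry point through the py4j gateway; asserts.py above calls it exactly this way. Below is a sketch of the intended call pattern, assuming `spark` is an active SparkSession with the RAPIDS plugin on the classpath; obtaining the gateway via `spark.sparkContext._jvm` and the `run_query_under_test` helper are assumptions of this sketch:

jvm = spark.sparkContext._jvm
callback = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback

callback.startCapture()             # begin buffering executed plans
try:
    run_query_under_test(spark)     # may raise, e.g. in an xfail test
    plans = callback.getResultsWithTimeout(10000)
finally:
    callback.endCapture()           # new in this commit: always drop capture state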
@@ -57,6 +57,15 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba
     }
   }
 
+  override def endCapture(): Unit = endCapture(10000)
+
+  override def endCapture(timeoutMillis: Long): Unit = synchronized {
+    if (shouldCapture) {
+      shouldCapture = false
+      execPlans.clear()
+    }
+  }
+
   override def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan] = {
     try {
       val spark = SparkSession.active
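Two details of this implementation are easy to miss. `endCapture(timeoutMillis)` takes a timeout only to mirror `startCapture(timeoutMillis)`; the argument is unused in the body shown here. And the line that actually prevents the OOM is `execPlans.clear()`, which drops the buffered SparkPlan references so they become garbage-collectable. Because the body is guarded by `if (shouldCapture)` inside `synchronized`, calling `endCapture()` from a finally block after a successful assertion has already reset the state is a cheap no-op.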
