Update timezone test framework to support both GPU and CPU POC #9739

@@ -19,16 +19,14 @@ package com.nvidia.spark.rapids
import java.lang.reflect.InvocationTargetException
import java.time.ZoneId
import java.util.Properties

import scala.collection.JavaConverters._
import scala.sys.process._
import scala.util.Try

import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner}
import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg}
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
@@ -381,6 +379,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
s"Driver timezone is $driverTimezone and executor timezone is " +
s"$executorTimezone. Set executor timezone to $driverTimezone.")
}
GpuTimeZoneDB.cacheDatabase()
Collaborator:
Nit: it seems initTimezoneDB() would be a more appropriate name than cacheDatabase().

}

GpuCoreDumpHandler.executorInit(conf, pluginContext)
@@ -503,6 +502,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
}

override def shutdown(): Unit = {
GpuTimeZoneDB.shutdown()
Collaborator:

The more of these shutdown calls we have, the higher the risk that one of them throws and leaves the remaining ones unexecuted.

Collaborator (Author):

Understood. FWIW, the GpuTimeZoneDB.shutdown() implementation is already wrapped in a try block, so it should fail silently and let the remaining calls proceed.

See NVIDIA/spark-rapids-jni#1553 and GpuTimeZoneDB.java there.
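For illustration, a minimal sketch of the fail-safe pattern under discussion, using a hypothetical helper (not part of this PR) that isolates each shutdown step in its own try block so one failure cannot skip the rest:

object SafeShutdown {
  // Run every shutdown step even if an earlier one throws.
  def shutdownAll(steps: (String, () => Unit)*): Unit = {
    steps.foreach { case (name, step) =>
      try {
        step()
      } catch {
        case e: Exception =>
          System.err.println(s"Shutdown of $name failed: ${e.getMessage}")
      }
    }
  }
}

// Usage sketch against the calls below:
// SafeShutdown.shutdownAll(
//   ("GpuTimeZoneDB", () => GpuTimeZoneDB.shutdown()),
//   ("GpuSemaphore", () => GpuSemaphore.shutdown()),
//   ("PythonWorkerSemaphore", () => PythonWorkerSemaphore.shutdown()))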


GpuSemaphore.shutdown()
PythonWorkerSemaphore.shutdown()
GpuDeviceManager.shutdown()
@@ -121,7 +121,7 @@ object TimeZoneDB {
assert(inputVector.getType == DType.TIMESTAMP_DAYS)
val rowCount = inputVector.getRowCount.toInt
withResource(inputVector.copyToHost()) { input =>
withResource(HostColumnVector.builder(DType.INT64, rowCount)) { builder =>
withResource(HostColumnVector.builder(DType.TIMESTAMP_MICROSECONDS, rowCount)) { builder =>
var currRow = 0
while (currRow < rowCount) {
val origin = input.getInt(currRow)
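The one-line builder change above is the substance of this hunk: the host column is now built as TIMESTAMP_MICROSECONDS instead of raw INT64, so the values carry timestamp semantics. A minimal standalone sketch, assuming only the cuDF Java builder API already used in this diff:

import ai.rapids.cudf.{DType, HostColumnVector}

// The stored values are still 64-bit longs, but the column's type now
// identifies them as microsecond timestamps.
val builder = HostColumnVector.builder(DType.TIMESTAMP_MICROSECONDS, 2)
builder.append(0L)                // 1970-01-01T00:00:00Z
builder.append(1520753400000000L) // 2018-03-11T07:30:00Z
val column = builder.build()
// ... use the column, then close it (or manage both with withResource) ...
column.close()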
@@ -25,12 +25,21 @@ import scala.collection.mutable
import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.SparkQueryCompareTestSuite
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToInstant
import org.apache.spark.sql.types._

class TimeZoneSuite extends SparkQueryCompareTestSuite {
class TimeZoneSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll {
private val useGPU = true
private val testAllTimezones = false
private val testAllYears = false

private var zones = Seq.empty[String]

/**
* create timestamp column vector
*/
@@ -91,13 +100,24 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
/**
* assert timestamp result with Spark result
*/
def assertTimestampRet(actualRet: ColumnVector, sparkRet: Seq[Row]): Unit = {
def assertTimestampRet(actualRet: ColumnVector, sparkRet: Seq[Row], input: ColumnVector): Unit = {
withResource(actualRet.copyToHost()) { host =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
val sparkInstant = sparkRet(i).getInstant(0)
val sparkMicro = sparkInstant.getEpochSecond * 1000000L + sparkInstant.getNano / 1000L
assert(host.getLong(i) == sparkMicro)
withResource(input.copyToHost()) { hostInput =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
val sparkInstant = sparkRet(i).getInstant(0)
val sparkMicro = sparkInstant.getEpochSecond * 1000000L + sparkInstant.getNano / 1000L
if (hostInput.getType == DType.TIMESTAMP_DAYS) {
assert(host.getLong(i) == sparkMicro,
s"for ${hostInput.getInt(i)} " +
s"${microsToInstant(host.getLong(i))} != ${microsToInstant(sparkMicro)}")

} else {
assert(host.getLong(i) == sparkMicro,
s"for ${hostInput.getLong(i)} (${microsToInstant(hostInput.getLong(i))}) " +
s"${microsToInstant(host.getLong(i))} != ${microsToInstant(sparkMicro)}")
}
}
}
}
}
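As a cross-check of the arithmetic in the assertion above, converting a java.time.Instant to epoch microseconds is easy to verify standalone:

import java.time.Instant

def instantToMicros(i: Instant): Long =
  i.getEpochSecond * 1000000L + i.getNano / 1000L

// 2018-03-11T07:30:00Z is 1520753400 seconds after the epoch.
assert(instantToMicros(Instant.parse("2018-03-11T07:30:00Z")) == 1520753400000000L)

This also holds for pre-1970 instants, because Instant stores a floored (negative) epoch second plus a non-negative nanosecond adjustment.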
@@ -160,18 +180,24 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
TimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(createColumnVector(epochSeconds)) { inputCv =>
val actualRet = if (useGPU) {
GpuTimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
} else {
TimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def testFromTimestampToUTCTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
def testFromTimestampToUtcTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
Expand All @@ -189,15 +215,21 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
TimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(createColumnVector(epochSeconds)) { inputCv =>
val actualRet = if (useGPU) {
GpuTimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
} else {
TimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def testFromTimestampToDate(epochSeconds: Array[Long], zoneStr: String): Unit = {
@@ -245,15 +277,15 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createDateColumnVector(epochDays)) { inputCv =>
TimeZoneDB.fromDateToTimestamp(
withResource(createDateColumnVector(epochDays)) { inputCv =>
val actualRet = TimeZoneDB.fromDateToTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def selectWithRepeatZones: Seq[String] = {
@@ -266,36 +298,72 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
repeatZones.slice(0, 2) ++ mustZones
}

def selectNonRepeatZones: Seq[String] = {
def selectTimeZones: Seq[String] = {
val mustZones = Array[String]("Asia/Shanghai", "America/Sao_Paulo")
val nonRepeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
// remove this line when we support repeat rules
(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
if (testAllTimezones) {
val nonRepeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
// remove this line when we support repeat rules
(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
}
val shuffled = scala.util.Random.shuffle(nonRepeatZones)
shuffled.slice(0, 2) ++ mustZones
} else {
mustZones
}
scala.util.Random.shuffle(nonRepeatZones)
nonRepeatZones.slice(0, 2) ++ mustZones
}
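For context, the JDK API used by the filter above can be exercised on its own; a small sketch (which zones qualify depends on the JDK's tzdata version):

import java.time.ZoneId

// A zone is "simple" for the POC when it is a fixed offset or has no
// recurring daylight-saving transition rules.
def isSimple(zone: String): Boolean = {
  val rules = ZoneId.of(zone).getRules
  rules.isFixedOffset || rules.getTransitionRules.isEmpty
}

// isSimple("UTC")                 // true: fixed offset
// isSimple("Asia/Shanghai")       // true: no recurring DST rules
// isSimple("America/Los_Angeles") // false: recurring DST transitions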

test("test all time zones") {
assume(false,
"It's time consuming for test all time zones, by default it's disabled")
override def beforeAll(): Unit = {
zones = selectTimeZones
if (useGPU) {
withGpuSparkSession(_ => { })
}
}

test("test timestamp to utc timestamp") {
for (zoneStr <- zones) {
// iterate years
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromTimestampToUtcTimestamp(epochSeconds, zoneStr)
}
}
}

val zones = selectNonRepeatZones
// iterate zones
test("test utc timestamp to timestamp") {
for (zoneStr <- zones) {
// iterate years
val startYear = 1
val endYear = 9999
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromUtcTimeStampToTimestamp(epochSeconds, zoneStr)
testFromTimestampToUTCTimestamp(epochSeconds, zoneStr)
testFromTimestampToDate(epochSeconds, zoneStr)
}
}
}

test("test timestamp to date") {
for (zoneStr <- zones) {
// iterate years
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromTimestampToDate(epochSeconds, zoneStr)
}
}
}

test("test date to timestamp") {
for (zoneStr <- zones) {
// iterate years
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
val epochDays = getEpochDays(startYear, endYear)
testFromDateToTimestamp(epochDays, zoneStr)
}
}

}
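A closing design note: because useGPU, testAllTimezones, and testAllYears are hard-coded vals, widening the test matrix requires editing the source. A hypothetical alternative (not in this PR) would read them from system properties, e.g.:

// Pass -Dtimezone.test.useGPU=false etc. to the test JVM.
private val useGPU = sys.props.getOrElse("timezone.test.useGPU", "true").toBoolean
private val testAllTimezones = sys.props.getOrElse("timezone.test.allTimezones", "false").toBoolean
private val testAllYears = sys.props.getOrElse("timezone.test.allYears", "false").toBoolean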