Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update timezone test framework to support both GPU and CPU POC #9739

Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ object TimeZoneDB {
assert(inputVector.getType == DType.TIMESTAMP_DAYS)
val rowCount = inputVector.getRowCount.toInt
withResource(inputVector.copyToHost()) { input =>
withResource(HostColumnVector.builder(DType.INT64, rowCount)) { builder =>
withResource(HostColumnVector.builder(DType.TIMESTAMP_MICROSECONDS, rowCount)) { builder =>
var currRow = 0
while (currRow < rowCount) {
if (input.isNull(currRow)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,22 @@ import scala.collection.mutable
import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.SparkQueryCompareTestSuite
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToInstant
import org.apache.spark.sql.rapids.TimeZoneDB
import org.apache.spark.sql.types._

class TimeZoneSuite extends SparkQueryCompareTestSuite {
class TimeZoneSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll {
private val useGPU = true
private val testAllTimezones = false
private val testAllYears = false

private var zones = Seq.empty[String]

/**
* create timestamp column vector
*/
Expand Down Expand Up @@ -92,13 +101,24 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
/**
* Assert that the timestamp result matches the Spark result.
*/
def assertTimestampRet(actualRet: ColumnVector, sparkRet: Seq[Row]): Unit = {
def assertTimestampRet(actualRet: ColumnVector, sparkRet: Seq[Row], input: ColumnVector): Unit = {
withResource(actualRet.copyToHost()) { host =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
val sparkInstant = sparkRet(i).getInstant(0)
val sparkMicro = sparkInstant.getEpochSecond * 1000000L + sparkInstant.getNano / 1000L
assert(host.getLong(i) == sparkMicro)
withResource(input.copyToHost()) { hostInput =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
val sparkInstant = sparkRet(i).getInstant(0)
val sparkMicro = sparkInstant.getEpochSecond * 1000000L + sparkInstant.getNano / 1000L
if (hostInput.getType == DType.TIMESTAMP_DAYS) {
assert(host.getLong(i) == sparkMicro,
s"for ${hostInput.getInt(i)} " +
s"${microsToInstant(host.getLong(i))} != ${microsToInstant(sparkMicro)}")

} else {
assert(host.getLong(i) == sparkMicro,
s"for ${hostInput.getLong(i)} (${microsToInstant(hostInput.getLong(i))}) " +
s"${microsToInstant(host.getLong(i))} != ${microsToInstant(sparkMicro)}")
}
}
}
}
}
Expand Down Expand Up @@ -161,18 +181,24 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
TimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(createColumnVector(epochSeconds)) { inputCv =>
val actualRet = if (useGPU) {
GpuTimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
} else {
TimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def testFromTimestampToUTCTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
def testFromTimestampToUtcTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
Expand All @@ -190,15 +216,21 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
TimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(createColumnVector(epochSeconds)) { inputCv =>
val actualRet = if (useGPU) {
GpuTimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
} else {
TimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def testFromTimestampToDate(epochSeconds: Array[Long], zoneStr: String): Unit = {
Expand Down Expand Up @@ -246,15 +278,15 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createDateColumnVector(epochDays)) { inputCv =>
TimeZoneDB.fromDateToTimestamp(
withResource(createDateColumnVector(epochDays)) { inputCv =>
val actualRet = TimeZoneDB.fromDateToTimestamp(
inputCv,
ZoneId.of(zoneStr))
withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet, inputCv)
}
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def selectWithRepeatZones: Seq[String] = {
Expand All @@ -267,36 +299,78 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
repeatZones.slice(0, 2) ++ mustZones
}

def selectNonRepeatZones: Seq[String] = {
def selectTimeZones: Seq[String] = {
val mustZones = Array[String]("Asia/Shanghai", "America/Sao_Paulo")
val nonRepeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
// remove this line when we support repeat rules
(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
if (testAllTimezones) {
val nonRepeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
// remove this line when we support repeat rules
(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
}
scala.util.Random.shuffle(nonRepeatZones)
nonRepeatZones.slice(0, 2) ++ mustZones
} else {
mustZones
}
}

override def beforeAll(): Unit = {
zones = selectTimeZones
if (useGPU) {
GpuTimeZoneDB.cacheDatabase()
}
scala.util.Random.shuffle(nonRepeatZones)
nonRepeatZones.slice(0, 2) ++ mustZones
}

test("test all time zones") {
assume(false,
"It's time consuming for test all time zones, by default it's disabled")
/**
 * Suite-level teardown: shuts down GpuTimeZoneDB when the suite is running in GPU mode,
 * then delegates to the parent teardown.
 *
 * `super.afterAll()` is invoked in a `finally` so that teardown contributed by
 * BeforeAndAfterAll / the parent suite still runs even if the GPU shutdown throws.
 */
override def afterAll(): Unit = {
  try {
    if (useGPU) {
      GpuTimeZoneDB.shutdown()
    }
  } finally {
    // Keep the stackable-trait chain intact; skipping this silently drops parent cleanup.
    super.afterAll()
  }
}

val zones = selectNonRepeatZones
// iterate zones
test("test timestamp to utc timestamp") {
for (zoneStr <- zones) {
// iterate years
val startYear = 1
val endYear = 9999
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromTimestampToUtcTimestamp(epochSeconds, zoneStr)
}
}
}

test("test utc timestamp to timestamp") {
for (zoneStr <- zones) {
// iterate years
val startYear = if (testAllYears) 1 else 1899
val endYear = if (testAllYears) 9999 else 2030
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromUtcTimeStampToTimestamp(epochSeconds, zoneStr)
testFromTimestampToUTCTimestamp(epochSeconds, zoneStr)
testFromTimestampToDate(epochSeconds, zoneStr)
}
}
}

test("test timestamp to date") {
  // Year window: the wide 1..9999 span is only used when testAllYears is enabled.
  val (firstYear, lastYear) = if (testAllYears) (1, 9999) else (1899, 2030)
  zones.foreach { zone =>
    // Sample every 7th year; `until` excludes lastYear itself.
    (firstYear until lastYear by 7).foreach { y =>
      // Input is rebuilt for each (zone, year) pair, covering the span [y, y + 1).
      testFromTimestampToDate(getEpochSeconds(y, y + 1), zone)
    }
  }
}

test("test date to timestamp") {
  // Year window: the wide 1..9999 span is only used when testAllYears is enabled.
  val (firstYear, lastYear) = if (testAllYears) (1, 9999) else (1899, 2030)
  zones.foreach { zone =>
    // The epoch-day input is requested anew for every zone rather than shared across zones.
    testFromDateToTimestamp(getEpochDays(firstYear, lastYear), zone)
  }
}

}