diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt
index 706bcfa30f..cfbfb025be 100644
--- a/src/main/cpp/CMakeLists.txt
+++ b/src/main/cpp/CMakeLists.txt
@@ -153,6 +153,7 @@ add_library(
     src/CastStringJni.cpp
     src/DateTimeRebaseJni.cpp
     src/DecimalUtilsJni.cpp
+    src/GpuTimeZoneDBJni.cpp
     src/HashJni.cpp
     src/HistogramJni.cpp
     src/MapUtilsJni.cpp
@@ -172,6 +173,7 @@ add_library(
     src/murmur_hash.cu
     src/parse_uri.cu
     src/row_conversion.cu
+    src/timezones.cu
     src/utilities.cu
     src/xxhash64.cu
     src/zorder.cu
diff --git a/src/main/cpp/src/GpuTimeZoneDBJni.cpp b/src/main/cpp/src/GpuTimeZoneDBJni.cpp
new file mode 100644
index 0000000000..55639853de
--- /dev/null
+++ b/src/main/cpp/src/GpuTimeZoneDBJni.cpp
@@ -0,0 +1,53 @@
+/* Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cudf_jni_apis.hpp"
+#include "timezones.hpp"
+
+extern "C" {
+
+JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_GpuTimeZoneDB_convertTimestampColumnToUTC(
+  JNIEnv* env, jclass, jlong input_handle, jlong transitions_handle, jint tz_index)
+{
+  JNI_NULL_CHECK(env, input_handle, "column is null", 0);
+  JNI_NULL_CHECK(env, transitions_handle, "column is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const input       = reinterpret_cast<cudf::column_view const*>(input_handle);
+    auto const transitions = reinterpret_cast<cudf::table_view const*>(transitions_handle);
+    auto const index       = static_cast<cudf::size_type>(tz_index);
+    return cudf::jni::ptr_as_jlong(
+      spark_rapids_jni::convert_timestamp_to_utc(*input, *transitions, index).release());
+  }
+  CATCH_STD(env, 0);
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_nvidia_spark_rapids_jni_GpuTimeZoneDB_convertUTCTimestampColumnToTimeZone(
+  JNIEnv* env, jclass, jlong input_handle, jlong transitions_handle, jint tz_index)
+{
+  JNI_NULL_CHECK(env, input_handle, "column is null", 0);
+  JNI_NULL_CHECK(env, transitions_handle, "column is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const input       = reinterpret_cast<cudf::column_view const*>(input_handle);
+    auto const transitions = reinterpret_cast<cudf::table_view const*>(transitions_handle);
+    auto const index       = static_cast<cudf::size_type>(tz_index);
+    return cudf::jni::ptr_as_jlong(
+      spark_rapids_jni::convert_utc_timestamp_to_timezone(*input, *transitions, index).release());
+  }
+  CATCH_STD(env, 0);
+}
+}
diff --git a/src/main/cpp/src/timezones.cu b/src/main/cpp/src/timezones.cu
new file mode 100644
index 0000000000..43fff55b3b
--- /dev/null
+++ b/src/main/cpp/src/timezones.cu
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "timezones.hpp"
+
+#include <cudf/column/column_device_view.cuh>
+#include <cudf/column/column_factories.hpp>
+#include <cudf/detail/null_mask.hpp>
+#include <cudf/lists/list_device_view.cuh>
+#include <cudf/lists/lists_column_device_view.cuh>
+#include <cudf/table/table_view.hpp>
+#include <cudf/utilities/span.hpp>
+#include <cudf/wrappers/durations.hpp>
+#include <cudf/wrappers/timestamps.hpp>
+
+#include <rmm/cuda_stream_view.hpp>
+#include <rmm/exec_policy.hpp>
+
+#include <thrust/binary_search.h>
+#include <thrust/execution_policy.h>
+#include <thrust/transform.h>
+
+using column                   = cudf::column;
+using column_device_view       = cudf::column_device_view;
+using column_view              = cudf::column_view;
+using lists_column_device_view = cudf::detail::lists_column_device_view;
+using size_type                = cudf::size_type;
+using struct_view              = cudf::struct_view;
+using table_view               = cudf::table_view;
+
+namespace {
+
+// This device functor uses a binary search to find the transition that applies to the input
+// instant, and then applies that transition's offset.
+// To transition to UTC: do a binary search on the tzInstant child column and subtract
+// the offset.
+// To transition from UTC: do a binary search on the utcInstant child column and add
+// the offset.
+template <typename timestamp_type>
+struct convert_timestamp_tz_functor {
+  using duration_type = typename timestamp_type::duration;
+
+  // The list column of transitions to figure out the correct offset
+  // to adjust the timestamp. The type of the values in this column is
+  // LIST<STRUCT<utcInstant: int64, tzInstant: int64, utcOffset: int32>>.
+  lists_column_device_view const transitions;
+  // the index of the specified zone id in the transitions table
+  size_type const tz_index;
+  // whether we are converting to UTC or converting to the timezone
+  bool const to_utc;
+
+  /**
+   * @brief Convert the timestamp value to either UTC or a specified timezone
+   * @param timestamp input timestamp
+   */
+  __device__ timestamp_type operator()(timestamp_type const& timestamp) const
+  {
+    auto const utc_instants = transitions.child().child(0);
+    auto const tz_instants  = transitions.child().child(1);
+    auto const utc_offsets  = transitions.child().child(2);
+
+    auto const epoch_seconds = static_cast<int64_t>(
+      cuda::std::chrono::duration_cast<cudf::duration_s>(timestamp.time_since_epoch()).count());
+    auto const tz_transitions = cudf::list_device_view{transitions, tz_index};
+    auto const list_size      = tz_transitions.size();
+
+    auto const transition_times = cudf::device_span<int64_t const>(
+      (to_utc ? tz_instants : utc_instants).data<int64_t>() + tz_transitions.element_offset(0),
+      static_cast<size_t>(list_size));
+
+    auto const it = thrust::upper_bound(
+      thrust::seq, transition_times.begin(), transition_times.end(), epoch_seconds);
+    auto const idx         = static_cast<size_type>(thrust::distance(transition_times.begin(), it));
+    auto const list_offset = tz_transitions.element_offset(idx - 1);
+    auto const utc_offset  = cuda::std::chrono::duration_cast<duration_type>(
+      cudf::duration_s{static_cast<int64_t>(utc_offsets.element<int32_t>(list_offset))});
+    return to_utc ? timestamp - utc_offset : timestamp + utc_offset;
+  }
+};
+
+template <typename timestamp_type>
+auto convert_timestamp_tz(column_view const& input,
+                          table_view const& transitions,
+                          size_type tz_index,
+                          bool to_utc,
+                          rmm::cuda_stream_view stream,
+                          rmm::mr::device_memory_resource* mr)
+{
+  // get the fixed transitions
+  auto const ft_cdv_ptr        = column_device_view::create(transitions.column(0), stream);
+  auto const fixed_transitions = lists_column_device_view{*ft_cdv_ptr};
+
+  auto results = cudf::make_timestamp_column(input.type(),
+                                             input.size(),
+                                             cudf::detail::copy_bitmask(input, stream, mr),
+                                             input.null_count(),
+                                             stream,
+                                             mr);
+
+  thrust::transform(
+    rmm::exec_policy(stream),
+    input.begin<timestamp_type>(),
+    input.end<timestamp_type>(),
+    results->mutable_view().begin<timestamp_type>(),
+    convert_timestamp_tz_functor<timestamp_type>{fixed_transitions, tz_index, to_utc});
+
+  return results;
+}
+
+}  // namespace
+
+namespace spark_rapids_jni {
+
+std::unique_ptr<column> convert_timestamp(column_view const& input,
+                                          table_view const& transitions,
+                                          size_type tz_index,
+                                          bool to_utc,
+                                          rmm::cuda_stream_view stream,
+                                          rmm::mr::device_memory_resource* mr)
+{
+  auto const type = input.type().id();
+
+  switch (type) {
+    case cudf::type_id::TIMESTAMP_SECONDS:
+      return convert_timestamp_tz<cudf::timestamp_s>(
+        input, transitions, tz_index, to_utc, stream, mr);
+    case cudf::type_id::TIMESTAMP_MILLISECONDS:
+      return convert_timestamp_tz<cudf::timestamp_ms>(
+        input, transitions, tz_index, to_utc, stream, mr);
+    case cudf::type_id::TIMESTAMP_MICROSECONDS:
+      return convert_timestamp_tz<cudf::timestamp_us>(
+        input, transitions, tz_index, to_utc, stream, mr);
+    default: CUDF_FAIL("Unsupported timestamp unit for timezone conversion");
+  }
+}
+
+std::unique_ptr<column> convert_timestamp_to_utc(column_view const& input,
+                                                 table_view const& transitions,
+                                                 size_type tz_index,
+                                                 rmm::cuda_stream_view stream,
+                                                 rmm::mr::device_memory_resource* mr)
+{
+  return convert_timestamp(input, transitions, tz_index, true, stream, mr);
+}
+
+std::unique_ptr<column> convert_utc_timestamp_to_timezone(column_view const& input,
+                                                          table_view const& transitions,
+                                                          size_type tz_index,
+                                                          rmm::cuda_stream_view stream,
+                                                          rmm::mr::device_memory_resource* mr)
+{
+  return convert_timestamp(input, transitions, tz_index, false, stream, mr);
+}
+
+}  // namespace spark_rapids_jni
diff --git a/src/main/cpp/src/timezones.hpp b/src/main/cpp/src/timezones.hpp
new file mode 100644
index 0000000000..c7ab3c0cc8
--- /dev/null
+++ b/src/main/cpp/src/timezones.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cudf/column/column.hpp>
+#include <cudf/table/table_view.hpp>
+#include <cudf/utilities/default_stream.hpp>
+
+#include <rmm/cuda_stream_view.hpp>
+#include <rmm/mr/per_device_resource.hpp>
+
+#include <memory>
+
+namespace spark_rapids_jni {
+
+/**
+ * @brief Convert input column timestamps in the current timezone to UTC
+ *
+ * The transition rules are enclosed in a table, and the index of the row corresponding to the
+ * current timezone is given.
+ *
+ * This method is the inverse of convert_utc_timestamp_to_timezone.
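+ *
+ * As an illustration (not normative): a fixed-offset zone such as UTC+8 is represented by a
+ * single transition entry {LONG_MIN, LONG_MIN, 28800}, so converting to UTC simply subtracts
+ * 28800 seconds from every timestamp.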
+ *
+ * @param input the column of input timestamps in the current timezone
+ * @param transitions the table of transitions for all timezones
+ * @param tz_index the index of the row in `transitions` corresponding to the current timezone
+ * @param stream CUDA stream used for device memory operations and kernel launches
+ * @param mr Device memory resource used to allocate the returned timestamp column's memory
+ */
+std::unique_ptr<cudf::column> convert_timestamp_to_utc(
+  cudf::column_view const& input,
+  cudf::table_view const& transitions,
+  cudf::size_type tz_index,
+  rmm::cuda_stream_view stream        = cudf::get_default_stream(),
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
+/**
+ * @brief Convert input column timestamps in UTC to the specified timezone
+ *
+ * The transition rules are enclosed in a table, and the index of the row corresponding to the
+ * specific timezone is given.
+ *
+ * This method is the inverse of convert_timestamp_to_utc.
+ *
+ * @param input the column of input timestamps in UTC
+ * @param transitions the table of transitions for all timezones
+ * @param tz_index the index of the row in `transitions` corresponding to the specific timezone
+ * @param stream CUDA stream used for device memory operations and kernel launches
+ * @param mr Device memory resource used to allocate the returned timestamp column's memory
+ */
+std::unique_ptr<cudf::column> convert_utc_timestamp_to_timezone(
+  cudf::column_view const& input,
+  cudf::table_view const& transitions,
+  cudf::size_type tz_index,
+  rmm::cuda_stream_view stream        = cudf::get_default_stream(),
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
+}  // namespace spark_rapids_jni
\ No newline at end of file
diff --git a/src/main/cpp/tests/CMakeLists.txt b/src/main/cpp/tests/CMakeLists.txt
index fcc956bc34..5e16398145 100644
--- a/src/main/cpp/tests/CMakeLists.txt
+++ b/src/main/cpp/tests/CMakeLists.txt
@@ -63,6 +63,9 @@ ConfigureTest(HASH
 ConfigureTest(BLOOM_FILTER
   bloom_filter.cu)
 
+ConfigureTest(TIMEZONES
+  timezones.cpp)
+
 ConfigureTest(UTILITIES
   utilities.cpp)
 
diff --git a/src/main/cpp/tests/timezones.cpp b/src/main/cpp/tests/timezones.cpp
new file mode 100644
index 0000000000..9801a3c0a4
--- /dev/null
+++ b/src/main/cpp/tests/timezones.cpp
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "timezones.hpp"
+
+#include <cudf_test/base_fixture.hpp>
+#include <cudf_test/column_utilities.hpp>
+#include <cudf_test/column_wrapper.hpp>
+#include <cudf_test/table_utilities.hpp>
+#include <cudf_test/type_lists.hpp>
+
+#include <cudf/column/column_factories.hpp>
+#include <cudf/table/table.hpp>
+
+#include <limits>
+
+auto constexpr int64_min = std::numeric_limits<int64_t>::min();
+
+using int32_col = cudf::test::fixed_width_column_wrapper<int32_t>;
+using int64_col = cudf::test::fixed_width_column_wrapper<int64_t>;
+
+using seconds_col =
+  cudf::test::fixed_width_column_wrapper<cudf::timestamp_s, cudf::timestamp_s::rep>;
+
+using millis_col =
+  cudf::test::fixed_width_column_wrapper<cudf::timestamp_ms, cudf::timestamp_ms::rep>;
+
+using micros_col =
+  cudf::test::fixed_width_column_wrapper<cudf::timestamp_us, cudf::timestamp_us::rep>;
+
+class TimeZoneTest : public cudf::test::BaseFixture {
+ protected:
+  void SetUp() override { transitions = make_transitions_table(); }
+  std::unique_ptr<cudf::table> transitions;
+
+ private:
+  std::unique_ptr<cudf::table> make_transitions_table()
+  {
+    auto instants_from_utc_col = int64_col({int64_min,
+                                            int64_min,
+                                            -1585904400L,
+                                            -933667200L,
+                                            -922093200L,
+                                            -908870400L,
+                                            -888829200L,
+                                            -650019600L,
+                                            515527200L,
+                                            558464400L,
+                                            684867600L});
+    auto instants_to_utc_col   = int64_col({int64_min,
+                                            int64_min,
+                                            -1585904400L,
+                                            -933634800L,
+                                            -922064400L,
+                                            -908838000L,
+                                            -888796800L,
+                                            -649990800L,
+                                            515559600L,
+                                            558493200L,
+                                            684896400L});
+    auto utc_offsets_col =
+      int32_col({18000, 29143, 28800, 32400, 28800, 32400, 28800, 28800, 32400, 28800, 28800});
+    auto struct_column = cudf::test::structs_column_wrapper{
+      {instants_from_utc_col, instants_to_utc_col, utc_offsets_col},
+      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
+    auto offsets       = cudf::test::fixed_width_column_wrapper<cudf::size_type>{0, 1, 11};
+    auto list_nullmask = std::vector<bool>(1, 1);
+    auto [null_mask, null_count] =
+      cudf::test::detail::make_null_mask(list_nullmask.begin(), list_nullmask.end());
+    auto list_column = cudf::make_lists_column(
+      2, offsets.release(), struct_column.release(), null_count, std::move(null_mask));
+    auto columns = std::vector<std::unique_ptr<cudf::column>>{};
+    columns.push_back(std::move(list_column));
+    return std::make_unique<cudf::table>(std::move(columns));
+  }
+};
+
+TEST_F(TimeZoneTest, ConvertToUTCSeconds)
+{
+  auto const ts_col = seconds_col{
+    -1262260800L,
+    -908838000L,
+    -908840700L,
+    -888800400L,
+    -888799500L,
+    -888796800L,
+    0L,
+    1699566167L,
+    568036800L,
+  };
+  // check the converted to utc version
+  auto const expected = seconds_col{-1262289600L,
+                                    -908870400L,
+                                    -908869500L,
+                                    -888832800L,
+                                    -888831900L,
+                                    -888825600L,
+                                    -28800L,
+                                    1699537367L,
+                                    568008000L};
+  auto const actual = spark_rapids_jni::convert_timestamp_to_utc(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
+
+TEST_F(TimeZoneTest, ConvertToUTCMilliseconds)
+{
+  auto const ts_col = millis_col{
+    -1262260800000L,
+    -908838000000L,
+    -908840700000L,
+    -888800400000L,
+    -888799500000L,
+    -888796800000L,
+    0L,
+    1699571634312L,
+    568036800000L,
+  };
+  // check the converted to utc version
+  auto const expected = millis_col{-1262289600000L,
+                                   -908870400000L,
+                                   -908869500000L,
+                                   -888832800000L,
+                                   -888831900000L,
+                                   -888825600000L,
+                                   -28800000L,
+                                   1699542834312L,
+                                   568008000000L};
+  auto const actual = spark_rapids_jni::convert_timestamp_to_utc(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
+
+TEST_F(TimeZoneTest, ConvertToUTCMicroseconds)
+{
+  auto const ts_col = micros_col{
+    -1262260800000000L,
+    -908838000000000L,
+    -908840700000000L,
+    -888800400000000L,
+    -888799500000000L,
+    -888796800000000L,
+    0L,
+    1699571634312000L,
+    568036800000000L,
+  };
+  // check the converted to utc version
+  auto const expected =
+    micros_col{-1262289600000000L,
+               -908870400000000L,
+               -908869500000000L,
+               -888832800000000L,
+               -888831900000000L,
+               -888825600000000L,
+               -28800000000L,
+               1699542834312000L,
+               568008000000000L};
+  auto const actual = spark_rapids_jni::convert_timestamp_to_utc(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
+
+TEST_F(TimeZoneTest, ConvertFromUTCSeconds)
+{
+  auto const ts_col = seconds_col{-1262289600L,
+                                  -908870400L,
+                                  -908869500L,
+                                  -888832800L,
+                                  -888831900L,
+                                  -888825600L,
+                                  0L,
+                                  1699537367L,
+                                  568008000L};
+  // check the converted to timezone version
+  auto const expected = seconds_col{
+    -1262260800L,
+    -908838000L,
+    -908837100L,
+    -888800400L,
+    -888799500L,
+    -888796800L,
+    28800L,
+    1699566167L,
+    568036800L,
+  };
+  auto const actual = spark_rapids_jni::convert_utc_timestamp_to_timezone(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
+
+TEST_F(TimeZoneTest, ConvertFromUTCMilliseconds)
+{
+  auto const ts_col = millis_col{-1262289600000L,
+                                 -908870400000L,
+                                 -908869500000L,
+                                 -888832800000L,
+                                 -888831900000L,
+                                 -888825600000L,
+                                 0L,
+                                 1699542834312L,
+                                 568008000000L};
+  // check the converted to timezone version
+  auto const expected = millis_col{
+    -1262260800000L,
+    -908838000000L,
+    -908837100000L,
+    -888800400000L,
+    -888799500000L,
+    -888796800000L,
+    28800000L,
+    1699571634312L,
+    568036800000L,
+  };
+  auto const actual = spark_rapids_jni::convert_utc_timestamp_to_timezone(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
+
+TEST_F(TimeZoneTest, ConvertFromUTCMicroseconds)
+{
+  auto const ts_col = micros_col{-1262289600000000L,
+                                 -908870400000000L,
+                                 -908869500000000L,
+                                 -888832800000000L,
+                                 -888831900000000L,
+                                 -888825600000000L,
+                                 0L,
+                                 1699542834312000L,
+                                 568008000000000L};
+  // check the converted to timezone version
+  auto const expected = micros_col{
+    -1262260800000000L,
+    -908838000000000L,
+    -908837100000000L,
+    -888800400000000L,
+    -888799500000000L,
+    -888796800000000L,
+    28800000000L,
+    1699571634312000L,
+    568036800000000L,
+  };
+  auto const actual = spark_rapids_jni::convert_utc_timestamp_to_timezone(
+    ts_col, *transitions, 1, cudf::get_default_stream(), rmm::mr::get_current_device_resource());
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual);
+}
\ No newline at end of file
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
new file mode 100644
index 0000000000..0eb56100e4
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
@@ -0,0 +1,317 @@
+/*
+* Copyright (c) 2023, NVIDIA CORPORATION.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package com.nvidia.spark.rapids.jni;
+
+import java.time.Instant;
+import java.time.ZoneId;
+import java.time.zone.ZoneOffsetTransition;
+import java.time.zone.ZoneRules;
+import java.time.zone.ZoneRulesException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TimeZone;
+import java.util.concurrent.*;
+
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.HostColumnVector;
+import ai.rapids.cudf.Table;
+
+public class GpuTimeZoneDB {
+
+  public static final int TIMEOUT_SECS = 300;
+
+
+  // For the timezone database, we store the transitions in a ColumnVector that is a list of
+  // structs. The type of this column vector is:
+  // LIST<STRUCT<utcInstant: int64, tzInstant: int64, utcOffset: int32>>
+  private CompletableFuture<Map<String, Integer>> zoneIdToTableFuture;
+  private CompletableFuture<HostColumnVector> fixedTransitionsFuture;
+
+  private boolean closed = false;
+
+  GpuTimeZoneDB() {
+    zoneIdToTableFuture = new CompletableFuture<>();
+    fixedTransitionsFuture = new CompletableFuture<>();
+  }
+
+  private static GpuTimeZoneDB instance = new GpuTimeZoneDB();
+  // This method is default visibility for testing purposes only. The instance will never be
+  // exposed publicly for this class.
+  static GpuTimeZoneDB getInstance() {
+    return instance;
+  }
+
+  /**
+   * Start to cache the database. This should be called on startup of an executor. It should start
+   * to cache the data on the CPU in a background thread. It should return immediately and allow the
+   * other APIs to be called. Depending on what we want to do we can have the other APIs block
+   * until this is done caching, or we can have private APIs that would let us load and use specific
+   * parts of the database. I prefer the former solution at least until we see a performance hit
+   * where we are waiting on the database to finish loading.
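+   *
+   * <p>An illustrative call sequence on an executor (the "Asia/Shanghai" zone is just an
+   * example; any supported zone works the same way):
+   * <pre>{@code
+   *   GpuTimeZoneDB.cacheDatabase();  // warm the cache in the background
+   *   try (ColumnVector utc = GpuTimeZoneDB.fromTimestampToUtcTimestamp(ts,
+   *       ZoneId.of("Asia/Shanghai"))) {
+   *     // use the UTC timestamp column
+   *   }
+   * }</pre>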
+   */
+  public static void cacheDatabase() {
+    synchronized (instance) {
+      if (!instance.isLoaded()) {
+        Executor executor = Executors.newSingleThreadExecutor(
+            new ThreadFactory() {
+              private ThreadFactory defaultFactory = Executors.defaultThreadFactory();
+
+              @Override
+              public Thread newThread(Runnable r) {
+                Thread thread = defaultFactory.newThread(r);
+                thread.setName("gpu-timezone-database-0");
+                thread.setDaemon(true);
+                return thread;
+              }
+            });
+        instance.loadData(executor);
+      }
+    }
+  }
+
+
+  public static void shutdown() {
+    if (instance.isLoaded()) {
+      instance.close();
+    }
+  }
+
+  public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneId currentTimeZone) {
+    // TODO: Remove this check when all timezones are supported
+    // (See https://github.com/NVIDIA/spark-rapids/issues/6840)
+    if (!isSupportedTimeZone(currentTimeZone)) {
+      throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
+          currentTimeZone.toString()));
+    }
+    if (!instance.isLoaded()) {
+      cacheDatabase(); // lazy load the database
+    }
+    Integer tzIndex = instance.getZoneIDMap().get(currentTimeZone.normalized().toString());
+    Table transitions = instance.getTransitions();
+    ColumnVector result = new ColumnVector(convertTimestampColumnToUTC(input.getNativeView(),
+        transitions.getNativeView(), tzIndex));
+    transitions.close();
+    return result;
+  }
+
+  public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneId desiredTimeZone) {
+    // TODO: Remove this check when all timezones are supported
+    // (See https://github.com/NVIDIA/spark-rapids/issues/6840)
+    if (!isSupportedTimeZone(desiredTimeZone)) {
+      throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
+          desiredTimeZone.toString()));
+    }
+    if (!instance.isLoaded()) {
+      cacheDatabase(); // lazy load the database
+    }
+    Integer tzIndex = instance.getZoneIDMap().get(desiredTimeZone.normalized().toString());
+    Table transitions = instance.getTransitions();
+    ColumnVector result = new ColumnVector(convertUTCTimestampColumnToTimeZone(input.getNativeView(),
+        transitions.getNativeView(), tzIndex));
+    transitions.close();
+    return result;
+  }
+
+  // TODO: Deprecate this API when we support all timezones
+  // (See https://github.com/NVIDIA/spark-rapids/issues/6840)
+  public static boolean isSupportedTimeZone(ZoneId desiredTimeZone) {
+    return desiredTimeZone != null &&
+        (desiredTimeZone.getRules().isFixedOffset() ||
+            desiredTimeZone.getRules().getTransitionRules().isEmpty());
+  }
+
+  public static boolean isSupportedTimeZone(String zoneId) {
+    try {
+      return isSupportedTimeZone(getZoneId(zoneId));
+    } catch (ZoneRulesException e) {
+      return false;
+    }
+  }
+
+  // Ported from Spark. Used to format time zone ID string with (+|-)h:mm and (+|-)hh:m
+  public static ZoneId getZoneId(String timeZoneId) {
+    String formattedZoneId = timeZoneId
+        // To support the (+|-)h:mm format because it was supported before Spark 3.0.
+        .replaceFirst("(\\+|\\-)(\\d):", "$10$2:")
+        // To support the (+|-)hh:m format because it was supported before Spark 3.0.
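+        // e.g. an input like "+08:0" becomes "+08:00" (illustrative example)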
+        .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3");
+    return ZoneId.of(formattedZoneId, ZoneId.SHORT_IDS);
+  }
+
+  private boolean isLoaded() {
+    return zoneIdToTableFuture.isDone();
+  }
+
+  private void loadData(Executor executor) throws IllegalStateException {
+    // Start loading the data in a separate thread and return
+    try {
+      executor.execute(this::doLoadData);
+    } catch (RejectedExecutionException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private void doLoadData() {
+    synchronized (this) {
+      try {
+        Map<String, Integer> zoneIdToTable = new HashMap<>();
+        List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
+        for (String tzId : TimeZone.getAvailableIDs()) {
+          ZoneId zoneId;
+          try {
+            zoneId = ZoneId.of(tzId).normalized(); // we use the normalized form to dedupe
+          } catch (ZoneRulesException e) {
+            // Sometimes getAvailableIDs() returns a 3-letter abbreviation. That use is
+            // deprecated because of ambiguity (the same abbreviation can refer to multiple
+            // time zones), and such IDs are not supported by ZoneId.of(...) here.
+            continue;
+          }
+          ZoneRules zoneRules = zoneId.getRules();
+          // Skip timezones with daylight-saving transition rules; only fixed-offset or
+          // non-repeating (historical transitions only) zones are supported for now.
+          if (!zoneRules.isFixedOffset() && !zoneRules.getTransitionRules().isEmpty()) {
+            continue;
+          }
+          if (!zoneIdToTable.containsKey(zoneId.getId())) {
+            List<ZoneOffsetTransition> transitions = zoneRules.getTransitions();
+            int idx = masterTransitions.size();
+            List<HostColumnVector.StructData> data = new ArrayList<>();
+            if (zoneRules.isFixedOffset()) {
+              data.add(
+                  new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
+                      zoneRules.getOffset(Instant.now()).getTotalSeconds())
+              );
+            } else {
+              // Capture the first official offset (before any transition) using Long min
+              ZoneOffsetTransition first = transitions.get(0);
+              data.add(
+                  new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
+                      first.getOffsetBefore().getTotalSeconds())
+              );
+              transitions.forEach(t -> {
+                // Whether the transition is an overlap vs a gap.
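+                // (Illustrative: a spring-forward gap skips local wall-clock times, while a
+                // fall-back overlap repeats them, so the two cases pick different offsets below.)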
+                // In Spark:
+                // if it's a gap, then we use the offset after *on* the instant
+                // If it's an overlap, then there are two sets of valid timestamps that overlap.
+                // So, for the transition to UTC, you need to compare to instant + {offset before}.
+                // The time math still uses {offset after}.
+                if (t.isGap()) {
+                  data.add(
+                      new HostColumnVector.StructData(
+                          t.getInstant().getEpochSecond(),
+                          t.getInstant().getEpochSecond() + t.getOffsetAfter().getTotalSeconds(),
+                          t.getOffsetAfter().getTotalSeconds())
+                  );
+                } else {
+                  data.add(
+                      new HostColumnVector.StructData(
+                          t.getInstant().getEpochSecond(),
+                          t.getInstant().getEpochSecond() + t.getOffsetBefore().getTotalSeconds(),
+                          t.getOffsetAfter().getTotalSeconds())
+                  );
+                }
+              });
+            }
+            masterTransitions.add(data);
+            zoneIdToTable.put(zoneId.getId(), idx);
+          }
+        }
+        HostColumnVector.DataType childType = new HostColumnVector.StructType(false,
+            new HostColumnVector.BasicType(false, DType.INT64),
+            new HostColumnVector.BasicType(false, DType.INT64),
+            new HostColumnVector.BasicType(false, DType.INT32));
+        HostColumnVector.DataType resultType =
+            new HostColumnVector.ListType(false, childType);
+        HostColumnVector fixedTransitions = HostColumnVector.fromLists(resultType,
+            masterTransitions.toArray(new List[0]));
+        fixedTransitionsFuture.complete(fixedTransitions);
+        zoneIdToTableFuture.complete(zoneIdToTable);
+      } catch (Exception e) {
+        fixedTransitionsFuture.completeExceptionally(e);
+        zoneIdToTableFuture.completeExceptionally(e);
+        throw e;
+      }
+    }
+  }
+
+  private void close() {
+    synchronized (this) {
+      if (closed) {
+        return;
+      }
+      try (HostColumnVector hcv = getHostFixedTransitions()) {
+        // automatically closed
+        closed = true;
+      }
+    }
+  }
+
+  private HostColumnVector getHostFixedTransitions() {
+    try {
+      return fixedTransitionsFuture.get(TIMEOUT_SECS, TimeUnit.SECONDS);
+    } catch (InterruptedException | ExecutionException | TimeoutException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private Map<String, Integer> getZoneIDMap() {
+    try {
+      return zoneIdToTableFuture.get(TIMEOUT_SECS, TimeUnit.SECONDS);
+    } catch (InterruptedException | ExecutionException | TimeoutException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private Table getTransitions() {
+    try (ColumnVector fixedTransitions = getFixedTransitions()) {
+      return new Table(fixedTransitions);
+    }
+  }
+
+  private ColumnVector getFixedTransitions() {
+    HostColumnVector hostTransitions = getHostFixedTransitions();
+    return hostTransitions.copyToDevice();
+  }
+
+  /**
+   * FOR TESTING PURPOSES ONLY, DO NOT USE IN PRODUCTION
+   *
+   * This method retrieves the raw list of struct data that forms the list of
+   * fixed transitions for a particular zoneId.
+   *
+   * It has default visibility so the test can access it.
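+   * For example, the unit test calls {@code getHostFixedTransitions("Asia/Shanghai")} to inspect
+   * that zone's transition structs.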
+   * @param zoneId the time zone ID whose fixed transitions should be returned
+   * @return list of fixed transitions
+   */
+  List<HostColumnVector.StructData> getHostFixedTransitions(String zoneId) {
+    zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe
+    Integer idx = getZoneIDMap().get(zoneId);
+    if (idx == null) {
+      return null;
+    }
+    HostColumnVector transitions = getHostFixedTransitions();
+    return transitions.getList(idx);
+  }
+
+
+  private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);
+
+  private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex);
+}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
new file mode 100644
index 0000000000..7aaec496de
--- /dev/null
+++ b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
@@ -0,0 +1,231 @@
+/*
+* Copyright (c) 2023, NVIDIA CORPORATION.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package com.nvidia.spark.rapids.jni;
+
+import java.time.ZoneId;
+import java.util.List;
+
+import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+import ai.rapids.cudf.ColumnVector;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+
+public class TimeZoneTest {
+  @BeforeAll
+  static void cacheTimezoneDatabase() {
+    GpuTimeZoneDB.cacheDatabase();
+  }
+
+  @AfterAll
+  static void cleanup() {
+    GpuTimeZoneDB.shutdown();
+  }
+
+  @Test
+  void databaseLoadedTest() {
+    // Check for a few timezones
+    GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance();
+    List transitions = instance.getHostFixedTransitions("UTC+8");
+    assertNotNull(transitions);
+    assertEquals(1, transitions.size());
+    transitions = instance.getHostFixedTransitions("Asia/Shanghai");
+    assertNotNull(transitions);
+    ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized();
+    assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size());
+  }
+
+  @Test
+  void convertToUtcSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampSecondsFromBoxedLongs(
+          -1262260800L,
+          -908838000L,
+          -908840700L,
+          -888800400L,
+          -888799500L,
+          -888796800L,
+          0L,
+          1699571634L,
+          568036800L);
+        ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs(
+          -1262289600L,
+          -908870400L,
+          -908869500L,
+          -888832800L,
+          -888831900L,
+          -888825600L,
+          -28800L,
+          1699542834L,
+          568008000L);
+        ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void convertToUtcMilliSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampMilliSecondsFromBoxedLongs(
+          -1262260800000L,
+          -908838000000L,
+          -908840700000L,
+          -888800400000L,
+          -888799500000L,
+          -888796800000L,
+          0L,
+          1699571634312L,
+          568036800000L);
+        ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs(
+          -1262289600000L,
+          -908870400000L,
+          -908869500000L,
+          -888832800000L,
+          -888831900000L,
+          -888825600000L,
+          -28800000L,
+          1699542834312L,
+          568008000000L);
+        ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void convertToUtcMicroSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampMicroSecondsFromBoxedLongs(
+          -1262260800000000L,
+          -908838000000000L,
+          -908840700000000L,
+          -888800400000000L,
+          -888799500000000L,
+          -888796800000000L,
+          0L,
+          1699571634312000L,
+          568036800000000L);
+        ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(
+          -1262289600000000L,
+          -908870400000000L,
+          -908869500000000L,
+          -888832800000000L,
+          -888831900000000L,
+          -888825600000000L,
+          -28800000000L,
+          1699542834312000L,
+          568008000000000L);
+        ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void convertFromUtcSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampSecondsFromBoxedLongs(
+          -1262289600L,
+          -908870400L,
+          -908869500L,
+          -888832800L,
+          -888831900L,
+          -888825600L,
+          0L,
+          1699542834L,
+          568008000L);
+        ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs(
+          -1262260800L,
+          -908838000L,
+          -908837100L,
+          -888800400L,
+          -888799500L,
+          -888796800L,
+          28800L,
+          1699571634L,
+          568036800L);
+        ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void convertFromUtcMilliSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampMilliSecondsFromBoxedLongs(
+          -1262289600000L,
+          -908870400000L,
+          -908869500000L,
+          -888832800000L,
+          -888831900000L,
+          -888825600000L,
+          0L,
+          1699542834312L,
+          568008000000L);
+        ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs(
+          -1262260800000L,
+          -908838000000L,
+          -908837100000L,
+          -888800400000L,
+          -888799500000L,
+          -888796800000L,
+          28800000L,
+          1699571634312L,
+          568036800000L);
+        ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void convertFromUtcMicroSecondsTest() {
+    try (ColumnVector input = ColumnVector.timestampMicroSecondsFromBoxedLongs(
+          -1262289600000000L,
+          -908870400000000L,
+          -908869500000000L,
+          -888832800000000L,
+          -888831900000000L,
+          -888825600000000L,
+          0L,
+          1699542834312000L,
+          568008000000000L);
+        ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(
+          -1262260800000000L,
+          -908838000000000L,
+          -908837100000000L,
+          -888800400000000L,
+          -888799500000000L,
+          -888796800000000L,
+          28800000000L,
+          1699571634312000L,
+          568036800000000L);
+        ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input,
+          ZoneId.of("Asia/Shanghai"))) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+}