Skip to content

Commit

Permalink
fix flaky array_item test failures (NVIDIA#11054)
Browse files: browse the repository at this point in the history
* fix flaky array_item test failures

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>

* fix indent

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>

* fix whitespace

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>

---------

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
  • Loading branch information
binmahone authored Jun 14, 2024
1 parent 356d5a1 commit 599ae17
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
23 changes: 15 additions & 8 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __repr__(self):
return super().__repr__() + '(' + str(self._child_gen) + ')'

def _cache_repr(self):
return super()._cache_repr() + '(' + self._child_gen._cache_repr() + ')'
return (super()._cache_repr() + '(' + self._child_gen._cache_repr() +
',' + str(self._func.__code__) + ')' )

def start(self, rand):
self._child_gen.start(rand)
Expand Down Expand Up @@ -667,7 +668,10 @@ def __repr__(self):
return super().__repr__() + '(' + str(self._child_gen) + ')'

def _cache_repr(self):
return super()._cache_repr() + '(' + self._child_gen._cache_repr() + ')'
return (super()._cache_repr() + '(' + self._child_gen._cache_repr() +
',' + str(self._min_length) + ',' + str(self._max_length) + ',' +
str(self.all_null) + ',' + str(self.convert_to_tuple) + ')')


def start(self, rand):
self._child_gen.start(rand)
Expand Down Expand Up @@ -701,7 +705,8 @@ def __repr__(self):
return super().__repr__() + '(' + str(self._key_gen) + ',' + str(self._value_gen) + ')'

def _cache_repr(self):
return super()._cache_repr() + '(' + self._key_gen._cache_repr() + ',' + self._value_gen._cache_repr() + ')'
return (super()._cache_repr() + '(' + self._key_gen._cache_repr() + ',' + self._value_gen._cache_repr() +
',' + str(self._min_length) + ',' + str(self._max_length) + ')')

def start(self, rand):
self._key_gen.start(rand)
Expand Down Expand Up @@ -769,12 +774,13 @@ def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTER
self._min_micros = (math.floor(min_value.total_seconds()) * 1000000) + min_value.microseconds
self._max_micros = (math.floor(max_value.total_seconds()) * 1000000) + max_value.microseconds
fields = ["day", "hour", "minute", "second"]
start_index = fields.index(start_field)
end_index = fields.index(end_field)
if start_index > end_index:
self._start_index = fields.index(start_field)
self._end_index = fields.index(end_field)
if self._start_index > self._end_index:
raise RuntimeError('Start field {}, end field {}, valid fields is {}, start field index should <= end '
'field index'.format(start_field, end_field, fields))
super().__init__(DayTimeIntervalType(start_index, end_index), nullable=nullable, special_cases=special_cases)
super().__init__(DayTimeIntervalType(self._start_index, self._end_index), nullable=nullable,
special_cases=special_cases)

def _gen_random(self, rand):
micros = rand.randint(self._min_micros, self._max_micros)
Expand All @@ -784,7 +790,8 @@ def _gen_random(self, rand):
return timedelta(microseconds=micros)

def _cache_repr(self):
return super()._cache_repr() + '(' + str(self._min_micros) + ',' + str(self._max_micros) + ')'
return (super()._cache_repr() + '(' + str(self._min_micros) + ',' + str(self._max_micros) +
',' + str(self._start_index) + ',' + str(self._end_index) + ')')

def start(self, rand):
self._start(rand, lambda: self._gen_random(rand))
Expand Down
6 changes: 5 additions & 1 deletion integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -224,6 +224,10 @@ def test_all_null_int96(spark_tmp_path):
class AllNullTimestampGen(TimestampGen):
def start(self, rand):
self._start(rand, lambda : None)

def _cache_repr(self):
return super()._cache_repr() + '(all_nulls)'

data_path = spark_tmp_path + '/PARQUET_DATA'
confs = copy_and_update(writer_confs, {'spark.sql.parquet.outputTimestampType': 'INT96'})
assert_gpu_and_cpu_writes_are_equal_collect(
Expand Down

0 comments on commit 599ae17

Please sign in to comment.