Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-sqlcompat validation in CalciteWindowQueryTest #15086

Merged
merged 39 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
bd55afd
fixes
kgyrtkirk Oct 4, 2023
5cf1e2c
check for latest rewrite place
kgyrtkirk Oct 4, 2023
c4ff274
Revert "check for latest rewrite place"
kgyrtkirk Oct 4, 2023
89844e9
some stuff
kgyrtkirk Oct 4, 2023
ed5d100
update test output
kgyrtkirk Oct 4, 2023
2a4a3ab
updates to test ouptuts
kgyrtkirk Oct 4, 2023
02aac9d
some stuff
kgyrtkirk Oct 4, 2023
a9877c4
move validator
kgyrtkirk Oct 4, 2023
f104ce4
cleanup
kgyrtkirk Oct 4, 2023
145fe82
fix
kgyrtkirk Oct 4, 2023
761cd68
change test slightly
kgyrtkirk Oct 4, 2023
a10fe7a
add apidoc cleanup warnings
kgyrtkirk Oct 4, 2023
6082df5
cleanup/etc
kgyrtkirk Oct 4, 2023
34a6aeb
instead of telling the story; add a fail with some reason whats the i…
kgyrtkirk Oct 4, 2023
9b74ef5
Merge remote-tracking branch 'apache/master' into windowing-fix-test-cmp
kgyrtkirk Oct 4, 2023
e376ae9
lead-lag fix
kgyrtkirk Oct 4, 2023
1acd53b
add test
kgyrtkirk Oct 5, 2023
a708ef2
remove unnecessary throw
kgyrtkirk Oct 5, 2023
8fa0664
druidexception-trial
kgyrtkirk Oct 5, 2023
b484606
Revert "druidexception-trial"
kgyrtkirk Oct 5, 2023
2858ff6
undo changes to no_grouping; add no_grouping2
kgyrtkirk Oct 5, 2023
2e91d7c
Merge remote-tracking branch 'apache/master' into windowing-fix-test-cmp
kgyrtkirk Oct 6, 2023
9a80c89
add missing assert on resultcount
kgyrtkirk Oct 6, 2023
ee2b35d
rename method; update
kgyrtkirk Oct 9, 2023
4e216b7
introduce enum/etc
kgyrtkirk Oct 9, 2023
5d0fcc0
make resultmatchmode accessible from TestBuilder#expectedResults
kgyrtkirk Oct 9, 2023
0ddd3be
fix dump results to use log
kgyrtkirk Oct 9, 2023
4073f5e
Merge remote-tracking branch 'apache/master' into windowing-fix-test-cmp
kgyrtkirk Oct 10, 2023
47e81d8
Merge remote-tracking branch 'apache/master' into windowing-fix-test-cmp
kgyrtkirk Oct 10, 2023
edea152
fix
kgyrtkirk Oct 10, 2023
de099c6
handle null correctly
kgyrtkirk Oct 10, 2023
e018b2c
disable feature type based things for MSQ
kgyrtkirk Oct 10, 2023
a74a9fd
fix varianssqlaggtest
kgyrtkirk Oct 10, 2023
185b8e7
use eps in other test
kgyrtkirk Oct 10, 2023
ed1bb89
fix intellij error
kgyrtkirk Oct 10, 2023
91b1be9
add final
kgyrtkirk Oct 11, 2023
df73774
addrss review
kgyrtkirk Oct 11, 2023
7714e2f
update test/string/etc
kgyrtkirk Oct 11, 2023
78d1d31
write concat in 3 lines :D
kgyrtkirk Oct 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.QueryTestRunner.QueryResults;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.calcite.util.TestDataBuilder;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
Expand Down Expand Up @@ -506,23 +506,9 @@ public void testGroupByAggregatorDefaultValues()
}

@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
public void assertResultsValid(ResultMatchMode matchMode, List<Object[]> expected, QueryResults queryResults)
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Object[] expectedResult = expectedResults.get(i);
Object[] result = results.get(i);
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 0.000001);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 0.000001);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
}
}
super.assertResultsValid(ResultMatchMode.EQUALS_EPS, expected, queryResults);
}

private static PostAggregator makeFieldAccessPostAgg(String name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.QueryTestRunner.QueryResults;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.calcite.util.TestDataBuilder;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
Expand Down Expand Up @@ -679,22 +679,8 @@ public void testVarianceAggAsInput()
}

@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
public void assertResultsValid(ResultMatchMode matchMode, List<Object[]> expected, QueryResults queryResults)
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Object[] expectedResult = expectedResults.get(i);
Object[] result = results.get(i);
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
}
}
super.assertResultsValid(ResultMatchMode.EQUALS_EPS, expected, queryResults);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.inject.Inject;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.data.Indexed;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -161,6 +162,28 @@ public static <T> T defaultValueForClass(final Class<T> clazz)
}
}

/**
* Returns the default value for the given {@link ValueType}.
*
* May be null or non-null based on the current SQL-compatible null handling mode.
*/
@Nullable
@SuppressWarnings("unchecked")
public static Object defaultValueForType(ValueType type)
{
if (type == ValueType.FLOAT) {
return defaultFloatValue();
} else if (type == ValueType.DOUBLE) {
return defaultDoubleValue();
} else if (type == ValueType.LONG) {
return defaultLongValue();
} else if (type == ValueType.STRING) {
return defaultStringValue();
} else {
return null;
}
}

public static boolean isNullOrEquivalent(@Nullable String value)
{
return replaceWithDefault() ? Strings.isNullOrEmpty(value) : value == null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
package org.apache.druid.common.config;

import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.data.ListIndexed;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;

import java.util.Collections;

import static org.apache.druid.common.config.NullHandling.defaultValueForClass;
import static org.apache.druid.common.config.NullHandling.defaultValueForType;
import static org.apache.druid.common.config.NullHandling.replaceWithDefault;
import static org.junit.Assert.assertEquals;

Expand Down Expand Up @@ -89,6 +92,17 @@ public void test_defaultValueForClass_object()
Assert.assertNull(NullHandling.defaultValueForClass(Object.class));
}

@Test
public void test_defaultValueForType()
{
assertEquals(defaultValueForClass(Float.class), defaultValueForType(ValueType.FLOAT));
assertEquals(defaultValueForClass(Double.class), defaultValueForType(ValueType.DOUBLE));
assertEquals(defaultValueForClass(Long.class), defaultValueForType(ValueType.LONG));
assertEquals(defaultValueForClass(String.class), defaultValueForType(ValueType.STRING));
assertEquals(defaultValueForClass(Object.class), defaultValueForType(ValueType.COMPLEX));
assertEquals(defaultValueForClass(Object.class), defaultValueForType(ValueType.ARRAY));
}

@Test
public void test_ignoreNullsStrings()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
Expand All @@ -86,6 +87,7 @@
import org.apache.druid.server.security.ForbiddenException;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.sql.SqlStatementFactory;
import org.apache.druid.sql.calcite.QueryTestRunner.QueryResults;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
Expand Down Expand Up @@ -120,19 +122,24 @@
import org.junit.rules.TemporaryFolder;

import javax.annotation.Nullable;

import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

/**
* A base class for SQL query testing. It sets up query execution environment, provides useful helper methods,
* and populates data using {@link CalciteTests#createMockWalker}.
Expand Down Expand Up @@ -1033,11 +1040,13 @@ public ObjectMapper jsonMapper()
@Override
public ResultsVerifier defaultResultsVerifier(
List<Object[]> expectedResults,
ResultMatchMode expectedResultMatchMode,
RowSignature expectedResultSignature
)
{
return BaseCalciteQueryTest.this.defaultResultsVerifier(
expectedResults,
expectedResultMatchMode,
expectedResultSignature
);
}
Expand All @@ -1055,6 +1064,115 @@ public Map<String, Object> baseQueryContext()
}
}

public enum ResultMatchMode
{
EQUALS {
@Override
void validate(int row, int column, ValueType type, Object expectedCell, Object resultCell)
{
assertEquals(
mismatchMessage(row, column),
expectedCell,
resultCell);
}
},
RELAX_NULLS {
@Override
void validate(int row, int column, ValueType type, Object expectedCell, Object resultCell)
{
if (expectedCell == null) {
if (resultCell == null) {
return;
}
expectedCell = NullHandling.defaultValueForType(type);
}
EQUALS.validate(row, column, type, expectedCell, resultCell);
}
},
EQUALS_EPS {
@Override
void validate(int row, int column, ValueType type, Object expectedCell, Object resultCell)
{
if (expectedCell instanceof Float) {
assertEquals(
mismatchMessage(row, column),
(Float) expectedCell,
(Float) resultCell,
1e-5);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should make this a static variable and set it to 1e-5. Also is there any rationale behind chosing this value ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two tests were overriding the assertResultsEquals method - and they had this constant in them; to set the absolute error bound to 1e-5

I think they were making assertions on approximation results - which may be a little bit different based on the jvm or other external reasons.

I've extracted it into a constant

} else if (expectedCell instanceof Double) {
assertEquals(
mismatchMessage(row, column),
(Double) expectedCell,
(Double) resultCell,
1e-5);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

} else {
EQUALS.validate(row, column, type, expectedCell, resultCell);
}
}
};

abstract void validate(int row, int column, ValueType type, Object expectedCell, Object resultCell);

private static String mismatchMessage(int row, int column)
{
return String.format(Locale.ENGLISH, "column content mismatch at %d,%d", row, column);
kgyrtkirk marked this conversation as resolved.
Show resolved Hide resolved
}

}

/**
* Validates the results with slight loosening in case {@link NullHandling} is not sql compatible.
*
* In case {@link NullHandling#replaceWithDefault()} is true, if the expected result is <code>null</code> it accepts
* both <code>null</code> and the default value for that column as actual result.
*/
public void assertResultsValid(ResultMatchMode matchMode, List<Object[]> expected, QueryResults queryResults)
{
List<Object[]> results = queryResults.results;
Assert.assertEquals("Result count mismatch", expected.size(), results.size());

final List<ValueType> types = new ArrayList<>();

boolean isMSQ = isMSQRowType(queryResults.signature);

if (!isMSQ) {
for (int i = 0; i < queryResults.signature.getColumnNames().size(); i++) {
Optional<ColumnType> columnType = queryResults.signature.getColumnType(i);
if (columnType.isPresent()) {
types.add(columnType.get().getType());
} else {
types.add(null);
}
}
}

int numRows = results.size();
for (int row = 0; row < numRows; row++) {
Object[] expectedRow = expected.get(row);
kgyrtkirk marked this conversation as resolved.
Show resolved Hide resolved
Object[] resultRow = results.get(row);
assertEquals("column count mismatch; at row#" + row, expectedRow.length, resultRow.length);

for (int i = 0; i < resultRow.length; i++) {
Object resultCell = resultRow[i];
Object expectedCell = expectedRow[i];

ResultMatchMode cellValidator = matchMode;
cellValidator.validate(
row,
i,
isMSQ ? null : types.get(i),
expectedCell,
resultCell);
}
}
}

private boolean isMSQRowType(RowSignature signature)
{
List<String> colNames = signature.getColumnNames();
return colNames.size() == 1 && "TASK".equals(colNames.get(0));
}

public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
{
int minSize = Math.min(results.size(), expectedResults.size());
Expand Down Expand Up @@ -1331,29 +1449,37 @@ default void verifyRowSignature(RowSignature rowSignature)
// do nothing
}

void verify(String sql, List<Object[]> results);
void verify(String sql, QueryResults queryResults);
}

private ResultsVerifier defaultResultsVerifier(
final List<Object[]> expectedResults,
ResultMatchMode expectedResultMatchMode,
final RowSignature expectedSignature
)
{
return new DefaultResultsVerifier(expectedResults, expectedSignature);
return new DefaultResultsVerifier(expectedResults, expectedResultMatchMode, expectedSignature);
}

public class DefaultResultsVerifier implements ResultsVerifier
{
protected final List<Object[]> expectedResults;
@Nullable
protected final RowSignature expectedResultRowSignature;
protected final ResultMatchMode expectedResultMatchMode;

public DefaultResultsVerifier(List<Object[]> expectedResults, RowSignature expectedSignature)
public DefaultResultsVerifier(List<Object[]> expectedResults, ResultMatchMode expectedResultMatchMode, RowSignature expectedSignature)
{
this.expectedResults = expectedResults;
this.expectedResultMatchMode = expectedResultMatchMode;
this.expectedResultRowSignature = expectedSignature;
}

public DefaultResultsVerifier(List<Object[]> expectedResults, RowSignature expectedSignature)
{
this(expectedResults, ResultMatchMode.EQUALS, expectedSignature);
}

@Override
public void verifyRowSignature(RowSignature rowSignature)
{
Expand All @@ -1363,17 +1489,18 @@ public void verifyRowSignature(RowSignature rowSignature)
}

@Override
public void verify(String sql, List<Object[]> results)
public void verify(String sql, QueryResults queryResults)
{
try {
Assert.assertEquals(StringUtils.format("result count: %s", sql), expectedResults.size(), results.size());
assertResultsEquals(sql, expectedResults, results);
assertResultsValid(expectedResultMatchMode, expectedResults, queryResults);
}
catch (AssertionError e) {
displayResults("Actual", results);
System.out.println("sql: " + sql);
displayResults("Actual", queryResults.results);
throw e;
}
}

}

/**
Expand Down
Loading