diff --git a/pom.xml b/pom.xml
index 9bee8348..53f15988 100644
--- a/pom.xml
+++ b/pom.xml
@@ -376,7 +376,7 @@
org.apache.arrow
flight-sql-jdbc-driver
- 16.1.0
+ 17.0.0
diff --git a/src/sqlancer/ComparatorHelper.java b/src/sqlancer/ComparatorHelper.java
index 5da635de..d061ecb7 100644
--- a/src/sqlancer/ComparatorHelper.java
+++ b/src/sqlancer/ComparatorHelper.java
@@ -131,7 +131,8 @@ public static void assumeResultSetsAreEqual(List resultSet, List
public static void assumeResultSetsAreEqual(List<String> resultSet, List<String> secondResultSet,
String originalQueryString, List<String> combinedString, SQLGlobalState<?, ?> state,
UnaryOperator<String> canonicalizationRule) {
- // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization function which is applied to
+ // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization
+ // function which is applied to
// both result sets before their comparison.
List canonicalizedResultSet = resultSet.stream().map(canonicalizationRule).collect(Collectors.toList());
List canonicalizedSecondResultSet = secondResultSet.stream().map(canonicalizationRule)
diff --git a/src/sqlancer/IgnoreMeException.java b/src/sqlancer/IgnoreMeException.java
index bc2e2591..cdae6acd 100644
--- a/src/sqlancer/IgnoreMeException.java
+++ b/src/sqlancer/IgnoreMeException.java
@@ -4,4 +4,11 @@ public class IgnoreMeException extends RuntimeException {
private static final long serialVersionUID = 1L;
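+
+ // Usage sketch (illustrative): test oracles throw this exception to skip a check that
+ // cannot produce a verdict, e.g. `throw new IgnoreMeException("Empty table is picked")`
+ // in DataFusionPQS; the message constructor records why the check was skipped.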
+ public IgnoreMeException() {
+ super();
+ }
+
+ public IgnoreMeException(String message) {
+ super(message);
+ }
}
diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java
index 1cbc0264..68639093 100644
--- a/src/sqlancer/Main.java
+++ b/src/sqlancer/Main.java
@@ -31,6 +31,7 @@
import sqlancer.common.query.Query;
import sqlancer.common.query.SQLancerResultSet;
import sqlancer.databend.DatabendProvider;
+import sqlancer.datafusion.DataFusionProvider;
import sqlancer.doris.DorisProvider;
import sqlancer.duckdb.DuckDBProvider;
import sqlancer.h2.H2Provider;
@@ -734,6 +735,7 @@ private static void checkForIssue799(List> providers)
providers.add(new CnosDBProvider());
providers.add(new CockroachDBProvider());
providers.add(new DatabendProvider());
+ providers.add(new DataFusionProvider());
providers.add(new DorisProvider());
providers.add(new DuckDBProvider());
providers.add(new H2Provider());
diff --git a/src/sqlancer/common/query/SQLancerResultSet.java b/src/sqlancer/common/query/SQLancerResultSet.java
index d1221a7f..7cc1523e 100644
--- a/src/sqlancer/common/query/SQLancerResultSet.java
+++ b/src/sqlancer/common/query/SQLancerResultSet.java
@@ -6,7 +6,7 @@
public class SQLancerResultSet implements Closeable {
- ResultSet rs;
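+ // Made public so that DBMS-specific oracles (e.g. DataFusionPQS) can read result
+ // metadata directly via rs.getMetaData()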
+ public ResultSet rs;
private Runnable runnableEpilogue;
public SQLancerResultSet(ResultSet rs) {
diff --git a/src/sqlancer/datafusion/DataFusionErrors.java b/src/sqlancer/datafusion/DataFusionErrors.java
index a646f304..9e62bbe9 100644
--- a/src/sqlancer/datafusion/DataFusionErrors.java
+++ b/src/sqlancer/datafusion/DataFusionErrors.java
@@ -44,28 +44,26 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("There is only support Literal types for field at idx:");
errors.add("nth_value not supported for n:");
errors.add("Invalid argument error: Nested comparison: List(");
+ errors.add("This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal");
+ errors.add(
+ "This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal");
/*
* Known bugs
*/
- errors.add("to type Int"); // https://github.com/apache/datafusion/issues/11249
errors.add("bitwise"); // https://github.com/apache/datafusion/issues/11260
errors.add("Sort expressions cannot be empty for streaming merge."); // https://github.com/apache/datafusion/issues/11561
- errors.add("compute_utf8_flag_op_scalar failed to cast literal value NULL for operation"); // https://github.com/apache/datafusion/issues/11623
errors.add("Schema error: No field named "); // https://github.com/apache/datafusion/issues/12006
- errors.add("Internal error: PhysicalExpr Column references column"); // https://github.com/apache/datafusion/issues/12012
- errors.add("APPROX_"); // https://github.com/apache/datafusion/issues/12058
- errors.add("External error: task"); // https://github.com/apache/datafusion/issues/12057
- errors.add("NTH_VALUE"); // https://github.com/apache/datafusion/issues/12073
- errors.add("SUBSTR"); // https://github.com/apache/datafusion/issues/12129
+ errors.add("NATURAL JOIN"); // https://github.com/apache/datafusion/issues/14015
/*
* False positives
*/
errors.add("Cannot cast string"); // ifnull() is passed two non-compattable type and caused execution error
- errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr
- // is generated in where
- // clause
+ // False positive: when aggr is generated in where clause
+ errors.add("Physical plan does not support logical expression AggregateFunction");
+ errors.add("Unsupported ArrowType Utf8View"); // Maybe bug in arrow flight
+ // jdbc driver
/*
* Not critical, investigate in the future
@@ -73,5 +71,16 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("does not match with the projection expression");
errors.add("invalid operator for nested");
errors.add("Arrow error: Cast error: Can't cast value");
+ errors.add("Nth value indices are 1 based");
+ /*
+ * Example query that triggers this error: create table t1(v1 int, v2 bool); select v1, sum(1) over (partition
+ * by v1 order by v2 range between 0 preceding and 0 following) from t1;
+ *
+ * Current error message: Arrow error: Invalid argument error: Invalid arithmetic operation: Boolean - Boolean
+ *
+ * TODO: The error message could be more meaningful to indicate that RANGE frame is not supported for boolean
+ * ORDER BY columns
+ */
+ errors.add("Invalid arithmetic operation");
}
}
diff --git a/src/sqlancer/datafusion/DataFusionOptions.java b/src/sqlancer/datafusion/DataFusionOptions.java
index 5a2d8b69..56db6b76 100644
--- a/src/sqlancer/datafusion/DataFusionOptions.java
+++ b/src/sqlancer/datafusion/DataFusionOptions.java
@@ -15,6 +15,7 @@
import sqlancer.datafusion.test.DataFusionNoCrashAggregate;
import sqlancer.datafusion.test.DataFusionNoCrashWindow;
import sqlancer.datafusion.test.DataFusionNoRECOracle;
+import sqlancer.datafusion.test.DataFusionPQS;
import sqlancer.datafusion.test.DataFusionQueryPartitioningAggrTester;
import sqlancer.datafusion.test.DataFusionQueryPartitioningHavingTester;
import sqlancer.datafusion.test.DataFusionQueryPartitioningWhereTester;
@@ -26,13 +27,11 @@ public class DataFusionOptions implements DBMSSpecificOptions getTestOracleFactory() {
- return Arrays.asList(
- // DataFusionOracleFactory.NO_CRASH_WINDOW,
- // DataFusionOracleFactory.NO_CRASH_AGGREGATE,
- DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE
+ return Arrays.asList(DataFusionOracleFactory.PQS, DataFusionOracleFactory.NO_CRASH_WINDOW,
+ DataFusionOracleFactory.NO_CRASH_AGGREGATE, DataFusionOracleFactory.NOREC,
+ DataFusionOracleFactory.QUERY_PARTITIONING_WHERE);
// DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE
- // ,DataFusionOracleFactory.QUERY_PARTITIONING_HAVING
- );
+ // DataFusionOracleFactory.QUERY_PARTITIONING_HAVING);
}
public enum DataFusionOracleFactory implements OracleFactory {
@@ -42,6 +41,12 @@ public TestOracle create(DataFusionGlobalState globalStat
return new DataFusionNoRECOracle(globalState);
}
},
+ PQS {
+ @Override
+ public TestOracle create(DataFusionGlobalState globalState) throws SQLException {
+ return new DataFusionPQS(globalState);
+ }
+ },
QUERY_PARTITIONING_WHERE {
@Override
public TestOracle create(DataFusionGlobalState globalState) throws SQLException {
diff --git a/src/sqlancer/datafusion/DataFusionProvider.java b/src/sqlancer/datafusion/DataFusionProvider.java
index 37328e4b..5ad60561 100644
--- a/src/sqlancer/datafusion/DataFusionProvider.java
+++ b/src/sqlancer/datafusion/DataFusionProvider.java
@@ -1,6 +1,5 @@
package sqlancer.datafusion;
-import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.DML;
import static sqlancer.datafusion.DataFusionUtil.dfAssert;
import static sqlancer.datafusion.DataFusionUtil.displayTables;
@@ -8,6 +7,7 @@
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.List;
+import java.util.Optional;
import java.util.Properties;
import java.util.stream.Collectors;
@@ -34,29 +34,52 @@ public DataFusionProvider() {
super(DataFusionGlobalState.class, DataFusionOptions.class);
}
+ // Basic tables generated are DataFusion memory tables (named t1, t2, ...)
+ // An equivalent table can be backed by a different physical implementation
+ // and will be named like t1_stringview, t2_parquet, etc.
+ //
+ // e.g. t1 and t1_stringview are logically equivalent tables, but backed by
+ // different physical representations
+ //
+ // This enables more metamorphic testing on tables, for example
+ // `select * from t1` and `select * from t1_stringview` should give the same
+ // result
+ //
+ // Supported physical implementations for tables:
+ // 1. Memory table (t1)
+ // 2. Memory table using StringView for TEXT columns (t1_stringview)
+ // Note: it's possible to convert only a random subset of TEXT columns to StringView
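+ //
+ // Example (see DataFusionTableGenerator.createStringViewTable): if t1 has a TEXT
+ // column v2, the mutated table is created with
+ // CREATE TABLE t1_stringview AS SELECT v1, arrow_cast(v2, 'Utf8View') AS v2 FROM t1;
+ // so oracles can compare `select * from t1` against `select * from t1_stringview`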
@Override
public void generateDatabase(DataFusionGlobalState globalState) throws Exception {
- int tableCount = Randomly.fromOptions(1, 2, 3, 4, 5, 6, 7);
+ // Create base tables
+ // ============================
+
+ int tableCount = Randomly.fromOptions(1, 2, 3, 4);
for (int i = 0; i < tableCount; i++) {
- SQLQueryAdapter queryCreateRandomTable = new DataFusionTableGenerator().getQuery(globalState);
+ SQLQueryAdapter queryCreateRandomTable = new DataFusionTableGenerator().getCreateStmt(globalState);
queryCreateRandomTable.execute(globalState);
globalState.updateSchema();
- globalState.dfLogger.appendToLog(DML, queryCreateRandomTable.toString() + "\n");
+ globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML,
+ queryCreateRandomTable.toString() + "\n");
}
// Now only `INSERT` DML is supported
// If more DMLs are added later, should use`StatementExecutor` instead
// (see DuckDB's implementation for reference)
+ // Generating rows in base tables (t1, t2, ..., not including t1_stringview, etc.)
+ // ============================
+
globalState.updateSchema();
- List<DataFusionTable> allTables = globalState.getSchema().getDatabaseTables();
- List<String> allTablesName = allTables.stream().map(t -> t.getName()).collect(Collectors.toList());
- if (allTablesName.isEmpty()) {
+ List<DataFusionTable> allBaseTables = globalState.getSchema().getDatabaseTables();
+ List<String> allBaseTablesName = allBaseTables.stream().map(DataFusionTable::getName)
+ .collect(Collectors.toList());
+ if (allBaseTablesName.isEmpty()) {
dfAssert(false, "Generate Database failed.");
}
// Randomly insert some data into existing tables
- for (DataFusionTable table : allTables) {
+ for (DataFusionTable table : allBaseTables) {
int nInsertQuery = globalState.getRandomly().getInteger(0, globalState.getOptions().getMaxNumberInserts());
for (int i = 0; i < nInsertQuery; i++) {
@@ -69,9 +92,24 @@ public void generateDatabase(DataFusionGlobalState globalState) throws Exception
}
insertQuery.execute(globalState);
- globalState.dfLogger.appendToLog(DML, insertQuery.toString() + "\n");
+ globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML, insertQuery.toString() + "\n");
+ }
+ }
+
+ // Construct mutated tables like t1_stringview, etc.
+ // ============================
+ for (DataFusionTable table : allBaseTables) {
+ Optional<SQLQueryAdapter> queryCreateStringViewTable = new DataFusionTableGenerator()
+ .createStringViewTable(globalState, table);
+ if (queryCreateStringViewTable.isPresent()) {
+ queryCreateStringViewTable.get().execute(globalState);
+ globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML,
+ queryCreateStringViewTable.get().toString() + "\n");
}
}
+ globalState.updateSchema();
+ List<DataFusionTable> allTables = globalState.getSchema().getDatabaseTables();
+ List<String> allTablesName = allTables.stream().map(DataFusionTable::getName).collect(Collectors.toList());
// TODO(datafusion) add `DataFUsionLogType.STATE` for this whole db state log
if (globalState.getDbmsSpecificOptions().showDebugInfo) {
diff --git a/src/sqlancer/datafusion/DataFusionSchema.java b/src/sqlancer/datafusion/DataFusionSchema.java
index 24c3a4eb..2d08a347 100644
--- a/src/sqlancer/datafusion/DataFusionSchema.java
+++ b/src/sqlancer/datafusion/DataFusionSchema.java
@@ -9,6 +9,8 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
+import java.util.regex.Pattern;
import java.util.stream.Collectors;
import sqlancer.Randomly;
@@ -32,6 +34,9 @@ public DataFusionSchema(List databaseTables) {
// update existing tables in DB by query again
// (like `show tables;`)
+ //
+ // This function also sets up table<->column reference pointers
+ // and equivalent tables (see `DataFusionTable.equivalentTables`)
public static DataFusionSchema fromConnection(SQLConnection con, String databaseName) throws SQLException {
List<DataFusionTable> databaseTables = new ArrayList<>();
List<String> tableNames = getTableNames(con);
@@ -47,6 +52,24 @@ public static DataFusionSchema fromConnection(SQLConnection con, String database
databaseTables.add(t);
}
+ // Setup equivalent tables
+ // For example, now we have t1, t1_csv, t1_parquet, t2_csv, t2_parquet
+ // t1's equivalent tables: t1, t1_csv, t1_parquet
+ // t2_csv's equivalent tables: t2_csv, t2_parquet
+ // ...
+ //
+ // It can be assumed that:
+ // base table names are like t1, t2, ...
+ // equivalent tables are like t1_csv, t1_parquet, ...
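+ // e.g. for base name "t1" the pattern below is "^t1(_.*)?$", which matches "t1" and
+ // "t1_stringview" but not "t11" or "t2_parquet"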
+ for (DataFusionTable t : databaseTables) {
+ String baseTableName = t.getName().split("_")[0];
+ String patternString = "^" + baseTableName + "(_.*)?$"; // t1 or t1_*
+ Pattern pattern = Pattern.compile(patternString);
+
+ t.equivalentTables = databaseTables.stream().filter(table -> pattern.matcher(table.getName()).matches())
+ .map(DataFusionTable::getName).collect(Collectors.toList());
+ }
+
return new DataFusionSchema(databaseTables);
}
@@ -120,8 +143,10 @@ public static DataFusionDataType parseFromDataFusionCatalog(String typeString) {
return DataFusionDataType.BOOLEAN;
case "Utf8":
return DataFusionDataType.STRING;
+ case "Utf8View":
+ return DataFusionDataType.STRING;
default:
- dfAssert(false, "Unreachable. All branches should be eovered");
+ dfAssert(false, "Uncovered branch typeString: " + typeString);
}
dfAssert(false, "Unreachable. All branches should be eovered");
@@ -169,25 +194,89 @@ public Node getRandomConstant(DataFusionGlobalState state)
public static class DataFusionColumn extends AbstractTableColumn<DataFusionTable, DataFusionDataType> {
private final boolean isNullable;
+ public Optional<String> alias;
public DataFusionColumn(String name, DataFusionDataType columnType, boolean isNullable) {
super(name, null, columnType);
this.isNullable = isNullable;
+ this.alias = Optional.empty();
}
public boolean isNullable() {
return isNullable;
}
+ public String getOrignalName() {
+ return getTable().getName() + "." + getName();
+ }
+
+ @Override
+ public String getFullQualifiedName() {
+ if (getTable() == null) {
+ return getName();
+ } else {
+ if (alias.isPresent()) {
+ return alias.get();
+ } else {
+ return getTable().getName() + "." + getName();
+ }
+ }
+ }
}
public static class DataFusionTable
extends AbstractRelationalTable {
+ // There might exist multiple logically equivalent tables with
+ // different physical formats.
+ // e.g. t1_csv, t1_parquet, ...
+ //
+ // When generating a random query, it's possible to randomly pick one
+ // of them for stronger randomization.
+ public List<String> equivalentTables;
+
+ // Pick a random equivalent table name
+ // This can be used when generating differential queries
+ public Optional<String> currentEquivalentTableName;
+
+ // For example, in the query `select * from t1 as tt1, t1 as tt2`,
+ // `tt1` is the alias for the first occurrence of `t1`
+ public Optional<String> alias;
public DataFusionTable(String tableName, List<DataFusionColumn> columns, boolean isView) {
super(tableName, columns, Collections.emptyList(), isView);
}
+ public String getNotAliasedName() {
+ if (currentEquivalentTableName != null && currentEquivalentTableName.isPresent()) {
+ // In case setup is not done yet
+ return currentEquivalentTableName.get();
+ } else {
+ return super.getName();
+ }
+ }
+
+ // TODO(datafusion) The current implementation is hacky; should send a patch
+ // to core to support this properly
+ @Override
+ public String getName() {
+ // Before setup equivalent tables, we use the original table name
+ // Setup happens in `fromConnection()`
+ if (equivalentTables == null || currentEquivalentTableName == null) {
+ return super.getName();
+ }
+
+ if (alias.isPresent()) {
+ return alias.get();
+ } else {
+ return currentEquivalentTableName.get();
+ }
+ }
+
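+ // e.g. for base table t1 this may pick any of its equivalent names
+ // ("t1", "t1_stringview", ...)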
+ public void pickAnotherEquivalentTableName() {
+ dfAssert(!equivalentTables.isEmpty(), "equivalentTables should not be empty");
+ currentEquivalentTableName = Optional.of(Randomly.fromList(equivalentTables));
+ }
+
public static List<DataFusionColumn> getAllColumns(List<DataFusionTable> tables) {
return tables.stream().map(AbstractTable::getColumns).flatMap(List::stream).collect(Collectors.toList());
}
diff --git a/src/sqlancer/datafusion/DataFusionToStringVisitor.java b/src/sqlancer/datafusion/DataFusionToStringVisitor.java
index f07e4a16..11965841 100644
--- a/src/sqlancer/datafusion/DataFusionToStringVisitor.java
+++ b/src/sqlancer/datafusion/DataFusionToStringVisitor.java
@@ -7,10 +7,14 @@
import sqlancer.Randomly;
import sqlancer.common.ast.newast.NewToStringVisitor;
import sqlancer.common.ast.newast.Node;
+import sqlancer.common.ast.newast.TableReferenceNode;
+import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
import sqlancer.datafusion.ast.DataFusionConstant;
import sqlancer.datafusion.ast.DataFusionExpression;
import sqlancer.datafusion.ast.DataFusionSelect;
+import sqlancer.datafusion.ast.DataFusionSelect.DataFusionAliasedTable;
import sqlancer.datafusion.ast.DataFusionSelect.DataFusionFrom;
+import sqlancer.datafusion.ast.DataFusionSpecialExpr.CastToStringView;
import sqlancer.datafusion.ast.DataFusionWindowExpr;
public class DataFusionToStringVisitor extends NewToStringVisitor<DataFusionExpression> {
@@ -37,6 +41,10 @@ public void visitSpecific(Node expr) {
visit((DataFusionFrom) expr);
} else if (expr instanceof DataFusionWindowExpr) {
visit((DataFusionWindowExpr) expr);
+ } else if (expr instanceof CastToStringView) {
+ visit((CastToStringView) expr);
+ } else if (expr instanceof DataFusionAliasedTable) {
+ visit((DataFusionAliasedTable) expr);
} else {
throw new AssertionError(expr.getClass());
}
@@ -49,13 +57,13 @@ private void visit(DataFusionFrom from) {
/* e.g. from t1, t2, t3 */
if (from.joinConditionList.isEmpty()) {
- visit(from.tableList);
+ visit(from.tableExprList);
return;
}
- dfAssert(from.joinConditionList.size() == from.tableList.size() - 1, "Validate from");
+ dfAssert(from.joinConditionList.size() == from.tableExprList.size() - 1, "Validate from");
/* e.g. from t1 join t2 on t1.v1=t2.v1 */
- visit(from.tableList.get(0));
+ visit(from.tableExprList.get(0));
for (int i = 0; i < from.joinConditionList.size(); i++) {
switch (from.joinTypeList.get(i)) {
case INNER:
@@ -80,7 +88,7 @@ private void visit(DataFusionFrom from) {
dfAssert(false, "Unreachable");
}
- visit(from.tableList.get(i + 1)); // ti
+ visit(from.tableExprList.get(i + 1)); // ti
/* ON ... */
Node<DataFusionExpression> cond = from.joinConditionList.get(i);
@@ -163,4 +171,31 @@ private void visit(DataFusionWindowExpr window) {
sb.append(")");
}
+ private void visit(CastToStringView castToStringView) {
+ sb.append("ARROW_CAST(");
+ visit(castToStringView.expr);
+ sb.append(", 'Utf8')");
+ }
+
+ private void visit(DataFusionAliasedTable alias) {
+ if (alias.table instanceof TableReferenceNode) {
+ DataFusionTable t = null;
+ if (alias.table instanceof TableReferenceNode) {
+ TableReferenceNode<?, ?> tableRef = (TableReferenceNode<?, ?>) alias.table;
+ t = (DataFusionTable) tableRef.getTable();
+ } else {
+ dfAssert(false, "Unreachable");
+ }
+
+ String baseName = t.getNotAliasedName();
+ sb.append(baseName);
+
+ dfAssert(t.alias.isPresent(), "Alias should be present");
+ sb.append(" AS ");
+ sb.append(t.alias.get());
+ } else {
+ dfAssert(false, "Unreachable");
+ }
+ }
+
}
diff --git a/src/sqlancer/datafusion/DataFusionUtil.java b/src/sqlancer/datafusion/DataFusionUtil.java
index ac082afd..4f5b34f6 100644
--- a/src/sqlancer/datafusion/DataFusionUtil.java
+++ b/src/sqlancer/datafusion/DataFusionUtil.java
@@ -34,7 +34,6 @@ public static String displayTables(DataFusionGlobalState state, List fro
ResultSetMetaData metaData = wholeTable.getMetaData();
int columnCount = metaData.getColumnCount();
- resultStringBuilder.append("Table: ").append(tableName).append("\n");
for (int i = 1; i <= columnCount; i++) {
resultStringBuilder.append(metaData.getColumnName(i)).append(" (")
.append(metaData.getColumnTypeName(i)).append(")");
@@ -58,7 +57,8 @@ public static String displayTables(DataFusionGlobalState state, List fro
} catch (SQLException err) {
resultStringBuilder.append("Table: ").append(tableName).append("\n");
resultStringBuilder.append("----------------------------------------\n\n");
- // resultStringBuilder.append("Error retrieving data from table ").append(tableName).append(":
+ // resultStringBuilder.append("Error retrieving data from table
+ // ").append(tableName).append(":
// ").append(err.getMessage()).append("\n");
}
}
@@ -66,7 +66,8 @@ public static String displayTables(DataFusionGlobalState state, List fro
return resultStringBuilder.toString();
}
- // During development, you might want to manually let this function call exit(1) to fail fast
+ // During development, you might want to manually let this function call exit(1)
+ // to fail fast
public static void dfAssert(boolean condition, String message) {
if (!condition) {
// Development mode assertion failure
diff --git a/src/sqlancer/datafusion/ast/DataFusionSelect.java b/src/sqlancer/datafusion/ast/DataFusionSelect.java
index 0b3922b1..d91b59cc 100644
--- a/src/sqlancer/datafusion/ast/DataFusionSelect.java
+++ b/src/sqlancer/datafusion/ast/DataFusionSelect.java
@@ -27,15 +27,21 @@ public class DataFusionSelect extends SelectBase> imp
// `from` is used to represent from table list and join clause
// `fromList` and `joinList` in base class should always be empty
public DataFusionFrom from;
+ // The randomly selected tables (the same tables that back `from.tableExprList`)
+ // Can be refactored; it's a hack for now
+ public List<DataFusionTable> tableList;
+
// e.g. let's say all columns are {c1, c2, c3, c4, c5}
// First randomly pick a subset say {c2, c1, c3, c4}
// `exprGenAll` can generate random expr using above 4 columns
//
- // Next, randomly take two non-overlapping subset from all columns used by `exprGenAll`
+ // Next, randomly take two non-overlapping subsets from all columns used by
+ // `exprGenAll`
// exprGenGroupBy: {c1} (randomly generate group by exprs using c1 only)
// exprGenAggregate: {c3, c4}
//
- // Finally, use all `Gen`s to generate different clauses in a query (`exprGenAll` in where clause, `exprGenGroupBy`
+ // Finally, use all `Gen`s to generate different clauses in a query
+ // (`exprGenAll` in where clause, `exprGenGroupBy`
// in group by clause, etc.)
public DataFusionExpressionGenerator exprGenAll;
public DataFusionExpressionGenerator exprGenGroupBy;
@@ -46,7 +52,8 @@ public enum JoinType {
}
// DataFusionFrom can be used to represent from table list or join list
- // 1. When `joinConditionList` is empty, then it's a table list (implicit cross join)
+ // 1. When `joinConditionList` is empty, then it's a table list (implicit cross
+ // join)
// join condition can be generated in `WHERE` clause (outside `FromClause`)
// e.g. select * from [expr], [expr] is t1, t3, t2
// - tableList -> {t1, t3,t2}
@@ -60,18 +67,25 @@ public enum JoinType {
// - joinTypeList -> {INNER, LEFT}
// - joinConditionList -> {[expr_with_t1_t2], [expr_with_t1_t2_t3]}
public static class DataFusionFrom implements Node {
- public List<Node<DataFusionExpression>> tableList;
+ public List<Node<DataFusionExpression>> tableExprList;
public List<JoinType> joinTypeList;
public List<Node<DataFusionExpression>> joinConditionList;
public DataFusionFrom() {
- tableList = new ArrayList<>();
+ tableExprList = new ArrayList<>();
joinTypeList = new ArrayList<>();
joinConditionList = new ArrayList<>();
}
+ public DataFusionFrom(List<DataFusionTable> tables) {
+ this();
+ tableExprList = tables.stream().map(t -> new TableReferenceNode<DataFusionExpression, DataFusionTable>(t))
+ .map(tableExpr -> new DataFusionAliasedTable(tableExpr)).collect(Collectors.toList());
+ }
+
public boolean isExplicitJoin() {
- // if it's explicit join, joinTypeList and joinConditionList should be both length of tableList.len - 1
+ // if it's an explicit join, joinTypeList and joinConditionList should both have
+ // length tableExprList.size() - 1
// Otherwise, both is empty
dfAssert(joinTypeList.size() == joinConditionList.size(), "Validate FromClause");
return !joinTypeList.isEmpty();
@@ -89,17 +103,19 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state,
List> randomTableNodes = randomTables.stream()
.map(t -> new TableReferenceNode(t))
.collect(Collectors.toList());
- fromClause.tableList = randomTableNodes;
+ fromClause.tableExprList = randomTableNodes;
/* If JoinConditionList is empty, FromClause will be interpreted as from list */
if (Randomly.getBoolean() && Randomly.getBoolean()) {
+ fromClause.setupAlias();
return fromClause;
}
/* Set fromClause's joinTypeList and joinConditionList */
List<DataFusionColumn> possibleColsToGenExpr = new ArrayList<>();
possibleColsToGenExpr.addAll(randomTables.get(0).getColumns()); // first table
- // Generate join conditions (see class-level comment example's joinConditionList)
+ // Generate join conditions (see class-level comment example's
+ // joinConditionList)
//
// Join Type | `ON` Clause Requirement
// INNER JOIN | Required
@@ -130,12 +146,80 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state,
.add(exprGen.generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN));
}
}
- // TODO(datafusion) make join conditions more likely to be 'col1=col2', also some join types don't have
+ // TODO(datafusion) make join conditions more likely to be 'col1=col2', also
+ // some join types don't have
// 'ON' condition
}
+ // TODO(datafusion) add an option to disable this when the issue is fixed
+ // https://github.com/apache/datafusion/issues/12337
+ fromClause.setupAlias();
+
return fromClause;
}
+
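+ // Assign each table in the FROM list an alias tt0, tt1, ... and wrap it in
+ // DataFusionAliasedTable so it is printed as e.g. `FROM t1 AS tt0, t2 AS tt1`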
+ public void setupAlias() {
+ for (int i = 0; i < tableExprList.size(); i++) {
+ if (tableExprList.get(i) instanceof TableReferenceNode) {
+ @SuppressWarnings("unchecked") // Suppress the unchecked cast warning
+ TableReferenceNode<DataFusionExpression, DataFusionTable> node = (TableReferenceNode<DataFusionExpression, DataFusionTable>) tableExprList
+ .get(i);
+ node.getTable().alias = Optional.of("tt" + i);
+ } else {
+ dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances");
+ }
+ }
+
+ // wrap each table in `DataFusionAliasedTable` for display
+ List<Node<DataFusionExpression>> wrappedTables = new ArrayList<>();
+ for (Node<DataFusionExpression> table : tableExprList) {
+ wrappedTables.add(new DataFusionAliasedTable(table));
+ }
+ tableExprList = wrappedTables;
+ }
+
+ }
+
+ // If original query is
+ // select * from t1, t2, t3
+ // The randomly mutated query looks like:
+ // select * from t1_csv, t2, t3_parquet
+ public void mutateEquivalentTableName() {
+ for (Node<DataFusionExpression> table : from.tableExprList) {
+ if (table instanceof DataFusionAliasedTable) {
+ Node<DataFusionExpression> aliasedTable = ((DataFusionAliasedTable) table).table;
+
+ if (aliasedTable instanceof TableReferenceNode) {
+ @SuppressWarnings("unchecked") // Suppress the unchecked cast warning
+ TableReferenceNode<DataFusionExpression, DataFusionTable> tableRef = (TableReferenceNode<DataFusionExpression, DataFusionTable>) aliasedTable;
+ tableRef.getTable().pickAnotherEquivalentTableName();
+ } else {
+ dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances");
+ }
+ } else {
+ dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances");
+ }
+ }
+ }
+
+ // Just a marker for tables in `DataFusionFrom`
+ //
+ // For example, in the query `select * from t1 as tt1, t1 as tt2`:
+ // if the table appears in the from list, we use the `DataFusionAliasedTable` wrapper
+ // and print it as 't1 as tt1';
+ // if the same table appears in expressions, we don't use the wrapper and print it as
+ // 'tt1'
+ public static class DataFusionAliasedTable implements Node<DataFusionExpression> {
+ public Node<DataFusionExpression> table;
+
+ public DataFusionAliasedTable(Node<DataFusionExpression> table) {
+ dfAssert(table instanceof TableReferenceNode, "Expected table reference node");
+ @SuppressWarnings("unchecked") // Suppress the unchecked cast warning
+ DataFusionTable t = ((TableReferenceNode<DataFusionExpression, DataFusionTable>) table).getTable();
+ dfAssert(t.alias.isPresent(), "Expected table to have alias");
+
+ this.table = table;
+ }
}
// Generate SELECT statement according to the dependency of exprs, e.g.:
@@ -149,12 +233,17 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state,
//
// The generation order will be:
// 1. [from_clause] - Pick tables like t1, t2, t3 and get a join clause
- // 2. [expr_all_cols] - Generate a non-aggregate expression with all columns in t1, t2, t3. e.g.:
+ // 2. [expr_all_cols] - Generate a non-aggregate expression with all columns in
+ // t1, t2, t3. e.g.:
// - t1.v1 = t2.v1 and t1.v2 > t3.v2
- // 3. [expr_groupby_cols], [expr_aggr_cols] - Randomly pick some cols in t1, t2, t3 as group by columns, and pick
- // some other columns as aggregation columns, and generate non-aggr expression [expr_groupby_cols] on group by
- // columns, finally generate aggregation expressions [expr_aggr_cols] on non-group-by/aggregation columns.
- // For example, group by column is t1.v1, and aggregate columns is t2.v1, t3.v1, generated expressions can be:
+ // 3. [expr_groupby_cols], [expr_aggr_cols] - Randomly pick some cols in t1, t2,
+ // t3 as group by columns, and pick
+ // some other columns as aggregation columns, and generate non-aggr expression
+ // [expr_groupby_cols] on group by
+ // columns, finally generate aggregation expressions [expr_aggr_cols] on
+ // non-group-by/aggregation columns.
+ // For example, the group by column is t1.v1, and the aggregate columns are t2.v1, t3.v1,
+ // generated expressions can be:
// - [expr_groupby_cols] t1.v1 + 1
// - [expr_aggr_cols] SUM(t3.v1 + t2.v1)
public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
@@ -172,6 +261,7 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
randomTables = randomTables.subList(0, maxSize);
}
DataFusionFrom randomFrom = DataFusionFrom.generateFromClause(state, randomTables);
+ randomSelect.tableList = randomTables;
/* Setup expression generators (to generate different clauses) */
List<DataFusionColumn> randomColumnsAll = DataFusionTable.getRandomColumns(randomTables);
@@ -190,9 +280,15 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
.generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN);
/* Constructing result */
- List<Node<DataFusionExpression>> randomColumnNodes = randomColumnsAll.stream()
- .map((c) -> new ColumnReferenceNode<DataFusionExpression, DataFusionColumn>(c))
- .collect(Collectors.toList());
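+ // Assumption (not stated in the original): wrapping STRING fetch columns in
+ // CastToStringView, which prints as ARROW_CAST(col, 'Utf8'), likely normalizes
+ // Utf8View columns from the *_stringview tables so the Flight SQL JDBC driver does
+ // not hit the "Unsupported ArrowType Utf8View" error registered in DataFusionErrors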
+ List<Node<DataFusionExpression>> randomColumnNodes = randomColumnsAll.stream().map((c) -> {
+ if (c.getType() == DataFusionSchema.DataFusionDataType.STRING) {
+ Node<DataFusionExpression> colRef = new ColumnReferenceNode<DataFusionExpression, DataFusionColumn>(c);
+ return new DataFusionSpecialExpr.CastToStringView(colRef);
+
+ } else {
+ return new ColumnReferenceNode<DataFusionExpression, DataFusionColumn>(c);
+ }
+ }).collect(Collectors.toList());
randomSelect.setFetchColumns(randomColumnNodes); // TODO(datafusion) make it more random like 'select *'
randomSelect.from = randomFrom;
@@ -218,7 +314,8 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
// ...
// group by v1
//
- // This method assume `DataFusionSelect` is propoerly initialized with `getRandomSelect()`
+ // This method assumes `DataFusionSelect` is properly initialized with
+ // `getRandomSelect()`
public void setAggregates(DataFusionGlobalState state) {
// group by exprs (e.g. group by v1, abs(v2))
List<Node<DataFusionExpression>> groupByExprs = this.exprGenGroupBy.generateExpressionsPreferColumns();
diff --git a/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java b/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java
new file mode 100644
index 00000000..8ae63861
--- /dev/null
+++ b/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java
@@ -0,0 +1,13 @@
+package sqlancer.datafusion.ast;
+
+import sqlancer.common.ast.newast.Node;
+
+public class DataFusionSpecialExpr {
+ public static class CastToStringView implements Node<DataFusionExpression> {
+ public Node<DataFusionExpression> expr;
+
+ public CastToStringView(Node<DataFusionExpression> expr) {
+ this.expr = expr;
+ }
+ }
+}
diff --git a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java
index 3754a826..c19ad00d 100644
--- a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java
+++ b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java
@@ -55,7 +55,8 @@ protected boolean canGenerateColumnOfType(DataFusionDataType type) {
return true;
}
- // If target expr type is numeric, when `supportAggregate`, make it more likely to generate aggregate functions
+ // If target expr type is numeric, when `supportAggregate`, make it more likely
+ // to generate aggregate functions
// Since randomly generated expressions are nested:
// For window/aggregate case, we want only outer layer to be window/aggr
// to make it more likely to generate valid query
@@ -82,7 +83,8 @@ private boolean filterBaseExpr(DataFusionBaseExpr expr, DataFusionDataType type,
}
// By default all possible non-aggregate expressions
- // To generate aggregate functions: set this.supportAggregate to `true`, generate exprs, and reset.
+ // To generate aggregate functions: set this.supportAggregate to `true`,
+ // generate exprs, and reset.
@Override
protected Node<DataFusionExpression> generateExpression(DataFusionDataType type, int depth) {
if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) {
@@ -103,7 +105,8 @@ protected Node generateExpression(DataFusionDataType type,
DataFusionBaseExpr randomExpr = Randomly.fromList(possibleBaseExprs);
- // if (randomExpr.exprType == DataFusionBaseExprCategory.AGGREGATE || randomExpr.exprType ==
+ // if (randomExpr.exprType == DataFusionBaseExprCategory.AGGREGATE ||
+ // randomExpr.exprType ==
// DataFusionBaseExprCategory.WINDOW) {
// if (depth == 0) {
// System.out.println("DBG depth 0");
@@ -143,7 +146,8 @@ protected Node generateExpression(DataFusionDataType type,
case BINARY:
dfAssert(randomExpr.argTypes.size() == 2 && randomExpr.nArgs == 2,
"Binrary expression should only have 2 argument" + randomExpr.argTypes);
- List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input arguments
+ List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input
+ // arguments
for (ArgumentType argumentType : randomExpr.argTypes) {
if (argumentType instanceof ArgumentType.Fixed) {
ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0);
public Node<DataFusionExpression> isNull(Node<DataFusionExpression> expr) {
return new NewUnaryPostfixOperatorNode<>(expr, createExpr(DataFusionBaseExprType.IS_NULL));
}
- // TODO(datafusion) refactor: make single generate aware of group by and aggr columns, and it can directly generate
+ // TODO(datafusion) refactor: make single generate aware of group by and aggr
+ // columns, and it can directly generate
// having clause
// Try best to generate a valid having clause
//
// Suppose query "... group by a, b ..."
// and all available columns are "a, b, c, d"
- // then a valid having clause can have expr of {a, b}, and expr of aggregation of {c, d}
+ // then a valid having clause can have expr of {a, b}, and expr of aggregation
+ // of {c, d}
// e.g. "having a=b and avg(c) > avg(d)"
//
// `groupbyGen` can generate expression only with group by cols
diff --git a/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java
index 28ce5da3..e1acddcf 100644
--- a/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java
+++ b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java
@@ -1,18 +1,24 @@
package sqlancer.datafusion.gen;
+import java.util.Optional;
+
import sqlancer.Randomly;
import sqlancer.common.query.ExpectedErrors;
import sqlancer.common.query.SQLQueryAdapter;
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
+import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
import sqlancer.datafusion.DataFusionSchema.DataFusionDataType;
+import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
public class DataFusionTableGenerator {
// Randomly generate a query like 'create table t1 (v1 bigint, v2 boolean)'
- public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) {
+ public SQLQueryAdapter getCreateStmt(DataFusionGlobalState globalState) {
ExpectedErrors errors = new ExpectedErrors();
StringBuilder sb = new StringBuilder();
String tableName = globalState.getSchema().getFreeTableName();
+
+ // Build "create table t1..." using sb
sb.append("CREATE TABLE ");
sb.append(tableName);
sb.append("(");
@@ -30,4 +36,40 @@ public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) {
return new SQLQueryAdapter(sb.toString(), errors, true);
}
+
+ // Given a table t1, return create statement to generate t1_stringview
+ // If t1 has no string column, return empty
+ //
+ // Query looks like (only v2 is TEXT column):
+ // create table t1_stringview as
+ // select v1, arrow_cast(v2, 'Utf8View') as v2 from t1;
+ public Optional<SQLQueryAdapter> createStringViewTable(DataFusionGlobalState globalState, DataFusionTable table) {
+ if (!table.getColumns().stream().anyMatch(c -> c.getType().equals(DataFusionDataType.STRING))) {
+ return Optional.empty();
+ }
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("CREATE TABLE ");
+ sb.append(table.getName());
+ sb.append("_stringview AS SELECT ");
+ for (DataFusionColumn column : table.getColumns()) {
+ String colName = column.getName();
+ if (column.getType().equals(DataFusionDataType.STRING)) {
+ // Found a TEXT column, cast it
+ sb.append("arrow_cast(").append(colName).append(", 'Utf8View') as ").append(colName);
+ } else {
+ sb.append(colName);
+ }
+
+ // Join expressions with ','
+ if (column != table.getColumns().get(table.getColumns().size() - 1)) {
+ sb.append(", ");
+ }
+ }
+
+ sb.append(" FROM ").append(table.getName()).append(";");
+
+ return Optional.of(new SQLQueryAdapter(sb.toString(), new ExpectedErrors(), true));
+ }
+
}
diff --git a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml
index dd441ba3..f690cb52 100644
--- a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml
+++ b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml
@@ -4,20 +4,26 @@ edition = "2021"
description = "Standalone DataFusion server"
license = "Apache-2.0"
+# TODO(datafusion): Figure out how to automatically manage arrow version
+# The arrow version should be the same as the one used by DataFusion
[dependencies]
-ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
-arrow = { version = "52.1.0", features = ["prettyprint"] }
-arrow-array = { version = "52.1.0", default-features = false, features = ["chrono-tz"] }
-arrow-buffer = { version = "52.1.0", default-features = false }
-arrow-flight = { version = "52.1.0", features = ["flight-sql-experimental"] }
-arrow-ipc = { version = "52.1.0", default-features = false, features = ["lz4"] }
-arrow-ord = { version = "52.1.0", default-features = false }
-arrow-schema = { version = "52.1.0", default-features = false }
-arrow-string = { version = "52.1.0", default-features = false }
+ahash = { version = "0.8", default-features = false, features = [
+ "runtime-rng",
+] }
+arrow = { version = "53.3.0", features = ["prettyprint"] }
+arrow-array = { version = "53.3.0", default-features = false, features = [
+ "chrono-tz",
+] }
+arrow-buffer = { version = "53.3.0", default-features = false }
+arrow-flight = { version = "53.3.0", features = ["flight-sql-experimental"] }
+arrow-ipc = { version = "53.3.0", default-features = false, features = ["lz4"] }
+arrow-ord = { version = "53.3.0", default-features = false }
+arrow-schema = { version = "53.3.0", default-features = false }
+arrow-string = { version = "53.3.0", default-features = false }
async-trait = "0.1.73"
bytes = "1.4"
-chrono = { version = "0.4.34", default-features = false }
-dashmap = "5.5.0"
+chrono = { version = "0.4.38", default-features = false }
+dashmap = "6.0.1"
# This version is for SQLancer CI run (disabled temporary for multiple newly fixed bugs)
# datafusion = { version = "41.0.0" }
# Use following line if you want to test against the latest main branch of DataFusion
@@ -28,17 +34,21 @@ half = { version = "2.2.1", default-features = false }
hashbrown = { version = "0.14.5", features = ["raw"] }
log = "0.4"
num_cpus = "1.13.0"
-object_store = { version = "0.10.1", default-features = false }
+object_store = { version = "0.11.0", default-features = false }
parking_lot = "0.12"
-parquet = { version = "52.0.0", default-features = false, features = ["arrow", "async", "object_store"] }
+parquet = { version = "53.3.0", default-features = false, features = [
+ "arrow",
+ "async",
+ "object_store",
+] }
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
tokio = { version = "1.36", features = ["macros", "rt", "sync"] }
-tonic = "0.11"
+tonic = "0.12.1"
uuid = "1.0"
-prost = { version = "0.12", default-features = false }
-prost-derive = { version = "0.12", default-features = false }
+prost = { version = "0.13.1" }
+prost-derive = "0.13.1"
mimalloc = { version = "0.1", default-features = false }
[[bin]]
diff --git a/src/sqlancer/datafusion/server/datafusion_server/src/main.rs b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs
index 13ec73e9..41b12e21 100644
--- a/src/sqlancer/datafusion/server/datafusion_server/src/main.rs
+++ b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs
@@ -17,7 +17,7 @@ use arrow_flight::{
use arrow_schema::{DataType, Field, Schema};
use dashmap::DashMap;
use datafusion::logical_expr::LogicalPlan;
-use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext};
+use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
use futures::{Stream, StreamExt, TryStreamExt};
use log::info;
use mimalloc::MiMalloc;
@@ -292,7 +292,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
let result = df
.collect()
.await
- .map_err(|e| status!("Error executing query", e))?;
+ .map_err(|e| status!("Errorr executing query", e))?;
// if we get an empty result, create an empty schema
let schema = match result.first() {
@@ -400,6 +400,8 @@ impl FlightSqlService for FlightSqlServiceImpl {
.and_then(|df| df.into_optimized_plan())
.map_err(|e| Status::internal(format!("Error building plan: {e}")))?;
+ info!("Plan is {:#?}", plan);
+
// store a copy of the plan, it will be used for execution
let plan_uuid = Uuid::new_v4().hyphenated().to_string();
self.statements.insert(plan_uuid.clone(), plan.clone());
@@ -417,6 +419,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
dataset_schema: schema_bytes,
parameter_schema: Default::default(),
};
+ info!("do_action_create_prepared_statement SUCCEED!");
Ok(res)
}
diff --git a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java
index c4a7defb..a7d876e8 100644
--- a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java
+++ b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java
@@ -7,6 +7,7 @@
import java.util.List;
import sqlancer.ComparatorHelper;
+import sqlancer.IgnoreMeException;
import sqlancer.common.oracle.NoRECBase;
import sqlancer.common.oracle.TestOracle;
import sqlancer.datafusion.DataFusionErrors;
@@ -43,24 +44,31 @@ public void check() throws SQLException {
// generate a random:
// SELECT [expr1] FROM [expr2] WHERE [expr3]
DataFusionSelect randomSelect = getRandomSelect(state);
+
// Q1: SELECT count(*) FROM [expr2] WHERE [expr3]
+ // Q1 and Q2 are constructed from randomSelect's fields,
+ // so for equivalent table mutation, we mutate randomSelect
+ randomSelect.mutateEquivalentTableName();
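+ // e.g. Q1 may end up reading FROM t1_stringview AS tt0 while Q2 (mutated again below)
+ // reads FROM t1 AS tt0; both variants must still agree for NoREC to pass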
DataFusionSelect q1 = new DataFusionSelect();
q1.setFetchColumnsString("COUNT(*)");
q1.from = randomSelect.from;
q1.setWhereClause(randomSelect.getWhereClause());
+ String q1String = DataFusionToStringVisitor.asString(q1);
+
// Q2: SELECT count(case when [expr3] then 1 else null end) FROM [expr2]
+ randomSelect.mutateEquivalentTableName();
DataFusionSelect q2 = new DataFusionSelect();
String selectExpr = String.format("COUNT(CASE WHEN %s THEN 1 ELSE NULL END)",
DataFusionToStringVisitor.asString(randomSelect.getWhereClause()));
q2.setFetchColumnsString(selectExpr);
q2.from = randomSelect.from;
q2.setWhereClause(null);
+ String q2String = DataFusionToStringVisitor.asString(q2);
/*
* Execute Q1 and Q2
*/
- String q1String = DataFusionToStringVisitor.asString(q1);
- String q2String = DataFusionToStringVisitor.asString(q2);
+ // System.out.println("DBG: " + q1String + "\n" + q2String);
List<String> q1ResultSet = null;
List<String> q2ResultSet = null;
try {
@@ -81,6 +89,13 @@ public void check() throws SQLException {
int count1 = q1ResultSet != null ? Integer.parseInt(q1ResultSet.get(0)) : -1;
int count2 = q2ResultSet != null ? Integer.parseInt(q2ResultSet.get(0)) : -1;
if (count1 != count2) {
+ // whitelist
+ // ---------
+ // https://github.com/apache/datafusion/issues/12468
+ if (q1String.contains("NATURAL JOIN")) {
+ throw new IgnoreMeException();
+ }
+
StringBuilder errorMessage = new StringBuilder().append("NoREC oracle violated:\n")
.append(" Q1(result size ").append(count1).append("):").append(q1String).append(";\n")
.append(" Q2(result size ").append(count2).append("):").append(q2String).append(";\n")
@@ -94,5 +109,6 @@ public void check() throws SQLException {
throw new AssertionError("\n\n" + indentedErrorLog);
}
+ // System.out.println("NOREC passed: \n" + q1String + "\n" + q2String);
}
}
diff --git a/src/sqlancer/datafusion/test/DataFusionPQS.java b/src/sqlancer/datafusion/test/DataFusionPQS.java
new file mode 100644
index 00000000..00bd9f2a
--- /dev/null
+++ b/src/sqlancer/datafusion/test/DataFusionPQS.java
@@ -0,0 +1,344 @@
+package sqlancer.datafusion.test;
+
+import static sqlancer.datafusion.DataFusionUtil.dfAssert;
+import static sqlancer.datafusion.ast.DataFusionSelect.getRandomSelect;
+
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import sqlancer.IgnoreMeException;
+import sqlancer.common.ast.newast.Node;
+import sqlancer.common.oracle.NoRECBase;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.common.query.SQLQueryAdapter;
+import sqlancer.common.query.SQLancerResultSet;
+import sqlancer.datafusion.DataFusionErrors;
+import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
+import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
+import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
+import sqlancer.datafusion.DataFusionToStringVisitor;
+import sqlancer.datafusion.DataFusionUtil;
+import sqlancer.datafusion.DataFusionUtil.DataFusionLogger;
+import sqlancer.datafusion.ast.DataFusionExpression;
+import sqlancer.datafusion.ast.DataFusionSelect;
+import sqlancer.datafusion.ast.DataFusionSelect.DataFusionFrom;
+
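+// A Pivoted Query Synthesis (PQS) style oracle (see the step comments in checkImpl()):
+// pick one random "pivot" row into table tt, find which of {p, NOT p, p IS NULL}
+// selects it, materialize all rows matching that predicate into ttt, and assert the
+// pivot row shows up there (the tt JOIN ttt count must not be 0).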
+public class DataFusionPQS extends NoRECBase<DataFusionGlobalState> implements TestOracle<DataFusionGlobalState> {
+
+ private final DataFusionGlobalState state;
+
+ // Table references in the randomly generated SELECT query used for the
+ // current PQS check.
+ // For the queries constructed ONLY for PQS, columns are temporarily aliased;
+ // remember to reset the aliases when the PQS check is done.
+ private List<DataFusionTable> pqsTables;
+
+ private StringBuilder currentCheckLog; // Each append should end with '\n'
+
+ public DataFusionPQS(DataFusionGlobalState globalState) {
+ super(globalState);
+ this.state = globalState;
+
+ DataFusionErrors.registerExpectedExecutionErrors(errors);
+ }
+
+ private void setColumnAlias(List<DataFusionTable> tables) {
+ List<DataFusionColumn> allColumns = tables.stream().flatMap(t -> t.getColumns().stream())
+ .collect(Collectors.toList());
+ for (int i = 0; i < allColumns.size(); i++) {
+ String alias = "cc" + i;
+ allColumns.get(i).alias = Optional.of(alias);
+ }
+ }
+
+ private void resetColumnAlias(List<DataFusionTable> tables) {
+ tables.stream().flatMap(t -> t.getColumns().stream()).forEach(c -> c.alias = Optional.empty());
+ }
+
+ private void pqsCleanUp() throws SQLException {
+ if (pqsTables == null) {
+ return;
+ }
+ resetColumnAlias(pqsTables);
+
+ // Drop temp tables used in PQS check
+ SQLQueryAdapter tttCleanUp = new SQLQueryAdapter("drop table if exists ttt", errors);
+ tttCleanUp.execute(state);
+ SQLQueryAdapter ttCleanUp = new SQLQueryAdapter("drop table if exists tt", errors);
+ ttCleanUp.execute(state);
+ }
+
+ @Override
+ public void check() throws SQLException {
+ this.currentCheckLog = new StringBuilder();
+ String replay = DataFusionUtil.getReplay(state.getDatabaseName());
+ currentCheckLog.append(replay);
+ pqsCleanUp();
+
+ try {
+ checkImpl();
+ } catch (Exception | AssertionError e) {
+ pqsCleanUp();
+ throw e;
+ }
+ pqsCleanUp();
+ }
+
+ public void checkImpl() throws SQLException {
+ // ======================================================
+ // Step 1:
+ // select tt0.c0 as cc0, tt0.c1 as cc1, tt1.c0 as cc2
+ // from t0 as tt0, t1 as tt1
+ // where tt0.c0 = tt1.c0;
+ // ======================================================
+ DataFusionSelect randomSelect = getRandomSelect(state);
+ randomSelect.from.joinConditionList = new ArrayList<>();
+ randomSelect.from.joinTypeList = new ArrayList<>();
+ randomSelect.mutateEquivalentTableName();
+ pqsTables = randomSelect.tableList;
+
+ // Reset fetch columns
+ List<DataFusionColumn> allColumns = randomSelect.tableList.stream().flatMap(t -> t.getColumns().stream())
+ .collect(Collectors.toList());
+ setColumnAlias(pqsTables);
+ List<String> aliasedColumns = allColumns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get())
+ .collect(Collectors.toList());
+ String fetchColumnsString = String.join(", ", aliasedColumns);
+ randomSelect.setFetchColumnsString(fetchColumnsString);
+
+ // ======================================================
+ // Step 2:
+ // create table tt as
+ // with cte0 as (select tt0.c0 as cc0, tt0.c1 as cc1 from t0 as tt0 order by
+ // random() limit 1),
+ // cte1 as (select tt1.c0 as cc2 from t1 as tt1 order by random() limit 1)
+ // select * from cte0, cte1;
+ // ======================================================
+ List<DataFusionSelect> cteSelects = new ArrayList<>();
+ for (DataFusionTable table : randomSelect.tableList) {
+ DataFusionSelect cteSelect = new DataFusionSelect();
+ DataFusionFrom cteFrom = new DataFusionFrom(Arrays.asList(table));
+ cteSelect.from = cteFrom;
+
+ List<DataFusionColumn> columns = table.getColumns();
+ List<String> cteAliasedColumns = columns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get())
+ .collect(Collectors.toList());
+ String cteFetchColumnsString = String.join(", ", cteAliasedColumns);
+ cteSelect.setFetchColumnsString(cteFetchColumnsString);
+ cteSelects.add(cteSelect);
+ }
+
+ // select tt0.c0 as cc0, tt0.c1 as cc1 from tt0 order by random()
+ List<String> cteSelectsPickOneRow = cteSelects.stream()
+ .map(cteSelect -> DataFusionToStringVisitor.asString(cteSelect) + " ORDER BY RANDOM() LIMIT 1")
+ .collect(Collectors.toList());
+ int ncte = cteSelectsPickOneRow.size();
+
+ List<String> ctes = new ArrayList<>();
+ for (int i = 0; i < ncte; i++) {
+ String cte = "cte" + i + " AS (" + cteSelectsPickOneRow.get(i) + ")";
+ ctes.add(cte);
+ }
+
+ // cte0, cte1, cte2
+ List<String> ctesInFrom = IntStream.range(0, ncte).mapToObj(i -> "cte" + i).collect(Collectors.toList());
+ String ctesInFromString = String.join(", ", ctesInFrom);
+
+ // Create tt (Table with one pivot row)
+ String ttCreate = "CREATE TABLE tt AS\n WITH\n " + String.join(",\n ", ctes) + "\n" + "SELECT * FROM "
+ + ctesInFromString;
+
+ currentCheckLog.append("==== Create tt (Table with one pivot row):\n").append(ttCreate).append("\n");
+
+ // ======================================================
+ // Step 3:
+ // Find the predicate that selects the pivot row;
+ // it can be one of {p, NOT p, p is NULL}
+ // Note 'p' is 'cc0 = cc2' in Step 1
+ // ======================================================
+ Node<DataFusionExpression> whereExpr = randomSelect.getWhereClause(); // must be valid
+ Node<DataFusionExpression> notWhereExpr = randomSelect.exprGenAll.negatePredicate(whereExpr);
+ Node<DataFusionExpression> isNullWhereExpr = randomSelect.exprGenAll.isNull(whereExpr);
+ List<Node<DataFusionExpression>> candidatePredicates = Arrays.asList(whereExpr, notWhereExpr, isNullWhereExpr);
+
+ String pivotQ1 = "select * from tt where " + DataFusionToStringVisitor.asString(whereExpr);
+ String pivotQ2 = "select * from tt where " + DataFusionToStringVisitor.asString(notWhereExpr);
+ String pivotQ3 = "select * from tt where " + DataFusionToStringVisitor.asString(isNullWhereExpr);
+
+ List<String> pivotQs = Arrays.asList(pivotQ1, pivotQ2, pivotQ3);
+
+ // Execute "crete tt" (table with one pivot row)
+ SQLQueryAdapter q = new SQLQueryAdapter(ttCreate, errors);
+ q.execute(state);
+
+ SQLancerResultSet ttResult = null;
+ SQLQueryAdapter ttSelect = new SQLQueryAdapter("select * from tt", errors);
+ int nrow = 0;
+ try {
+ ttResult = ttSelect.executeAndGetLogged(state);
+
+ if (ttResult == null) {
+ // Possible bug here, investigate later
+ throw new IgnoreMeException();
+ }
+
+ while (ttResult.next()) {
+ nrow++;
+ }
+ } catch (Exception e) {
+ // Possible bug here, investigate later
+ throw new IgnoreMeException();
+ } finally {
+ if (ttResult != null && !ttResult.isClosed()) {
+ ttResult.close();
+ }
+ }
+
+ if (nrow == 0) {
+ // If empty table is picked, we can't find a pivot row
+ // Give up current check
+ // TODO(datafusion): support empty tables
+ throw new IgnoreMeException("Empty table is picked");
+ }
+
+ Node<DataFusionExpression> pivotPredicate = null;
+ String pivotRow = "";
+ for (int i = 0; i < pivotQs.size(); i++) {
+ String pivotQ = pivotQs.get(i);
+ SQLQueryAdapter qSelect = new SQLQueryAdapter(pivotQ, errors);
+ SQLancerResultSet rs = null;
+ try {
+ rs = qSelect.executeAndGetLogged(state);
+ if (rs == null) {
+ // Only one of the 3 pivot queries will return 1 row
+ continue;
+ }
+
+ int rowCount = 0;
+ while (rs.next()) {
+ rowCount += 1;
+ for (int ii = 1; ii <= rs.rs.getMetaData().getColumnCount(); ii++) {
+ pivotRow += "[" + rs.getString(ii) + "]";
+ }
+ pivotPredicate = candidatePredicates.get(i);
+ }
+
+ dfAssert(rowCount <= 1, "Pivot row should be length of 1, got " + rowCount);
+
+ if (rowCount == 1) {
+ break;
+ }
+ } catch (Exception e) {
+ currentCheckLog.append(pivotQ).append("\n");
+ currentCheckLog.append(e.getMessage()).append("\n").append(e.getCause()).append("\n");
+
+ String fullErrorMessage = currentCheckLog.toString();
+ state.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.ERROR, fullErrorMessage);
+
+ throw new AssertionError(fullErrorMessage);
+ } finally {
+ if (rs != null && !rs.isClosed()) {
+ rs.close();
+ }
+ }
+ }
+
+ if (pivotPredicate == null) {
+ // Sometimes all valid pivot queries failed
+ // Potential bug, investigate later
+ currentCheckLog.append("ALl pivot q failed! ").append(pivotQs).append("\n");
+ throw new IgnoreMeException("All pivot queries failed " + pivotQs);
+ }
+
+ // ======================================================
+ // Step 4:
+ // Let's say in Step 3 we found the predicate is "Not (cc0 = cc2)"
+ // Check if the pivot row can be found in
+ // "select * from tt0, tt1 where Not(cc0 = cc2)"
+ // Then we construct table ttt with the above query
+ // Finally, join 'tt' and 'ttt' to make sure one pivot row is output
+ // ======================================================
+ DataFusionSelect selectAllRows = new DataFusionSelect();
+ DataFusionFrom selectAllRowsFrom = new DataFusionFrom(randomSelect.tableList);
+ selectAllRows.from = selectAllRowsFrom;
+ selectAllRows.setWhereClause(pivotPredicate);
+
+ List<DataFusionColumn> allSelectColumns = randomSelect.tableList.stream().flatMap(t -> t.getColumns().stream())
+ .collect(Collectors.toList());
+ // tt0.v0 as cc0, tt0.v1 as cc1, tt1.v0 as cc2
+ List<String> allSelectExprs = allSelectColumns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get())
+ .collect(Collectors.toList());
+ resetColumnAlias(randomSelect.tableList);
+ String selectAllRowsFetchColStr = String.join(", ", allSelectExprs);
+ selectAllRows.setFetchColumnsString(selectAllRowsFetchColStr);
+
+ String selectAllRowsString = DataFusionToStringVisitor.asString(selectAllRows);
+ String tttCreate = "CREATE TABLE ttt AS\n" + selectAllRowsString;
+
+ SQLQueryAdapter tttCreateStmt = new SQLQueryAdapter(tttCreate, errors);
+ tttCreateStmt.execute(state);
+ setColumnAlias(randomSelect.tableList);
+
+ // ======================================================
+ // Step 5:
+ // Make sure the following query returns 1 pivot row
+ // Otherwise the PQS oracle is violated
+ // select * from tt join ttt
+ // on
+ // tt.cc0 is not distinct from ttt.cc0
+ // and tt.cc1 is not distinct from ttt.cc1
+ // and tt.cc2 is not distinct from ttt.cc2
+ // ======================================================
+ List<String> onConditions = allSelectColumns.stream()
+ .map(c -> "(tt." + c.alias.get() + " IS NOT DISTINCT FROM ttt." + c.alias.get() + ")")
+ .collect(Collectors.toList());
+ String onCond = String.join("\nAND ", onConditions);
+ String joinQuery = "SELECT COUNT(*) FROM tt JOIN ttt ON\n" + onCond;
+
+ SQLQueryAdapter qJoin = new SQLQueryAdapter(joinQuery, errors);
+
+ SQLancerResultSet rsFull = null;
+ try {
+ rsFull = qJoin.executeAndGetLogged(state);
+
+ if (rsFull == null) {
+ throw new IgnoreMeException("Join query returned no results: " + joinQuery);
+ }
+
+ String joinCount = "invalid";
+ while (rsFull.next()) {
+ joinCount = rsFull.getString(1);
+ }
+
+ if (joinCount.equals("0")) {
+ String replay = DataFusionUtil.getReplay(state.getDatabaseName());
+ StringBuilder errorLog = new StringBuilder().append("PQS oracle violated:\n").append("Found ")
+ .append(joinCount).append(" pivot rows:\n").append(" Pivot row: ").append(pivotRow).append("\n")
+ .append("Query to select pivot row: ").append(ttCreate).append("\n")
+ .append("Query to select all rows: ").append(tttCreate).append("\n").append("Join: ")
+ .append(joinQuery).append("\n").append(replay).append("\n");
+
+ String errorString = errorLog.toString();
+ String indentedErrorLog = errorString.replaceAll("(?m)^", " ");
+ state.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.ERROR, errorString);
+
+ throw new AssertionError("\n\n" + indentedErrorLog);
+ } else if (!joinCount.matches("\\d+")) {
+ // If joinCount is not an integer > 0, throw an exception
+ throw new IgnoreMeException("Join query returned invalid result: " + joinCount);
+ }
+ } catch (Exception e) {
+ throw new IgnoreMeException("Failed to execute join query: " + joinQuery);
+ } finally {
+ if (rsFull != null && !rsFull.isClosed()) {
+ rsFull.close();
+ }
+ }
+ }
+}
diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java
index 9d6c2efc..31f90d0f 100644
--- a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java
+++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java
@@ -8,6 +8,7 @@
import java.util.List;
import sqlancer.ComparatorHelper;
+import sqlancer.IgnoreMeException;
import sqlancer.Randomly;
import sqlancer.common.ast.newast.NewBinaryOperatorNode;
import sqlancer.common.ast.newast.Node;
@@ -43,6 +44,7 @@ public void check() throws SQLException {
// generate a random 'SELECT [expr1] FROM [expr2] WHERE [expr3]
super.check();
DataFusionSelect randomSelect = select;
+ randomSelect.mutateEquivalentTableName();
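+ // q and the partitions qp1/qp2/qp3 are re-rendered after each
+ // mutateEquivalentTableName() call below, so they may read from different but
+ // logically equivalent tables (e.g. t1 vs t1_stringview)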
if (Randomly.getBoolean()) {
randomSelect.distinct = true;
@@ -69,12 +71,15 @@ public void check() throws SQLException {
randomSelect.setWhereClause(null);
qString = DataFusionToStringVisitor.asString(randomSelect);
+ randomSelect.mutateEquivalentTableName();
randomSelect.setWhereClause(predicate);
qp1String = DataFusionToStringVisitor.asString(randomSelect);
+ randomSelect.mutateEquivalentTableName();
randomSelect.setWhereClause(negatedPredicate);
qp2String = DataFusionToStringVisitor.asString(randomSelect);
+ randomSelect.mutateEquivalentTableName();
randomSelect.setWhereClause(isNullPredicate);
qp3String = DataFusionToStringVisitor.asString(randomSelect);
} else {
@@ -111,6 +116,8 @@ public void check() throws SQLException {
/*
* Run all queires
*/
+ // System.out.println("DBG TLP: " + qString + "\n" + qp1String + "\n" +
+ // qp2String + "\n" + qp3String);
List qResultSet = ComparatorHelper.getResultSetFirstColumnAsString(qString, errors, state);
List combinedString = new ArrayList<>();
List qpResultSet = ComparatorHelper.getCombinedResultSet(qp1String, qp2String, qp3String,
@@ -121,6 +128,13 @@ public void check() throws SQLException {
ComparatorHelper.assumeResultSetsAreEqual(qResultSet, qpResultSet, qString, combinedString, state,
ComparatorHelper::canonicalizeResultValue);
} catch (AssertionError e) {
+ // whitelist
+ // ---------
+ // https://github.com/apache/datafusion/issues/12468
+ if (qp1String.contains("NATURAL JOIN")) {
+ throw new IgnoreMeException();
+ }
+
// Append more error message
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n";