diff --git a/pom.xml b/pom.xml index 9bee8348..53f15988 100644 --- a/pom.xml +++ b/pom.xml @@ -376,7 +376,7 @@ org.apache.arrow flight-sql-jdbc-driver - 16.1.0 + 17.0.0 diff --git a/src/sqlancer/ComparatorHelper.java b/src/sqlancer/ComparatorHelper.java index 5da635de..d061ecb7 100644 --- a/src/sqlancer/ComparatorHelper.java +++ b/src/sqlancer/ComparatorHelper.java @@ -131,7 +131,8 @@ public static void assumeResultSetsAreEqual(List resultSet, List public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, String originalQueryString, List combinedString, SQLGlobalState state, UnaryOperator canonicalizationRule) { - // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization function which is applied to + // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization + // function which is applied to // both result sets before their comparison. List canonicalizedResultSet = resultSet.stream().map(canonicalizationRule).collect(Collectors.toList()); List canonicalizedSecondResultSet = secondResultSet.stream().map(canonicalizationRule) diff --git a/src/sqlancer/IgnoreMeException.java b/src/sqlancer/IgnoreMeException.java index bc2e2591..cdae6acd 100644 --- a/src/sqlancer/IgnoreMeException.java +++ b/src/sqlancer/IgnoreMeException.java @@ -4,4 +4,11 @@ public class IgnoreMeException extends RuntimeException { private static final long serialVersionUID = 1L; + public IgnoreMeException() { + super(); + } + + public IgnoreMeException(String message) { + super(message); + } } diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java index 1cbc0264..68639093 100644 --- a/src/sqlancer/Main.java +++ b/src/sqlancer/Main.java @@ -31,6 +31,7 @@ import sqlancer.common.query.Query; import sqlancer.common.query.SQLancerResultSet; import sqlancer.databend.DatabendProvider; +import sqlancer.datafusion.DataFusionProvider; import sqlancer.doris.DorisProvider; import sqlancer.duckdb.DuckDBProvider; import sqlancer.h2.H2Provider; @@ -734,6 +735,7 @@ private static void checkForIssue799(List> providers) providers.add(new CnosDBProvider()); providers.add(new CockroachDBProvider()); providers.add(new DatabendProvider()); + providers.add(new DataFusionProvider()); providers.add(new DorisProvider()); providers.add(new DuckDBProvider()); providers.add(new H2Provider()); diff --git a/src/sqlancer/common/query/SQLancerResultSet.java b/src/sqlancer/common/query/SQLancerResultSet.java index d1221a7f..7cc1523e 100644 --- a/src/sqlancer/common/query/SQLancerResultSet.java +++ b/src/sqlancer/common/query/SQLancerResultSet.java @@ -6,7 +6,7 @@ public class SQLancerResultSet implements Closeable { - ResultSet rs; + public ResultSet rs; private Runnable runnableEpilogue; public SQLancerResultSet(ResultSet rs) { diff --git a/src/sqlancer/datafusion/DataFusionErrors.java b/src/sqlancer/datafusion/DataFusionErrors.java index a646f304..9e62bbe9 100644 --- a/src/sqlancer/datafusion/DataFusionErrors.java +++ b/src/sqlancer/datafusion/DataFusionErrors.java @@ -44,28 +44,26 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) { errors.add("There is only support Literal types for field at idx:"); errors.add("nth_value not supported for n:"); errors.add("Invalid argument error: Nested comparison: List("); + errors.add("This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal"); + errors.add( + "This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"); /* * Known bugs 
*/ - errors.add("to type Int"); // https://github.com/apache/datafusion/issues/11249 errors.add("bitwise"); // https://github.com/apache/datafusion/issues/11260 errors.add("Sort expressions cannot be empty for streaming merge."); // https://github.com/apache/datafusion/issues/11561 - errors.add("compute_utf8_flag_op_scalar failed to cast literal value NULL for operation"); // https://github.com/apache/datafusion/issues/11623 errors.add("Schema error: No field named "); // https://github.com/apache/datafusion/issues/12006 - errors.add("Internal error: PhysicalExpr Column references column"); // https://github.com/apache/datafusion/issues/12012 - errors.add("APPROX_"); // https://github.com/apache/datafusion/issues/12058 - errors.add("External error: task"); // https://github.com/apache/datafusion/issues/12057 - errors.add("NTH_VALUE"); // https://github.com/apache/datafusion/issues/12073 - errors.add("SUBSTR"); // https://github.com/apache/datafusion/issues/12129 + errors.add("NATURAL JOIN"); // https://github.com/apache/datafusion/issues/14015 /* * False positives */ errors.add("Cannot cast string"); // ifnull() is passed two non-compattable type and caused execution error - errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr - // is generated in where - // clause + // False positive: when aggr is generated in where clause + errors.add("Physical plan does not support logical expression AggregateFunction"); + errors.add("Unsupported ArrowType Utf8View"); // Maybe bug in arrow flight + // jdbc driver /* * Not critical, investigate in the future @@ -73,5 +71,16 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) { errors.add("does not match with the projection expression"); errors.add("invalid operator for nested"); errors.add("Arrow error: Cast error: Can't cast value"); + errors.add("Nth value indices are 1 based"); + /* + * Example query that triggers this error: create table t1(v1 int, v2 bool); select v1, sum(1) over (partition + * by v1 order by v2 range between 0 preceding and 0 following) from t1; + * + * Current error message: Arrow error: Invalid argument error: Invalid arithmetic operation: Boolean - Boolean + * + * TODO: The error message could be more meaningful to indicate that RANGE frame is not supported for boolean + * ORDER BY columns + */ + errors.add("Invalid arithmetic operation"); } } diff --git a/src/sqlancer/datafusion/DataFusionOptions.java b/src/sqlancer/datafusion/DataFusionOptions.java index 5a2d8b69..56db6b76 100644 --- a/src/sqlancer/datafusion/DataFusionOptions.java +++ b/src/sqlancer/datafusion/DataFusionOptions.java @@ -15,6 +15,7 @@ import sqlancer.datafusion.test.DataFusionNoCrashAggregate; import sqlancer.datafusion.test.DataFusionNoCrashWindow; import sqlancer.datafusion.test.DataFusionNoRECOracle; +import sqlancer.datafusion.test.DataFusionPQS; import sqlancer.datafusion.test.DataFusionQueryPartitioningAggrTester; import sqlancer.datafusion.test.DataFusionQueryPartitioningHavingTester; import sqlancer.datafusion.test.DataFusionQueryPartitioningWhereTester; @@ -26,13 +27,11 @@ public class DataFusionOptions implements DBMSSpecificOptions getTestOracleFactory() { - return Arrays.asList( - // DataFusionOracleFactory.NO_CRASH_WINDOW, - // DataFusionOracleFactory.NO_CRASH_AGGREGATE, - DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE + return Arrays.asList(DataFusionOracleFactory.PQS, DataFusionOracleFactory.NO_CRASH_WINDOW, + 
DataFusionOracleFactory.NO_CRASH_AGGREGATE, DataFusionOracleFactory.NOREC, + DataFusionOracleFactory.QUERY_PARTITIONING_WHERE); // DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE - // ,DataFusionOracleFactory.QUERY_PARTITIONING_HAVING - ); + // DataFusionOracleFactory.QUERY_PARTITIONING_HAVING); } public enum DataFusionOracleFactory implements OracleFactory { @@ -42,6 +41,12 @@ public TestOracle create(DataFusionGlobalState globalStat return new DataFusionNoRECOracle(globalState); } }, + PQS { + @Override + public TestOracle create(DataFusionGlobalState globalState) throws SQLException { + return new DataFusionPQS(globalState); + } + }, QUERY_PARTITIONING_WHERE { @Override public TestOracle create(DataFusionGlobalState globalState) throws SQLException { diff --git a/src/sqlancer/datafusion/DataFusionProvider.java b/src/sqlancer/datafusion/DataFusionProvider.java index 37328e4b..5ad60561 100644 --- a/src/sqlancer/datafusion/DataFusionProvider.java +++ b/src/sqlancer/datafusion/DataFusionProvider.java @@ -1,6 +1,5 @@ package sqlancer.datafusion; -import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.DML; import static sqlancer.datafusion.DataFusionUtil.dfAssert; import static sqlancer.datafusion.DataFusionUtil.displayTables; @@ -8,6 +7,7 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.util.List; +import java.util.Optional; import java.util.Properties; import java.util.stream.Collectors; @@ -34,29 +34,52 @@ public DataFusionProvider() { super(DataFusionGlobalState.class, DataFusionOptions.class); } + // Basic tables generated are DataFusion memory tables (named t1, t2, ...) + // An equivalent table can be backed by a different physical implementation + // which will be named like t1_stringview, t2_parquet, etc. + // + // e.g. t1 and t1_stringview are logically equivalent tables, but backed by + // different physical representations + // + // This helps to do more metamorphic testing on tables, for example + // `select * from t1` and `select * from t1_stringview` should give the same + // result + // + // Supported physical implementations for tables: + // 1. Memory table (t1) + // 2. Memory table using StringView for TEXT columns (t1_stringview) + // Note: it's possible to convert only a random subset of TEXT columns to StringView @Override public void generateDatabase(DataFusionGlobalState globalState) throws Exception { - int tableCount = Randomly.fromOptions(1, 2, 3, 4, 5, 6, 7); + // Create base tables + // ============================ + + int tableCount = Randomly.fromOptions(1, 2, 3, 4); for (int i = 0; i < tableCount; i++) { - SQLQueryAdapter queryCreateRandomTable = new DataFusionTableGenerator().getQuery(globalState); + SQLQueryAdapter queryCreateRandomTable = new DataFusionTableGenerator().getCreateStmt(globalState); queryCreateRandomTable.execute(globalState); globalState.updateSchema(); - globalState.dfLogger.appendToLog(DML, queryCreateRandomTable.toString() + "\n"); + globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML, + queryCreateRandomTable.toString() + "\n"); } // Now only `INSERT` DML is supported // If more DMLs are added later, should use`StatementExecutor` instead // (see DuckDB's implementation for reference) + // Generating rows in base tables (t1, t2, ... not including t1_stringview, etc.) 
+ // ============================ + globalState.updateSchema(); - List allTables = globalState.getSchema().getDatabaseTables(); - List allTablesName = allTables.stream().map(t -> t.getName()).collect(Collectors.toList()); - if (allTablesName.isEmpty()) { + List allBaseTables = globalState.getSchema().getDatabaseTables(); + List allBaseTablesName = allBaseTables.stream().map(DataFusionTable::getName) + .collect(Collectors.toList()); + if (allBaseTablesName.isEmpty()) { dfAssert(false, "Generate Database failed."); } // Randomly insert some data into existing tables - for (DataFusionTable table : allTables) { + for (DataFusionTable table : allBaseTables) { int nInsertQuery = globalState.getRandomly().getInteger(0, globalState.getOptions().getMaxNumberInserts()); for (int i = 0; i < nInsertQuery; i++) { @@ -69,9 +92,24 @@ public void generateDatabase(DataFusionGlobalState globalState) throws Exception } insertQuery.execute(globalState); - globalState.dfLogger.appendToLog(DML, insertQuery.toString() + "\n"); + globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML, insertQuery.toString() + "\n"); + } + } + + // Construct mutated tables like t1_stringview, etc. + // ============================ + for (DataFusionTable table : allBaseTables) { + Optional queryCreateStringViewTable = new DataFusionTableGenerator() + .createStringViewTable(globalState, table); + if (queryCreateStringViewTable.isPresent()) { + queryCreateStringViewTable.get().execute(globalState); + globalState.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.DML, + queryCreateStringViewTable.get().toString() + "\n"); } } + globalState.updateSchema(); + List allTables = globalState.getSchema().getDatabaseTables(); + List allTablesName = allTables.stream().map(DataFusionTable::getName).collect(Collectors.toList()); // TODO(datafusion) add `DataFUsionLogType.STATE` for this whole db state log if (globalState.getDbmsSpecificOptions().showDebugInfo) { diff --git a/src/sqlancer/datafusion/DataFusionSchema.java b/src/sqlancer/datafusion/DataFusionSchema.java index 24c3a4eb..2d08a347 100644 --- a/src/sqlancer/datafusion/DataFusionSchema.java +++ b/src/sqlancer/datafusion/DataFusionSchema.java @@ -9,6 +9,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.regex.Pattern; import java.util.stream.Collectors; import sqlancer.Randomly; @@ -32,6 +34,9 @@ public DataFusionSchema(List databaseTables) { // update existing tables in DB by query again // (like `show tables;`) + // + // This function also setup table<->column reference pointers + // and equivalent tables(see `DataFusionTable.equivalentTables) public static DataFusionSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { List databaseTables = new ArrayList<>(); List tableNames = getTableNames(con); @@ -47,6 +52,24 @@ public static DataFusionSchema fromConnection(SQLConnection con, String database databaseTables.add(t); } + // Setup equivalent tables + // For example, now we have t1, t1_csv, t1_parquet, t2_csv, t2_parquet + // t1's equivalent tables: t1, t1_csv, t1_parquet + // t2_csv's equivalent tables: t2_csv, t2_parquet + // ... + // + // It can be assumed that: + // base table names are like t1, t2, ... + // equivalent tables are like t1_csv, t1_parquet, ... 
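For illustration only, this is roughly the kind of equivalence group the setup above targets, together with the metamorphic expectation it enables. The DDL below is a sketch: the column names v1/v2 are hypothetical, but the arrow_cast pattern mirrors what createStringViewTable() in DataFusionTableGenerator emits:

    -- base memory table, as produced by the random CREATE TABLE generator
    CREATE TABLE t1 (v1 BIGINT, v2 TEXT);
    -- logically equivalent copy whose TEXT column is stored as Utf8View
    CREATE TABLE t1_stringview AS SELECT v1, arrow_cast(v2, 'Utf8View') AS v2 FROM t1;
    -- any query should return the same rows no matter which member of the t1 group it reads
    SELECT COUNT(*) FROM t1 WHERE v2 IS NOT NULL;
    SELECT COUNT(*) FROM t1_stringview WHERE v2 IS NOT NULL;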
+ for (DataFusionTable t : databaseTables) { + String baseTableName = t.getName().split("_")[0]; + String patternString = "^" + baseTableName + "(_.*)?$"; // t1 or t1_* + Pattern pattern = Pattern.compile(patternString); + + t.equivalentTables = databaseTables.stream().filter(table -> pattern.matcher(table.getName()).matches()) + .map(DataFusionTable::getName).collect(Collectors.toList()); + } + return new DataFusionSchema(databaseTables); } @@ -120,8 +143,10 @@ public static DataFusionDataType parseFromDataFusionCatalog(String typeString) { return DataFusionDataType.BOOLEAN; case "Utf8": return DataFusionDataType.STRING; + case "Utf8View": + return DataFusionDataType.STRING; default: - dfAssert(false, "Unreachable. All branches should be eovered"); + dfAssert(false, "Uncovered branch typeString: " + typeString); } dfAssert(false, "Unreachable. All branches should be eovered"); @@ -169,25 +194,89 @@ public Node getRandomConstant(DataFusionGlobalState state) public static class DataFusionColumn extends AbstractTableColumn { private final boolean isNullable; + public Optional alias; public DataFusionColumn(String name, DataFusionDataType columnType, boolean isNullable) { super(name, null, columnType); this.isNullable = isNullable; + this.alias = Optional.empty(); } public boolean isNullable() { return isNullable; } + public String getOrignalName() { + return getTable().getName() + "." + getName(); + } + + @Override + public String getFullQualifiedName() { + if (getTable() == null) { + return getName(); + } else { + if (alias.isPresent()) { + return alias.get(); + } else { + return getTable().getName() + "." + getName(); + } + } + } } public static class DataFusionTable extends AbstractRelationalTable { + // There might exist multiple logically equivalent tables with + // different physical format. + // e.g. t1_csv, t1_parquet, ... + // + // When generating random query, it's possible to randomly pick one + // of them for stronger randomization. 
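A sketch of how an equivalence group is then used during query generation (table and column names are made up): because every table is referenced through a tt<i> alias, pickAnotherEquivalentTableName() can swap in a physically different member without touching the rest of the query, and the two variants must return identical results:

    SELECT tt0.v1, tt1.v1 FROM t1 AS tt0, t2 AS tt1 WHERE tt0.v1 = tt1.v1;
    -- same query after the equivalent-table mutation picks t1_stringview for the first reference
    SELECT tt0.v1, tt1.v1 FROM t1_stringview AS tt0, t2 AS tt1 WHERE tt0.v1 = tt1.v1;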
+ public List equivalentTables; + + // Pick a random equivalent table name + // This can be used when generating differential queries + public Optional currentEquivalentTableName; + + // For example in query `select * from t1 as tt1, t1 as tt2` + // `tt1` is the alias for the first occurance of `t1` + public Optional alias; public DataFusionTable(String tableName, List columns, boolean isView) { super(tableName, columns, Collections.emptyList(), isView); } + public String getNotAliasedName() { + if (currentEquivalentTableName != null && currentEquivalentTableName.isPresent()) { + // In case setup is not done yet + return currentEquivalentTableName.get(); + } else { + return super.getName(); + } + } + + // TODO(datafusion) Now implementation is hacky, should send a patch + // to core to support this + @Override + public String getName() { + // Before setup equivalent tables, we use the original table name + // Setup happens in `fromConnection()` + if (equivalentTables == null || currentEquivalentTableName == null) { + return super.getName(); + } + + if (alias.isPresent()) { + return alias.get(); + } else { + return currentEquivalentTableName.get(); + } + } + + public void pickAnotherEquivalentTableName() { + dfAssert(!equivalentTables.isEmpty(), "equivalentTables should not be empty"); + currentEquivalentTableName = Optional.of(Randomly.fromList(equivalentTables)); + } + public static List getAllColumns(List tables) { return tables.stream().map(AbstractTable::getColumns).flatMap(List::stream).collect(Collectors.toList()); } diff --git a/src/sqlancer/datafusion/DataFusionToStringVisitor.java b/src/sqlancer/datafusion/DataFusionToStringVisitor.java index f07e4a16..11965841 100644 --- a/src/sqlancer/datafusion/DataFusionToStringVisitor.java +++ b/src/sqlancer/datafusion/DataFusionToStringVisitor.java @@ -7,10 +7,14 @@ import sqlancer.Randomly; import sqlancer.common.ast.newast.NewToStringVisitor; import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; import sqlancer.datafusion.ast.DataFusionConstant; import sqlancer.datafusion.ast.DataFusionExpression; import sqlancer.datafusion.ast.DataFusionSelect; +import sqlancer.datafusion.ast.DataFusionSelect.DataFusionAliasedTable; import sqlancer.datafusion.ast.DataFusionSelect.DataFusionFrom; +import sqlancer.datafusion.ast.DataFusionSpecialExpr.CastToStringView; import sqlancer.datafusion.ast.DataFusionWindowExpr; public class DataFusionToStringVisitor extends NewToStringVisitor { @@ -37,6 +41,10 @@ public void visitSpecific(Node expr) { visit((DataFusionFrom) expr); } else if (expr instanceof DataFusionWindowExpr) { visit((DataFusionWindowExpr) expr); + } else if (expr instanceof CastToStringView) { + visit((CastToStringView) expr); + } else if (expr instanceof DataFusionAliasedTable) { + visit((DataFusionAliasedTable) expr); } else { throw new AssertionError(expr.getClass()); } @@ -49,13 +57,13 @@ private void visit(DataFusionFrom from) { /* e.g. from t1, t2, t3 */ if (from.joinConditionList.isEmpty()) { - visit(from.tableList); + visit(from.tableExprList); return; } - dfAssert(from.joinConditionList.size() == from.tableList.size() - 1, "Validate from"); + dfAssert(from.joinConditionList.size() == from.tableExprList.size() - 1, "Validate from"); /* e.g. 
from t1 join t2 on t1.v1=t2.v1 */ - visit(from.tableList.get(0)); + visit(from.tableExprList.get(0)); for (int i = 0; i < from.joinConditionList.size(); i++) { switch (from.joinTypeList.get(i)) { case INNER: @@ -80,7 +88,7 @@ private void visit(DataFusionFrom from) { dfAssert(false, "Unreachable"); } - visit(from.tableList.get(i + 1)); // ti + visit(from.tableExprList.get(i + 1)); // ti /* ON ... */ Node cond = from.joinConditionList.get(i); @@ -163,4 +171,31 @@ private void visit(DataFusionWindowExpr window) { sb.append(")"); } + private void visit(CastToStringView castToStringView) { + sb.append("ARROW_CAST("); + visit(castToStringView.expr); + sb.append(", 'Utf8')"); + } + + private void visit(DataFusionAliasedTable alias) { + if (alias.table instanceof TableReferenceNode) { + DataFusionTable t = null; + if (alias.table instanceof TableReferenceNode) { + TableReferenceNode tableRef = (TableReferenceNode) alias.table; + t = (DataFusionTable) tableRef.getTable(); + } else { + dfAssert(false, "Unreachable"); + } + + String baseName = t.getNotAliasedName(); + sb.append(baseName); + + dfAssert(t.alias.isPresent(), "Alias should be present"); + sb.append(" AS "); + sb.append(t.alias.get()); + } else { + dfAssert(false, "Unreachable"); + } + } + } diff --git a/src/sqlancer/datafusion/DataFusionUtil.java b/src/sqlancer/datafusion/DataFusionUtil.java index ac082afd..4f5b34f6 100644 --- a/src/sqlancer/datafusion/DataFusionUtil.java +++ b/src/sqlancer/datafusion/DataFusionUtil.java @@ -34,7 +34,6 @@ public static String displayTables(DataFusionGlobalState state, List fro ResultSetMetaData metaData = wholeTable.getMetaData(); int columnCount = metaData.getColumnCount(); - resultStringBuilder.append("Table: ").append(tableName).append("\n"); for (int i = 1; i <= columnCount; i++) { resultStringBuilder.append(metaData.getColumnName(i)).append(" (") .append(metaData.getColumnTypeName(i)).append(")"); @@ -58,7 +57,8 @@ public static String displayTables(DataFusionGlobalState state, List fro } catch (SQLException err) { resultStringBuilder.append("Table: ").append(tableName).append("\n"); resultStringBuilder.append("----------------------------------------\n\n"); - // resultStringBuilder.append("Error retrieving data from table ").append(tableName).append(": + // resultStringBuilder.append("Error retrieving data from table + // ").append(tableName).append(": // ").append(err.getMessage()).append("\n"); } } @@ -66,7 +66,8 @@ public static String displayTables(DataFusionGlobalState state, List fro return resultStringBuilder.toString(); } - // During development, you might want to manually let this function call exit(1) to fail fast + // During development, you might want to manually let this function call exit(1) + // to fail fast public static void dfAssert(boolean condition, String message) { if (!condition) { // Development mode assertion failure diff --git a/src/sqlancer/datafusion/ast/DataFusionSelect.java b/src/sqlancer/datafusion/ast/DataFusionSelect.java index 0b3922b1..d91b59cc 100644 --- a/src/sqlancer/datafusion/ast/DataFusionSelect.java +++ b/src/sqlancer/datafusion/ast/DataFusionSelect.java @@ -27,15 +27,21 @@ public class DataFusionSelect extends SelectBase> imp // `from` is used to represent from table list and join clause // `fromList` and `joinList` in base class should always be empty public DataFusionFrom from; + // Randomly selected table (equivalent to `from.tableList`) + // Can be refactored, it's a hack for now + public List tableList; + // e.g. 
let's say all colummns are {c1, c2, c3, c4, c5} // First randomly pick a subset say {c2, c1, c3, c4} // `exprGenAll` can generate random expr using above 4 columns // - // Next, randomly take two non-overlapping subset from all columns used by `exprGenAll` + // Next, randomly take two non-overlapping subset from all columns used by + // `exprGenAll` // exprGenGroupBy: {c1} (randomly generate group by exprs using c1 only) // exprGenAggregate: {c3, c4} // - // Finally, use all `Gen`s to generate different clauses in a query (`exprGenAll` in where clause, `exprGenGroupBy` + // Finally, use all `Gen`s to generate different clauses in a query + // (`exprGenAll` in where clause, `exprGenGroupBy` // in group by clause, etc.) public DataFusionExpressionGenerator exprGenAll; public DataFusionExpressionGenerator exprGenGroupBy; @@ -46,7 +52,8 @@ public enum JoinType { } // DataFusionFrom can be used to represent from table list or join list - // 1. When `joinConditionList` is empty, then it's a table list (implicit cross join) + // 1. When `joinConditionList` is empty, then it's a table list (implicit cross + // join) // join condition can be generated in `WHERE` clause (outside `FromClause`) // e.g. select * from [expr], [expr] is t1, t3, t2 // - tableList -> {t1, t3,t2} @@ -60,18 +67,25 @@ public enum JoinType { // - joinTypeList -> {INNER, LEFT} // - joinConditionList -> {[expr_with_t1_t2], [expr_with_t1_t2_t3]} public static class DataFusionFrom implements Node { - public List> tableList; + public List> tableExprList; public List joinTypeList; public List> joinConditionList; public DataFusionFrom() { - tableList = new ArrayList<>(); + tableExprList = new ArrayList<>(); joinTypeList = new ArrayList<>(); joinConditionList = new ArrayList<>(); } + public DataFusionFrom(List tables) { + this(); + tableExprList = tables.stream().map(t -> new TableReferenceNode(t)) + .map(tableExpr -> new DataFusionAliasedTable(tableExpr)).collect(Collectors.toList()); + } + public boolean isExplicitJoin() { - // if it's explicit join, joinTypeList and joinConditionList should be both length of tableList.len - 1 + // if it's explicit join, joinTypeList and joinConditionList should be both + // length of tableList.len - 1 // Otherwise, both is empty dfAssert(joinTypeList.size() == joinConditionList.size(), "Validate FromClause"); return !joinTypeList.isEmpty(); @@ -89,17 +103,19 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state, List> randomTableNodes = randomTables.stream() .map(t -> new TableReferenceNode(t)) .collect(Collectors.toList()); - fromClause.tableList = randomTableNodes; + fromClause.tableExprList = randomTableNodes; /* If JoinConditionList is empty, FromClause will be interpreted as from list */ if (Randomly.getBoolean() && Randomly.getBoolean()) { + fromClause.setupAlias(); return fromClause; } /* Set fromClause's joinTypeList and joinConditionList */ List possibleColsToGenExpr = new ArrayList<>(); possibleColsToGenExpr.addAll(randomTables.get(0).getColumns()); // first table - // Generate join conditions (see class-level comment example's joinConditionList) + // Generate join conditions (see class-level comment example's + // joinConditionList) // // Join Type | `ON` Clause Requirement // INNER JOIN | Required @@ -130,12 +146,80 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state, .add(exprGen.generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN)); } } - // TODO(datafusion) make join conditions more likely to be 'col1=col2', also some 
join types don't have + // TODO(datafusion) make join conditions more likely to be 'col1=col2', also + // some join types don't have // 'ON' condition } + // TODO(datafusion) add an option to disable this when issue fixed + // https://github.com/apache/datafusion/issues/12337 + fromClause.setupAlias(); + return fromClause; } + + public void setupAlias() { + for (int i = 0; i < tableExprList.size(); i++) { + if (tableExprList.get(i) instanceof TableReferenceNode) { + @SuppressWarnings("unchecked") // Suppress the unchecked cast warning + TableReferenceNode node = (TableReferenceNode) tableExprList + .get(i); + node.getTable().alias = Optional.of("tt" + i); + } else { + dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances"); + } + } + + // wrap table in `DataFusionAlias` for display + List> wrappedTables = new ArrayList<>(); + for (Node table : tableExprList) { + wrappedTables.add(new DataFusionAliasedTable(table)); + } + tableExprList = wrappedTables; + } + + } + + // If original query is + // select * from t1, t2, t3 + // The randomly mutated query looks like: + // select * from t1_csv, t2, t3_parquet + public void mutateEquivalentTableName() { + for (Node table : from.tableExprList) { + if (table instanceof DataFusionAliasedTable) { + Node aliasedTable = ((DataFusionAliasedTable) table).table; + + if (aliasedTable instanceof TableReferenceNode) { + @SuppressWarnings("unchecked") // Suppress the unchecked cast warning + TableReferenceNode tableRef = (TableReferenceNode) aliasedTable; + tableRef.getTable().pickAnotherEquivalentTableName(); + } else { + dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances"); + } + } else { + dfAssert(false, "Expected all items in tableList to be TableReferenceNode instances"); + } + } + } + + // Just a marker for table in `DataFusionFrom` + // + // For example in query `select * from t1 as tt1, t1 as tt2` + // If it's in the from list, we use `DataFusionAlias` wrapper on the table + // and print it as 't1 as tt1' + // If the same table is in expressions, don't use the wrapper and print it as + // 'tt1' + public static class DataFusionAliasedTable implements Node { + public Node table; + + public DataFusionAliasedTable(Node table) { + dfAssert(table instanceof TableReferenceNode, "Expected table reference node"); + @SuppressWarnings("unchecked") // Suppress the unchecked cast warning + DataFusionTable t = ((TableReferenceNode) table).getTable(); + dfAssert(t.alias.isPresent(), "Expected table to have alias"); + + this.table = table; + } } // Generate SELECT statement according to the dependency of exprs, e.g.: @@ -149,12 +233,17 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state, // // The generation order will be: // 1. [from_clause] - Pick tables like t1, t2, t3 and get a join clause - // 2. [expr_all_cols] - Generate a non-aggregate expression with all columns in t1, t2, t3. e.g.: + // 2. [expr_all_cols] - Generate a non-aggregate expression with all columns in + // t1, t2, t3. e.g.: // - t1.v1 = t2.v1 and t1.v2 > t3.v2 - // 3. [expr_groupby_cols], [expr_aggr_cols] - Randomly pick some cols in t1, t2, t3 as group by columns, and pick - // some other columns as aggregation columns, and generate non-aggr expression [expr_groupby_cols] on group by - // columns, finally generate aggregation expressions [expr_aggr_cols] on non-group-by/aggregation columns. 
- // For example, group by column is t1.v1, and aggregate columns is t2.v1, t3.v1, generated expressions can be: + // 3. [expr_groupby_cols], [expr_aggr_cols] - Randomly pick some cols in t1, t2, + // t3 as group by columns, and pick + // some other columns as aggregation columns, and generate non-aggr expression + // [expr_groupby_cols] on group by + // columns, finally generate aggregation expressions [expr_aggr_cols] on + // non-group-by/aggregation columns. + // For example, group by column is t1.v1, and aggregate columns is t2.v1, t3.v1, + // generated expressions can be: // - [expr_groupby_cols] t1.v1 + 1 // - [expr_aggr_cols] SUM(t3.v1 + t2.v1) public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { @@ -172,6 +261,7 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { randomTables = randomTables.subList(0, maxSize); } DataFusionFrom randomFrom = DataFusionFrom.generateFromClause(state, randomTables); + randomSelect.tableList = randomTables; /* Setup expression generators (to generate different clauses) */ List randomColumnsAll = DataFusionTable.getRandomColumns(randomTables); @@ -190,9 +280,15 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { .generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN); /* Constructing result */ - List> randomColumnNodes = randomColumnsAll.stream() - .map((c) -> new ColumnReferenceNode(c)) - .collect(Collectors.toList()); + List> randomColumnNodes = randomColumnsAll.stream().map((c) -> { + if (c.getType() == DataFusionSchema.DataFusionDataType.STRING) { + Node colRef = new ColumnReferenceNode(c); + return new DataFusionSpecialExpr.CastToStringView(colRef); + + } else { + return new ColumnReferenceNode(c); + } + }).collect(Collectors.toList()); randomSelect.setFetchColumns(randomColumnNodes); // TODO(datafusion) make it more random like 'select *' randomSelect.from = randomFrom; @@ -218,7 +314,8 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { // ... // group by v1 // - // This method assume `DataFusionSelect` is propoerly initialized with `getRandomSelect()` + // This method assume `DataFusionSelect` is propoerly initialized with + // `getRandomSelect()` public void setAggregates(DataFusionGlobalState state) { // group by exprs (e.g. 
group by v1, abs(v2)) List> groupByExprs = this.exprGenGroupBy.generateExpressionsPreferColumns(); diff --git a/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java b/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java new file mode 100644 index 00000000..8ae63861 --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionSpecialExpr.java @@ -0,0 +1,13 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.newast.Node; + +public class DataFusionSpecialExpr { + public static class CastToStringView implements Node { + public Node expr; + + public CastToStringView(Node expr) { + this.expr = expr; + } + } +} diff --git a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java index 3754a826..c19ad00d 100644 --- a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java +++ b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java @@ -55,7 +55,8 @@ protected boolean canGenerateColumnOfType(DataFusionDataType type) { return true; } - // If target expr type is numeric, when `supportAggregate`, make it more likely to generate aggregate functions + // If target expr type is numeric, when `supportAggregate`, make it more likely + // to generate aggregate functions // Since randomly generated expressions are nested: // For window/aggregate case, we want only outer layer to be window/aggr // to make it more likely to generate valid query @@ -82,7 +83,8 @@ private boolean filterBaseExpr(DataFusionBaseExpr expr, DataFusionDataType type, } // By default all possible non-aggregate expressions - // To generate aggregate functions: set this.supportAggregate to `true`, generate exprs, and reset. + // To generate aggregate functions: set this.supportAggregate to `true`, + // generate exprs, and reset. @Override protected Node generateExpression(DataFusionDataType type, int depth) { if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { @@ -103,7 +105,8 @@ protected Node generateExpression(DataFusionDataType type, DataFusionBaseExpr randomExpr = Randomly.fromList(possibleBaseExprs); - // if (randomExpr.exprType == DataFusionBaseExprCategory.AGGREGATE || randomExpr.exprType == + // if (randomExpr.exprType == DataFusionBaseExprCategory.AGGREGATE || + // randomExpr.exprType == // DataFusionBaseExprCategory.WINDOW) { // if (depth == 0) { // System.out.println("DBG depth 0"); @@ -143,7 +146,8 @@ protected Node generateExpression(DataFusionDataType type, case BINARY: dfAssert(randomExpr.argTypes.size() == 2 && randomExpr.nArgs == 2, "Binrary expression should only have 2 argument" + randomExpr.argTypes); - List argTypeList = new ArrayList<>(); // types of current expression's input arguments + List argTypeList = new ArrayList<>(); // types of current expression's input + // arguments for (ArgumentType argumentType : randomExpr.argTypes) { if (argumentType instanceof ArgumentType.Fixed) { ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0); @@ -333,13 +337,15 @@ public Node isNull(Node expr) { return new NewUnaryPostfixOperatorNode<>(expr, createExpr(DataFusionBaseExprType.IS_NULL)); } - // TODO(datafusion) refactor: make single generate aware of group by and aggr columns, and it can directly generate + // TODO(datafusion) refactor: make single generate aware of group by and aggr + // columns, and it can directly generate // having clause // Try best to generate a valid having clause // // Suppose query "... group by a, b ..." 
// and all available columns are "a, b, c, d" - // then a valid having clause can have expr of {a, b}, and expr of aggregation of {c, d} + // then a valid having clause can have expr of {a, b}, and expr of aggregation + // of {c, d} // e.g. "having a=b and avg(c) > avg(d)" // // `groupbyGen` can generate expression only with group by cols diff --git a/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java index 28ce5da3..e1acddcf 100644 --- a/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java +++ b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java @@ -1,18 +1,24 @@ package sqlancer.datafusion.gen; +import java.util.Optional; + import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; public class DataFusionTableGenerator { // Randomly generate a query like 'create table t1 (v1 bigint, v2 boolean)' - public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) { + public SQLQueryAdapter getCreateStmt(DataFusionGlobalState globalState) { ExpectedErrors errors = new ExpectedErrors(); StringBuilder sb = new StringBuilder(); String tableName = globalState.getSchema().getFreeTableName(); + + // Build "create table t1..." using sb sb.append("CREATE TABLE "); sb.append(tableName); sb.append("("); @@ -30,4 +36,40 @@ public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) { return new SQLQueryAdapter(sb.toString(), errors, true); } + + // Given a table t1, return create statement to generate t1_stringview + // If t1 has no string column, return empty + // + // Query looks like (only v2 is TEXT column): + // create table t1_stringview as + // select v1, arrow_cast(v2, 'Utf8View') as v2 from t1; + public Optional createStringViewTable(DataFusionGlobalState globalState, DataFusionTable table) { + if (!table.getColumns().stream().anyMatch(c -> c.getType().equals(DataFusionDataType.STRING))) { + return Optional.empty(); + } + + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE "); + sb.append(table.getName()); + sb.append("_stringview AS SELECT "); + for (DataFusionColumn column : table.getColumns()) { + String colName = column.getName(); + if (column.getType().equals(DataFusionDataType.STRING)) { + // Found a TEXT column, cast it + sb.append("arrow_cast(").append(colName).append(", 'Utf8View') as ").append(colName); + } else { + sb.append(colName); + } + + // Join expressions with ',' + if (column != table.getColumns().get(table.getColumns().size() - 1)) { + sb.append(", "); + } + } + + sb.append(" FROM ").append(table.getName()).append(";"); + + return Optional.of(new SQLQueryAdapter(sb.toString(), new ExpectedErrors(), true)); + } + } diff --git a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml index dd441ba3..f690cb52 100644 --- a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml +++ b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml @@ -4,20 +4,26 @@ edition = "2021" description = "Standalone DataFusion server" license = "Apache-2.0" +# TODO(datafusion): Figure out how to automatically manage arrow version +# arrow version should be the same as the one used by datafusion [dependencies] 
-ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } -arrow = { version = "52.1.0", features = ["prettyprint"] } -arrow-array = { version = "52.1.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "52.1.0", default-features = false } -arrow-flight = { version = "52.1.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "52.1.0", default-features = false, features = ["lz4"] } -arrow-ord = { version = "52.1.0", default-features = false } -arrow-schema = { version = "52.1.0", default-features = false } -arrow-string = { version = "52.1.0", default-features = false } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +arrow = { version = "53.3.0", features = ["prettyprint"] } +arrow-array = { version = "53.3.0", default-features = false, features = [ + "chrono-tz", +] } +arrow-buffer = { version = "53.3.0", default-features = false } +arrow-flight = { version = "53.3.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "53.3.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "53.3.0", default-features = false } +arrow-schema = { version = "53.3.0", default-features = false } +arrow-string = { version = "53.3.0", default-features = false } async-trait = "0.1.73" bytes = "1.4" -chrono = { version = "0.4.34", default-features = false } -dashmap = "5.5.0" +chrono = { version = "0.4.38", default-features = false } +dashmap = "6.0.1" # This version is for SQLancer CI run (disabled temporary for multiple newly fixed bugs) # datafusion = { version = "41.0.0" } # Use following line if you want to test against the latest main branch of DataFusion @@ -28,17 +34,21 @@ half = { version = "2.2.1", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } log = "0.4" num_cpus = "1.13.0" -object_store = { version = "0.10.1", default-features = false } +object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "52.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "53.3.0", default-features = false, features = [ + "arrow", + "async", + "object_store", +] } rand = "0.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } -tonic = "0.11" +tonic = "0.12.1" uuid = "1.0" -prost = { version = "0.12", default-features = false } -prost-derive = { version = "0.12", default-features = false } +prost = { version = "0.13.1" } +prost-derive = "0.13.1" mimalloc = { version = "0.1", default-features = false } [[bin]] diff --git a/src/sqlancer/datafusion/server/datafusion_server/src/main.rs b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs index 13ec73e9..41b12e21 100644 --- a/src/sqlancer/datafusion/server/datafusion_server/src/main.rs +++ b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs @@ -17,7 +17,7 @@ use arrow_flight::{ use arrow_schema::{DataType, Field, Schema}; use dashmap::DashMap; use datafusion::logical_expr::LogicalPlan; -use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; +use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use futures::{Stream, StreamExt, TryStreamExt}; use log::info; use mimalloc::MiMalloc; @@ -292,7 +292,7 @@ impl FlightSqlService for FlightSqlServiceImpl { let result = df .collect() .await - .map_err(|e| status!("Error executing query", e))?; + 
.map_err(|e| status!("Errorr executing query", e))?; // if we get an empty result, create an empty schema let schema = match result.first() { @@ -400,6 +400,8 @@ impl FlightSqlService for FlightSqlServiceImpl { .and_then(|df| df.into_optimized_plan()) .map_err(|e| Status::internal(format!("Error building plan: {e}")))?; + info!("Plan is {:#?}", plan); + // store a copy of the plan, it will be used for execution let plan_uuid = Uuid::new_v4().hyphenated().to_string(); self.statements.insert(plan_uuid.clone(), plan.clone()); @@ -417,6 +419,7 @@ impl FlightSqlService for FlightSqlServiceImpl { dataset_schema: schema_bytes, parameter_schema: Default::default(), }; + info!("do_action_create_prepared_statement SUCCEED!"); Ok(res) } diff --git a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java index c4a7defb..a7d876e8 100644 --- a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java +++ b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java @@ -7,6 +7,7 @@ import java.util.List; import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; import sqlancer.common.oracle.NoRECBase; import sqlancer.common.oracle.TestOracle; import sqlancer.datafusion.DataFusionErrors; @@ -43,24 +44,31 @@ public void check() throws SQLException { // generate a random: // SELECT [expr1] FROM [expr2] WHERE [expr3] DataFusionSelect randomSelect = getRandomSelect(state); + // Q1: SELECT count(*) FROM [expr2] WHERE [expr3] + // Q1 and Q2 is constructed from randomSelect's fields + // So for equivalent table mutation, we mutate randomSelect + randomSelect.mutateEquivalentTableName(); DataFusionSelect q1 = new DataFusionSelect(); q1.setFetchColumnsString("COUNT(*)"); q1.from = randomSelect.from; q1.setWhereClause(randomSelect.getWhereClause()); + String q1String = DataFusionToStringVisitor.asString(q1); + // Q2: SELECT count(case when [expr3] then 1 else null end) FROM [expr2] + randomSelect.mutateEquivalentTableName(); DataFusionSelect q2 = new DataFusionSelect(); String selectExpr = String.format("COUNT(CASE WHEN %s THEN 1 ELSE NULL END)", DataFusionToStringVisitor.asString(randomSelect.getWhereClause())); q2.setFetchColumnsString(selectExpr); q2.from = randomSelect.from; q2.setWhereClause(null); + String q2String = DataFusionToStringVisitor.asString(q2); /* * Execute Q1 and Q2 */ - String q1String = DataFusionToStringVisitor.asString(q1); - String q2String = DataFusionToStringVisitor.asString(q2); + // System.out.println("DBG: " + q1String + "\n" + q2String); List q1ResultSet = null; List q2ResultSet = null; try { @@ -81,6 +89,13 @@ public void check() throws SQLException { int count1 = q1ResultSet != null ? Integer.parseInt(q1ResultSet.get(0)) : -1; int count2 = q2ResultSet != null ? 
Integer.parseInt(q2ResultSet.get(0)) : -1; if (count1 != count2) { + // whitelist + // --------- + // https://github.com/apache/datafusion/issues/12468 + if (q1String.contains("NATURAL JOIN")) { + throw new IgnoreMeException(); + } + StringBuilder errorMessage = new StringBuilder().append("NoREC oracle violated:\n") .append(" Q1(result size ").append(count1).append("):").append(q1String).append(";\n") .append(" Q2(result size ").append(count2).append("):").append(q2String).append(";\n") @@ -94,5 +109,6 @@ public void check() throws SQLException { throw new AssertionError("\n\n" + indentedErrorLog); } + // System.out.println("NOREC passed: \n" + q1String + "\n" + q2String); } } diff --git a/src/sqlancer/datafusion/test/DataFusionPQS.java b/src/sqlancer/datafusion/test/DataFusionPQS.java new file mode 100644 index 00000000..00bd9f2a --- /dev/null +++ b/src/sqlancer/datafusion/test/DataFusionPQS.java @@ -0,0 +1,344 @@ +package sqlancer.datafusion.test; + +import static sqlancer.datafusion.DataFusionUtil.dfAssert; +import static sqlancer.datafusion.ast.DataFusionSelect.getRandomSelect; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import sqlancer.IgnoreMeException; +import sqlancer.common.ast.newast.Node; +import sqlancer.common.oracle.NoRECBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.datafusion.DataFusionErrors; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.DataFusionToStringVisitor; +import sqlancer.datafusion.DataFusionUtil; +import sqlancer.datafusion.DataFusionUtil.DataFusionLogger; +import sqlancer.datafusion.ast.DataFusionExpression; +import sqlancer.datafusion.ast.DataFusionSelect; +import sqlancer.datafusion.ast.DataFusionSelect.DataFusionFrom; + +public class DataFusionPQS extends NoRECBase implements TestOracle { + + private final DataFusionGlobalState state; + + // Table references in a randomly generated SELECT query + // used for current PQS check. 
+ // To construct queries ONLY used in PQS, + // columns will be temporarily aliased + // Remember to reset them when the PQS check is done + private List pqsTables; + + private StringBuilder currentCheckLog; // Each append should end with '\n' + + public DataFusionPQS(DataFusionGlobalState globalState) { + super(globalState); + this.state = globalState; + + DataFusionErrors.registerExpectedExecutionErrors(errors); + } + + private void setColumnAlias(List tables) { + List allColumns = tables.stream().flatMap(t -> t.getColumns().stream()) + .collect(Collectors.toList()); + for (int i = 0; i < allColumns.size(); i++) { + String alias = "cc" + i; + allColumns.get(i).alias = Optional.of(alias); + } + } + + private void resetColumnAlias(List tables) { + tables.stream().flatMap(t -> t.getColumns().stream()).forEach(c -> c.alias = Optional.empty()); + } + + private void pqsCleanUp() throws SQLException { + if (pqsTables == null) { + return; + } + resetColumnAlias(pqsTables); + + // Drop temp tables used in PQS check + SQLQueryAdapter tttCleanUp = new SQLQueryAdapter("drop table if exists ttt", errors); + tttCleanUp.execute(state); + SQLQueryAdapter ttCleanUp = new SQLQueryAdapter("drop table if exists tt", errors); + ttCleanUp.execute(state); + } + + @Override + public void check() throws SQLException { + this.currentCheckLog = new StringBuilder(); + String replay = DataFusionUtil.getReplay(state.getDatabaseName()); + currentCheckLog.append(replay); + pqsCleanUp(); + + try { + checkImpl(); + } catch (Exception | AssertionError e) { + pqsCleanUp(); + throw e; + } + pqsCleanUp(); + } + + public void checkImpl() throws SQLException { + // ====================================================== + // Step 1: + // select tt0.c0 as cc0, tt0.c1 as cc1, tt1.c0 as cc2 + // from t0 as tt0, t1 as tt1 + // where tt0.c0 = tt1.c0; + // ====================================================== + DataFusionSelect randomSelect = getRandomSelect(state); + randomSelect.from.joinConditionList = new ArrayList<>(); + randomSelect.from.joinTypeList = new ArrayList<>(); + randomSelect.mutateEquivalentTableName(); + pqsTables = randomSelect.tableList; + + // Reset fetch columns + List allColumns = randomSelect.tableList.stream().flatMap(t -> t.getColumns().stream()) + .collect(Collectors.toList()); + setColumnAlias(pqsTables); + List aliasedColumns = allColumns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get()) + .collect(Collectors.toList()); + String fetchColumnsString = String.join(", ", aliasedColumns); + randomSelect.setFetchColumnsString(fetchColumnsString); + + // ====================================================== + // Step 2: + // create table tt as + // with cte0 as (select tt0.c0 as cc0, tt0.c1 as cc1 from t0 as tt0 order by + // random() limit 1), + // cte1 as (select tt1.c0 as cc2 from t1 as tt1 order by random() limit 1) + // select * from cte0, cte1; + // ====================================================== + List cteSelects = new ArrayList<>(); + for (DataFusionTable table : randomSelect.tableList) { + DataFusionSelect cteSelect = new DataFusionSelect(); + DataFusionFrom cteFrom = new DataFusionFrom(Arrays.asList(table)); + cteSelect.from = cteFrom; + + List columns = table.getColumns(); + List cteAliasedColumns = columns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get()) + .collect(Collectors.toList()); + String cteFetchColumnsString = String.join(", ", cteAliasedColumns); + cteSelect.setFetchColumnsString(cteFetchColumnsString); + cteSelects.add(cteSelect); + } + + // select tt0.c0 
as cc0, tt0.c1 as cc1 from tt0 order by random() + List cteSelectsPickOneRow = cteSelects.stream() + .map(cteSelect -> DataFusionToStringVisitor.asString(cteSelect) + " ORDER BY RANDOM() LIMIT 1") + .collect(Collectors.toList()); + int ncte = cteSelectsPickOneRow.size(); + + List ctes = new ArrayList<>(); + for (int i = 0; i < ncte; i++) { + String cte = "cte" + i + " AS (" + cteSelectsPickOneRow.get(i) + ")"; + ctes.add(cte); + } + + // cte0, cte1, cte2 + List ctesInFrom = IntStream.range(0, ncte).mapToObj(i -> "cte" + i).collect(Collectors.toList()); + String ctesInFromString = String.join(", ", ctesInFrom); + + // Create tt (Table with one pivot row) + String ttCreate = "CREATE TABLE tt AS\n WITH\n " + String.join(",\n ", ctes) + "\n" + "SELECT * FROM " + + ctesInFromString; + + currentCheckLog.append("==== Create tt (Table with one pivot row):\n").append(ttCreate).append("\n"); + + // ====================================================== + // Step 3: + // Find the predicate that can select the pivot row + // can be {p, NOT p, p is NULL} + // Note 'p' is 'cc0 = cc2' in Step 1 + // ====================================================== + Node whereExpr = randomSelect.getWhereClause(); // must be valid + Node notWhereExpr = randomSelect.exprGenAll.negatePredicate(whereExpr); + Node isNullWhereExpr = randomSelect.exprGenAll.isNull(whereExpr); + List> candidatePredicates = Arrays.asList(whereExpr, notWhereExpr, isNullWhereExpr); + + String pivotQ1 = "select * from tt where " + DataFusionToStringVisitor.asString(whereExpr); + String pivotQ2 = "select * from tt where " + DataFusionToStringVisitor.asString(notWhereExpr); + String pivotQ3 = "select * from tt where " + DataFusionToStringVisitor.asString(isNullWhereExpr); + + List pivotQs = Arrays.asList(pivotQ1, pivotQ2, pivotQ3); + + // Execute "create tt" (table with one pivot row) + SQLQueryAdapter q = new SQLQueryAdapter(ttCreate, errors); + q.execute(state); + + SQLancerResultSet ttResult = null; + SQLQueryAdapter ttSelect = new SQLQueryAdapter("select * from tt", errors); + int nrow = 0; + try { + ttResult = ttSelect.executeAndGetLogged(state); + + if (ttResult == null) { + // Possible bug here, investigate later + throw new IgnoreMeException(); + } + + while (ttResult.next()) { + nrow++; + } + } catch (Exception e) { + // Possible bug here, investigate later + throw new IgnoreMeException(); + } finally { + if (ttResult != null && !ttResult.isClosed()) { + ttResult.close(); + } + } + + if (nrow == 0) { + // If an empty table is picked, we can't find a pivot row + // Give up current check + // TODO(datafusion): support empty tables + throw new IgnoreMeException("Empty table is picked"); + } + + Node pivotPredicate = null; + String pivotRow = ""; + for (int i = 0; i < pivotQs.size(); i++) { + String pivotQ = pivotQs.get(i); + SQLQueryAdapter qSelect = new SQLQueryAdapter(pivotQ, errors); + SQLancerResultSet rs = null; + try { + rs = qSelect.executeAndGetLogged(state); + if (rs == null) { + // Only one of the 3 pivot queries will return 1 row + continue; + } + + int rowCount = 0; + while (rs.next()) { + rowCount += 1; + for (int ii = 1; ii <= rs.rs.getMetaData().getColumnCount(); ii++) { + pivotRow += "[" + rs.getString(ii) + "]"; + } + pivotPredicate = candidatePredicates.get(i); + } + + dfAssert(rowCount <= 1, "Pivot row should be length of 1, got " + rowCount); + + if (rowCount == 1) { + break; + } + } catch (Exception e) { + currentCheckLog.append(pivotQ).append("\n"); 
currentCheckLog.append(e.getMessage()).append("\n").append(e.getCause()).append("\n"); + + String fullErrorMessage = currentCheckLog.toString(); + state.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.ERROR, fullErrorMessage); + + throw new AssertionError(fullErrorMessage); + } finally { + if (rs != null && !rs.isClosed()) { + rs.close(); + } + } + } + + if (pivotPredicate == null) { + // Sometimes all valid pivot queries fail + // Potential bug, investigate later + currentCheckLog.append("All pivot queries failed! ").append(pivotQs).append("\n"); + throw new IgnoreMeException("All pivot queries failed " + pivotQs); + } + + // ====================================================== + // Step 4: + // Let's say in Step 3 we found the predicate is "Not (cc0 = cc2)" + // Check if the pivot row can be found in + // "select * from tt0, tt1 where Not(cc0 = cc2)" + // Then we construct table ttt with the above query + // Finally join 'tt' and 'ttt' and make sure one pivot row will be output + // ====================================================== + DataFusionSelect selectAllRows = new DataFusionSelect(); + DataFusionFrom selectAllRowsFrom = new DataFusionFrom(randomSelect.tableList); + selectAllRows.from = selectAllRowsFrom; + selectAllRows.setWhereClause(pivotPredicate); + + List allSelectColumns = randomSelect.tableList.stream().flatMap(t -> t.getColumns().stream()) + .collect(Collectors.toList()); + // tt0.v0 as cc0, tt0.v1 as cc1, tt1.v0 as cc2 + List allSelectExprs = allSelectColumns.stream().map(c -> c.getOrignalName() + " AS " + c.alias.get()) + .collect(Collectors.toList()); + resetColumnAlias(randomSelect.tableList); + String selectAllRowsFetchColStr = String.join(", ", allSelectExprs); + selectAllRows.setFetchColumnsString(selectAllRowsFetchColStr); + + String selectAllRowsString = DataFusionToStringVisitor.asString(selectAllRows); + String tttCreate = "CREATE TABLE ttt AS\n" + selectAllRowsString; + + SQLQueryAdapter tttCreateStmt = new SQLQueryAdapter(tttCreate, errors); + tttCreateStmt.execute(state); + setColumnAlias(randomSelect.tableList); + + // ====================================================== + // Step 5: + // Make sure the following query returns 1 pivot row + // Otherwise PQS oracle is violated + // select * from tt join ttt + // on + // tt.cc0 is not distinct from ttt.cc0 + // and tt.cc1 is not distinct from ttt.cc1 + // and tt.cc2 is not distinct from ttt.cc2 + // ====================================================== + List onConditions = allSelectColumns.stream() + .map(c -> "(tt." + c.alias.get() + " IS NOT DISTINCT FROM ttt." 
+ c.alias.get() + ")") + .collect(Collectors.toList()); + String onCond = String.join("\nAND ", onConditions); + String joinQuery = "SELECT COUNT(*) FROM tt JOIN ttt ON\n" + onCond; + + SQLQueryAdapter qJoin = new SQLQueryAdapter(joinQuery, errors); + + SQLancerResultSet rsFull = null; + try { + rsFull = qJoin.executeAndGetLogged(state); + + if (rsFull == null) { + throw new IgnoreMeException("Join query returned no results: " + joinQuery); + } + + String joinCount = "invalid"; + while (rsFull.next()) { + joinCount = rsFull.getString(1); + } + + if (joinCount.equals("0")) { + String replay = DataFusionUtil.getReplay(state.getDatabaseName()); + StringBuilder errorLog = new StringBuilder().append("PQS oracle violated:\n").append("Found ") + .append(joinCount).append(" pivot rows:\n").append(" Pivot row: ").append(pivotRow).append("\n") + .append("Query to select pivot row: ").append(ttCreate).append("\n") + .append("Query to select all rows: ").append(tttCreate).append("\n").append("Join: ") + .append(joinQuery).append("\n").append(replay).append("\n"); + + String errorString = errorLog.toString(); + String indentedErrorLog = errorString.replaceAll("(?m)^", " "); + state.dfLogger.appendToLog(DataFusionLogger.DataFusionLogType.ERROR, errorString); + + throw new AssertionError("\n\n" + indentedErrorLog); + } else if (!joinCount.matches("\\d+")) { + // If joinCount is not a integer > 0, throw exception + throw new IgnoreMeException("Join query returned invalid result: " + joinCount); + } + } catch (Exception e) { + throw new IgnoreMeException("Failed to execute join query: " + joinQuery); + } finally { + if (rsFull != null && !rsFull.isClosed()) { + rsFull.close(); + } + } + } +} diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java index 9d6c2efc..31f90d0f 100644 --- a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java +++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java @@ -8,6 +8,7 @@ import java.util.List; import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.ast.newast.NewBinaryOperatorNode; import sqlancer.common.ast.newast.Node; @@ -43,6 +44,7 @@ public void check() throws SQLException { // generate a random 'SELECT [expr1] FROM [expr2] WHERE [expr3] super.check(); DataFusionSelect randomSelect = select; + randomSelect.mutateEquivalentTableName(); if (Randomly.getBoolean()) { randomSelect.distinct = true; @@ -69,12 +71,15 @@ public void check() throws SQLException { randomSelect.setWhereClause(null); qString = DataFusionToStringVisitor.asString(randomSelect); + randomSelect.mutateEquivalentTableName(); randomSelect.setWhereClause(predicate); qp1String = DataFusionToStringVisitor.asString(randomSelect); + randomSelect.mutateEquivalentTableName(); randomSelect.setWhereClause(negatedPredicate); qp2String = DataFusionToStringVisitor.asString(randomSelect); + randomSelect.mutateEquivalentTableName(); randomSelect.setWhereClause(isNullPredicate); qp3String = DataFusionToStringVisitor.asString(randomSelect); } else { @@ -111,6 +116,8 @@ public void check() throws SQLException { /* * Run all queires */ + // System.out.println("DBG TLP: " + qString + "\n" + qp1String + "\n" + + // qp2String + "\n" + qp3String); List qResultSet = ComparatorHelper.getResultSetFirstColumnAsString(qString, errors, state); List combinedString = new ArrayList<>(); List qpResultSet = 
ComparatorHelper.getCombinedResultSet(qp1String, qp2String, qp3String, @@ -121,6 +128,13 @@ public void check() throws SQLException { ComparatorHelper.assumeResultSetsAreEqual(qResultSet, qpResultSet, qString, combinedString, state, ComparatorHelper::canonicalizeResultValue); } catch (AssertionError e) { + // whitelist + // --------- + // https://github.com/apache/datafusion/issues/12468 + if (qp1String.contains("NATURAL JOIN")) { + throw new IgnoreMeException(); + } + // Append more error message String replay = DataFusionUtil.getReplay(state.getDatabaseName()); String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n";
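For reference, the WHERE-partitioning (TLP) check exercised above compares an unpartitioned query Q against the combined results of three partitions built from a random predicate p, and with this change each of the four statements may additionally be rewritten to reference a different equivalent table. A rough sketch with a made-up predicate and schema:

    -- Q: no predicate
    SELECT t1.v1 FROM t1;
    -- Q1, Q2, Q3: partitions over p = (t1.v1 > 0)
    SELECT t1.v1 FROM t1 WHERE t1.v1 > 0;
    SELECT t1.v1 FROM t1 WHERE NOT (t1.v1 > 0);
    SELECT t1.v1 FROM t1 WHERE (t1.v1 > 0) IS NULL;

The rows of Q must equal the multiset union of the rows of Q1, Q2, and Q3; swapping in t1_stringview for t1 in any of these statements must not change that outcome.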