Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix class cast exception when select columns system table #30545

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ private void registerTableScanExecutor(final Schema sqlFederationSchema, final D
TableScanExecutorContext executorContext = new TableScanExecutorContext(databaseName, schemaName, metaData.getProps(), federationContext);
EnumerableScanExecutor scanExecutor = new EnumerableScanExecutor(prepareEngine, jdbcExecutor, callback, optimizerContext, metaData.getGlobalRuleMetaData(), executorContext, statistics);
// TODO register only the required tables
for (String each : metaData.getDatabase(databaseName).getSchema(schemaName).getAllTableNames()) {
Table table = sqlFederationSchema.getTable(each);
for (ShardingSphereTable each : metaData.getDatabase(databaseName).getSchema(schemaName).getTables().values()) {
Table table = sqlFederationSchema.getTable(each.getName());
if (table instanceof SQLFederationTable) {
((SQLFederationTable) table).setScanExecutor(scanExecutor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ private void computeConnectionOffsets(final ExecutionContext context) {
private Enumerable<Object> executeByShardingSphereData(final String databaseName, final String schemaName, final ShardingSphereTable table, final DatabaseType databaseType) {
// TODO move this logic to ShardingSphere statistics
if (databaseType instanceof OpenGaussDatabaseType && SYSTEM_CATALOG_TABLES.contains(table.getName())) {
return createMemoryEnumerator(createSystemCatalogTableData(table));
return createMemoryEnumerator(createSystemCatalogTableData(table), table, databaseType);
}
Optional<ShardingSphereTableData> tableData = Optional.ofNullable(statistics.getDatabaseData().get(databaseName)).map(optional -> optional.getSchemaData().get(schemaName))
.map(ShardingSphereSchemaData::getTableData).map(shardingSphereData -> shardingSphereData.get(table.getName()));
return tableData.map(this::createMemoryEnumerator).orElseGet(this::createEmptyEnumerable);
return tableData.map(optional -> createMemoryEnumerator(optional, table, databaseType)).orElseGet(this::createEmptyEnumerable);
}

private ShardingSphereTableData createSystemCatalogTableData(final ShardingSphereTable table) {
Expand Down Expand Up @@ -231,12 +231,12 @@ private void appendOpenGaussRoleData(final ShardingSphereTableData tableData, fi
}
}

private Enumerable<Object> createMemoryEnumerator(final ShardingSphereTableData tableData) {
private Enumerable<Object> createMemoryEnumerator(final ShardingSphereTableData tableData, final ShardingSphereTable table, final DatabaseType databaseType) {
return new AbstractEnumerable<Object>() {

@Override
public Enumerator<Object> enumerator() {
return new MemoryEnumerator(tableData.getRows());
return new MemoryEnumerator(tableData.getRows(), table.getColumns().values(), databaseType);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,22 @@

package org.apache.shardingsphere.sqlfederation.executor.row;

import lombok.SneakyThrows;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.impl.driver.jdbc.type.util.ResultSetUtils;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.statistics.ShardingSphereRowData;
import org.apache.shardingsphere.sqlfederation.optimizer.metadata.util.SQLFederationDataTypeUtils;

import java.sql.SQLFeatureNotSupportedException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Memory enumerator.
Expand All @@ -30,13 +41,36 @@ public final class MemoryEnumerator implements Enumerator<Object> {

private final Collection<ShardingSphereRowData> rows;

private Iterator<ShardingSphereRowData> rowDataIterator;
private final DatabaseType databaseType;

private final Map<Integer, Class<?>> columnTypes;

private Iterator<ShardingSphereRowData> iterator;

private Object current;

public MemoryEnumerator(final Collection<ShardingSphereRowData> rows) {
public MemoryEnumerator(final Collection<ShardingSphereRowData> rows, final Collection<ShardingSphereColumn> columns, final DatabaseType databaseType) {
this.rows = rows;
rowDataIterator = rows.iterator();
this.databaseType = databaseType;
columnTypes = createColumnTypes(new ArrayList<>(columns));
iterator = rows.iterator();
}

private Map<Integer, Class<?>> createColumnTypes(final List<ShardingSphereColumn> columns) {
Map<Integer, Class<?>> result = new HashMap<>(columns.size(), 1F);
for (int index = 0; index < columns.size(); index++) {
int finalIndex = index;
getSqlTypeClass(columns, index).ifPresent(optional -> result.put(finalIndex, optional));
}
return result;
}

private Optional<Class<?>> getSqlTypeClass(final List<ShardingSphereColumn> columns, final int index) {
try {
return Optional.of(SQLFederationDataTypeUtils.getSqlTypeClass(databaseType, columns.get(index)));
} catch (final IllegalArgumentException ex) {
return Optional.empty();
}
}

@Override
Expand All @@ -46,22 +80,41 @@ public Object current() {

@Override
public boolean moveNext() {
if (rowDataIterator.hasNext()) {
current = rowDataIterator.next().getRows().toArray();
if (iterator.hasNext()) {
current = convertToTargetType(iterator.next().getRows().toArray());
return true;
}
current = null;
rowDataIterator = rows.iterator();
iterator = rows.iterator();
return false;
}

@SneakyThrows
private Object[] convertToTargetType(final Object[] rows) {
Object[] result = new Object[rows.length];
for (int index = 0; index < rows.length; index++) {
if (columnTypes.containsKey(index)) {
result[index] = convertValue(rows, index);
}
}
return result;
}

private Object convertValue(final Object[] rows, final int index) {
try {
return ResultSetUtils.convertValue(rows[index], columnTypes.get(index));
} catch (final SQLFeatureNotSupportedException ex) {
return rows[index];
}
}

@Override
public void reset() {
}

@Override
public void close() {
rowDataIterator = rows.iterator();
iterator = rows.iterator();
current = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.calcite.linq4j.Enumerable;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
import org.apache.shardingsphere.infra.metadata.statistics.ShardingSphereDatabaseData;
import org.apache.shardingsphere.infra.metadata.statistics.ShardingSphereRowData;
Expand All @@ -32,6 +33,7 @@
import org.apache.shardingsphere.sqlfederation.optimizer.metadata.schema.table.ScanExecutorContext;
import org.junit.jupiter.api.Test;

import java.sql.Types;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.instanceOf;
Expand Down Expand Up @@ -59,10 +61,11 @@ void assertExecuteWithStatistics() {
ShardingSphereTableData tableData = mock(ShardingSphereTableData.class);
when(tableData.getRows()).thenReturn(Collections.singletonList(new ShardingSphereRowData(Collections.singletonList(1))));
when(schemaData.getTableData().get("test")).thenReturn(tableData);
ShardingSphereTable shardingSphereTable = mock(ShardingSphereTable.class);
when(shardingSphereTable.getName()).thenReturn("test");
ShardingSphereTable table = mock(ShardingSphereTable.class, RETURNS_DEEP_STUBS);
when(table.getName()).thenReturn("test");
when(table.getColumns().values()).thenReturn(Collections.singleton(new ShardingSphereColumn("id", Types.INTEGER, true, false, false, false, true, false)));
Enumerable<Object> enumerable = new EnumerableScanExecutor(null, null, null, optimizerContext, null, executorContext, statistics)
.execute(shardingSphereTable, mock(ScanExecutorContext.class));
.execute(table, mock(ScanExecutorContext.class));
try (Enumerator<Object> actual = enumerable.enumerator()) {
actual.moveNext();
Object row = actual.current();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ private static RelDataType getRelDataType(final DatabaseType protocolType, final
return typeFactory.createTypeWithNullability(javaType, true);
}

private static Class<?> getSqlTypeClass(final DatabaseType protocolType, final ShardingSphereColumn column) {
/**
* Get SQL type class.
*
* @param protocolType protocol type
* @param column ShardingSphere column
* @return SQL type class
*/
public static Class<?> getSqlTypeClass(final DatabaseType protocolType, final ShardingSphereColumn column) {
Optional<Class<?>> typeClazz = Optional.empty();
if (protocolType instanceof MySQLDatabaseType) {
typeClazz = findMySQLTypeClass(column);
Expand Down