Skip to content

Commit

Permalink
Refactor MySQLComStmtPrepareParameterMarkerExtractor and PostgreSQLCo…
Browse files Browse the repository at this point in the history
…mDescribeExecutor (#33970)

* Refactor MySQLComStmtPrepareParameterMarkerExtractor

* Refactor MySQLComStmtPrepareParameterMarkerExtractor and PostgreSQLComDescribeExecutor
  • Loading branch information
terrymanu authored Dec 8, 2024
1 parent d7a4bd9 commit 934ff92
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.postgresql.exception.metadata.ColumnNotFoundException;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
Expand All @@ -32,8 +34,8 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.stream.Collectors;

/**
Expand All @@ -51,42 +53,54 @@ public final class MySQLComStmtPrepareParameterMarkerExtractor {
* @return corresponding columns of parameter markers
*/
public static List<ShardingSphereColumn> findColumnsOfParameterMarkers(final SQLStatement sqlStatement, final ShardingSphereSchema schema) {
return sqlStatement instanceof InsertStatement ? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema) : Collections.emptyList();
return sqlStatement instanceof InsertStatement && ((InsertStatement) sqlStatement).getTable().isPresent()
? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema)
: Collections.emptyList();
}

private static List<ShardingSphereColumn> findColumnsOfParameterMarkersForInsert(final InsertStatement insertStatement, final ShardingSphereSchema schema) {
ShardingSphereTable table = schema.getTable(insertStatement.getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""));
List<String> columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table);
List<ShardingSphereColumn> result = getParameterMarkerColumns(insertStatement, table, columnNamesOfInsert);
insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> result.addAll(getOnDuplicateKeyParameterMarkerColumns(optional.getColumns(), table)));
return result;
}

private static List<String> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
return insertStatement.getColumns().isEmpty() ? table.getColumnNames() : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList());
}

private static List<ShardingSphereColumn> getParameterMarkerColumns(final InsertStatement insertStatement, final ShardingSphereTable table, final List<String> columnNamesOfInsert) {
List<ShardingSphereColumn> result = new ArrayList<>(insertStatement.getParameterMarkerSegments().size());
for (InsertValuesSegment each : insertStatement.getValues()) {
ListIterator<ExpressionSegment> listIterator = each.getValues().listIterator();
for (int columnIndex = listIterator.nextIndex(); listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
ExpressionSegment value = listIterator.next();
if (!(value instanceof ParameterMarkerExpressionSegment)) {
continue;
}
String columnName = columnNamesOfInsert.get(columnIndex);
ShardingSphereColumn column = table.getColumn(columnName);
result.add(column);
}
result.addAll(getParameterMarkerColumns(table, columnNamesOfInsert, each));
}
insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> appendOnDuplicateKeyParameterMarkers(optional.getColumns(), table, result));
return result;
}

private static List<String> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
return insertStatement.getColumns().isEmpty() ? table.getColumnNames() : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList());
private static List<ShardingSphereColumn> getParameterMarkerColumns(final ShardingSphereTable table, final List<String> columnNamesOfInsert, final InsertValuesSegment segment) {
List<ShardingSphereColumn> result = new LinkedList<>();
int index = 0;
for (ExpressionSegment each : segment.getValues()) {
if (each instanceof ParameterMarkerExpressionSegment) {
String columnName = columnNamesOfInsert.get(index);
ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
result.add(table.getColumn(columnName));
}
index++;
}
return result;
}

private static void appendOnDuplicateKeyParameterMarkers(final Collection<ColumnAssignmentSegment> onDuplicateKeyColumns,
final ShardingSphereTable table, final List<ShardingSphereColumn> result) {
private static List<ShardingSphereColumn> getOnDuplicateKeyParameterMarkerColumns(final Collection<ColumnAssignmentSegment> onDuplicateKeyColumns, final ShardingSphereTable table) {
List<ShardingSphereColumn> result = new LinkedList<>();
for (ColumnAssignmentSegment each : onDuplicateKeyColumns) {
if (!(each.getValue() instanceof ParameterMarkerExpressionSegment)) {
continue;
if (each.getValue() instanceof ParameterMarkerExpressionSegment) {
String columnName = each.getColumns().iterator().next().getIdentifier().getValue();
ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
result.add(table.getColumn(columnName));
}
String columnName = each.getColumns().iterator().next().getIdentifier().getValue();
ShardingSphereColumn column = table.getColumn(columnName);
result.add(column);
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@

class MySQLComStmtPrepareParameterMarkerExtractorTest {

private final DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");

@Test
void assertFindColumnsOfParameterMarkersForInsertStatement() {
String sql = "insert into user (id, name, age) values (1, ?, ?), (?, 'bar', ?)";
SQLStatement sqlStatement = new ShardingSphereSQLParserEngine(TypedSPILoader.getService(DatabaseType.class, "MySQL"), new CacheOption(0, 0L), new CacheOption(0, 0L)).parse(sql, false);
ShardingSphereSchema schema = prepareSchema();
String sql = "INSERT INTO user (id, name, age) VALUES (1, ?, ?), (?, 'bar', ?)";
SQLStatement sqlStatement = new ShardingSphereSQLParserEngine(databaseType, new CacheOption(0, 0L), new CacheOption(0, 0L)).parse(sql, false);
ShardingSphereSchema schema = createSchema();
List<ShardingSphereColumn> actual = MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatement, schema);
assertThat(actual.get(0), is(schema.getTable("user").getColumn("name")));
assertThat(actual.get(1), is(schema.getTable("user").getColumn("age")));
assertThat(actual.get(2), is(schema.getTable("user").getColumn("id")));
assertThat(actual.get(3), is(schema.getTable("user").getColumn("age")));
}

private ShardingSphereSchema prepareSchema() {
private ShardingSphereSchema createSchema() {
ShardingSphereTable table = new ShardingSphereTable("user", Arrays.asList(
new ShardingSphereColumn("id", Types.BIGINT, true, false, false, false, true, false),
new ShardingSphereColumn("name", Types.VARCHAR, false, false, false, false, false, false),
new ShardingSphereColumn("age", Types.SMALLINT, false, false, false, false, true, false)), Collections.emptyList(), Collections.emptyList());
ShardingSphereSchema result = new ShardingSphereSchema("foo_db");
result.putTable(table);
return result;
return new ShardingSphereSchema("foo_db", Collections.singleton(table), Collections.emptyList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,18 @@ private void describeInsertStatementByDatabaseMetaData(final PostgreSQLServerPre
preparedStatement.setRowDescription(returningSegment.<PostgreSQLPacket>map(returning -> describeReturning(returning, table)).orElseGet(PostgreSQLNoDataPacket::getInstance));
int parameterMarkerIndex = 0;
for (InsertValuesSegment each : insertStatement.getValues()) {
ListIterator<ExpressionSegment> listIterator = each.getValues().listIterator();
for (int columnIndex = listIterator.nextIndex(); listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
ExpressionSegment value = listIterator.next();
for (int i = 0; i < each.getValues().size(); i++) {
ExpressionSegment value = each.getValues().get(i);
if (!(value instanceof ParameterMarkerExpressionSegment)) {
continue;
}
if (!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
parameterMarkerIndex++;
continue;
}
String columnName = columnNamesOfInsert.get(columnIndex);
ShardingSphereColumn column = table.getColumn(columnName);
ShardingSpherePreconditions.checkNotNull(column, () -> new ColumnNotFoundException(logicTableName, columnName));
preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(column.getDataType()));
String columnName = columnNamesOfInsert.get(i);
ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(logicTableName, columnName));
preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(table.getColumn(columnName).getDataType()));
}
}
}
Expand Down Expand Up @@ -179,8 +177,8 @@ private PostgreSQLRowDescriptionPacket describeReturning(final ReturningSegment
Collection<PostgreSQLColumnDescription> result = new LinkedList<>();
for (ProjectionSegment each : returningSegment.getProjections().getProjections()) {
if (each instanceof ShorthandProjectionSegment) {
table.getAllColumns().stream().map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), ""))
.forEach(result::add);
table.getAllColumns().stream()
.map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), "")).forEach(result::add);
}
if (each instanceof ColumnProjectionSegment) {
ColumnProjectionSegment segment = (ColumnProjectionSegment) each;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ class PostgreSQLComDescribeExecutorTest {

private static final String TABLE_NAME = "t_order";

private static final SQLParserEngine SQL_PARSER_ENGINE = new ShardingSphereSQLParserEngine(
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), new CacheOption(2000, 65535L), new CacheOption(128, 1024L));
private static final DatabaseType DATABASE_TYPE = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");

private static final SQLParserEngine SQL_PARSER_ENGINE = new ShardingSphereSQLParserEngine(DATABASE_TYPE, new CacheOption(2000, 65535L), new CacheOption(128, 1024L));

@Mock
private PortalContext portalContext;
Expand Down Expand Up @@ -395,9 +396,9 @@ private ContextManager mockContextManager() {
new ShardingSphereColumn("c", Types.CHAR, true, false, false, true, false, false),
new ShardingSphereColumn("pad", Types.CHAR, true, false, false, true, false, false));
when(schema.getTable(TABLE_NAME)).thenReturn(new ShardingSphereTable(TABLE_NAME, columns, Collections.emptyList(), Collections.emptyList()));
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"));
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getProtocolType()).thenReturn(DATABASE_TYPE);
StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS);
when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"));
when(storageUnit.getStorageType()).thenReturn(DATABASE_TYPE);
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getResourceMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("ds_0", storageUnit));
when(result.getMetaDataContexts().getMetaData().containsDatabase(DATABASE_NAME)).thenReturn(true);
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).containsSchema("public")).thenReturn(true);
Expand Down

0 comments on commit 934ff92

Please sign in to comment.