Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SQLTranslator interface to reduce parameter passing #28902

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private RouteContext route(final QueryContext queryContext, final ShardingSphere
private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData,
final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext());
return sqlRewriteEntry.rewrite(queryContext, routeContext, connectionContext);
}

private ExecutionContext createExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final RouteContext routeContext, final SQLRewriteResult rewriteResult) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.infra.rewrite;

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
Expand All @@ -30,10 +29,10 @@
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

Expand Down Expand Up @@ -61,27 +60,23 @@ public SQLRewriteEntry(final ShardingSphereDatabase database, final RuleMetaData
/**
* Rewrite.
*
* @param sql SQL
* @param params SQL parameters
* @param sqlStatementContext SQL statement context
* @param queryContext query context
* @param routeContext route context
* @param connectionContext connection context
* @param hintValueContext hint value context
*
* @return route unit and SQL rewrite result map
*/
public SQLRewriteResult rewrite(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext result = new SQLRewriteContext(database, sqlStatementContext, sql, params, connectionContext, hintValueContext);
private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
HintValueContext hintValueContext = queryContext.getHintValueContext();
SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext);
decorate(decorators, result, routeContext, hintValueContext);
result.generateSQLTokens();
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.sql.impl.DefaultSQLBuilder;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

import java.util.Map;
Expand All @@ -46,14 +47,14 @@ public final class GenericSQLRewriteEngine {
* Rewrite SQL and parameters.
*
* @param sqlRewriteContext SQL rewrite context
* @param queryContext query context
* @return SQL rewrite result
*/
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final QueryContext queryContext) {
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType();
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData);
String sql = translatorRule.translate(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), queryContext, storageType, database, globalRuleMetaData);
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand Down Expand Up @@ -62,9 +63,10 @@ public final class RouteSQLRewriteEngine {
*
* @param sqlRewriteContext SQL rewrite context
* @param routeContext route context
* @param queryContext query context
* @return SQL rewrite result
*/
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) {
Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
for (Entry<String, Collection<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
Collection<RouteUnit> routeUnits = entry.getValue();
Expand All @@ -74,7 +76,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext,
addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
}
}
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits));
return new RouteSQLRewriteResult(translate(queryContext, sqlRewriteUnits));
}

private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
Expand Down Expand Up @@ -155,12 +157,12 @@ private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final Rou
return false;
}

private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatementContext sqlStatementContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
private Map<RouteUnit, SQLRewriteUnit> translate(final QueryContext queryContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData);
String sql = translatorRule.translate(entry.getValue().getSql(), queryContext, storageType, database, globalRuleMetaData);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand All @@ -59,12 +60,21 @@ void assertRewriteForGenericSQLRewriteResult() {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(
database, new RuleMetaData(Collections.singleton(new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()))), new ConfigurationProperties(new Properties()));
RouteContext routeContext = new RouteContext();
GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext,
mock(ConnectionContext.class), new HintValueContext());
GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class));
assertThat(sqlRewriteResult.getSqlRewriteUnit().getSql(), is("SELECT ?"));
assertThat(sqlRewriteResult.getSqlRewriteUnit().getParameters(), is(Collections.singletonList(1)));
}

private QueryContext createQueryContext() {
QueryContext result = mock(QueryContext.class);
when(result.getSql()).thenReturn("SELECT ?");
when(result.getParameters()).thenReturn(Collections.singletonList(1));
CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class);
when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
when(result.getHintValueContext()).thenReturn(new HintValueContext());
return result;
}

@Test
void assertRewriteForRouteSQLRewriteResult() {
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME, TypedSPILoader.getService(DatabaseType.class, "H2"), mockResourceMetaData(),
Expand All @@ -77,8 +87,7 @@ void assertRewriteForRouteSQLRewriteResult() {
RouteUnit secondRouteUnit = mock(RouteUnit.class);
when(secondRouteUnit.getDataSourceMapper()).thenReturn(new RouteMapper("ds", "ds_1"));
routeContext.getRouteUnits().addAll(Arrays.asList(firstRouteUnit, secondRouteUnit));
RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?",
Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext, mock(ConnectionContext.class), new HintValueContext());
RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class));
assertThat(sqlRewriteResult.getSqlRewriteUnits().size(), is(2));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.junit.jupiter.api.Test;
Expand All @@ -53,7 +54,7 @@ void assertRewrite() {
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
new HintValueContext()), mock(QueryContext.class));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
Expand All @@ -67,7 +68,7 @@ void assertRewriteStorageTypeIsEmpty() {
when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
new HintValueContext()), mock(QueryContext.class));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
Expand Down
Loading
Loading