Skip to content

Commit

Permalink
Refactor SQLTranslator interface to reduce parameter passing
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Oct 31, 2023
1 parent 58432f3 commit 6cd86e2
Show file tree
Hide file tree
Showing 17 changed files with 67 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private RouteContext route(final QueryContext queryContext, final ShardingSphere
private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData,
final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext());
return sqlRewriteEntry.rewrite(queryContext, routeContext, connectionContext);
}

private ExecutionContext createExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final RouteContext routeContext, final SQLRewriteResult rewriteResult) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.infra.rewrite;

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
Expand All @@ -30,10 +29,10 @@
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

Expand Down Expand Up @@ -61,27 +60,23 @@ public SQLRewriteEntry(final ShardingSphereDatabase database, final RuleMetaData
/**
* Rewrite.
*
* @param sql SQL
* @param params SQL parameters
* @param sqlStatementContext SQL statement context
* @param queryContext query context
* @param routeContext route context
* @param connectionContext connection context
* @param hintValueContext hint value context
*
* @return route unit and SQL rewrite result map
*/
public SQLRewriteResult rewrite(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext result = new SQLRewriteContext(database, sqlStatementContext, sql, params, connectionContext, hintValueContext);
private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
HintValueContext hintValueContext = queryContext.getHintValueContext();
SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext);
decorate(decorators, result, routeContext, hintValueContext);
result.generateSQLTokens();
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.sql.impl.DefaultSQLBuilder;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

import java.util.Map;
Expand All @@ -46,14 +47,14 @@ public final class GenericSQLRewriteEngine {
* Rewrite SQL and parameters.
*
* @param sqlRewriteContext SQL rewrite context
* @param queryContext query context
* @return SQL rewrite result
*/
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final QueryContext queryContext) {
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType();
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData);
String sql = translatorRule.translate(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), queryContext, storageType, database, globalRuleMetaData);
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand Down Expand Up @@ -62,9 +63,10 @@ public final class RouteSQLRewriteEngine {
*
* @param sqlRewriteContext SQL rewrite context
* @param routeContext route context
* @param queryContext query context
* @return SQL rewrite result
*/
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) {
Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
for (Entry<String, Collection<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
Collection<RouteUnit> routeUnits = entry.getValue();
Expand All @@ -74,7 +76,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext,
addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
}
}
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits));
return new RouteSQLRewriteResult(translate(queryContext, sqlRewriteUnits));
}

private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
Expand Down Expand Up @@ -155,12 +157,12 @@ private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final Rou
return false;
}

private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatementContext sqlStatementContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
private Map<RouteUnit, SQLRewriteUnit> translate(final QueryContext queryContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData);
String sql = translatorRule.translate(entry.getValue().getSql(), queryContext, storageType, database, globalRuleMetaData);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand All @@ -59,12 +60,21 @@ void assertRewriteForGenericSQLRewriteResult() {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(
database, new RuleMetaData(Collections.singleton(new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()))), new ConfigurationProperties(new Properties()));
RouteContext routeContext = new RouteContext();
GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext,
mock(ConnectionContext.class), new HintValueContext());
GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class));
assertThat(sqlRewriteResult.getSqlRewriteUnit().getSql(), is("SELECT ?"));
assertThat(sqlRewriteResult.getSqlRewriteUnit().getParameters(), is(Collections.singletonList(1)));
}

private QueryContext createQueryContext() {
QueryContext result = mock(QueryContext.class);
when(result.getSql()).thenReturn("SELECT ?");
when(result.getParameters()).thenReturn(Collections.singletonList(1));
CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class);
when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
when(result.getHintValueContext()).thenReturn(new HintValueContext());
return result;
}

@Test
void assertRewriteForRouteSQLRewriteResult() {
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME, TypedSPILoader.getService(DatabaseType.class, "H2"), mockResourceMetaData(),
Expand All @@ -77,8 +87,7 @@ void assertRewriteForRouteSQLRewriteResult() {
RouteUnit secondRouteUnit = mock(RouteUnit.class);
when(secondRouteUnit.getDataSourceMapper()).thenReturn(new RouteMapper("ds", "ds_1"));
routeContext.getRouteUnits().addAll(Arrays.asList(firstRouteUnit, secondRouteUnit));
RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?",
Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext, mock(ConnectionContext.class), new HintValueContext());
RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class));
assertThat(sqlRewriteResult.getSqlRewriteUnits().size(), is(2));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.junit.jupiter.api.Test;
Expand All @@ -53,7 +54,7 @@ void assertRewrite() {
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
new HintValueContext()), mock(QueryContext.class));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
Expand All @@ -67,7 +68,7 @@ void assertRewriteStorageTypeIsEmpty() {
when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
new HintValueContext()), mock(QueryContext.class));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
Expand Down
Loading

0 comments on commit 6cd86e2

Please sign in to comment.