From 6cd86e272c3ffd83e02007f10b08b48b8c0fc5d1 Mon Sep 17 00:00:00 2001 From: duanzhengqiang Date: Tue, 31 Oct 2023 11:48:20 +0800 Subject: [PATCH] Refactor SQLTranslator interface to reduce parameter passing --- .../connection/kernel/KernelProcessor.java | 2 +- .../infra/rewrite/SQLRewriteEntry.java | 23 ++++++++----------- .../engine/GenericSQLRewriteEngine.java | 7 +++--- .../rewrite/engine/RouteSQLRewriteEngine.java | 10 ++++---- .../infra/rewrite/SQLRewriteEntryTest.java | 17 ++++++++++---- .../engine/GenericSQLRewriteEngineTest.java | 5 ++-- .../engine/RouteSQLRewriteEngineTest.java | 13 ++++++----- .../datanode/SingleTableDataNodeLoader.java | 2 +- kernel/sql-translator/api/pom.xml | 2 +- .../sqltranslator/spi/SQLTranslator.java | 6 ++--- .../sqltranslator/rule/SQLTranslatorRule.java | 8 +++---- .../rule/SQLTranslatorRuleTest.java | 12 +++++----- .../fixture/AlwaysFailedSQLTranslator.java | 4 ++-- .../ConvertToUpperCaseSQLTranslator.java | 4 ++-- .../sqltranslator/jooq/JooQSQLTranslator.java | 4 ++-- .../natived/NativeSQLTranslator.java | 4 ++-- .../test/it/rewrite/engine/SQLRewriterIT.java | 3 +-- 17 files changed, 67 insertions(+), 59 deletions(-) diff --git a/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java b/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java index 0b9f9364cb3dc..a5b82ff3ade21 100644 --- a/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java +++ b/infra/context/src/main/java/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java @@ -63,7 +63,7 @@ private RouteContext route(final QueryContext queryContext, final ShardingSphere private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData, final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) { SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props); - return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext()); + return sqlRewriteEntry.rewrite(queryContext, routeContext, connectionContext); } private ExecutionContext createExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final RouteContext routeContext, final SQLRewriteResult rewriteResult) { diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java index c41b9a8c81fcd..0bf9d71d10f22 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java @@ -17,7 +17,6 @@ package org.apache.shardingsphere.infra.rewrite; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.config.props.ConfigurationProperties; import org.apache.shardingsphere.infra.hint.HintValueContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; @@ -30,10 +29,10 @@ import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.rule.ShardingSphereRule; import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; -import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -61,27 +60,23 @@ public SQLRewriteEntry(final ShardingSphereDatabase database, final RuleMetaData /** * Rewrite. * - * @param sql SQL - * @param params SQL parameters - * @param sqlStatementContext SQL statement context + * @param queryContext query context * @param routeContext route context * @param connectionContext connection context - * @param hintValueContext hint value context * * @return route unit and SQL rewrite result map */ - public SQLRewriteResult rewrite(final String sql, final List params, final SQLStatementContext sqlStatementContext, - final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) { - SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext); + public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) { + SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext); SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class); return routeContext.getRouteUnits().isEmpty() - ? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext) - : new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext); + ? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext) + : new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext); } - private SQLRewriteContext createSQLRewriteContext(final String sql, final List params, final SQLStatementContext sqlStatementContext, - final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) { - SQLRewriteContext result = new SQLRewriteContext(database, sqlStatementContext, sql, params, connectionContext, hintValueContext); + private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) { + HintValueContext hintValueContext = queryContext.getHintValueContext(); + SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext); decorate(decorators, result, routeContext, hintValueContext); result.generateSQLTokens(); return result; diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java index 4618c1ed01391..b710d640de9e3 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java @@ -26,6 +26,7 @@ import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult; import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit; import org.apache.shardingsphere.infra.rewrite.sql.impl.DefaultSQLBuilder; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; import java.util.Map; @@ -46,14 +47,14 @@ public final class GenericSQLRewriteEngine { * Rewrite SQL and parameters. * * @param sqlRewriteContext SQL rewrite context + * @param queryContext query context * @return SQL rewrite result */ - public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) { + public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final QueryContext queryContext) { DatabaseType protocolType = database.getProtocolType(); Map storageUnits = database.getResourceMetaData().getStorageUnits(); DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(); - String sql = translatorRule.translate( - new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData); + String sql = translatorRule.translate(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), queryContext, storageType, database, globalRuleMetaData); return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters())); } } diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java index 484d012c90cfd..d42228a059c1e 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java @@ -34,6 +34,7 @@ import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder; import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.route.context.RouteUnit; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils; import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; @@ -62,9 +63,10 @@ public final class RouteSQLRewriteEngine { * * @param sqlRewriteContext SQL rewrite context * @param routeContext route context + * @param queryContext query context * @return SQL rewrite result */ - public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) { + public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) { Map sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F); for (Entry> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) { Collection routeUnits = entry.getValue(); @@ -74,7 +76,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits); } } - return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits)); + return new RouteSQLRewriteResult(translate(queryContext, sqlRewriteUnits)); } private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection routeUnits) { @@ -155,12 +157,12 @@ private boolean isInSameDataNode(final Collection dataNodes, final Rou return false; } - private Map translate(final SQLStatementContext sqlStatementContext, final Map sqlRewriteUnits) { + private Map translate(final QueryContext queryContext, final Map sqlRewriteUnits) { Map result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F); Map storageUnits = database.getResourceMetaData().getStorageUnits(); for (Entry entry : sqlRewriteUnits.entrySet()) { DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType(); - String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData); + String sql = translatorRule.translate(entry.getValue().getSql(), queryContext, storageType, database, globalRuleMetaData); SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters()); result.put(entry.getKey(), sqlRewriteUnit); } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java index 28b47e698c7a2..2fb3eacfba9e4 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntryTest.java @@ -33,6 +33,7 @@ import org.apache.shardingsphere.infra.route.context.RouteMapper; import org.apache.shardingsphere.infra.route.context.RouteUnit; import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; @@ -59,12 +60,21 @@ void assertRewriteForGenericSQLRewriteResult() { SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry( database, new RuleMetaData(Collections.singleton(new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()))), new ConfigurationProperties(new Properties())); RouteContext routeContext = new RouteContext(); - GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext, - mock(ConnectionContext.class), new HintValueContext()); + GenericSQLRewriteResult sqlRewriteResult = (GenericSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class)); assertThat(sqlRewriteResult.getSqlRewriteUnit().getSql(), is("SELECT ?")); assertThat(sqlRewriteResult.getSqlRewriteUnit().getParameters(), is(Collections.singletonList(1))); } + private QueryContext createQueryContext() { + QueryContext result = mock(QueryContext.class); + when(result.getSql()).thenReturn("SELECT ?"); + when(result.getParameters()).thenReturn(Collections.singletonList(1)); + CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class); + when(result.getSqlStatementContext()).thenReturn(sqlStatementContext); + when(result.getHintValueContext()).thenReturn(new HintValueContext()); + return result; + } + @Test void assertRewriteForRouteSQLRewriteResult() { ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME, TypedSPILoader.getService(DatabaseType.class, "H2"), mockResourceMetaData(), @@ -77,8 +87,7 @@ void assertRewriteForRouteSQLRewriteResult() { RouteUnit secondRouteUnit = mock(RouteUnit.class); when(secondRouteUnit.getDataSourceMapper()).thenReturn(new RouteMapper("ds", "ds_1")); routeContext.getRouteUnits().addAll(Arrays.asList(firstRouteUnit, secondRouteUnit)); - RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite("SELECT ?", - Collections.singletonList(1), mock(CommonSQLStatementContext.class), routeContext, mock(ConnectionContext.class), new HintValueContext()); + RouteSQLRewriteResult sqlRewriteResult = (RouteSQLRewriteResult) sqlRewriteEntry.rewrite(createQueryContext(), routeContext, mock(ConnectionContext.class)); assertThat(sqlRewriteResult.getSqlRewriteUnits().size(), is(2)); } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java index 2c550a2f5be87..8d3512e9bdf94 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java @@ -28,6 +28,7 @@ import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult; import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; import org.junit.jupiter.api.Test; @@ -53,7 +54,7 @@ void assertRewrite() { when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits); GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) .rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), - new HintValueContext())); + new HintValueContext()), mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); } @@ -67,7 +68,7 @@ void assertRewriteStorageTypeIsEmpty() { when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap()); GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) .rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), - new HintValueContext())); + new HintValueContext()), mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java index 372ad392bbfe8..ad83303895cac 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java @@ -35,6 +35,7 @@ import org.apache.shardingsphere.infra.route.context.RouteMapper; import org.apache.shardingsphere.infra.route.context.RouteUnit; import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; import org.junit.jupiter.api.Test; @@ -62,7 +63,7 @@ void assertRewriteWithStandardParameterBuilder() { RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -92,7 +93,7 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() { routeContext.getRouteUnits().add(firstRouteUnit); routeContext.getRouteUnits().add(secondRouteUnit); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getSql(), is("SELECT ? UNION ALL SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getParameters(), is(Arrays.asList(1, 1))); @@ -113,7 +114,7 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() { RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -136,7 +137,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { // TODO check why data node is "ds.tbl_0", not "ds_0.tbl_0" routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0"))); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -158,7 +159,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() { routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.emptyList()); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -180,7 +181,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() { routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1"))); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, mock(QueryContext.class)); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty()); diff --git a/kernel/single/core/src/main/java/org/apache/shardingsphere/single/datanode/SingleTableDataNodeLoader.java b/kernel/single/core/src/main/java/org/apache/shardingsphere/single/datanode/SingleTableDataNodeLoader.java index 2265e0bf91c4c..42500e70f37a6 100644 --- a/kernel/single/core/src/main/java/org/apache/shardingsphere/single/datanode/SingleTableDataNodeLoader.java +++ b/kernel/single/core/src/main/java/org/apache/shardingsphere/single/datanode/SingleTableDataNodeLoader.java @@ -20,9 +20,9 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.apache.shardingsphere.infra.database.DatabaseTypeEngine; +import org.apache.shardingsphere.infra.database.core.metadata.data.loader.type.SchemaMetaDataLoader; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.datanode.DataNode; -import org.apache.shardingsphere.infra.database.core.metadata.data.loader.type.SchemaMetaDataLoader; import org.apache.shardingsphere.infra.rule.ShardingSphereRule; import org.apache.shardingsphere.single.api.constant.SingleTableConstants; import org.apache.shardingsphere.single.exception.SingleTablesLoadingException; diff --git a/kernel/sql-translator/api/pom.xml b/kernel/sql-translator/api/pom.xml index 14fe397ff9dc3..6b7db04db9472 100644 --- a/kernel/sql-translator/api/pom.xml +++ b/kernel/sql-translator/api/pom.xml @@ -34,7 +34,7 @@ org.apache.shardingsphere - shardingsphere-infra-binder + shardingsphere-infra-session ${project.version} diff --git a/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java b/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java index d0753cbeb01d7..ed569f47da64b 100644 --- a/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java +++ b/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java @@ -17,10 +17,10 @@ package org.apache.shardingsphere.sqltranslator.spi; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.infra.spi.annotation.SingletonSPI; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPI; @@ -34,11 +34,11 @@ public interface SQLTranslator extends TypedSPI { * Translate SQL. * * @param sql to be translated SQL - * @param sqlStatementContext SQL statement context + * @param queryContext query context * @param storageType storage type * @param database database * @param globalRuleMetaData global rule meta data * @return translated SQL */ - String translate(String sql, SQLStatementContext sqlStatementContext, DatabaseType storageType, ShardingSphereDatabase database, RuleMetaData globalRuleMetaData); + String translate(String sql, QueryContext queryContext, DatabaseType storageType, ShardingSphereDatabase database, RuleMetaData globalRuleMetaData); } diff --git a/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java b/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java index 6f7b3096447ac..7f8163b6647a3 100644 --- a/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java +++ b/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java @@ -18,11 +18,11 @@ package org.apache.shardingsphere.sqltranslator.rule; import lombok.Getter; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.rule.identifier.scope.GlobalRule; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.exception.SQLTranslationException; @@ -50,20 +50,20 @@ public SQLTranslatorRule(final SQLTranslatorRuleConfiguration ruleConfig) { * Translate SQL. * * @param sql to be translated SQL - * @param sqlStatementContext SQL statement context + * @param queryContext query context * @param storageType storage type * @param database database * @param globalRuleMetaData global rule meta data * @return translated SQL */ - public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + public String translate(final String sql, final QueryContext queryContext, final DatabaseType storageType, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { DatabaseType protocolType = database.getProtocolType(); if (protocolType.equals(storageType) || null == storageType) { return sql; } try { - return translator.translate(sql, sqlStatementContext, storageType, database, globalRuleMetaData); + return translator.translate(sql, queryContext, storageType, database, globalRuleMetaData); } catch (final SQLTranslationException ex) { if (useOriginalSQLWhenTranslatingFailed) { return sql; diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java index 7e703ed826a2e..8c482a04b91a9 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java @@ -17,10 +17,10 @@ package org.apache.shardingsphere.sqltranslator.rule; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedDatabaseException; @@ -42,7 +42,7 @@ void assertTranslateWhenProtocolSameAsStorage() { DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); when(database.getProtocolType()).thenReturn(databaseType); - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(SQLStatementContext.class), databaseType, database, + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(QueryContext.class), databaseType, database, mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @@ -53,7 +53,7 @@ void assertTranslateWhenNoStorage() { ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); DatabaseType protocolType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); when(database.getProtocolType()).thenReturn(protocolType); - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(SQLStatementContext.class), null, database, + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(QueryContext.class), null, database, mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @@ -65,7 +65,7 @@ void assertTranslateWithProtocolDifferentWithStorage() { ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); when(database.getProtocolType()).thenReturn(protocolType); DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(input, mock(SQLStatementContext.class), storageType, database, + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(input, mock(QueryContext.class), storageType, database, mock(RuleMetaData.class)); assertThat(actual, is(input.toUpperCase(Locale.ROOT))); } @@ -78,7 +78,7 @@ void assertUseOriginalSQLWhenTranslatingFailed() { when(database.getProtocolType()).thenReturn(protocolType); DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); String actual = - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", true)).translate(expected, mock(SQLStatementContext.class), storageType, database, mock(RuleMetaData.class)); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", true)).translate(expected, mock(QueryContext.class), storageType, database, mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @@ -89,7 +89,7 @@ void assertNotUseOriginalSQLWhenTranslatingFailed() { when(database.getProtocolType()).thenReturn(protocolType); DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); assertThrows(UnsupportedTranslatedDatabaseException.class, - () -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", false)).translate("", mock(SQLStatementContext.class), storageType, database, + () -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", false)).translate("", mock(QueryContext.class), storageType, database, mock(RuleMetaData.class))); } diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java index c8f0b06c1d976..fb10033457f34 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java @@ -17,17 +17,17 @@ package org.apache.shardingsphere.sqltranslator.rule.fixture; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedDatabaseException; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; public final class AlwaysFailedSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + public String translate(final String sql, final QueryContext queryContext, final DatabaseType storageType, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { throw new UnsupportedTranslatedDatabaseException(storageType); } diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java index 5d71ef0d76603..6978a901c8c6f 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java @@ -17,10 +17,10 @@ package org.apache.shardingsphere.sqltranslator.rule.fixture; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; import java.util.Locale; @@ -28,7 +28,7 @@ public final class ConvertToUpperCaseSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + public String translate(final String sql, final QueryContext queryContext, final DatabaseType storageType, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { return sql.toUpperCase(Locale.ROOT); } diff --git a/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java b/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java index 32ba37b1c44a1..3797b3a1fe1bc 100644 --- a/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java +++ b/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java @@ -17,10 +17,10 @@ package org.apache.shardingsphere.sqltranslator.jooq; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedSQLException; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; import org.jooq.Query; @@ -32,7 +32,7 @@ public final class JooQSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatementContext statementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + public String translate(final String sql, final QueryContext queryContext, final DatabaseType storageType, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { try { Query query = DSL.using(JooQDialectRegistry.getSQLDialect(database.getProtocolType())).parser().parseQuery(sql); diff --git a/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java b/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java index 1ce5e045e2306..b93ecd5d6191a 100644 --- a/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java +++ b/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java @@ -17,10 +17,10 @@ package org.apache.shardingsphere.sqltranslator.natived; -import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; /** @@ -29,7 +29,7 @@ public final class NativeSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatementContext statementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + public String translate(final String sql, final QueryContext queryContext, final DatabaseType storageType, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { // TODO return sql; diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java index f21492e1e76a0..388d9384c0cf0 100644 --- a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java +++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java @@ -148,8 +148,7 @@ private Collection createSQLRewriteUnits(final SQLRewriteEngineT SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props); ConnectionContext connectionContext = mock(ConnectionContext.class); when(connectionContext.getCursorContext()).thenReturn(new CursorConnectionContext()); - SQLRewriteResult sqlRewriteResult = sqlRewriteEntry.rewrite(testParams.getInputSQL(), testParams.getInputParameters(), sqlStatementContext, routeContext, connectionContext, - queryContext.getHintValueContext()); + SQLRewriteResult sqlRewriteResult = sqlRewriteEntry.rewrite(queryContext, routeContext, connectionContext); return sqlRewriteResult instanceof GenericSQLRewriteResult ? Collections.singleton(((GenericSQLRewriteResult) sqlRewriteResult).getSqlRewriteUnit()) : (((RouteSQLRewriteResult) sqlRewriteResult).getSqlRewriteUnits()).values();