diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java index ae8198aa47746..c41b9a8c81fcd 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java @@ -19,10 +19,8 @@ import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.config.props.ConfigurationProperties; -import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.hint.HintValueContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; -import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator; @@ -76,11 +74,9 @@ public SQLRewriteResult rewrite(final String sql, final List params, fin final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) { SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext); SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class); - DatabaseType protocolType = database.getProtocolType(); - Map storageUnits = database.getResourceMetaData().getStorageUnits(); return routeContext.getRouteUnits().isEmpty() - ? new GenericSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext) - : new RouteSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext); + ? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext) + : new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext); } private SQLRewriteContext createSQLRewriteContext(final String sql, final List params, final SQLStatementContext sqlStatementContext, diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java index 7e1a4547fdec8..4618c1ed01391 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngine.java @@ -19,6 +19,7 @@ import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; @@ -37,9 +38,7 @@ public final class GenericSQLRewriteEngine { private final SQLTranslatorRule translatorRule; - private final DatabaseType protocolType; - - private final Map storageUnits; + private final ShardingSphereDatabase database; private final RuleMetaData globalRuleMetaData; @@ -50,9 +49,11 @@ public final class GenericSQLRewriteEngine { * @return SQL rewrite result */ public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) { + DatabaseType protocolType = database.getProtocolType(); + Map storageUnits = database.getResourceMetaData().getStorageUnits(); + DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(); String sql = translatorRule.translate( - new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext().getSqlStatement(), protocolType, - storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(), globalRuleMetaData); + new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData); return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters())); } } diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java index a7c0d2c0ddc33..484d012c90cfd 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java @@ -22,6 +22,7 @@ import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.datanode.DataNode; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; @@ -33,7 +34,6 @@ import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder; import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.route.context.RouteUnit; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils; import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; @@ -53,9 +53,7 @@ public final class RouteSQLRewriteEngine { private final SQLTranslatorRule translatorRule; - private final DatabaseType protocolType; - - private final Map storageUnits; + private final ShardingSphereDatabase database; private final RuleMetaData globalRuleMetaData; @@ -76,7 +74,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits); } } - return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits)); + return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits)); } private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection routeUnits) { @@ -157,11 +155,12 @@ private boolean isInSameDataNode(final Collection dataNodes, final Rou return false; } - private Map translate(final SQLStatement sqlStatement, final Map sqlRewriteUnits) { + private Map translate(final SQLStatementContext sqlStatementContext, final Map sqlRewriteUnits) { Map result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F); + Map storageUnits = database.getResourceMetaData().getStorageUnits(); for (Entry entry : sqlRewriteUnits.entrySet()) { DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType(); - String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType, globalRuleMetaData); + String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData); SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters()); result.put(entry.getKey(), sqlRewriteUnit); } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java index 90d19b541ed78..2c550a2f5be87 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java @@ -47,8 +47,12 @@ class GenericSQLRewriteEngineTest { void assertRewrite() { DatabaseType databaseType = mock(DatabaseType.class); SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()); - GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)) - .rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + when(database.getProtocolType()).thenReturn(databaseType); + Map storageUnits = mockStorageUnits(databaseType); + when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits); + GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) + .rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new HintValueContext())); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); @@ -57,8 +61,12 @@ void assertRewrite() { @Test void assertRewriteStorageTypeIsEmpty() { SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()); - GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap(), mock(RuleMetaData.class)) - .rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + when(database.getName()).thenReturn(DefaultDatabase.LOGIC_NAME); + when(database.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class))); + when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap()); + GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) + .rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new HintValueContext())); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); @@ -69,11 +77,4 @@ private Map mockStorageUnits(final DatabaseType databaseTyp when(result.getStorageType()).thenReturn(databaseType); return Collections.singletonMap("ds_0", result); } - - private ShardingSphereDatabase mockDatabase() { - ShardingSphereDatabase result = mock(ShardingSphereDatabase.class); - when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME); - when(result.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class))); - return result; - } } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java index db5eae2ea573d..372ad392bbfe8 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java @@ -54,33 +54,45 @@ class RouteSQLRewriteEngineTest { @Test void assertRewriteWithStandardParameterBuilder() { + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); } + private ShardingSphereDatabase mockDatabase(final DatabaseType databaseType) { + ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + when(result.getProtocolType()).thenReturn(databaseType); + Map storageUnits = mockStorageUnits(databaseType); + when(result.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits); + when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME); + when(result.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class))); + return result; + } + @Test void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() { SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); when(statementContext.getOrderByContext().getItems()).thenReturn(Collections.emptyList()); when(statementContext.getPaginationContext().isHasPagination()).thenReturn(false); - SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(mockDatabase(), statementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, statementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteContext routeContext = new RouteContext(); RouteUnit firstRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteUnit secondRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_1"))); routeContext.getRouteUnits().add(firstRouteUnit); routeContext.getRouteUnits().add(secondRouteUnit); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getSql(), is("SELECT ? UNION ALL SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getParameters(), is(Arrays.asList(1, 1))); @@ -93,14 +105,15 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() { when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -113,16 +126,17 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); // TODO check why data node is "ds.tbl_0", not "ds_0.tbl_0" routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0"))); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -135,15 +149,16 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() { when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.emptyList()); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -156,27 +171,21 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() { when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); + DatabaseType databaseType = mock(DatabaseType.class); + ShardingSphereDatabase database = mockDatabase(databaseType); SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1"))); - DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty()); } - private ShardingSphereDatabase mockDatabase() { - ShardingSphereDatabase result = mock(ShardingSphereDatabase.class); - when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME); - when(result.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class))); - return result; - } - private Map mockStorageUnits(final DatabaseType databaseType) { StorageUnit result = mock(StorageUnit.class, RETURNS_DEEP_STUBS); when(result.getStorageType()).thenReturn(databaseType); diff --git a/kernel/sql-translator/api/pom.xml b/kernel/sql-translator/api/pom.xml index 987edd885af76..04aea3fda3893 100644 --- a/kernel/sql-translator/api/pom.xml +++ b/kernel/sql-translator/api/pom.xml @@ -32,5 +32,10 @@ shardingsphere-infra-common ${project.version} + + org.apache.shardingsphere + shardingsphere-infra-binder + ${project.version} + diff --git a/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java b/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java index ed0d35b962098..d0753cbeb01d7 100644 --- a/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java +++ b/kernel/sql-translator/api/src/main/java/org/apache/shardingsphere/sqltranslator/spi/SQLTranslator.java @@ -17,11 +17,12 @@ package org.apache.shardingsphere.sqltranslator.spi; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.spi.annotation.SingletonSPI; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPI; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; /** * SQL translator. @@ -33,11 +34,11 @@ public interface SQLTranslator extends TypedSPI { * Translate SQL. * * @param sql to be translated SQL - * @param sqlStatement to be translated SQL statement - * @param protocolType protocol type + * @param sqlStatementContext SQL statement context * @param storageType storage type + * @param database database * @param globalRuleMetaData global rule meta data * @return translated SQL */ - String translate(String sql, SQLStatement sqlStatement, DatabaseType protocolType, DatabaseType storageType, RuleMetaData globalRuleMetaData); + String translate(String sql, SQLStatementContext sqlStatementContext, DatabaseType storageType, ShardingSphereDatabase database, RuleMetaData globalRuleMetaData); } diff --git a/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java b/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java index c55b14b9ddaa6..6f7b3096447ac 100644 --- a/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java +++ b/kernel/sql-translator/core/src/main/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRule.java @@ -18,11 +18,12 @@ package org.apache.shardingsphere.sqltranslator.rule; import lombok.Getter; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.rule.identifier.scope.GlobalRule; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; import org.apache.shardingsphere.sqltranslator.exception.SQLTranslationException; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; @@ -49,18 +50,20 @@ public SQLTranslatorRule(final SQLTranslatorRuleConfiguration ruleConfig) { * Translate SQL. * * @param sql to be translated SQL - * @param sqlStatement to be translated SQL statement - * @param protocolType protocol type + * @param sqlStatementContext SQL statement context * @param storageType storage type + * @param database database * @param globalRuleMetaData global rule meta data * @return translated SQL */ - public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) { + public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + final RuleMetaData globalRuleMetaData) { + DatabaseType protocolType = database.getProtocolType(); if (protocolType.equals(storageType) || null == storageType) { return sql; } try { - return translator.translate(sql, sqlStatement, protocolType, storageType, globalRuleMetaData); + return translator.translate(sql, sqlStatementContext, storageType, database, globalRuleMetaData); } catch (final SQLTranslationException ex) { if (useOriginalSQLWhenTranslatingFailed) { return sql; diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java index 4736040d9156c..7e703ed826a2e 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/SQLTranslatorRuleTest.java @@ -17,7 +17,9 @@ package org.apache.shardingsphere.sqltranslator.rule; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration; @@ -30,6 +32,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; class SQLTranslatorRuleTest { @@ -37,39 +40,57 @@ class SQLTranslatorRuleTest { void assertTranslateWhenProtocolSameAsStorage() { String expected = "select 1"; DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, null, databaseType, databaseType, mock(RuleMetaData.class)); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); + when(database.getProtocolType()).thenReturn(databaseType); + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(SQLStatementContext.class), databaseType, database, + mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @Test void assertTranslateWhenNoStorage() { String expected = "select 1"; - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate( - expected, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), null, mock(RuleMetaData.class)); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); + DatabaseType protocolType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); + when(database.getProtocolType()).thenReturn(protocolType); + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, mock(SQLStatementContext.class), null, database, + mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @Test void assertTranslateWithProtocolDifferentWithStorage() { String input = "select 1"; - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate( - input, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class)); + DatabaseType protocolType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); + when(database.getProtocolType()).thenReturn(protocolType); + DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); + String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(input, mock(SQLStatementContext.class), storageType, database, + mock(RuleMetaData.class)); assertThat(actual, is(input.toUpperCase(Locale.ROOT))); } @Test void assertUseOriginalSQLWhenTranslatingFailed() { String expected = "select 1"; - String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", true)).translate(expected, null, - TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class)); + DatabaseType protocolType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); + when(database.getProtocolType()).thenReturn(protocolType); + DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); + String actual = + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", true)).translate(expected, mock(SQLStatementContext.class), storageType, database, mock(RuleMetaData.class)); assertThat(actual, is(expected)); } @Test void assertNotUseOriginalSQLWhenTranslatingFailed() { + DatabaseType protocolType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class); + when(database.getProtocolType()).thenReturn(protocolType); + DatabaseType storageType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); assertThrows(UnsupportedTranslatedDatabaseException.class, - () -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", false)).translate("", null, - TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class))); + () -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", false)).translate("", mock(SQLStatementContext.class), storageType, database, + mock(RuleMetaData.class))); } @Test diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java index 388c48ea3a35e..c8f0b06c1d976 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/AlwaysFailedSQLTranslator.java @@ -17,16 +17,18 @@ package org.apache.shardingsphere.sqltranslator.rule.fixture; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedDatabaseException; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; public final class AlwaysFailedSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) { + public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + final RuleMetaData globalRuleMetaData) { throw new UnsupportedTranslatedDatabaseException(storageType); } diff --git a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java index 4427a63de02e4..5d71ef0d76603 100644 --- a/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java +++ b/kernel/sql-translator/core/src/test/java/org/apache/shardingsphere/sqltranslator/rule/fixture/ConvertToUpperCaseSQLTranslator.java @@ -17,9 +17,10 @@ package org.apache.shardingsphere.sqltranslator.rule.fixture; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; import java.util.Locale; @@ -27,7 +28,8 @@ public final class ConvertToUpperCaseSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) { + public String translate(final String sql, final SQLStatementContext sqlStatementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + final RuleMetaData globalRuleMetaData) { return sql.toUpperCase(Locale.ROOT); } diff --git a/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java b/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java index 878f88736d80d..32ba37b1c44a1 100644 --- a/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java +++ b/kernel/sql-translator/provider/jooq/src/main/java/org/apache/shardingsphere/sqltranslator/jooq/JooQSQLTranslator.java @@ -17,9 +17,10 @@ package org.apache.shardingsphere.sqltranslator.jooq; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedSQLException; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; import org.jooq.Query; @@ -31,9 +32,10 @@ public final class JooQSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatement statement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) { + public String translate(final String sql, final SQLStatementContext statementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + final RuleMetaData globalRuleMetaData) { try { - Query query = DSL.using(JooQDialectRegistry.getSQLDialect(protocolType)).parser().parseQuery(sql); + Query query = DSL.using(JooQDialectRegistry.getSQLDialect(database.getProtocolType())).parser().parseQuery(sql); return DSL.using(JooQDialectRegistry.getSQLDialect(storageType)).render(query); // CHECKSTYLE:OFF } catch (final Exception ignored) { diff --git a/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java b/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java index 1f959eabb5c61..1ce5e045e2306 100644 --- a/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java +++ b/kernel/sql-translator/provider/native/src/main/java/org/apache/shardingsphere/sqltranslator/natived/NativeSQLTranslator.java @@ -17,9 +17,10 @@ package org.apache.shardingsphere.sqltranslator.natived; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; -import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator; /** @@ -28,7 +29,8 @@ public final class NativeSQLTranslator implements SQLTranslator { @Override - public String translate(final String sql, final SQLStatement statement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) { + public String translate(final String sql, final SQLStatementContext statementContext, final DatabaseType storageType, final ShardingSphereDatabase database, + final RuleMetaData globalRuleMetaData) { // TODO return sql; }