Skip to content

Commit

Permalink
Refactor sql translator interface and modify jooq and native translator
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Oct 18, 2023
1 parent 32aed3e commit 2afa34d
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
Expand Down Expand Up @@ -76,11 +74,9 @@ public SQLRewriteResult rewrite(final String sql, final List<Object> params, fin
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
Expand All @@ -37,9 +38,7 @@ public final class GenericSQLRewriteEngine {

private final SQLTranslatorRule translatorRule;

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final ShardingSphereDatabase database;

private final RuleMetaData globalRuleMetaData;

Expand All @@ -50,9 +49,11 @@ public final class GenericSQLRewriteEngine {
* @return SQL rewrite result
*/
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType();
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext().getSqlStatement(), protocolType,
storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(), globalRuleMetaData);
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData);
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
Expand All @@ -33,7 +34,6 @@
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand All @@ -53,9 +53,7 @@ public final class RouteSQLRewriteEngine {

private final SQLTranslatorRule translatorRule;

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final ShardingSphereDatabase database;

private final RuleMetaData globalRuleMetaData;

Expand All @@ -76,7 +74,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext,
addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
}
}
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits));
}

private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
Expand Down Expand Up @@ -157,11 +155,12 @@ private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final Rou
return false;
}

private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatementContext sqlStatementContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType, globalRuleMetaData);
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ class GenericSQLRewriteEngineTest {
void assertRewrite() {
DatabaseType databaseType = mock(DatabaseType.class);
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getProtocolType()).thenReturn(databaseType);
Map<String, StorageUnit> storageUnits = mockStorageUnits(databaseType);
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
Expand All @@ -57,8 +61,12 @@ void assertRewrite() {
@Test
void assertRewriteStorageTypeIsEmpty() {
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap(), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
when(database.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
Expand All @@ -69,11 +77,4 @@ private Map<String, StorageUnit> mockStorageUnits(final DatabaseType databaseTyp
when(result.getStorageType()).thenReturn(databaseType);
return Collections.singletonMap("ds_0", result);
}

private ShardingSphereDatabase mockDatabase() {
ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
when(result.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
return result;
}
}
Loading

0 comments on commit 2afa34d

Please sign in to comment.