Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sql translator interface and modify jooq and native translator #28791

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
Expand Down Expand Up @@ -76,11 +74,9 @@ public SQLRewriteResult rewrite(final String sql, final List<Object> params, fin
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnits, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
Expand All @@ -37,9 +38,7 @@ public final class GenericSQLRewriteEngine {

private final SQLTranslatorRule translatorRule;

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final ShardingSphereDatabase database;

private final RuleMetaData globalRuleMetaData;

Expand All @@ -50,9 +49,11 @@ public final class GenericSQLRewriteEngine {
* @return SQL rewrite result
*/
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
DatabaseType storageType = storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType();
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext().getSqlStatement(), protocolType,
storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(), globalRuleMetaData);
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext(), storageType, database, globalRuleMetaData);
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
Expand All @@ -33,7 +34,6 @@
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
Expand All @@ -53,9 +53,7 @@ public final class RouteSQLRewriteEngine {

private final SQLTranslatorRule translatorRule;

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final ShardingSphereDatabase database;

private final RuleMetaData globalRuleMetaData;

Expand All @@ -76,7 +74,7 @@ public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext,
addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
}
}
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext(), sqlRewriteUnits));
}

private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
Expand Down Expand Up @@ -157,11 +155,12 @@ private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final Rou
return false;
}

private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatementContext sqlStatementContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType, globalRuleMetaData);
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatementContext, storageType, database, globalRuleMetaData);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ class GenericSQLRewriteEngineTest {
void assertRewrite() {
DatabaseType databaseType = mock(DatabaseType.class);
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getProtocolType()).thenReturn(databaseType);
Map<String, StorageUnit> storageUnits = mockStorageUnits(databaseType);
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
Expand All @@ -57,8 +61,12 @@ void assertRewrite() {
@Test
void assertRewriteStorageTypeIsEmpty() {
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap(), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
when(database.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.emptyMap());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
Expand All @@ -69,11 +77,4 @@ private Map<String, StorageUnit> mockStorageUnits(final DatabaseType databaseTyp
when(result.getStorageType()).thenReturn(databaseType);
return Collections.singletonMap("ds_0", result);
}

private ShardingSphereDatabase mockDatabase() {
ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
when(result.getSchemas()).thenReturn(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
return result;
}
}
Loading