Skip to content

Commit

Permalink
Skip sharding sql rewrite when insert statement not contains sharding…
Browse files Browse the repository at this point in the history
… table (#29091)
  • Loading branch information
strongduanmu authored Nov 20, 2023
1 parent e9dcbc5 commit a203da9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.shardingsphere.sharding.rewrite.context;

import lombok.Setter;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
Expand All @@ -38,12 +40,25 @@ public final class ShardingSQLRewriteContextDecorator implements SQLRewriteConte

@Override
public void decorate(final ShardingRule shardingRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (sqlStatementContext instanceof InsertStatementContext && !containsShardingTable(shardingRule, sqlStatementContext)) {
return;
}
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters =
new ShardingParameterRewriterBuilder(shardingRule, routeContext, sqlRewriteContext.getDatabase().getSchemas(), sqlRewriteContext.getSqlStatementContext()).getParameterRewriters();
new ShardingParameterRewriterBuilder(shardingRule, routeContext, sqlRewriteContext.getDatabase().getSchemas(), sqlStatementContext).getParameterRewriters();
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(shardingRule, routeContext, sqlRewriteContext.getSqlStatementContext()).getSQLTokenGenerators());
sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(shardingRule, routeContext, sqlStatementContext).getSQLTokenGenerators());
}

private boolean containsShardingTable(final ShardingRule shardingRule, final SQLStatementContext sqlStatementContext) {
for (String each : sqlStatementContext.getTablesContext().getTableNames()) {
if (shardingRule.findTableRule(each).isPresent()) {
return true;
}
}
return false;
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.rewrite.context;

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
Expand All @@ -26,6 +27,7 @@
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
Expand All @@ -43,4 +45,16 @@ void assertDecorate() {
new ShardingSQLRewriteContextDecorator().decorate(mock(ShardingRule.class), mock(ConfigurationProperties.class), sqlRewriteContext, mock(RouteContext.class));
assertTrue(sqlRewriteContext.getSqlTokens().isEmpty());
}

@Test
void assertDecorateWhenInsertStatementNotContainsShardingTable() {
SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
InsertStatementContext insertStatementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(insertStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("t_order"));
when(sqlRewriteContext.getSqlStatementContext()).thenReturn(insertStatementContext);
ShardingRule shardingRule = mock(ShardingRule.class);
when(shardingRule.findTableRule("t_order")).thenReturn(Optional.empty());
new ShardingSQLRewriteContextDecorator().decorate(shardingRule, mock(ConfigurationProperties.class), sqlRewriteContext, mock(RouteContext.class));
assertTrue(sqlRewriteContext.getSqlTokens().isEmpty());
}
}

0 comments on commit a203da9

Please sign in to comment.