diff --git a/mode/core/src/main/java/org/apache/shardingsphere/mode/metadata/refresher/MetaDataRefreshEngine.java b/mode/core/src/main/java/org/apache/shardingsphere/mode/metadata/refresher/MetaDataRefreshEngine.java index a98b52b05b809..9ded4c358124f 100644 --- a/mode/core/src/main/java/org/apache/shardingsphere/mode/metadata/refresher/MetaDataRefreshEngine.java +++ b/mode/core/src/main/java/org/apache/shardingsphere/mode/metadata/refresher/MetaDataRefreshEngine.java @@ -26,7 +26,6 @@ import org.apache.shardingsphere.infra.route.context.RouteUnit; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.mode.persist.service.MetaDataManagerPersistService; -import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.AlterIndexStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.AlterSchemaStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.AlterTableStatement; @@ -79,14 +78,10 @@ public void refresh(final SQLStatementContext sqlStatementContext, final Collect } Optional schemaRefresher = TypedSPILoader.findService(MetaDataRefresher.class, sqlStatementClass); if (schemaRefresher.isPresent()) { - String schemaName = null; - if (sqlStatementContext instanceof TableAvailable) { - schemaName = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName() - .orElseGet(() -> new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())).toLowerCase(); - } Collection logicDataSourceNames = routeUnits.stream().map(each -> each.getDataSourceMapper().getLogicName()).collect(Collectors.toList()); - schemaRefresher.get().refresh(metaDataManagerPersistService, database, - logicDataSourceNames, schemaName, sqlStatementContext.getDatabaseType(), sqlStatementContext.getSqlStatement(), props); + String schemaName = sqlStatementContext instanceof TableAvailable ? getSchemaName(sqlStatementContext) : null; + schemaRefresher.get().refresh( + metaDataManagerPersistService, database, logicDataSourceNames, schemaName, sqlStatementContext.getDatabaseType(), sqlStatementContext.getSqlStatement(), props); } } @@ -97,11 +92,13 @@ public void refresh(final SQLStatementContext sqlStatementContext, final Collect */ @SuppressWarnings("unchecked") public void refresh(final SQLStatementContext sqlStatementContext) { - getFederationMetaDataRefresher(sqlStatementContext).ifPresent(federationMetaDataRefresher -> { - String schemaName = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName() - .orElseGet(() -> new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())).toLowerCase(); - federationMetaDataRefresher.refresh(metaDataManagerPersistService, database, schemaName, sqlStatementContext.getDatabaseType(), sqlStatementContext.getSqlStatement()); - }); + getFederationMetaDataRefresher(sqlStatementContext).ifPresent(optional -> optional.refresh( + metaDataManagerPersistService, database, getSchemaName(sqlStatementContext), sqlStatementContext.getDatabaseType(), sqlStatementContext.getSqlStatement())); + } + + private String getSchemaName(final SQLStatementContext sqlStatementContext) { + return ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName() + .orElseGet(() -> new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())).toLowerCase(); } /** @@ -116,7 +113,6 @@ public boolean isFederation(final SQLStatementContext sqlStatementContext) { @SuppressWarnings("rawtypes") private Optional getFederationMetaDataRefresher(final SQLStatementContext sqlStatementContext) { - Class sqlStatementClass = sqlStatementContext.getSqlStatement().getClass(); - return TypedSPILoader.findService(FederationMetaDataRefresher.class, sqlStatementClass.getSuperclass()); + return TypedSPILoader.findService(FederationMetaDataRefresher.class, sqlStatementContext.getSqlStatement().getClass().getSuperclass()); } }