Skip to content

Commit

Permalink
Merge pull request apache#32828 from strongduanmu/dev-0910
Browse files Browse the repository at this point in the history
Minor refactor for sharding insert clause condition handle logic
  • Loading branch information
iamhucong authored Sep 10, 2024
2 parents b3dabf7 + e0f3ac3 commit 7c56b28
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.route.engine.condition.engine;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.algorithm.core.context.AlgorithmSQLContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
Expand All @@ -28,7 +29,6 @@
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.table.NoSuchTableException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.algorithm.core.context.AlgorithmSQLContext;
import org.apache.shardingsphere.sharding.route.engine.condition.ExpressionConditionUtils;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
Expand Down Expand Up @@ -77,7 +77,8 @@ public List<ShardingCondition> createShardingConditions(final InsertStatementCon
}

private List<ShardingCondition> createShardingConditionsWithInsertValues(final InsertStatementContext sqlStatementContext, final List<Object> params) {
String tableName = sqlStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
String tableName = sqlStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue())
.orElseGet(() -> sqlStatementContext.getTablesContext().getTableNames().iterator().next());
Collection<String> columnNames = getColumnNames(sqlStatementContext);
List<InsertValueContext> insertValueContexts = sqlStatementContext.getInsertValueContexts();
List<ShardingCondition> result = new ArrayList<>(insertValueContexts.size());
Expand All @@ -92,7 +93,8 @@ private List<ShardingCondition> createShardingConditionsWithInsertValues(final I
private void appendMissingShardingConditions(final InsertStatementContext sqlStatementContext, final Collection<String> columnNames, final List<ShardingCondition> shardingConditions) {
String defaultSchemaName = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName());
ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchemaName));
String tableName = sqlStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
String tableName = sqlStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue())
.orElseGet(() -> sqlStatementContext.getTablesContext().getTableNames().iterator().next());
ShardingSpherePreconditions.checkState(schema.containsTable(tableName), () -> new NoSuchTableException(tableName));
Collection<String> allColumnNames = schema.getTable(tableName).getColumnNames();
if (columnNames.size() == allColumnNames.size()) {
Expand All @@ -113,12 +115,13 @@ private void appendMissingShardingConditions(final List<ShardingCondition> shard

private Collection<String> getColumnNames(final InsertStatementContext insertStatementContext) {
Optional<GeneratedKeyContext> generatedKey = insertStatementContext.getGeneratedKeyContext();
List<String> columnNames = insertStatementContext.getColumnNames();
if (generatedKey.isPresent() && generatedKey.get().isGenerated()) {
Collection<String> result = new LinkedHashSet<>(insertStatementContext.getColumnNames());
Collection<String> result = new LinkedHashSet<>(columnNames);
result.remove(generatedKey.get().getColumnName());
return result;
}
return new LinkedHashSet<>(insertStatementContext.getColumnNames());
return new LinkedHashSet<>(columnNames);
}

private ShardingCondition createShardingCondition(final String tableName, final Iterator<String> columnNames,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public TablesContext(final Collection<? extends TableSegment> tables, final Map<
}
this.tables.addAll(tables);
for (TableSegment each : tables) {
if (each instanceof SimpleTableSegment) {
if (each instanceof SimpleTableSegment && !"DUAL".equalsIgnoreCase(((SimpleTableSegment) each).getTableName().getIdentifier().getValue())) {
SimpleTableSegment simpleTableSegment = (SimpleTableSegment) each;
simpleTables.add(simpleTableSegment);
tableNames.add(simpleTableSegment.getTableName().getIdentifier().getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
~ limitations under the License.
-->

<dataset update-count="2">
<dataset update-count="1">
<metadata data-nodes="tbl.t_order_${0..9}">
<column name="order_id" type="numeric" />
<column name="user_id" type="numeric" />
Expand All @@ -25,16 +25,15 @@
<column name="creation_date" type="datetime" />
</metadata>
<row data-node="tbl.t_order_0" values="1000, 10, init, 1, test, 2017-08-08" />
<row data-node="tbl.t_order_1" values="1, 1, insert, 1, test, 2017-08-08" />
<row data-node="tbl.t_order_1" values="1001, 10, init, 2, test, 2017-08-08" />
<row data-node="tbl.t_order_2" values="2, 2, insert, 2, test, 2017-08-08" />
<row data-node="tbl.t_order_2" values="1002, 10, init, 3, test, 2017-08-08" />
<row data-node="tbl.t_order_3" values="1003, 10, init, 4, test, 2017-08-08" />
<row data-node="tbl.t_order_4" values="1004, 10, init, 5, test, 2017-08-08" />
<row data-node="tbl.t_order_5" values="1005, 10, init, 6, test, 2017-08-08" />
<row data-node="tbl.t_order_6" values="1006, 10, init, 7, test, 2017-08-08" />
<row data-node="tbl.t_order_7" values="1007, 10, init, 8, test, 2017-08-08" />
<row data-node="tbl.t_order_8" values="1008, 10, init, 9, test, 2017-08-08" />
<row data-node="tbl.t_order_9" values="999, 10, insertALL, 1, test, 2017-08-08" />
<row data-node="tbl.t_order_9" values="1009, 10, init, 10, test, 2017-08-08" />
<row data-node="tbl.t_order_0" values="1100, 11, init, 11, test, 2017-08-08" />
<row data-node="tbl.t_order_1" values="1101, 11, init, 12, test, 2017-08-08" />
Expand All @@ -46,5 +45,4 @@
<row data-node="tbl.t_order_7" values="1107, 11, init, 18, test, 2017-08-08" />
<row data-node="tbl.t_order_8" values="1108, 11, init, 19, test, 2017-08-08" />
<row data-node="tbl.t_order_9" values="1109, 11, init, 20, test, 2017-08-08" />
<row data-node="tbl.t_order_9" values="999, 10, insertALL, 1, test, 2017-08-08" />
</dataset>

0 comments on commit 7c56b28

Please sign in to comment.