Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Dec 17, 2024
1 parent b72d984 commit 5098f5e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@
import org.apache.shardingsphere.infra.database.postgresql.type.PostgreSQLDatabaseType;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.table.TableExistsException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.SchemaNotFoundException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.TableNotFoundException;
import org.apache.shardingsphere.infra.metadata.database.schema.manager.SystemSchemaManager;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.column.ColumnDefinitionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
Expand Down Expand Up @@ -72,11 +76,10 @@ public static SimpleTableSegment bind(final SimpleTableSegment segment, final SQ
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
fillPivotColumnNamesInBinderContext(segment, binderContext);
IdentifierValue databaseName = getDatabaseName(segment, binderContext);
ShardingSpherePreconditions.checkNotNull(databaseName.getValue(), NoDatabaseSelectedException::new);
IdentifierValue schemaName = getSchemaName(segment, binderContext);
IdentifierValue schemaName = getSchemaName(segment, binderContext, databaseName);
IdentifierValue tableName = segment.getTableName().getIdentifier();
checkTableExists(binderContext, databaseName.getValue(), schemaName.getValue(), tableName.getValue());
ShardingSphereSchema schema = binderContext.getMetaData().getDatabase(databaseName.getValue()).getSchema(schemaName.getValue());
checkTableExists(binderContext, schema, schemaName.getValue(), tableName.getValue());
tableBinderContexts.put(new CaseInsensitiveString(segment.getAliasName().orElseGet(tableName::getValue)),
createSimpleTableBinderContext(segment, schema, databaseName, schemaName, binderContext));
TableNameSegment tableNameSegment = new TableNameSegment(segment.getTableName().getStartIndex(), segment.getTableName().getStopIndex(), tableName);
Expand All @@ -94,7 +97,17 @@ private static void fillPivotColumnNamesInBinderContext(final SimpleTableSegment
private static IdentifierValue getDatabaseName(final SimpleTableSegment segment, final SQLStatementBinderContext binderContext) {
DialectDatabaseMetaData dialectDatabaseMetaData = new DatabaseTypeRegistry(binderContext.getSqlStatement().getDatabaseType()).getDialectDatabaseMetaData();
Optional<OwnerSegment> owner = dialectDatabaseMetaData.getDefaultSchema().isPresent() ? segment.getOwner().flatMap(OwnerSegment::getOwner) : segment.getOwner();
return new IdentifierValue(owner.map(optional -> optional.getIdentifier().getValue()).orElse(binderContext.getCurrentDatabaseName()));
IdentifierValue result = new IdentifierValue(owner.map(optional -> optional.getIdentifier().getValue()).orElse(binderContext.getCurrentDatabaseName()));
ShardingSpherePreconditions.checkNotNull(result.getValue(), NoDatabaseSelectedException::new);
ShardingSpherePreconditions.checkState(binderContext.getMetaData().containsDatabase(result.getValue()), () -> new UnknownDatabaseException(result.getValue()));
return result;
}

private static IdentifierValue getSchemaName(final SimpleTableSegment segment, final SQLStatementBinderContext binderContext, final IdentifierValue databaseName) {
IdentifierValue result = getSchemaName(segment, binderContext);
ShardingSpherePreconditions.checkState(binderContext.getMetaData().getDatabase(databaseName.getValue()).containsSchema(result.getValue()),
() -> new SchemaNotFoundException(result.getValue()));
return result;
}

private static IdentifierValue getSchemaName(final SimpleTableSegment segment, final SQLStatementBinderContext binderContext) {
Expand All @@ -110,8 +123,10 @@ private static IdentifierValue getSchemaName(final SimpleTableSegment segment, f
return new IdentifierValue(new DatabaseTypeRegistry(databaseType).getDefaultSchemaName(binderContext.getCurrentDatabaseName()));
}

private static void checkTableExists(final SQLStatementBinderContext binderContext, final String databaseName, final String schemaName, final String tableName) {
private static void checkTableExists(final SQLStatementBinderContext binderContext, final ShardingSphereSchema schema, final String schemaName, final String tableName) {
if (binderContext.getSqlStatement() instanceof CreateTableStatement) {
CreateTableStatement sqlStatement = (CreateTableStatement) binderContext.getSqlStatement();
ShardingSpherePreconditions.checkState(sqlStatement.isIfNotExists() || !schema.containsTable(tableName), () -> new TableExistsException(tableName));
return;
}
if ("DUAL".equalsIgnoreCase(tableName)) {
Expand All @@ -123,10 +138,7 @@ private static void checkTableExists(final SQLStatementBinderContext binderConte
if (binderContext.getExternalTableBinderContexts().containsKey(new CaseInsensitiveString(tableName))) {
return;
}
ShardingSpherePreconditions.checkState(binderContext.getMetaData().containsDatabase(databaseName)
&& binderContext.getMetaData().getDatabase(databaseName).containsSchema(schemaName)
&& binderContext.getMetaData().getDatabase(databaseName).getSchema(schemaName).containsTable(tableName),
() -> new TableNotFoundException(tableName));
ShardingSpherePreconditions.checkState(schema.containsTable(tableName), () -> new TableNotFoundException(tableName));
}

private static SimpleTableSegmentBinderContext createSimpleTableBinderContext(final SimpleTableSegment segment, final ShardingSphereSchema schema, final IdentifierValue databaseName,
Expand All @@ -135,9 +147,22 @@ private static SimpleTableSegmentBinderContext createSimpleTableBinderContext(fi
if (binderContext.getMetaData().getDatabase(databaseName.getValue()).getSchema(schemaName.getValue()).containsTable(tableName.getValue())) {
return createSimpleTableSegmentBinderContextWithMetaData(segment, schema, databaseName, schemaName, binderContext, tableName);
}
if (binderContext.getSqlStatement() instanceof CreateTableStatement) {
return new SimpleTableSegmentBinderContext(createProjectionSegments((CreateTableStatement) binderContext.getSqlStatement(), databaseName, schemaName, tableName));
}
return new SimpleTableSegmentBinderContext(Collections.emptyList());
}

private static Collection<ProjectionSegment> createProjectionSegments(final CreateTableStatement sqlStatement, final IdentifierValue databaseName,
final IdentifierValue schemaName, final IdentifierValue tableName) {
Collection<ProjectionSegment> result = new LinkedList<>();
for (ColumnDefinitionSegment each : sqlStatement.getColumnDefinitions()) {
each.getColumnName().setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(databaseName, schemaName), tableName, each.getColumnName().getIdentifier()));
result.add(new ColumnProjectionSegment(each.getColumnName()));
}
return result;
}

private static SimpleTableSegmentBinderContext createSimpleTableSegmentBinderContextWithMetaData(final SimpleTableSegment segment, final ShardingSphereSchema schema,
final IdentifierValue databaseName, final IdentifierValue schemaName,
final SQLStatementBinderContext binderContext, final IdentifierValue tableName) {
Expand Down
38 changes: 19 additions & 19 deletions test/it/binder/src/test/resources/cases/ddl/create-table.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,69 +18,69 @@

<sql-parser-test-cases>
<create-table sql-case-id="create_table">
<table name="t_order" start-index="13" stop-index="19">
<table name="t_order_tmp" start-index="13" stop-index="23">
<table-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
</table-bound>
</table>
<column-definition type="BIGINT" primary-key="true" start-index="22" stop-index="48">
<column-definition type="BIGINT" primary-key="true" start-index="26" stop-index="52">
<column name="order_id">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="order_id" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="order_id" />
</column-bound>
</column>
</column-definition>
<column-definition type="INT" not-null="true" start-index="51" stop-index="70">
<column-definition type="INT" not-null="true" start-index="55" stop-index="74">
<column name="user_id">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="user_id" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="user_id" />
</column-bound>
</column>
</column-definition>
<column-definition type="VARCHAR" not-null="true" start-index="73" stop-index="99">
<column-definition type="VARCHAR" not-null="true" start-index="77" stop-index="103">
<column name="status">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="status" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="status" />
</column-bound>
</column>
</column-definition>
<column-definition type="INT" start-index="102" stop-index="116">
<column-definition type="INT" start-index="106" stop-index="120">
<column name="merchant_id">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="merchant_id" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="merchant_id" />
</column-bound>
</column>
</column-definition>
<column-definition type="VARCHAR" not-null="true" start-index="119" stop-index="145">
<column-definition type="VARCHAR" not-null="true" start-index="123" stop-index="149">
<column name="remark">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="remark" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="remark" />
</column-bound>
</column>
</column-definition>
<column-definition type="DATE" not-null="true" start-index="148" stop-index="174">
<column-definition type="DATE" not-null="true" start-index="152" stop-index="178">
<column name="creation_date">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="creation_date" start-delimiter="`" end-delimiter="`" />
<original-table name="t_order_tmp" />
<original-column name="creation_date" />
</column-bound>
</column>
</column-definition>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
-->

<sql-cases>
<sql-case id="create_table" value="CREATE TABLE t_order (order_id BIGINT PRIMARY KEY, user_id INT NOT NULL, status VARCHAR(50) NOT NULL, merchant_id INT, remark VARCHAR(50) NOT NULL, creation_date DATE NOT NULL)" db-types="MySQL,Doris" />
<sql-case id="create_table" value="CREATE TABLE t_order_tmp (order_id BIGINT PRIMARY KEY, user_id INT NOT NULL, status VARCHAR(50) NOT NULL, merchant_id INT, remark VARCHAR(50) NOT NULL, creation_date DATE NOT NULL)" db-types="MySQL,Doris" />
</sql-cases>
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

<rewrite-assertions yaml-rule="scenario/encrypt/config/query-with-cipher.yaml">
<rewrite-assertion id="create_table_for_cipher" db-types="MySQL">
<input sql="CREATE TABLE t_account_bak (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', password VARCHAR(255) NOT NULL DEFAULT '', PRIMARY KEY (`id`))" />
<output sql="CREATE TABLE t_account_bak (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', cipher_password VARCHAR(4000), assisted_query_password VARCHAR(4000), like_query_password VARCHAR(4000), PRIMARY KEY (`id`))" />
<input sql="CREATE TABLE t_account_bak_for_create (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', password VARCHAR(255) NOT NULL DEFAULT '', PRIMARY KEY (`id`))" />
<output sql="CREATE TABLE t_account_bak_for_create (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', cipher_password VARCHAR(4000), assisted_query_password VARCHAR(4000), like_query_password VARCHAR(4000), PRIMARY KEY (`id`))" />
</rewrite-assertion>

<rewrite-assertion id="create_table_with_index_for_cipher" db-types="MySQL">
<input sql="CREATE TABLE t_account_bak (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', password VARCHAR(255) NOT NULL DEFAULT '', PRIMARY KEY (`id`), KEY `t_account_bak_amount_idx` (`amount`))" />
<output sql="CREATE TABLE t_account_bak (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', cipher_password VARCHAR(4000), assisted_query_password VARCHAR(4000), like_query_password VARCHAR(4000), PRIMARY KEY (`id`), KEY `t_account_bak_amount_idx` (`cipher_amount`))" />
<input sql="CREATE TABLE t_account_bak_for_create (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', password VARCHAR(255) NOT NULL DEFAULT '', PRIMARY KEY (`id`), KEY `t_account_bak_amount_idx` (`amount`))" />
<output sql="CREATE TABLE t_account_bak_for_create (id INT NOT NULL, name VARCHAR(100) NOT NULL DEFAULT '', cipher_password VARCHAR(4000), assisted_query_password VARCHAR(4000), like_query_password VARCHAR(4000), PRIMARY KEY (`id`), KEY `t_account_bak_amount_idx` (`cipher_amount`))" />
</rewrite-assertion>
</rewrite-assertions>
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,32 @@ rules:
cipher:
name: cipher_amount
encryptorName: rewrite_normal_fixture
t_account_bak_for_create:
columns:
certificate_number:
cipher:
name: cipher_certificate_number
encryptorName: rewrite_normal_fixture
assistedQuery:
name: assisted_query_certificate_number
encryptorName: rewrite_assisted_query_fixture
likeQuery:
name: like_query_certificate_number
encryptorName: rewrite_like_query_fixture
password:
cipher:
name: cipher_password
encryptorName: rewrite_normal_fixture
assistedQuery:
name: assisted_query_password
encryptorName: rewrite_assisted_query_fixture
likeQuery:
name: like_query_password
encryptorName: rewrite_like_query_fixture
amount:
cipher:
name: cipher_amount
encryptorName: rewrite_normal_fixture
t_account_detail:
columns:
certificate_number:
Expand Down

0 comments on commit 5098f5e

Please sign in to comment.