Skip to content

Commit

Permalink
Support extracting column name from function for postgresql (#34147)
Browse files Browse the repository at this point in the history
* Support extracting column name from function for postgresql

* Support extracting column name from function for postgresql

* Support extracting column name from function for postgresql

* Support extracting column name from function for postgresql

* Support extracting column name from function for postgresql
  • Loading branch information
FlyingZC authored Dec 25, 2024
1 parent 9fc1b2a commit e43c4aa
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPI;
import org.apache.shardingsphere.infra.spi.annotation.SingletonSPI;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

Expand Down Expand Up @@ -46,12 +47,13 @@ public interface DialectProjectionIdentifierExtractor extends DatabaseTypedSPI {
String getColumnNameFromFunction(String functionName, String functionExpression);

/**
* Get column name from expression.
* Get column name from expression segment.
*
* @param expression expression
* @param expressionSegment expression segment
* @return column name
*/
String getColumnNameFromExpression(String expression);

String getColumnNameFromExpression(ExpressionSegment expressionSegment);

/**
* Get column name from subquery segment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

Expand Down Expand Up @@ -61,11 +62,12 @@ public String getColumnNameFromFunction(final String functionName, final String
/**
* Get column name from expression.
*
* @param expression expression
* @param expressionSegment expression segment
* @return column name
*/
public String getColumnNameFromExpression(final String expression) {
return DatabaseTypedSPILoader.findService(DialectProjectionIdentifierExtractor.class, databaseType).map(optional -> optional.getColumnNameFromExpression(expression)).orElse(expression);
public String getColumnNameFromExpression(final ExpressionSegment expressionSegment) {
return DatabaseTypedSPILoader.findService(DialectProjectionIdentifierExtractor.class, databaseType).map(optional -> optional.getColumnNameFromExpression(expressionSegment))
.orElse(expressionSegment.getText());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.dialect;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.DialectProjectionIdentifierExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

Expand All @@ -37,8 +40,10 @@ public String getColumnNameFromFunction(final String functionName, final String
}

@Override
public String getColumnNameFromExpression(final String expression) {
return "?column?";
public String getColumnNameFromExpression(final ExpressionSegment expressionSegment) {
return expressionSegment instanceof ExpressionProjectionSegment && ((ExpressionProjectionSegment) expressionSegment).getExpr() instanceof FunctionSegment
? ((FunctionSegment) ((ExpressionProjectionSegment) expressionSegment).getExpr()).getFunctionName()
: "?column?";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.dialect;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.DialectProjectionIdentifierExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

Expand All @@ -37,8 +38,8 @@ public String getColumnNameFromFunction(final String functionName, final String
}

@Override
public String getColumnNameFromExpression(final String expression) {
return expression.replace(" ", "").toUpperCase();
public String getColumnNameFromExpression(final ExpressionSegment expressionSegment) {
return expressionSegment.getText().replace(" ", "").toUpperCase();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.dialect;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.DialectProjectionIdentifierExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

Expand All @@ -37,8 +40,10 @@ public String getColumnNameFromFunction(final String functionName, final String
}

@Override
public String getColumnNameFromExpression(final String expression) {
return "?column?";
public String getColumnNameFromExpression(final ExpressionSegment expressionSegment) {
return expressionSegment instanceof ExpressionProjectionSegment && ((ExpressionProjectionSegment) expressionSegment).getExpr() instanceof FunctionSegment
? ((FunctionSegment) ((ExpressionProjectionSegment) expressionSegment).getExpr()).getFunctionName()
: "?column?";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public String getColumnName() {
@Override
public String getColumnLabel() {
ProjectionIdentifierExtractEngine extractEngine = new ProjectionIdentifierExtractEngine(databaseType);
return getAlias().map(extractEngine::getIdentifierValue).orElseGet(() -> extractEngine.getColumnNameFromExpression(expressionSegment.getText()));
return getAlias().map(extractEngine::getIdentifierValue).orElseGet(() -> extractEngine.getColumnNameFromExpression(expressionSegment));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ private static String getColumnNameFromExpression(final ExpressionSegment expres
if (expressionSegment instanceof AliasAvailable && ((AliasAvailable) expressionSegment).getAlias().isPresent()) {
return extractEngine.getIdentifierValue(((AliasAvailable) expressionSegment).getAlias().get());
}
return extractEngine.getColumnNameFromExpression(expressionSegment.getText());
return extractEngine.getColumnNameFromExpression(expressionSegment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -50,7 +51,7 @@ void assertGetColumnNameFromFunction() {

@Test
void assertGetColumnNameFromExpression() {
assertThat(new ProjectionIdentifierExtractEngine(databaseType).getColumnNameFromExpression("expression"), is("expression"));
assertThat(new ProjectionIdentifierExtractEngine(databaseType).getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "expression")), is("expression"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
Expand All @@ -49,7 +51,12 @@ void assertGetColumnNameFromFunction() {

@Test
void assertGetColumnNameFromExpression() {
assertThat(extractor.getColumnNameFromExpression("expression"), is("?column?"));
assertThat(extractor.getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "expression")), is("?column?"));
}

@Test
void assertGetColumnNameFromFunctionExpression() {
assertThat(extractor.getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "SUM(ID)", new FunctionSegment(0, 0, "SUM", "SUM(ID)"))), is("SUM"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
Expand All @@ -49,7 +50,7 @@ void assertGetColumnNameFromFunction() {

@Test
void assertGetColumnNameFromExpression() {
assertThat(extractor.getColumnNameFromExpression("expression"), is("EXPRESSION"));
assertThat(extractor.getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "expression")), is("EXPRESSION"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
Expand All @@ -49,7 +51,12 @@ void assertGetColumnNameFromFunction() {

@Test
void assertGetColumnNameFromExpression() {
assertThat(extractor.getColumnNameFromExpression("expression"), is("?column?"));
assertThat(extractor.getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "expression")), is("?column?"));
}

@Test
void assertGetColumnNameFromFunctionExpression() {
assertThat(extractor.getColumnNameFromExpression(new ExpressionProjectionSegment(0, 0, "SUM(ID)", new FunctionSegment(0, 0, "SUM", "SUM(ID)"))), is("SUM"));
}

@Test
Expand Down

0 comments on commit e43c4aa

Please sign in to comment.