Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent declaredParameter list from Modification after compiling AbstractJdbcCall #33729

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,16 @@ protected CallableStatementCreatorFactory getCallableStatementFactory() {
* @param parameter the {@link SqlParameter} to add
*/
public void addDeclaredParameter(SqlParameter parameter) {
Assert.notNull(parameter, "The supplied parameter must not be null");
if (!StringUtils.hasText(parameter.getName())) {
throw new InvalidDataAccessApiUsageException(
"You must specify a parameter name when declaring parameters for \"" + getProcedureName() + "\"");
}
this.declaredParameters.add(parameter);
if (logger.isDebugEnabled()) {
logger.debug("Added declared parameter for [" + getProcedureName() + "]: " + parameter.getName());
if(!isCompiled()) {
Assert.notNull(parameter, "The supplied parameter must not be null");
if (!StringUtils.hasText(parameter.getName())) {
throw new InvalidDataAccessApiUsageException(
"You must specify a parameter name when declaring parameters for \"" + getProcedureName() + "\"");
}
this.declaredParameters.add(parameter);
if (logger.isDebugEnabled()) {
logger.debug("Added declared parameter for [" + getProcedureName() + "]: " + parameter.getName());
}
}
}

Expand All @@ -266,9 +268,11 @@ public void addDeclaredParameter(SqlParameter parameter) {
* @param rowMapper the RowMapper implementation to use
*/
public void addDeclaredRowMapper(String parameterName, RowMapper<?> rowMapper) {
this.declaredRowMappers.put(parameterName, rowMapper);
if (logger.isDebugEnabled()) {
logger.debug("Added row mapper for [" + getProcedureName() + "]: " + parameterName);
if(!isCompiled()) {
this.declaredRowMappers.put(parameterName, rowMapper);
if (logger.isDebugEnabled()) {
logger.debug("Added row mapper for [" + getProcedureName() + "]: " + parameterName);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

package org.springframework.jdbc.core.simple;

import java.lang.reflect.Field;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

Expand All @@ -30,6 +33,7 @@

import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.jdbc.BadSqlGrammarException;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlOutParameter;
import org.springframework.jdbc.core.SqlParameter;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
Expand Down Expand Up @@ -360,4 +364,78 @@ void correctSybaseFunctionStatementNamed() throws Exception {
verifyStatement(adder, "{call ADD_INVOICE(@AMOUNT = ?, @CUSTID = ?)}");
}

/**
* This test verifies that when declaring a parameter for a SimpleJdbcCall,
* then the parameter is added as expected.
*/
@SuppressWarnings("unchecked")
@Test
void verifyUncompiledDeclareParameterIsAdded() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
SimpleJdbcCall call = new SimpleJdbcCall(dataSource)
.withProcedureName("procedure_name")
.declareParameters(new SqlParameter("PARAM", Types.VARCHAR));

Field params = AbstractJdbcCall.class.getDeclaredField("declaredParameters");
params.setAccessible(true);
List<SqlParameter> paramList = (List<SqlParameter>) params.get(call);
assertThat(paramList).hasSize(1).allMatch(sqlParam -> sqlParam.getName().equals("PARAM"));
}

/**
* This verifies that once the SimpleJdbcCall is compiled, then adding
* a parameter is ignored
*/
@SuppressWarnings("unchecked")
@Test
void verifyWhenCompiledThenDeclareParameterIsIgnored() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
SimpleJdbcCall call = new SimpleJdbcCall(dataSource)
.withProcedureName("procedure_name")
.declareParameters(new SqlParameter("PARAM", Types.VARCHAR));
call.compile();

call.declareParameters(new SqlParameter("Ignored Param", Types.VARCHAR));

Field params = AbstractJdbcCall.class.getDeclaredField("declaredParameters");
params.setAccessible(true);
List<SqlParameter> paramList = (List<SqlParameter>) params.get(call);
assertThat(paramList).hasSize(1).allMatch(sqlParam -> sqlParam.getName().equals("PARAM"));
}

/**
* When adding a declared row mapper, this verifies that the declaredRowMappers
* gets the new mapper
*/
@SuppressWarnings("unchecked")
@Test
void verifyUncompiledDeclareRowMapperIsAdded() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
SimpleJdbcCall call = new SimpleJdbcCall(dataSource)
.withProcedureName("procedure_name")
.returningResultSet("result_set", (rs,i) -> new Object());

Field rowMappers = AbstractJdbcCall.class.getDeclaredField("declaredRowMappers");
rowMappers.setAccessible(true);
Map<String, RowMapper<?>> mappers = (Map<String, RowMapper<?>>) rowMappers.get(call);
assertThat(mappers).hasSize(1).allSatisfy((key,value) -> key.equals("result_set"));
}

/**
* This verifies that when adding a row mapper after the call is compiled
* then the request is ignored
*/
@SuppressWarnings("unchecked")
@Test
void verifyWhenCompiledThenDeclareRowMapperIsIgnored() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
SimpleJdbcCall call = new SimpleJdbcCall(dataSource)
.withProcedureName("procedure_name")
.returningResultSet("result_set", (rs,i) -> new Object());
call.compile();

call.returningResultSet("not added", (rs,i) -> new Object());

Field rowMappers = AbstractJdbcCall.class.getDeclaredField("declaredRowMappers");
rowMappers.setAccessible(true);
Map<String, RowMapper<?>> mappers = (Map<String, RowMapper<?>>) rowMappers.get(call);
assertThat(mappers).hasSize(1).allSatisfy((key,value) -> key.equals("result_set"));
}

}