Skip to content

Commit

Permalink
FIR-31838 getting account_id from the server and not sending it as a …
Browse files Browse the repository at this point in the history
…parameter for system engine
  • Loading branch information
alexradzin committed Apr 8, 2024
1 parent f967503 commit c9a933d
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,18 @@ private Map<String, String> getAllParameters(FireboltProperties fireboltProperti

getResponseFormatParameter(statementInfoWrapper.getType() == StatementType.QUERY, isLocalDb)
.ifPresent(format -> params.put(format.getKey(), format.getValue()));

String accountId = fireboltProperties.getAccountId();
if (systemEngine) {
if (fireboltProperties.getAccountId() != null) {
params.put(FireboltQueryParameterKey.ACCOUNT_ID.getKey(), fireboltProperties.getAccountId());
if (accountId != null && connection.getInfraVersion() < 2) {
// if infra version >= 2 we should add account_id only if it was supplied by system URL returned from server.
// In this case it will be in additionalProperties anyway.
params.put(FireboltQueryParameterKey.ACCOUNT_ID.getKey(), accountId);
}
} else {
if (connection.getInfraVersion() >= 2) {
if (fireboltProperties.getAccountId() != null) {
params.put(FireboltQueryParameterKey.ACCOUNT_ID.getKey(), fireboltProperties.getAccountId());
if (accountId != null) {
params.put(FireboltQueryParameterKey.ACCOUNT_ID.getKey(), accountId);
params.put(FireboltQueryParameterKey.ENGINE.getKey(), fireboltProperties.getEngine());
}
params.put(FireboltQueryParameterKey.QUERY_LABEL.getKey(), statementInfoWrapper.getLabel()); //QUERY_LABEL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.firebolt.jdbc.client.authentication.ServiceAccountAuthenticationRequest;
import com.firebolt.jdbc.client.gateway.GatewayUrlResponse;
import com.firebolt.jdbc.connection.settings.FireboltProperties;
import com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey;
import com.firebolt.jdbc.exception.FireboltException;
import com.firebolt.jdbc.service.FireboltAccountIdService;
import com.firebolt.jdbc.service.FireboltAuthenticationService;
Expand All @@ -20,10 +21,15 @@
import lombok.NonNull;
import okhttp3.OkHttpClient;

import java.net.URI;
import java.net.URL;
import java.sql.SQLException;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Properties;

import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.ACCOUNT_ID;
import static com.firebolt.jdbc.exception.ExceptionType.RESOURCE_NOT_FOUND;
import static java.lang.String.format;

Expand Down Expand Up @@ -100,15 +106,24 @@ private FireboltProperties getSessionPropertiesForSystemEngine(String accessToke
String systemEngineEndpoint = fireboltGatewayUrlService.getUrl(accessToken, accountName);
FireboltAccount account = fireboltAccountIdService.getValue(accessToken, accountName);
infraVersion = account.getInfraVersion();
URL systemEngienUrl = UrlUtil.createUrl(systemEngineEndpoint);
Map<String, String> systemEngineUrlUrlParams = UrlUtil.getQueryParameters(systemEngienUrl);
String accountId = systemEngineUrlUrlParams.getOrDefault(ACCOUNT_ID.getKey(), account.getId());
for (Entry<String, String> e : systemEngineUrlUrlParams.entrySet()) {
loginProperties.addProperty(e);
}
return loginProperties
.toBuilder()
.systemEngine(true)
.compress(false)
.accountId(account.getId())
.host(UrlUtil.createUrl(systemEngineEndpoint).getHost())
.accountId(accountId)
.host(systemEngienUrl.getHost())
.build();
}




private FireboltEngineService getFireboltEngineService() {
if (fireboltEngineService == null) {
int currentInfraVersion = Optional.ofNullable(loginProperties.getAdditionalProperties().get("infraVersion")).map(Integer::parseInt).orElse(infraVersion);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Optional;
import java.util.Set;

import static com.firebolt.jdbc.statement.rawstatement.StatementValidatorFactory.createValidator;
import static java.util.stream.Collectors.toCollection;

@CustomLog
Expand Down Expand Up @@ -92,6 +93,7 @@ protected Optional<ResultSet> execute(List<StatementInfoWrapper> statements) thr
}

private Optional<ResultSet> execute(StatementInfoWrapper statementInfoWrapper, boolean verifyNotCancelled, boolean isStandardSql) throws SQLException {
createValidator(statementInfoWrapper.getInitialStatement(), connection).validate(statementInfoWrapper.getInitialStatement());
ResultSet resultSet = null;
if (!verifyNotCancelled || isStatementNotCancelled(statementInfoWrapper)) {
runningStatementLabel = statementInfoWrapper.getLabel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.Map;

import static com.firebolt.jdbc.statement.StatementUtil.replaceParameterMarksWithValues;
import static com.firebolt.jdbc.statement.rawstatement.StatementValidatorFactory.createValidator;
import static java.sql.Types.VARBINARY;

@CustomLog
Expand All @@ -61,6 +62,7 @@ public FireboltPreparedStatement(FireboltStatementService statementService, Fire
log.debug("Populating PreparedStatement object for SQL: {}", sql);
this.providedParameters = new HashMap<>();
this.rawStatement = StatementUtil.parseToRawStatementWrapper(sql);
rawStatement.getSubStatements().forEach(statement -> createValidator(statement, connection).validate(statement));
this.rows = new ArrayList<>();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.firebolt.jdbc.statement.rawstatement;

public class NoOpStatementValidator implements StatementValidator {
@Override
public void validate(RawStatement statement) {
// do nothing
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,10 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;

import java.util.Arrays;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;

import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.ACCOUNT_ID;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.DATABASE;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.ENGINE;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.OUTPUT_FORMAT;
import static com.firebolt.jdbc.statement.StatementType.PARAM_SETTING;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static java.lang.String.format;
import static java.util.stream.Collectors.toCollection;

/**
* A Set param statement is a special statement that sets a parameter internally
Expand All @@ -27,35 +17,15 @@
@Getter
@EqualsAndHashCode(callSuper = true)
public class SetParamRawStatement extends RawStatement {
private static final Set<String> forbiddenParameters = caseInsensitiveNameSet(DATABASE, ENGINE, ACCOUNT_ID, OUTPUT_FORMAT);
private static final Set<String> useSupporting = caseInsensitiveNameSet(DATABASE, ENGINE);
private static final String FORBIDDEN_PROPERTY_ERROR_PREFIX = "Could not set parameter. Set parameter '%s' is not allowed. ";
private static final String FORBIDDEN_PROPERTY_ERROR_USE_SUFFIX = "Try again with 'USE %s' instead of SET.";
private static final String FORBIDDEN_PROPERTY_ERROR_SET_SUFFIX = "Try again with a different parameter name.";
private static final String USE_ERROR = FORBIDDEN_PROPERTY_ERROR_PREFIX + FORBIDDEN_PROPERTY_ERROR_USE_SUFFIX;
private static final String SET_ERROR = FORBIDDEN_PROPERTY_ERROR_PREFIX + FORBIDDEN_PROPERTY_ERROR_SET_SUFFIX;

private final Entry<String, String> additionalProperty;

public SetParamRawStatement(String sql, String cleanSql, List<ParamMarker> paramPositions, Entry<String, String> additionalProperty) {
super(sql, cleanSql, paramPositions);
validateProperty(additionalProperty.getKey().toUpperCase());
this.additionalProperty = additionalProperty;
}

@Override
public StatementType getStatementType() {
return PARAM_SETTING;
}

private void validateProperty(String name) {
if (forbiddenParameters.contains(name)) {
throw new IllegalArgumentException(format(useSupporting.contains(name) ? USE_ERROR : SET_ERROR, name, name));
}
}

@SafeVarargs
private static <T extends Enum<T>> Set<String> caseInsensitiveNameSet(Enum<T> ... elements) {
return Arrays.stream(elements).map(Enum::name).collect(toCollection(() -> new TreeSet<>(CASE_INSENSITIVE_ORDER)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.firebolt.jdbc.statement.rawstatement;

import com.firebolt.jdbc.connection.FireboltConnection;

import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.ACCOUNT_ID;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.DATABASE;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.ENGINE;
import static com.firebolt.jdbc.connection.settings.FireboltQueryParameterKey.OUTPUT_FORMAT;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static java.lang.String.format;
import static java.util.stream.Collectors.toMap;

public class SetValidator implements StatementValidator {
private static final Map<String, String> forbiddenParameters1 = caseInsensitiveNameSet(DATABASE, ENGINE, ACCOUNT_ID, OUTPUT_FORMAT);
private static final Map<String, String> forbiddenParameters2 = caseInsensitiveNameSet(DATABASE, ENGINE, OUTPUT_FORMAT);
private static final Map<String, String> useSupporting = caseInsensitiveNameSet(DATABASE, ENGINE);
private static final String FORBIDDEN_PROPERTY_ERROR_PREFIX = "Could not set parameter. Set parameter '%s' is not allowed. ";
private static final String FORBIDDEN_PROPERTY_ERROR_USE_SUFFIX = "Try again with 'USE %s' instead of SET.";
private static final String FORBIDDEN_PROPERTY_ERROR_SET_SUFFIX = "Try again with a different parameter name.";
private static final String USE_ERROR = FORBIDDEN_PROPERTY_ERROR_PREFIX + FORBIDDEN_PROPERTY_ERROR_USE_SUFFIX;
private static final String SET_ERROR = FORBIDDEN_PROPERTY_ERROR_PREFIX + FORBIDDEN_PROPERTY_ERROR_SET_SUFFIX;

private final Map<String, String> forbiddenParameters;

public SetValidator(FireboltConnection connection) {
forbiddenParameters = connection.getInfraVersion() < 2 ? forbiddenParameters1 : forbiddenParameters2;
}

@Override
public void validate(RawStatement statement) {
validateProperty(((SetParamRawStatement)statement).getAdditionalProperty().getKey());
}

private void validateProperty(String name) {
String standardName = forbiddenParameters.get(name);
if (standardName != null) {
throw new IllegalArgumentException(format(useSupporting.containsKey(name) ? USE_ERROR : SET_ERROR, standardName, standardName));
}
}

@SafeVarargs
private static <T extends Enum<T>> Map<String, String> caseInsensitiveNameSet(Enum<T> ... elements) {
return Arrays.stream(elements).map(Enum::name).collect(toMap(name -> name, name -> name, (one, two) -> two, () -> new TreeMap<>(CASE_INSENSITIVE_ORDER)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.firebolt.jdbc.statement.rawstatement;

public interface StatementValidator {
void validate(RawStatement statement);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.firebolt.jdbc.statement.rawstatement;

import com.firebolt.jdbc.connection.FireboltConnection;

public abstract class StatementValidatorFactory {
private StatementValidatorFactory() {
// empty private constructor to ensure that this class will be used as factory only.
}

public static StatementValidator createValidator(RawStatement statement, FireboltConnection connection) {
return statement instanceof SetParamRawStatement ? new SetValidator(connection) : new NoOpStatementValidator();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.firebolt.jdbc.client.account.FireboltAccountRetriever;
import com.firebolt.jdbc.client.gateway.GatewayUrlResponse;
import com.firebolt.jdbc.connection.settings.FireboltProperties;
import com.firebolt.jdbc.exception.FireboltException;
import com.firebolt.jdbc.service.FireboltGatewayUrlService;
import org.junit.jupiter.api.Test;
Expand All @@ -10,9 +11,12 @@

import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Map;
import java.util.Properties;

import static java.lang.String.format;
import static java.util.stream.Collectors.toMap;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -77,19 +81,21 @@ void getMetadata(String testName, String engineParameter, boolean readOnly) thro

@ParameterizedTest(name = "{0}")
@CsvSource({
"http://the-endpoint,the-endpoint",
"https://the-endpoint,the-endpoint",
"the-endpoint,the-endpoint",
"http://the-endpoint?foo=1&bar=2,the-endpoint",
"https://the-endpoint?foo=1&bar=2,the-endpoint",
"the-endpoint?foo=1&bar=2,the-endpoint",
"http://the-endpoint,the-endpoint,",
"https://the-endpoint,the-endpoint,",
"the-endpoint,the-endpoint,",
"http://the-endpoint?foo=1&bar=2,the-endpoint,foo=1;bar=2",
"https://the-endpoint?foo=1&bar=2,the-endpoint,foo=1;bar=2",
"the-endpoint?foo=1&bar=2,the-endpoint,foo=1;bar=2",
})
void checkSystemEngineEndpoint(String gatewayUrl, String expectedHost) throws SQLException {
void checkSystemEngineEndpoint(String gatewayUrl, String expectedHost, String expectedProps) throws SQLException {
@SuppressWarnings("unchecked") FireboltAccountRetriever<GatewayUrlResponse> fireboltGatewayUrlClient = mock(FireboltAccountRetriever.class);
when(fireboltGatewayUrlClient.retrieve(any(), any())).thenReturn(new GatewayUrlResponse(gatewayUrl));
FireboltGatewayUrlService gatewayUrlService = new FireboltGatewayUrlService(fireboltGatewayUrlClient);
FireboltConnection connection = new FireboltConnectionServiceSecret(SYSTEM_ENGINE_URL, connectionProperties, fireboltAuthenticationService, gatewayUrlService, fireboltStatementService, fireboltEngineService, fireboltAccountIdService);
assertEquals(expectedHost, connection.getSessionProperties().getHost());
FireboltProperties sessionProperties = connection.getSessionProperties();
assertEquals(expectedHost, sessionProperties.getHost());
assertEquals(expectedProps == null ? Map.of() : Arrays.stream(expectedProps.split(";")).map(kv -> kv.split("=")).collect(toMap(kv -> kv[0], kv -> kv[1])), sessionProperties.getAdditionalProperties());
}

@Test
Expand Down
41 changes: 38 additions & 3 deletions src/test/java/com/firebolt/jdbc/connection/UrlUtilTest.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package com.firebolt.jdbc.connection;

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class UrlUtilTest {

Expand All @@ -25,4 +32,32 @@ void shouldGetAllPropertiesFromUri(String uri, String expectedPath, String expec
assertEquals(expectedProperties, properties);
}

@Test
void createUrl() throws MalformedURLException {
String spec = "http://myhost/path?x=1&y=2";
assertEquals(new URL(spec), UrlUtil.createUrl(spec));
}

@Test
void createBadUrl() {
assertEquals(MalformedURLException.class, assertThrows(IllegalArgumentException.class, () -> UrlUtil.createUrl("not url")).getCause().getClass());
}

@ParameterizedTest
@ValueSource(strings = {"http://host", "http://host/", "http://host/?", "http://host?", "http://host:8080", "http://host:8080/", "http://host:8080/?", "http://host:8080?"})
void getQueryParametersNoParameters(String spec) throws MalformedURLException {
assertEquals(Map.of(), UrlUtil.getQueryParameters(new URL(spec)));
}

@ParameterizedTest
@ValueSource(strings = {
"http://the-host.com?database&engine=diesel&format=json", // set each parameter only once
"http://the-host.com?database&format=xml&engine=benzine&engine=diesel&format=json" // override parameters
})
void getQueryParameters() throws MalformedURLException {
Map<String, String> expected = new HashMap<>();
expected.put("engine", "diesel");
expected.put("format", "json");
assertEquals(expected, UrlUtil.getQueryParameters(new URL("http://the-host.com?database&engine=diesel&format=json")));
}
}
Loading

0 comments on commit c9a933d

Please sign in to comment.