Skip to content

Commit

Permalink
Error management for Snowflake source and sink, Added new validation …
Browse files Browse the repository at this point in the history
…for maximum split size and NPE issue handled
  • Loading branch information
Amit-CloudSufi committed Dec 18, 2024
1 parent 1d9cacd commit 7006718
Show file tree
Hide file tree
Showing 16 changed files with 495 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,33 +230,10 @@ public String getConnectionArguments() {
}

public void validate(FailureCollector collector) {
if (getOauth2Enabled()) {
if (!containsMacro(PROPERTY_CLIENT_ID)
&& Strings.isNullOrEmpty(getClientId())) {
collector.addFailure("Client ID is not set.", null)
.withConfigProperty(PROPERTY_CLIENT_ID);
}
if (!containsMacro(PROPERTY_CLIENT_SECRET)
&& Strings.isNullOrEmpty(getClientSecret())) {
collector.addFailure("Client Secret is not set.", null)
.withConfigProperty(PROPERTY_CLIENT_SECRET);
}
if (!containsMacro(PROPERTY_REFRESH_TOKEN)
&& Strings.isNullOrEmpty(getRefreshToken())) {
collector.addFailure("Refresh Token is not set.", null)
.withConfigProperty(PROPERTY_REFRESH_TOKEN);
}
} else if (getKeyPairEnabled()) {
if (!containsMacro(PROPERTY_USERNAME)
&& Strings.isNullOrEmpty(getUsername())) {
collector.addFailure("Username is not set.", null)
.withConfigProperty(PROPERTY_USERNAME);
}
if (!containsMacro(PROPERTY_PRIVATE_KEY)
&& Strings.isNullOrEmpty(getPrivateKey())) {
collector.addFailure("Private Key is not set.", null)
.withConfigProperty(PROPERTY_PRIVATE_KEY);
}
if (Boolean.TRUE.equals(getOauth2Enabled())) {
validateWhenOath2Enabled(collector);
} else if (Boolean.TRUE.equals(getKeyPairEnabled())) {
validateWhenKeyPairEnabled(collector);
} else {
if (!containsMacro(PROPERTY_USERNAME)
&& Strings.isNullOrEmpty(getUsername())) {
Expand All @@ -272,6 +249,37 @@ public void validate(FailureCollector collector) {
validateConnection(collector);
}

private void validateWhenKeyPairEnabled(FailureCollector collector) {
if (!containsMacro(PROPERTY_USERNAME)
&& Strings.isNullOrEmpty(getUsername())) {
collector.addFailure("Username is not set.", null)
.withConfigProperty(PROPERTY_USERNAME);
}
if (!containsMacro(PROPERTY_PRIVATE_KEY)
&& Strings.isNullOrEmpty(getPrivateKey())) {
collector.addFailure("Private Key is not set.", null)
.withConfigProperty(PROPERTY_PRIVATE_KEY);
}
}

private void validateWhenOath2Enabled(FailureCollector collector) {
if (!containsMacro(PROPERTY_CLIENT_ID)
&& Strings.isNullOrEmpty(getClientId())) {
collector.addFailure("Client ID is not set.", null)
.withConfigProperty(PROPERTY_CLIENT_ID);
}
if (!containsMacro(PROPERTY_CLIENT_SECRET)
&& Strings.isNullOrEmpty(getClientSecret())) {
collector.addFailure("Client Secret is not set.", null)
.withConfigProperty(PROPERTY_CLIENT_SECRET);
}
if (!containsMacro(PROPERTY_REFRESH_TOKEN)
&& Strings.isNullOrEmpty(getRefreshToken())) {
collector.addFailure("Refresh Token is not set.", null)
.withConfigProperty(PROPERTY_REFRESH_TOKEN);
}
}

public boolean canConnect() {
return (!containsMacro(PROPERTY_DATABASE) && !containsMacro(PROPERTY_SCHEMA_NAME)
&& !containsMacro(PROPERTY_ACCOUNT_NAME) && !containsMacro(PROPERTY_USERNAME)
Expand Down Expand Up @@ -299,7 +307,7 @@ protected void validateConnection(FailureCollector collector) {
.withConfigProperty(PROPERTY_USERNAME);

// TODO: for oauth2
if (keyPairEnabled) {
if (Boolean.TRUE.equals(keyPairEnabled)) {
failure.withConfigProperty(PROPERTY_PRIVATE_KEY);
} else {
failure.withConfigProperty(PROPERTY_PASSWORD);
Expand Down
23 changes: 19 additions & 4 deletions src/main/java/io/cdap/plugin/snowflake/common/OAuthUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,27 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import io.cdap.cdap.api.exception.ErrorCategory;
import io.cdap.cdap.api.exception.ErrorType;
import io.cdap.cdap.api.exception.ErrorUtils;
import io.cdap.cdap.api.exception.ProgramFailureException;
import io.cdap.cdap.etl.api.exception.ErrorPhase;
import io.cdap.plugin.snowflake.common.exception.ConnectionTimeoutException;
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import scala.xml.Null;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.util.Base64;
import java.util.Objects;

/**
* A class which contains utilities to make OAuth2 specific calls.
Expand All @@ -50,9 +59,15 @@ public static String getAccessTokenByRefreshToken(CloseableHttpClient httpclient
httppost.setHeader("Content-type", "application/x-www-form-urlencoded");

// set grant type and refresh_token. It should be in body not url!
StringEntity entity = new StringEntity(String.format("refresh_token=%s&grant_type=refresh_token",
URLEncoder.encode(config.getRefreshToken(), "UTF-8")));
httppost.setEntity(entity);
try {
StringEntity entity = new StringEntity(String.format("refresh_token=%s&grant_type=refresh_token",
URLEncoder.encode(Objects.requireNonNull(Objects.requireNonNull(config).getRefreshToken()), "UTF-8")));
httppost.setEntity(entity);
} catch (NullPointerException e) {
String errorMessage = "Error encoding URL due to missing Refresh Token.";
throw ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage, String.format("Error message: %s", errorMessage), ErrorType.SYSTEM, true, e);
}

// set 'Authorization' header
String stringToEncode = config.getClientId() + ":" + config.getClientSecret();
Expand All @@ -72,7 +87,7 @@ public static String getAccessTokenByRefreshToken(CloseableHttpClient httpclient

// if exception happened during parsing OR if json does not contain 'access_token' key.
if (jsonElement == null) {
throw new RuntimeException(String.format("Unexpected response '%s' from '%s'", responseString, uri.toString()));
throw new RuntimeException(String.format("Unexpected response '%s' from '%s'", responseString, uri));
}

return jsonElement.getAsString();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.plugin.snowflake.common;

import com.google.common.base.Throwables;
import io.cdap.cdap.api.data.format.UnexpectedFormatException;
import io.cdap.cdap.api.exception.ErrorCategory;
import io.cdap.cdap.api.exception.ErrorCodeType;
import io.cdap.cdap.api.exception.ErrorType;
import io.cdap.cdap.api.exception.ErrorUtils;
import io.cdap.cdap.api.exception.ProgramFailureException;
import io.cdap.cdap.etl.api.exception.ErrorContext;
import io.cdap.cdap.etl.api.exception.ErrorDetailsProvider;
import io.cdap.plugin.snowflake.common.exception.ConnectionTimeoutException;
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;

import java.net.URISyntaxException;
import java.util.List;


/**
* Error details provided for the Snowflake
**/
public class SnowflakeErrorDetailsProvider implements ErrorDetailsProvider {

@Override
public ProgramFailureException getExceptionDetails(Exception e, ErrorContext errorContext) {
List<Throwable> causalChain = Throwables.getCausalChain(e);
for (Throwable t : causalChain) {
if (t instanceof ProgramFailureException) {
// if causal chain already has program failure exception, return null to avoid double wrap.
return null;
}
if (t instanceof IllegalArgumentException) {
return getProgramFailureException((IllegalArgumentException) t, errorContext);
}
if (t instanceof IllegalStateException) {
return getProgramFailureException((IllegalStateException) t, errorContext);
}
if (t instanceof URISyntaxException) {
return getProgramFailureException((URISyntaxException) t, errorContext);
}
if (t instanceof SchemaParseException) {
return getProgramFailureException((SchemaParseException) t, errorContext);
}
if (t instanceof UnexpectedFormatException) {
return getProgramFailureException((UnexpectedFormatException) t, errorContext);
}
if (t instanceof ConnectionTimeoutException) {
return getProgramFailureException((ConnectionTimeoutException) t, errorContext);
}
}
return null;
}

/**
* Get a ProgramFailureException with the given error
* information from {@link IllegalArgumentException}.
*
* @param e The IllegalArgumentException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(IllegalArgumentException e, ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";

return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.USER, false, e);
}

/**
* Get a ProgramFailureException with the given error
* information from {@link IllegalStateException}.
*
* @param e The IllegalStateException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(IllegalStateException e, ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
}

/**
* Get a ProgramFailureException with the given error
* information from {@link URISyntaxException}.
*
* @param e The URISyntaxException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(URISyntaxException e,
ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
}

/**
* Get a ProgramFailureException with the given error
* information from {@link SchemaParseException}.
*
* @param e The SchemaParseException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(SchemaParseException e, ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
}

/**
* Get a ProgramFailureException with the given error
* information from {@link UnexpectedFormatException}.
*
* @param e The UnexpectedFormatException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(UnexpectedFormatException e, ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
}

/**
* Get a ProgramFailureException with the given error
* information from {@link ConnectionTimeoutException}.
*
* @param e The ConnectionTimeoutException to get the error information from.
* @return A ProgramFailureException with the given error information.
*/
private ProgramFailureException getProgramFailureException(ConnectionTimeoutException e, ErrorContext errorContext) {
String errorMessage = e.getMessage();
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
errorMessage,
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.plugin.snowflake.common;

import io.cdap.cdap.api.exception.ErrorType;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/**
* Error Type provided based on the Snowflake error message code
*
**/
public class SnowflakeErrorType {

//https://github.com/snowflakedb/snowflake-jdbc/blob/master/src/main/java/net/snowflake/client/jdbc/ErrorCode.java
private static final Set<Integer> USER_ERRORS = new HashSet<>(Arrays.asList(
200004, 200006, 200007, 200008, 200009, 200010, 200011, 200012, 200014,
200017, 200018, 200019, 200021, 200023, 200024, 200025, 200026, 200028,
200029, 200030, 200031, 200032, 200033, 200034, 200035, 200036, 200037,
200038, 200045, 200046, 200047, 200056
));

private static final Set<Integer> SYSTEM_ERRORS = new HashSet<>(Arrays.asList(
200001, 200002, 200003, 200013, 200015, 200016, 200020, 200022, 200039,
200040, 200044, 200061
));

/**
* Method to get the error type based on the error code.
*
* @param errorCode the error code to classify
* @return the corresponding ErrorType (USER, SYSTEM, UNKNOWN)
*/
public static ErrorType getErrorType(int errorCode) {
if (USER_ERRORS.contains(errorCode)) {
return ErrorType.USER;
} else if (SYSTEM_ERRORS.contains(errorCode)) {
return ErrorType.SYSTEM;
} else {
return ErrorType.UNKNOWN;
}
}
}
Loading

0 comments on commit 7006718

Please sign in to comment.