diff --git a/src/aws_secret.cpp b/src/aws_secret.cpp index 0e9e74e..cd6685a 100644 --- a/src/aws_secret.cpp +++ b/src/aws_secret.cpp @@ -135,12 +135,13 @@ static unique_ptr CreateAWSSecretFromCredentialChain(ClientContext & scope.push_back("gcs://"); scope.push_back("gs://"); } else { - throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type); + throw InternalException("Unknown secret type found in aws extension: '%s'", input.type); } } auto result = ConstructBaseS3Secret(scope, input.type, input.provider, input.name); + if (!region.empty()) { result->secret_map["region"] = region; } @@ -156,39 +157,65 @@ static unique_ptr CreateAWSSecretFromCredentialChain(ClientContext & ParseCoreS3Config(input, *result); + // Set endpoint defaults TODO: move to consumer side of secret + auto endpoint_lu = result->secret_map.find("endpoint"); + if (endpoint_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) { + if (input.type == "s3") { + result->secret_map["endpoint"] = "s3.amazonaws.com"; + } else if (input.type == "r2") { + if (input.options.find("account_id") != input.options.end()) { + result->secret_map["endpoint"] = input.options["account_id"].ToString() + ".r2.cloudflarestorage.com"; + } + } else if (input.type == "gcs") { + result->secret_map["endpoint"] = "storage.googleapis.com"; + } else { + throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type); + } + } + + // Set endpoint defaults TODO: move to consumer side of secret + auto url_style_lu = result->secret_map.find("url_style"); + if (url_style_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) { + if (input.type == "gcs" || input.type == "r2") { + result->secret_map["url_style"] = "path"; + } + } + return result; } void CreateAwsSecretFunctions::Register(DatabaseInstance &instance) { - string type = "S3"; - - // Register the credential_chain secret provider - CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain}; - - // Params for adding / overriding settings to the automatically fetched ones - cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN; - cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN; - - if (type == "r2") { - cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR; - } + vector types = {"s3", "r2", "gcs"}; + + for (const auto& type : types) { + // Register the credential_chain secret provider + CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain}; + + // Params for adding / overriding settings to the automatically fetched ones + cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN; + cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN; + + if (type == "r2") { + cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR; + } - // Param for configuring the chain that is used - cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR; + // Param for configuring the chain that is used + cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR; - // Params for configuring the credential loading - cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR; - cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR; + // Params for configuring the credential loading + cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR; + cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR; - ExtensionUtil::RegisterFunction(instance, cred_chain_function); + ExtensionUtil::RegisterFunction(instance, cred_chain_function); + } } } // namespace duckdb diff --git a/test/sql/aws_secret_gcs.test b/test/sql/aws_secret_gcs.test new file mode 100644 index 0000000..0b1fd40 --- /dev/null +++ b/test/sql/aws_secret_gcs.test @@ -0,0 +1,29 @@ +# name: test/sql/aws_secret_gcs.test +# description: test aws extension with gcs secret +# group: [aws] + +require aws + +require httpfs + +# Note this test is not very intelligent since we dont assume any profiles to be available + +statement ok +SET allow_persistent_secrets=false + +statement ok +CREATE SECRET s1 ( + TYPE GCS, + PROVIDER credential_chain +); + +query I +SELECT which_secret('gcs://haha/hoehoe.parkoe', 'gcs') +---- +s1 + +statement error +from "gcs://a/b.csv" +---- +https://storage.googleapis.com/a/b.csv + diff --git a/test/sql/aws_secret_r2.test b/test/sql/aws_secret_r2.test new file mode 100644 index 0000000..01be38b --- /dev/null +++ b/test/sql/aws_secret_r2.test @@ -0,0 +1,29 @@ +# name: test/sql/aws_secret_r2.test +# description: test aws extension with r2 secret +# group: [aws] + +require aws + +require httpfs + +# Note this test is not very intelligent since we dont assume any profiles to be available + +statement ok +SET allow_persistent_secrets=false + +statement ok +CREATE SECRET s1 ( + TYPE R2, + PROVIDER credential_chain, + ACCOUNT_ID "" +); + +query I +SELECT which_secret('r2://haha/hoehoe.parkoe', 'r2') +---- +s1 + +statement error +from "r2://blabla/file.csv" +---- +https://.r2.cloudflarestorage.com/blabla/file.csv' \ No newline at end of file