Skip to content

Commit

Permalink
fix endpoints not being set correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
samansmink committed Feb 22, 2024
1 parent 043890c commit 43f23a3
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 27 deletions.
81 changes: 54 additions & 27 deletions src/aws_secret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,13 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &
scope.push_back("gcs://");
scope.push_back("gs://");
} else {
throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type);
throw InternalException("Unknown secret type found in aws extension: '%s'", input.type);
}
}

auto result = ConstructBaseS3Secret(scope, input.type, input.provider, input.name);


if (!region.empty()) {
result->secret_map["region"] = region;
}
Expand All @@ -156,39 +157,65 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &

ParseCoreS3Config(input, *result);

// Set endpoint defaults TODO: move to consumer side of secret
auto endpoint_lu = result->secret_map.find("endpoint");
if (endpoint_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "s3") {
result->secret_map["endpoint"] = "s3.amazonaws.com";
} else if (input.type == "r2") {
if (input.options.find("account_id") != input.options.end()) {
result->secret_map["endpoint"] = input.options["account_id"].ToString() + ".r2.cloudflarestorage.com";
}
} else if (input.type == "gcs") {
result->secret_map["endpoint"] = "storage.googleapis.com";
} else {
throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type);
}
}

// Set endpoint defaults TODO: move to consumer side of secret
auto url_style_lu = result->secret_map.find("url_style");
if (url_style_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "gcs" || input.type == "r2") {
result->secret_map["url_style"] = "path";
}
}

return result;
}

void CreateAwsSecretFunctions::Register(DatabaseInstance &instance) {
string type = "S3";

// Register the credential_chain secret provider
CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain};

// Params for adding / overriding settings to the automatically fetched ones
cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;

if (type == "r2") {
cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR;
}
vector<string> types = {"s3", "r2", "gcs"};

for (const auto& type : types) {
// Register the credential_chain secret provider
CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain};

// Params for adding / overriding settings to the automatically fetched ones
cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;

if (type == "r2") {
cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR;
}

// Param for configuring the chain that is used
cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR;
// Param for configuring the chain that is used
cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR;

// Params for configuring the credential loading
cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR;
// Params for configuring the credential loading
cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR;

ExtensionUtil::RegisterFunction(instance, cred_chain_function);
ExtensionUtil::RegisterFunction(instance, cred_chain_function);
}
}

} // namespace duckdb
29 changes: 29 additions & 0 deletions test/sql/aws_secret_gcs.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# name: test/sql/aws_secret_gcs.test
# description: test aws extension with gcs secret
# group: [aws]

require aws

require httpfs

# Note this test is not very intelligent since we dont assume any profiles to be available

statement ok
SET allow_persistent_secrets=false

statement ok
CREATE SECRET s1 (
TYPE GCS,
PROVIDER credential_chain
);

query I
SELECT which_secret('gcs://haha/hoehoe.parkoe', 'gcs')
----
s1

statement error
from "gcs://a/b.csv"
----
https://storage.googleapis.com/a/b.csv

29 changes: 29 additions & 0 deletions test/sql/aws_secret_r2.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# name: test/sql/aws_secret_r2.test
# description: test aws extension with r2 secret
# group: [aws]

require aws

require httpfs

# Note this test is not very intelligent since we dont assume any profiles to be available

statement ok
SET allow_persistent_secrets=false

statement ok
CREATE SECRET s1 (
TYPE R2,
PROVIDER credential_chain,
ACCOUNT_ID "<account>"
);

query I
SELECT which_secret('r2://haha/hoehoe.parkoe', 'r2')
----
s1

statement error
from "r2://blabla/file.csv"
----
https://<account>.r2.cloudflarestorage.com/blabla/file.csv'

0 comments on commit 43f23a3

Please sign in to comment.