Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix endpoints set incorrectly with credential chain #32

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/MainDistributionPipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ concurrency:
jobs:
duckdb-stable-build:
name: Build extension binaries
uses: duckdb/duckdb/.github/workflows/[email protected]
uses: duckdb/extension-ci-tools/.github/workflows/[email protected]
with:
extension_name: aws
duckdb_version: 'v0.10.0'
Expand All @@ -23,7 +23,7 @@ jobs:
duckdb-stable-deploy:
name: Deploy extension binaries
needs: duckdb-stable-build
uses: duckdb/duckdb/.github/workflows/[email protected]
uses: duckdb/extension-ci-tools/.github/workflows/[email protected]
secrets: inherit
with:
extension_name: aws
Expand Down
81 changes: 54 additions & 27 deletions src/aws_secret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,13 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &
scope.push_back("gcs://");
scope.push_back("gs://");
} else {
throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type);
throw InternalException("Unknown secret type found in aws extension: '%s'", input.type);
}
}

auto result = ConstructBaseS3Secret(scope, input.type, input.provider, input.name);


if (!region.empty()) {
result->secret_map["region"] = region;
}
Expand All @@ -156,39 +157,65 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &

ParseCoreS3Config(input, *result);

// Set endpoint defaults TODO: move to consumer side of secret
auto endpoint_lu = result->secret_map.find("endpoint");
if (endpoint_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "s3") {
result->secret_map["endpoint"] = "s3.amazonaws.com";
} else if (input.type == "r2") {
if (input.options.find("account_id") != input.options.end()) {
result->secret_map["endpoint"] = input.options["account_id"].ToString() + ".r2.cloudflarestorage.com";
}
} else if (input.type == "gcs") {
result->secret_map["endpoint"] = "storage.googleapis.com";
} else {
throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type);
}
}

// Set endpoint defaults TODO: move to consumer side of secret
auto url_style_lu = result->secret_map.find("url_style");
if (url_style_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "gcs" || input.type == "r2") {
result->secret_map["url_style"] = "path";
}
}

return result;
}

void CreateAwsSecretFunctions::Register(DatabaseInstance &instance) {
string type = "S3";

// Register the credential_chain secret provider
CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain};

// Params for adding / overriding settings to the automatically fetched ones
cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;

if (type == "r2") {
cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR;
}
vector<string> types = {"s3", "r2", "gcs"};

for (const auto& type : types) {
// Register the credential_chain secret provider
CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain};

// Params for adding / overriding settings to the automatically fetched ones
cred_chain_function.named_parameters["key_id"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["secret"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["region"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["session_token"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["url_style"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;

if (type == "r2") {
cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR;
}

// Param for configuring the chain that is used
cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR;
// Param for configuring the chain that is used
cred_chain_function.named_parameters["chain"] = LogicalType::VARCHAR;

// Params for configuring the credential loading
cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR;
// Params for configuring the credential loading
cred_chain_function.named_parameters["profile"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_resource_path"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_endpoint"] = LogicalType::VARCHAR;
cred_chain_function.named_parameters["task_role_token"] = LogicalType::VARCHAR;

ExtensionUtil::RegisterFunction(instance, cred_chain_function);
ExtensionUtil::RegisterFunction(instance, cred_chain_function);
}
}

} // namespace duckdb
5 changes: 5 additions & 0 deletions test/sql/aws_secret_chains.test
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ CREATE SECRET config_no_profile (
CHAIN 'config'
);

query I
SELECT secret_string FROM duckdb_secrets(redact=false) where name='config_no_profile';
----
<REGEX>:.*endpoint=s3.amazonaws.com.*

statement ok
CREATE SECRET config_with_profile (
TYPE S3,
Expand Down
29 changes: 29 additions & 0 deletions test/sql/aws_secret_gcs.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# name: test/sql/aws_secret_gcs.test
# description: test aws extension with gcs secret
# group: [aws]

require aws

require httpfs

# Note this test is not very intelligent since we dont assume any profiles to be available

statement ok
SET allow_persistent_secrets=false

statement ok
CREATE SECRET s1 (
TYPE GCS,
PROVIDER credential_chain
);

query I
SELECT which_secret('gcs://haha/hoehoe.parkoe', 'gcs')
----
s1

statement error
from "gcs://a/b.csv"
----
https://storage.googleapis.com/a/b.csv

29 changes: 29 additions & 0 deletions test/sql/aws_secret_r2.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# name: test/sql/aws_secret_r2.test
# description: test aws extension with r2 secret
# group: [aws]

require aws

require httpfs

# Note this test is not very intelligent since we dont assume any profiles to be available

statement ok
SET allow_persistent_secrets=false

statement ok
CREATE SECRET s1 (
TYPE R2,
PROVIDER credential_chain,
ACCOUNT_ID "<account>"
);

query I
SELECT which_secret('r2://haha/hoehoe.parkoe', 'r2')
----
s1

statement error
from "r2://blabla/file.csv"
----
https://<account>.r2.cloudflarestorage.com/blabla/file.csv'
Loading