Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Account ID from C #312

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions Source/AwsCommonRuntimeKit/auth/credentials/Credentials.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ public final class Credentials {

let rawValue: OpaquePointer

// TODO: remove this property once aws-c-auth supports account_id
private let accountId: String?

init(rawValue: OpaquePointer, accountId: String? = nil) {
init(rawValue: OpaquePointer) {
self.rawValue = rawValue
aws_credentials_acquire(rawValue)
self.accountId = accountId
}

/// Creates a new set of aws credentials
Expand All @@ -23,15 +19,15 @@ public final class Credentials {
/// - accessKey: value for the aws access key id field
/// - secret: value for the secret access key field
/// - sessionToken: (Optional) security token associated with the credentials
/// - accountId: (Optional) the account ID for the resolved credentials, if known
/// - accountId: (Optional) account id associated with the credentials
/// - expiration: (Optional) Point in time after which credentials will no longer be valid.
/// For credentials that do not expire, use nil.
/// If expiration.timeIntervalSince1970 is greater than UInt64.max, it will be converted to nil.
/// - Throws: CommonRuntimeError.crtError
public init(accessKey: String,
secret: String,
accountId: String? = nil,
sessionToken: String? = nil,
accountId: String? = nil,
expiration: Date? = nil) throws {

let expirationTimeout: UInt64
Expand All @@ -45,19 +41,20 @@ public final class Credentials {
guard let rawValue = (withByteCursorFromStrings(
accessKey,
secret,
sessionToken) { accessKeyCursor, secretCursor, sessionTokenCursor in
sessionToken,
accountId) { accessKeyCursor, secretCursor, sessionTokenCursor, accountIdCursor in

return aws_credentials_new(
return aws_credentials_new_with_account_id(
allocator.rawValue,
accessKeyCursor,
secretCursor,
sessionTokenCursor,
accountIdCursor,
expirationTimeout)
}) else {
throw CommonRunTimeError.crtError(.makeFromLastError())
}
self.rawValue = rawValue
self.accountId = accountId
}

/// Gets the access key from the `aws_credentials` instance
Expand All @@ -76,15 +73,6 @@ public final class Credentials {
return secret.toOptionalString()
}

/// Gets the account ID from the `Credentials`, if any.
///
/// Temporarily, `accountId` is backed by a Swift instance variable.
/// In the future, when the C implementation implements `account_id` the implementation will get account ID from the `aws_credentials` instance.
/// - Returns:`String?`: The AWS `accountId` or nil
public func getAccountId() -> String? {
accountId
}

/// Gets the session token from the `aws_credentials` instance
///
/// - Returns:`String?`: The AWS Session token or nil
Expand All @@ -93,6 +81,14 @@ public final class Credentials {
return token.toOptionalString()
}

/// Gets the account id from the `aws_credentials` instance
///
/// - Returns:`String?`: The account id or nil
public func getAccountId() -> String? {
let accountId = aws_credentials_get_account_id(rawValue)
return accountId.toOptionalString()
}

/// Gets the expiration timeout from the `aws_credentials` instance
///
/// - Returns:`Data?`: The timeout in seconds of when the credentials expire.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ public class CredentialsProvider: CredentialsProviding {

let rawValue: UnsafeMutablePointer<aws_credentials_provider>

// TODO: remove this property once aws-c-auth supports account_id
private let accountId: String?

init(credentialsProvider: UnsafeMutablePointer<aws_credentials_provider>, accountId: String? = nil) {
init(credentialsProvider: UnsafeMutablePointer<aws_credentials_provider>) {
self.rawValue = credentialsProvider
self.accountId = accountId
}

/// Retrieves credentials from a provider by calling its implementation of get credentials and returns them to
Expand All @@ -29,10 +25,7 @@ public class CredentialsProvider: CredentialsProviding {
/// - Throws: CommonRuntimeError.crtError
public func getCredentials() async throws -> Credentials {
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Credentials, Error>) in
let continuationCore = ContinuationCore(
continuation: continuation,
userData: ["accountId": accountId as Any]
)
let continuationCore = ContinuationCore(continuation: continuation)
if aws_credentials_provider_get_credentials(rawValue,
onGetCredentials,
continuationCore.passRetained()) != AWS_OP_SUCCESS {
Expand All @@ -59,14 +52,6 @@ extension CredentialsProvider {
self.init(credentialsProvider: unsafeProvider)
}

// TODO: Remove the following initializer when aws-c-auth provides account_id in credentials
/// Creates a credentials provider that sources the credentials from the provided source and `accountId`
@_spi(AccountIDTempSupport)
public convenience init(source: Source, accountId: String?) throws {
let unsafeProvider = try source.makeProvider()
self.init(credentialsProvider: unsafeProvider, accountId: accountId)
}

/// Create a credentials provider that depends on provider to fetch the credentials.
/// It will retain the provider until shutdown callback is triggered for AwsCredentialsProvider
/// - Parameters:
Expand Down Expand Up @@ -98,12 +83,14 @@ extension CredentialsProvider.Source {
/// - accessKey: The access key to use.
/// - secret: The secret to use.
/// - sessionToken: (Optional) Session token to use.
/// - accountId: (Optional) Account id to use.
/// - shutdownCallback: (Optional) shutdown callback
/// - Returns: `CredentialsProvider`
/// - Throws: CommonRuntimeError.crtError
public static func `static`(accessKey: String,
secret: String,
sessionToken: String? = nil,
accountId: String? = nil,
shutdownCallback: ShutdownCallback? = nil) -> Self {
Self {

Expand All @@ -113,10 +100,12 @@ extension CredentialsProvider.Source {
guard let provider: UnsafeMutablePointer<aws_credentials_provider> = withByteCursorFromStrings(
accessKey,
secret,
sessionToken, { accessKeyCursor, secretCursor, sessionTokenCursor in
sessionToken,
accountId, { accessKeyCursor, secretCursor, sessionTokenCursor, accountIdCursor in
staticOptions.access_key_id = accessKeyCursor
staticOptions.secret_access_key = secretCursor
staticOptions.session_token = sessionTokenCursor
staticOptions.account_id = accountIdCursor
return aws_credentials_provider_new_static(allocator.rawValue, &staticOptions)
})
else {
Expand Down Expand Up @@ -583,8 +572,7 @@ private func onGetCredentials(credentials: OpaquePointer?,
}

// Success
let accountId = continuationCore.userData?["accountId"] as? String
continuationCore.continuation.resume(returning: Credentials(rawValue: credentials!, accountId: accountId))
continuationCore.continuation.resume(returning: Credentials(rawValue: credentials!))
}

// We need to share this pointer to C in a task block but Swift compiler complains
Expand Down
6 changes: 1 addition & 5 deletions Source/AwsCommonRuntimeKit/crt/ContinuationCore.swift
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0.

// TODO: Remove userData property once it is no longer needed for accountId on credentials

/// Core classes have manual memory management.
/// You have to balance the retain & release calls in all cases to avoid leaking memory.
class ContinuationCore<T> {
let continuation: CheckedContinuation<T, Error>
let userData: [String: Any]?

init(continuation: CheckedContinuation<T, Error>, userData: [String: Any]? = nil) {
init(continuation: CheckedContinuation<T, Error>) {
self.continuation = continuation
self.userData = userData
}

func passRetained() -> UnsafeMutableRawPointer {
Expand Down
23 changes: 18 additions & 5 deletions Test/AwsCommonRuntimeKitTests/auth/CredentialsProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
// SPDX-License-Identifier: Apache-2.0.

import XCTest
@_spi(AccountIDTempSupport) @testable import AwsCommonRuntimeKit
@testable import AwsCommonRuntimeKit

class CredentialsProviderTests: XCBaseTestCase {
let accessKey = "AccessKey"
let secret = "Sekrit"
var accountId: String? = nil
let sessionToken = "Token"

let shutdownWasCalled = XCTestExpectation(description: "Shutdown callback was called")
Expand Down Expand Up @@ -69,17 +68,31 @@ class CredentialsProviderTests: XCBaseTestCase {
wait(for: [shutdownWasCalled], timeout: 15)
}

// TODO: change this test to not pass accountId separately once the source function handles it
func testCreateCredentialsProviderStatic() async throws {
accountId = "0123456789"
do {
let provider = try CredentialsProvider(source: .static(accessKey: accessKey,
secret: secret,
sessionToken: sessionToken,
shutdownCallback: getShutdownCallback()), accountId: accountId)
shutdownCallback: getShutdownCallback()))
let credentials = try await provider.getCredentials()
XCTAssertNotNil(credentials)
assertCredentials(credentials: credentials)
}
wait(for: [shutdownWasCalled], timeout: 15)
}

func testCreateCredentialsProviderStaticWithAccountId() async throws {
do {
let accountId = "Account ID"
let provider = try CredentialsProvider(source: .static(accessKey: accessKey,
secret: secret,
sessionToken: sessionToken,
accountId: accountId,
shutdownCallback: getShutdownCallback()))
let credentials = try await provider.getCredentials()
XCTAssertNotNil(credentials)
assertCredentials(credentials: credentials)
XCTAssertEqual(accountId, credentials.getAccountId())
}
wait(for: [shutdownWasCalled], timeout: 15)
}
Expand Down
20 changes: 1 addition & 19 deletions Test/AwsCommonRuntimeKitTests/auth/CredentialsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ class CredentialsTests: XCBaseTestCase {
func testCreateAWSCredentials() async throws {
let accessKey = "AccessKey"
let secret = "Secret"
let accountId = "0123456789"
let sessionToken = "Token"
let expiration = Date(timeIntervalSinceNow: 10)

let credentials = try Credentials(accessKey: accessKey, secret: secret, accountId: accountId, sessionToken: sessionToken, expiration: expiration)
let credentials = try Credentials(accessKey: accessKey, secret: secret, sessionToken: sessionToken, expiration: expiration)

XCTAssertEqual(accessKey, credentials.getAccessKey())
XCTAssertEqual(secret, credentials.getSecret())
XCTAssertEqual(accountId, credentials.getAccountId())
XCTAssertEqual(sessionToken, credentials.getSessionToken())
XCTAssertEqual(UInt64(expiration.timeIntervalSince1970), UInt64(credentials.getExpiration()!.timeIntervalSince1970))

Expand All @@ -40,22 +38,6 @@ class CredentialsTests: XCBaseTestCase {
XCTAssertNil(credentials2.getExpiration())
}

func testCreateAWSCredentialsWithoutAccountId() async throws {
let accessKey = "AccessKey"
let secret = "Secret"
let sessionToken = "Token"
let expiration = Date(timeIntervalSinceNow: 10)

let credentials = try Credentials(accessKey: accessKey, secret: secret, accountId: nil, sessionToken: sessionToken, expiration: expiration)

XCTAssertEqual(accessKey, credentials.getAccessKey())
XCTAssertEqual(secret, credentials.getSecret())
XCTAssertNil(credentials.getAccountId())
XCTAssertEqual(sessionToken, credentials.getSessionToken())
XCTAssertEqual(UInt64(expiration.timeIntervalSince1970), UInt64(credentials.getExpiration()!.timeIntervalSince1970))

}

func testCreateAWSCredentialsWithoutSessionToken() async throws {
let accessKey = "AccessKey"
let secret = "Secret"
Expand Down
2 changes: 1 addition & 1 deletion aws-common-runtime/s2n
Submodule s2n updated 267 files
Loading