From 4aebaa6c8bc689bffd05d12bca41ed7c8b8e0c57 Mon Sep 17 00:00:00 2001 From: Waqar Ahmed Khan Date: Tue, 9 Jan 2024 11:03:01 -0800 Subject: [PATCH] Fix EventStream to properly handle utf-8 (#216) --- .../AwsCommonRuntimeKit/crt/AWSString.swift | 2 +- .../AwsCommonRuntimeKit/crt/Utilities.swift | 5 +- .../event-stream/EventStreamMessage.swift | 152 +++++++++--------- .../event-stream/EventStreamTests.swift | 6 +- .../http/HTTPTests.swift | 7 + 5 files changed, 92 insertions(+), 80 deletions(-) diff --git a/Source/AwsCommonRuntimeKit/crt/AWSString.swift b/Source/AwsCommonRuntimeKit/crt/AWSString.swift index b2bfc19bb..31927b35d 100644 --- a/Source/AwsCommonRuntimeKit/crt/AWSString.swift +++ b/Source/AwsCommonRuntimeKit/crt/AWSString.swift @@ -6,7 +6,7 @@ final class AWSString { let rawValue: UnsafeMutablePointer init(_ str: String) { - self.rawValue = aws_string_new_from_array(allocator.rawValue, str, str.count) + self.rawValue = aws_string_new_from_array(allocator.rawValue, str, str.utf8.count) } var count: Int { diff --git a/Source/AwsCommonRuntimeKit/crt/Utilities.swift b/Source/AwsCommonRuntimeKit/crt/Utilities.swift index 5a7900abd..09a96c632 100644 --- a/Source/AwsCommonRuntimeKit/crt/Utilities.swift +++ b/Source/AwsCommonRuntimeKit/crt/Utilities.swift @@ -32,7 +32,7 @@ extension String { func withByteCursor(_ body: (aws_byte_cursor) -> Result ) -> Result { return self.withCString { arg1C in - return body(aws_byte_cursor_from_c_str(arg1C)) + return body(aws_byte_cursor_from_array(arg1C, self.utf8.count)) } } @@ -94,6 +94,9 @@ extension aws_byte_buf { } func toData() -> Data { + if self.len == 0 { + return Data() + } return Data(bytes: self.buffer, count: self.len) } } diff --git a/Source/AwsCommonRuntimeKit/event-stream/EventStreamMessage.swift b/Source/AwsCommonRuntimeKit/event-stream/EventStreamMessage.swift index 4c1660407..ceb3f3c8d 100644 --- a/Source/AwsCommonRuntimeKit/event-stream/EventStreamMessage.swift +++ b/Source/AwsCommonRuntimeKit/event-stream/EventStreamMessage.swift @@ -47,90 +47,90 @@ public struct EventStreamMessage { extension EventStreamMessage { func addHeader(header: EventStreamHeader, rawHeaders: UnsafeMutablePointer) throws { - if header.name.count > EventStreamHeader.maxNameLength { + let headerNameLength = header.name.utf8.count + if headerNameLength > EventStreamHeader.maxNameLength { throw CommonRunTimeError.crtError( .init( code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue)) } let addCHeader: () throws -> Int32 = { - return try header.name.withCString { headerName in - switch header.value { - case .bool(let value): - return aws_event_stream_add_bool_header( - rawHeaders, - headerName, - UInt8(header.name.count), - Int8(value.uintValue)) - case .byte(let value): - return aws_event_stream_add_byte_header( - rawHeaders, - headerName, - UInt8(header.name.count), - value) - case .int16(let value): - return aws_event_stream_add_int16_header( - rawHeaders, - headerName, - UInt8(header.name.count), - value) - case .int32(let value): - return aws_event_stream_add_int32_header( + let headerNameLength = UInt8(headerNameLength) + switch header.value { + case .bool(let value): + return aws_event_stream_add_bool_header( + rawHeaders, + header.name, + headerNameLength, + Int8(value.uintValue)) + case .byte(let value): + return aws_event_stream_add_byte_header( + rawHeaders, + header.name, + headerNameLength, + value) + case .int16(let value): + return aws_event_stream_add_int16_header( + rawHeaders, + header.name, + headerNameLength, + value) + case .int32(let value): + return aws_event_stream_add_int32_header( + rawHeaders, + header.name, + headerNameLength, + value) + case .int64(let value): + return aws_event_stream_add_int64_header( + rawHeaders, + header.name, + headerNameLength, + value) + case .byteBuf(var value): + let valueCount = value.count + if valueCount > EventStreamHeader.maxValueLength { + throw CommonRunTimeError.crtError( + .init( + code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue)) + } + return value.withUnsafeMutableBytes { + let bytes = $0.bindMemory(to: UInt8.self).baseAddress! + return aws_event_stream_add_bytebuf_header( rawHeaders, - headerName, - UInt8(header.name.count), - value) - case .int64(let value): - return aws_event_stream_add_int64_header( + header.name, + headerNameLength, + bytes, + UInt16(valueCount), + 1) + } + case .string(let value): + let valueCount = value.utf8.count + if valueCount > EventStreamHeader.maxValueLength { + throw CommonRunTimeError.crtError( + .init( + code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue)) + } + return aws_event_stream_add_string_header( rawHeaders, - headerName, - UInt8(header.name.count), - value) - case .byteBuf(var value): - if value.count > EventStreamHeader.maxValueLength { - throw CommonRunTimeError.crtError( - .init( - code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue)) - } - return value.withUnsafeMutableBytes { - let bytes = $0.bindMemory(to: UInt8.self).baseAddress! - return aws_event_stream_add_bytebuf_header( - rawHeaders, - headerName, - UInt8(header.name.count), - bytes, - UInt16($0.count), - 1) - } - case .string(let value): - if value.count > EventStreamHeader.maxValueLength { - throw CommonRunTimeError.crtError( - .init( - code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue)) - } - return value.withCString { - aws_event_stream_add_string_header( - rawHeaders, - headerName, - UInt8(header.name.count), - $0, - UInt16(value.count), - 1) - } - case .timestamp(let value): - return aws_event_stream_add_timestamp_header( + header.name, + headerNameLength, + value, + UInt16(valueCount), + 1) + case .timestamp(let value): + return aws_event_stream_add_timestamp_header( + rawHeaders, + header.name, + headerNameLength, + Int64(value.millisecondsSince1970)) + case .uuid(let value): + return withUnsafeBytes(of: value) { + let address = $0.baseAddress?.assumingMemoryBound(to: UInt8.self) + return aws_event_stream_add_uuid_header( rawHeaders, - headerName, - UInt8(header.name.count), - Int64(value.millisecondsSince1970)) - case .uuid(let value): - return withUnsafeBytes(of: value) { - let address = $0.baseAddress?.assumingMemoryBound(to: UInt8.self) - return aws_event_stream_add_uuid_header( - rawHeaders, - headerName, - UInt8(header.name.count), - address) - } + header.name, + headerNameLength, + address) } } } diff --git a/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift b/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift index 9b20bfda2..18bfd7b9a 100644 --- a/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift +++ b/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift @@ -18,8 +18,10 @@ class EventStreamTests: XCBaseTestCase { EventStreamHeader(name: "int32", value: .int32(value: 32)), EventStreamHeader(name: "int64", value: .int32(value: 64)), EventStreamHeader(name: "byteBuf", value: .byteBuf(value: "data".data(using: .utf8)!)), + EventStreamHeader(name: "emptyByteBuf", value: .byteBuf(value: Data())), EventStreamHeader(name: "host", value: .string(value: "aws-crt-test-stuff.s3.amazonaws.com")), EventStreamHeader(name: "host", value: .string(value: "aws-crt-test-stuff.s3.amazonaws.com")), + EventStreamHeader(name: "headerWithUtf8Character🧐", value: .string(value: "testValueWithEmoji🤯")), EventStreamHeader(name: "bool", value: .bool(value: false)), EventStreamHeader(name: "timestamp", value: .timestamp(value: Date(timeIntervalSinceNow: 10))), EventStreamHeader(name: "uuid", value: .uuid(value: UUID(uuidString: "63318232-1C63-4D04-9A0C-6907F347704E")!)), @@ -32,8 +34,8 @@ class EventStreamTests: XCBaseTestCase { XCTFail("OnPayload callback is triggered unexpectedly.") }, onPreludeReceived: { totalLength, headersLength in - XCTAssertEqual(totalLength, 210) - XCTAssertEqual(headersLength, 194) + XCTAssertEqual(totalLength, 279) + XCTAssertEqual(headersLength, 263) }, onHeaderReceived: { header in decodedHeaders.append(header) diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift index b811df57a..2f9d02723 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift @@ -15,6 +15,13 @@ class HTTPTests: HTTPClientTestFixture { _ = try await sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager) _ = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/delete", expectedStatus: 404, connectionManager: connectionManager) } + + func testGetHTTPSRequestWithUtf8Header() async throws { + let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let utf8Header = HTTPHeader(name: "TestHeader", value: "TestValueWithEmoji🤯") + let headers = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/response-headers?\(utf8Header.name)=\(utf8Header.value)", connectionManager: connectionManager).headers + XCTAssertTrue(headers.contains(where: {$0.name == utf8Header.name && $0.value==utf8Header.value})) + } func testGetHTTPRequest() async throws { let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: false, port: 80)