diff --git a/clients/client-s3/test/unit/flexibleChecksums.spec.ts b/clients/client-s3/test/unit/flexibleChecksums.spec.ts index 710b513c1b30d..32b56e10e0af1 100644 --- a/clients/client-s3/test/unit/flexibleChecksums.spec.ts +++ b/clients/client-s3/test/unit/flexibleChecksums.spec.ts @@ -1,4 +1,9 @@ -import { ChecksumAlgorithm } from "@aws-sdk/middleware-flexible-checksums"; +import { + ChecksumAlgorithm, + DEFAULT_CHECKSUM_ALGORITHM, + RequestChecksumCalculation, + ResponseChecksumValidation, +} from "@aws-sdk/middleware-flexible-checksums"; import { HttpRequest } from "@smithy/protocol-http"; import { BuildMiddleware } from "@smithy/types"; import { Readable } from "stream"; @@ -7,7 +12,7 @@ import { describe, expect, test as it } from "vitest"; import { ChecksumAlgorithm as Algo, S3 } from "../../src/index"; describe("Flexible Checksums", () => { - const testCases = [ + const testCases: [string, string | undefined, string][] = [ ["", ChecksumAlgorithm.CRC32, "AAAAAA=="], ["abc", ChecksumAlgorithm.CRC32, "NSRBwg=="], ["Hello world", ChecksumAlgorithm.CRC32, "i9aeUg=="], @@ -23,148 +28,204 @@ describe("Flexible Checksums", () => { ["", ChecksumAlgorithm.SHA256, "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU="], ["abc", ChecksumAlgorithm.SHA256, "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0="], ["Hello world", ChecksumAlgorithm.SHA256, "ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw="], + + // Choose default checksum algorithm when explicily not provided. + ["Hello world", undefined, "i9aeUg=="], ]; describe("putObject", () => { - testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { - const checksumHeader = `x-amz-checksum-${checksumAlgorithm.toLowerCase()}`; - - describe(`sets ${checksumHeader}="${checksumValue}"" for checksum="${checksumAlgorithm}"`, () => { - const getBodyAsReadableStream = (content: string) => { - const readableStream = new Readable(); - const separator = " "; - const wordsAsChunks = content.split(separator); - wordsAsChunks.forEach((word, index) => { - readableStream.push(word); - if (index !== wordsAsChunks.length - 1) { - readableStream.push(separator); - } - }); - readableStream.push(null); - return readableStream; - }; - - it(`when body is sent as a request`, async () => { - const requestChecksumValidator: BuildMiddleware = (next) => async (args) => { - // middleware intercept the request and return it early - const request = args.request as HttpRequest; - const { headers } = request; - expect(headers["x-amz-sdk-checksum-algorithm"]).to.equal(checksumAlgorithm); - expect(headers[checksumHeader]).to.equal(checksumValue); - return { output: {} as any, response: {} as any }; - }; - - const client = new S3({ - region: "us-west-2", - credentials: { - accessKeyId: "CLIENT_TEST", - secretAccessKey: "CLIENT_TEST", - }, - }); - client.middlewareStack.addRelativeTo(requestChecksumValidator, { - relation: "after", - toMiddleware: "flexibleChecksumsMiddleware", - }); - - return await client.putObject({ - Bucket: "bucket", - Key: "key", - Body: body, - ChecksumAlgorithm: checksumAlgorithm as Algo, - }); - }); - - it(`when body is sent as a stream`, async () => { - const requestChecksumValidator: BuildMiddleware = (next) => async (args) => { - // middleware intercept the request and return it early - const request = args.request as HttpRequest; - const { headers, body } = request; - expect(headers["content-length"]).to.be.undefined; - expect(headers["content-encoding"]).to.equal("aws-chunked"); - expect(headers["transfer-encoding"]).to.equal("chunked"); - expect(headers["x-amz-content-sha256"]).to.equal("STREAMING-UNSIGNED-PAYLOAD-TRAILER"); - expect(headers["x-amz-trailer"]).to.equal(checksumHeader); - body.on("data", (data: any) => { - const stringValue = data.toString(); - if (stringValue.startsWith(checksumHeader)) { - const receivedChecksum = stringValue.replace("\r\n", "").split(":")[1]; - expect(receivedChecksum).to.equal(checksumValue); - } + describe.each([undefined, RequestChecksumCalculation.WHEN_SUPPORTED, RequestChecksumCalculation.WHEN_REQUIRED])( + `when requestChecksumCalculation='%s'`, + (requestChecksumCalculation) => { + describe.each(testCases)( + `for body="%s" and checksumAlgorithm="%s", sets checksum="%s"`, + (body, checksumAlgorithm, checksumValue) => { + const checksumHeader = `x-amz-checksum-${(checksumAlgorithm ?? DEFAULT_CHECKSUM_ALGORITHM).toLowerCase()}`; + const getBodyAsReadableStream = (content: string) => { + const readableStream = new Readable(); + const separator = " "; + const wordsAsChunks = content.split(separator); + wordsAsChunks.forEach((word, index) => { + readableStream.push(word); + if (index !== wordsAsChunks.length - 1) { + readableStream.push(separator); + } + }); + readableStream.push(null); + return readableStream; + }; + + it(`when body is sent as a string`, async () => { + const requestChecksumValidator: BuildMiddleware = (next) => async (args) => { + // middleware intercept the request and return it early + const request = args.request as HttpRequest; + const { headers } = request; + + // Headers are not set when checksumAlgorithm is not provided, + // and requestChecksumCalculation is explicitly set to WHEN_SUPPORTED. + if ( + checksumAlgorithm === undefined && + requestChecksumCalculation === RequestChecksumCalculation.WHEN_REQUIRED + ) { + expect(headers["x-amz-sdk-checksum-algorithm"]).toBeUndefined(); + expect(headers[checksumHeader]).toBeUndefined(); + } else { + expect(headers["x-amz-sdk-checksum-algorithm"]).toEqual( + checksumAlgorithm ?? DEFAULT_CHECKSUM_ALGORITHM + ); + expect(headers[checksumHeader]).toEqual(checksumValue); + } + + return { output: {} as any, response: {} as any }; + }; + + const client = new S3({ + region: "us-west-2", + credentials: { + accessKeyId: "CLIENT_TEST", + secretAccessKey: "CLIENT_TEST", + }, + requestChecksumCalculation, + }); + client.middlewareStack.addRelativeTo(requestChecksumValidator, { + relation: "after", + toMiddleware: "flexibleChecksumsMiddleware", + }); + + return await client.putObject({ + Bucket: "bucket", + Key: "key", + Body: body, + ChecksumAlgorithm: checksumAlgorithm as Algo, + }); + }); + + it(`when body is sent as a stream`, async () => { + const requestChecksumValidator: BuildMiddleware = (next) => async (args) => { + // middleware intercept the request and return it early + const request = args.request as HttpRequest; + const { headers, body } = request; + expect(headers["content-length"]).toBeUndefined(); + + // Headers are not set when checksumAlgorithm is not provided, + // and requestChecksumCalculation is explicitly set to WHEN_SUPPORTED. + if ( + checksumAlgorithm === undefined && + requestChecksumCalculation === RequestChecksumCalculation.WHEN_REQUIRED + ) { + expect(headers["content-encoding"]).toBeUndefined(); + expect(headers["transfer-encoding"]).toBeUndefined(); + expect(headers["x-amz-content-sha256"]).toBeUndefined(); + expect(headers["x-amz-trailer"]).toBeUndefined(); + } else { + expect(headers["content-encoding"]).toEqual("aws-chunked"); + expect(headers["transfer-encoding"]).toEqual("chunked"); + expect(headers["x-amz-content-sha256"]).toEqual("STREAMING-UNSIGNED-PAYLOAD-TRAILER"); + expect(headers["x-amz-trailer"]).toEqual(checksumHeader); + } + body.on("data", (data: any) => { + const stringValue = data.toString(); + if (stringValue.startsWith(checksumHeader)) { + const receivedChecksum = stringValue.replace("\r\n", "").split(":")[1]; + expect(receivedChecksum).toEqual(checksumValue); + } + }); + return { output: {} as any, response: {} as any }; + }; + + const client = new S3({ + region: "us-west-2", + credentials: { + accessKeyId: "CLIENT_TEST", + secretAccessKey: "CLIENT_TEST", + }, + requestChecksumCalculation, + }); + client.middlewareStack.addRelativeTo(requestChecksumValidator, { + relation: "after", + toMiddleware: "flexibleChecksumsMiddleware", + }); + + const bodyStream = getBodyAsReadableStream(body); + await client.putObject({ + Bucket: "bucket", + Key: "key", + Body: bodyStream, + ChecksumAlgorithm: checksumAlgorithm as Algo, + }); }); - return { output: {} as any, response: {} as any }; - }; - - const client = new S3({ - region: "us-west-2", - credentials: { - accessKeyId: "CLIENT_TEST", - secretAccessKey: "CLIENT_TEST", - }, - }); - client.middlewareStack.addRelativeTo(requestChecksumValidator, { - relation: "after", - toMiddleware: "flexibleChecksumsMiddleware", - }); - - const bodyStream = getBodyAsReadableStream(body); - await client.putObject({ - Bucket: "bucket", - Key: "key", - Body: bodyStream, - ChecksumAlgorithm: checksumAlgorithm as Algo, - }); - }); - }); - }); + } + ); + } + ); }); describe("getObject", async () => { - testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { - const checksumHeader = `x-amz-checksum-${checksumAlgorithm.toLowerCase()}`; - - it(`validates ${checksumHeader}="${checksumValue}"" set for checksum="${checksumAlgorithm}"`, async () => { - const responseBody = new Readable(); - responseBody.push(body); - responseBody.push(null); - const responseChecksumValidator: BuildMiddleware = (next, context) => async (args) => { - const request = args.request as HttpRequest; - return { - output: { - $metadata: { attempts: 0, httpStatusCode: 200 }, - request, - context, - Body: responseBody, - } as any, - response: { - body: responseBody, - headers: { - [checksumHeader]: checksumValue, + describe.each([undefined, ResponseChecksumValidation.WHEN_SUPPORTED, ResponseChecksumValidation.WHEN_REQUIRED])( + `when responseChecksumValidation='%s'`, + (responseChecksumValidation) => { + it.each(testCases)( + `for body="%s" and checksumAlgorithm="%s", validates ChecksumMode`, + async (body, checksumAlgorithm, checksumValue) => { + const checksumHeader = `x-amz-checksum-${(checksumAlgorithm ?? DEFAULT_CHECKSUM_ALGORITHM).toLowerCase()}`; + + const responseBody = new Readable(); + responseBody.push(body); + responseBody.push(null); + const responseChecksumValidator: BuildMiddleware = (next, context) => async (args) => { + // ChecksumMode is not set when checksumAlgorithm is not provided, + // and responseChecksumValidation is explicitly set to WHEN_SUPPORTED. + if ( + checksumAlgorithm === undefined && + responseChecksumValidation === ResponseChecksumValidation.WHEN_REQUIRED + ) { + expect(args.input.ChecksumMode).toBeUndefined(); + } else { + expect(args.input.ChecksumMode).toEqual("ENABLED"); + } + + const request = args.request as HttpRequest; + return { + output: { + $metadata: { attempts: 0, httpStatusCode: 200 }, + request, + context, + Body: responseBody, + } as any, + response: { + body: responseBody, + headers: { + [checksumHeader]: checksumValue, + }, + } as any, + }; + }; + + const client = new S3({ + region: "us-west-2", + credentials: { + accessKeyId: "CLIENT_TEST", + secretAccessKey: "CLIENT_TEST", }, - } as any, - }; - }; - - const client = new S3({ - region: "us-west-2", - credentials: { - accessKeyId: "CLIENT_TEST", - secretAccessKey: "CLIENT_TEST", - }, - }); - client.middlewareStack.addRelativeTo(responseChecksumValidator, { - relation: "after", - toMiddleware: "flexibleChecksumsMiddleware", - }); - - const { Body } = await client.getObject({ - Bucket: "bucket", - Key: "key", - ChecksumMode: "ENABLED", - }); - (Body as Readable).on("data", (chunk) => { - expect(chunk.toString()).to.equal(body); - }); - }); - }); + responseChecksumValidation, + }); + client.middlewareStack.addRelativeTo(responseChecksumValidator, { + relation: "after", + toMiddleware: "flexibleChecksumsMiddleware", + }); + + const { Body } = await client.getObject({ + Bucket: "bucket", + Key: "key", + // Do not pass ChecksumMode if algorithm is not explicitly defined. It'll be set by SDK. + ChecksumMode: checksumAlgorithm ? "ENABLED" : undefined, + }); + (Body as Readable).on("data", (chunk) => { + expect(chunk.toString()).toEqual(body); + }); + } + ); + } + ); }); }); diff --git a/packages/middleware-flexible-checksums/src/configuration.ts b/packages/middleware-flexible-checksums/src/configuration.ts index e92871b491b52..a6a066c749242 100644 --- a/packages/middleware-flexible-checksums/src/configuration.ts +++ b/packages/middleware-flexible-checksums/src/configuration.ts @@ -4,10 +4,13 @@ import { Encoder, GetAwsChunkedEncodingStream, HashConstructor, + Provider, StreamCollector, StreamHasher, } from "@smithy/types"; +import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; + export interface PreviouslyResolved { /** * The function that will be used to convert binary data to a base64-encoded string. @@ -31,6 +34,16 @@ export interface PreviouslyResolved { */ md5: ChecksumConstructor | HashConstructor; + /** + * Determines when a checksum will be calculated for request payloads + */ + requestChecksumCalculation: Provider; + + /** + * Determines when a checksum will be calculated for response payloads + */ + responseChecksumValidation: Provider; + /** * A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes. * @internal diff --git a/packages/middleware-flexible-checksums/src/constants.ts b/packages/middleware-flexible-checksums/src/constants.ts index 1a1d13a2efa7f..e55ca2cc07e2b 100644 --- a/packages/middleware-flexible-checksums/src/constants.ts +++ b/packages/middleware-flexible-checksums/src/constants.ts @@ -52,6 +52,9 @@ export const DEFAULT_RESPONSE_CHECKSUM_VALIDATION = RequestChecksumCalculation.W * Checksum Algorithms supported by the SDK. */ export enum ChecksumAlgorithm { + /** + * @deprecated Use {@link ChecksumAlgorithm.CRC32} instead. + */ MD5 = "MD5", CRC32 = "CRC32", CRC32C = "CRC32C", @@ -70,9 +73,4 @@ export enum ChecksumLocation { /** * @internal */ -export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.MD5; - -/** - * @internal - */ -export const S3_EXPRESS_DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.CRC32; +export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.CRC32; diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts new file mode 100644 index 0000000000000..7f715114d18c1 --- /dev/null +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts @@ -0,0 +1,103 @@ +import { setFeature } from "@aws-sdk/core"; +import { afterEach, describe, expect, test as it, vi } from "vitest"; + +import { PreviouslyResolved } from "./configuration"; +import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; +import { flexibleChecksumsInputMiddleware } from "./flexibleChecksumsInputMiddleware"; + +vi.mock("@aws-sdk/core"); + +describe(flexibleChecksumsInputMiddleware.name, () => { + const mockNext = vi.fn(); + const mockRequestValidationModeMember = "mockRequestValidationModeMember"; + + const mockConfig = { + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED), + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED), + } as PreviouslyResolved; + + afterEach(() => { + expect(mockNext).toHaveBeenCalledTimes(1); + vi.clearAllMocks(); + }); + + describe("sets input.requestValidationModeMember", () => { + it("when requestValidationModeMember is defined and responseChecksumValidation is supported", async () => { + const mockMiddlewareConfigWithMockRequestValidationModeMember = { + requestValidationModeMember: mockRequestValidationModeMember, + }; + const handler = flexibleChecksumsInputMiddleware( + mockConfig, + mockMiddlewareConfigWithMockRequestValidationModeMember + )(mockNext, {}); + await handler({ input: {} }); + expect(mockNext).toHaveBeenCalledWith({ input: { [mockRequestValidationModeMember]: "ENABLED" } }); + }); + }); + + describe("leaves input.requestValidationModeMember", () => { + const mockArgs = { input: {} }; + + it("when requestValidationModeMember is not defined", async () => { + const handler = flexibleChecksumsInputMiddleware(mockConfig, {})(mockNext, {}); + await handler(mockArgs); + expect(mockNext).toHaveBeenCalledWith(mockArgs); + }); + + it("when responseChecksumValidation is required", async () => { + const mockConfigResWhenRequired = { + ...mockConfig, + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED), + } as PreviouslyResolved; + + const handler = flexibleChecksumsInputMiddleware(mockConfigResWhenRequired, {})(mockNext, {}); + await handler(mockArgs); + + expect(mockNext).toHaveBeenCalledWith(mockArgs); + }); + }); + + describe("set feature", () => { + it.each([ + [ + "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", + "a", + "requestChecksumCalculation", + RequestChecksumCalculation.WHEN_REQUIRED, + ], + [ + "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", + "Z", + "requestChecksumCalculation", + RequestChecksumCalculation.WHEN_SUPPORTED, + ], + [ + "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", + "c", + "responseChecksumValidation", + ResponseChecksumValidation.WHEN_REQUIRED, + ], + [ + "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", + "b", + "responseChecksumValidation", + ResponseChecksumValidation.WHEN_SUPPORTED, + ], + ])("logs %s:%s when %s=%s", async (feature, value, configKey, configValue) => { + const mockConfigOverride = { + ...mockConfig, + [configKey]: () => Promise.resolve(configValue), + } as PreviouslyResolved; + + const handler = flexibleChecksumsInputMiddleware(mockConfigOverride, {})(mockNext, {}); + await handler({ input: {} }); + + expect(setFeature).toHaveBeenCalledTimes(2); + if (configKey === "requestChecksumCalculation") { + expect(setFeature).toHaveBeenNthCalledWith(1, expect.anything(), feature, value); + } else { + expect(setFeature).toHaveBeenNthCalledWith(2, expect.anything(), feature, value); + } + }); + }); +}); diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts new file mode 100644 index 0000000000000..0dcb7c94cba92 --- /dev/null +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts @@ -0,0 +1,82 @@ +import { setFeature } from "@aws-sdk/core"; +import { + HandlerExecutionContext, + MetadataBearer, + RelativeMiddlewareOptions, + SerializeHandler, + SerializeHandlerArguments, + SerializeHandlerOutput, + SerializeMiddleware, +} from "@smithy/types"; + +import { PreviouslyResolved } from "./configuration"; +import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; + +export interface FlexibleChecksumsInputMiddlewareConfig { + /** + * Defines a top-level operation input member used to opt-in to best-effort validation + * of a checksum returned in the HTTP response of the operation. + */ + requestValidationModeMember?: string; +} + +/** + * @internal + */ +export const flexibleChecksumsInputMiddlewareOptions: RelativeMiddlewareOptions = { + name: "flexibleChecksumsInputMiddleware", + toMiddleware: "serializerMiddleware", + relation: "before", + tags: ["BODY_CHECKSUM"], + override: true, +}; + +/** + * @internal + * + * The input counterpart to the flexibleChecksumsMiddleware. + */ +export const flexibleChecksumsInputMiddleware = + ( + config: PreviouslyResolved, + middlewareConfig: FlexibleChecksumsInputMiddlewareConfig + ): SerializeMiddleware => + ( + next: SerializeHandler, + context: HandlerExecutionContext + ): SerializeHandler => + async (args: SerializeHandlerArguments): Promise> => { + const input = args.input; + const { requestValidationModeMember } = middlewareConfig; + + const requestChecksumCalculation = await config.requestChecksumCalculation(); + const responseChecksumValidation = await config.responseChecksumValidation(); + + switch (requestChecksumCalculation) { + case RequestChecksumCalculation.WHEN_REQUIRED: + setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", "a"); + break; + case RequestChecksumCalculation.WHEN_SUPPORTED: + setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", "Z"); + break; + } + + switch (responseChecksumValidation) { + case ResponseChecksumValidation.WHEN_REQUIRED: + setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", "c"); + break; + case ResponseChecksumValidation.WHEN_SUPPORTED: + setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", "b"); + break; + } + + // The value for input member to opt-in to best-effort validation of a checksum returned in the HTTP response is not set. + if (requestValidationModeMember && !input[requestValidationModeMember]) { + // Set requestValidationModeMember as ENABLED only if response checksum validation is supported. + if (responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED) { + input[requestValidationModeMember] = "ENABLED"; + } + } + + return next(args); + }; diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts index 3d173e5841170..27afd7773cce3 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts @@ -3,7 +3,7 @@ import { BuildHandlerArguments } from "@smithy/types"; import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; @@ -13,6 +13,7 @@ import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; import { stringHasher } from "./stringHasher"; +vi.mock("@aws-sdk/core"); vi.mock("@smithy/protocol-http"); vi.mock("./getChecksumAlgorithmForRequest"); vi.mock("./getChecksumLocationName"); @@ -28,10 +29,14 @@ describe(flexibleChecksumsMiddleware.name, () => { const mockChecksum = "mockChecksum"; const mockChecksumAlgorithmFunction = vi.fn(); const mockChecksumLocationName = "mock-checksum-location-name"; + const mockRequestAlgorithmMember = "mockRequestAlgorithmMember"; + const mockRequestAlgorithmMemberHttpHeader = "mock-request-algorithm-member-http-header"; const mockInput = {}; - const mockConfig = {} as PreviouslyResolved; - const mockMiddlewareConfig = { requestChecksumRequired: false }; + const mockConfig = { + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED), + } as PreviouslyResolved; + const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false }; const mockBody = { body: "mockRequestBody" }; const mockHeaders = { "content-length": 100, "content-encoding": "gzip" }; @@ -41,9 +46,8 @@ describe(flexibleChecksumsMiddleware.name, () => { beforeEach(() => { mockNext.mockResolvedValueOnce(mockResult); - const { isInstance } = HttpRequest; - (isInstance as unknown as any).mockReturnValue(true); - vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.MD5); + vi.mocked(HttpRequest.isInstance).mockReturnValue(true); + vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.CRC32); vi.mocked(getChecksumLocationName).mockReturnValue(mockChecksumLocationName); vi.mocked(hasHeader).mockReturnValue(true); vi.mocked(hasHeaderWithPrefix).mockReturnValue(false); @@ -58,8 +62,7 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("skips", () => { it("if not an instance of HttpRequest", async () => { - const { isInstance } = HttpRequest; - (isInstance as unknown as any).mockReturnValue(false); + vi.mocked(HttpRequest.isInstance).mockReturnValue(false); const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); await handler(mockArgs); expect(getChecksumAlgorithmForRequest).not.toHaveBeenCalled(); @@ -77,7 +80,7 @@ describe(flexibleChecksumsMiddleware.name, () => { expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); }); - it("if header is already present", async () => { + it("skip if header is already present", async () => { const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); vi.mocked(hasHeaderWithPrefix).mockReturnValue(true); @@ -94,11 +97,53 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("adds checksum in the request header", () => { afterEach(() => { + expect(HttpRequest.isInstance).toHaveBeenCalledTimes(1); + expect(hasHeaderWithPrefix).toHaveBeenCalledTimes(1); expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); }); + describe("if input.requestAlgorithmMember can be set", () => { + describe("input[requestAlgorithmMember] is not defined and", () => { + const mockMwConfigWithReqAlgoMember = { + ...mockMiddlewareConfig, + requestAlgorithmMember: { + name: mockRequestAlgorithmMember, + httpHeader: mockRequestAlgorithmMemberHttpHeader, + }, + }; + + it("requestChecksumCalculation is supported", async () => { + const handler = flexibleChecksumsMiddleware( + { + ...mockConfig, + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED), + }, + mockMwConfigWithReqAlgoMember + )(mockNext, {}); + await handler(mockArgs); + expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM); + expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + + it("requestChecksumRequired is set to true", async () => { + const handler = flexibleChecksumsMiddleware(mockConfig, { + ...mockMwConfigWithReqAlgoMember, + requestChecksumRequired: true, + })(mockNext, {}); + + await handler(mockArgs); + expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM); + expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + }); + }); + it("for streaming body", async () => { vi.mocked(isStreaming).mockReturnValue(true); const mockUpdatedBody = { body: "mockUpdatedBody" }; diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts index 470a2fcc08e68..8872adde5d938 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts @@ -11,7 +11,7 @@ import { } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { hasHeader } from "./hasHeader"; @@ -73,15 +73,27 @@ export const flexibleChecksumsMiddleware = const { body: requestBody, headers } = request; const { base64Encoder, streamHasher } = config; const { requestChecksumRequired, requestAlgorithmMember } = middlewareConfig; + const requestChecksumCalculation = await config.requestChecksumCalculation(); - const checksumAlgorithm = getChecksumAlgorithmForRequest( - input, - { - requestChecksumRequired, - requestAlgorithmMember: requestAlgorithmMember?.name, - }, - !!context.isS3ExpressBucket - ); + const requestAlgorithmMemberName = requestAlgorithmMember?.name; + const requestAlgorithmMemberHttpHeader = requestAlgorithmMember?.httpHeader; + // The value for input member to configure flexible checksum is not set. + if (requestAlgorithmMemberName && !input[requestAlgorithmMemberName]) { + // Set requestAlgorithmMember as default checksum algorithm only if request checksum calculation is supported + // or request checksum is required. + if (requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired) { + input[requestAlgorithmMemberName] = DEFAULT_CHECKSUM_ALGORITHM; + if (requestAlgorithmMemberHttpHeader) { + headers[requestAlgorithmMemberHttpHeader] = DEFAULT_CHECKSUM_ALGORITHM; + } + } + } + + const checksumAlgorithm = getChecksumAlgorithmForRequest(input, { + requestChecksumRequired, + requestAlgorithmMember: requestAlgorithmMember?.name, + requestChecksumCalculation, + }); let updatedBody = requestBody; let updatedHeaders = headers; diff --git a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts index f333b56f9d9eb..7ac3ce5438d7a 100644 --- a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts +++ b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts @@ -1,6 +1,6 @@ import { describe, expect, test as it } from "vitest"; -import { ChecksumAlgorithm } from "./constants"; +import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { CLIENT_SUPPORTED_ALGORITHMS } from "./types"; @@ -8,36 +8,64 @@ describe(getChecksumAlgorithmForRequest.name, () => { const mockRequestAlgorithmMember = "mockRequestAlgorithmMember"; describe("when requestAlgorithmMember is not provided", () => { - it("returns MD5 if requestChecksumRequired is set", () => { - expect(getChecksumAlgorithmForRequest({}, { requestChecksumRequired: true })).toEqual(ChecksumAlgorithm.MD5); - }); + describe(`when requestChecksumCalculation is '${RequestChecksumCalculation.WHEN_REQUIRED}'`, () => { + const mockOptions = { requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED }; + + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); - it("returns undefined if requestChecksumRequired is false", () => { - expect(getChecksumAlgorithmForRequest({}, { requestChecksumRequired: false })).toBeUndefined(); + it("returns undefined if requestChecksumRequired is false", () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toBeUndefined(); + }); }); - }); - describe("when requestAlgorithmMember is not set in input", () => { - const mockOptions = { requestAlgorithmMember: mockRequestAlgorithmMember }; + describe(`when requestChecksumCalculation is '${RequestChecksumCalculation.WHEN_SUPPORTED}'`, () => { + const mockOptions = { requestChecksumCalculation: RequestChecksumCalculation.WHEN_SUPPORTED }; - it("returns MD5 if requestChecksumRequired is set", () => { - expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( - ChecksumAlgorithm.MD5 - ); + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is false`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); }); + }); - it("returns undefined if requestChecksumRequired is false", () => { - expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toBeUndefined(); + describe("returns undefined if input[requestAlgorithmMember] is not set", () => { + describe.each([true, false])("when requestChecksumRequired='%s'", (requestChecksumRequired) => { + it.each([RequestChecksumCalculation.WHEN_SUPPORTED, RequestChecksumCalculation.WHEN_REQUIRED])( + "when requestChecksumCalculation='%s'", + (requestChecksumCalculation) => { + const mockOptions = { + requestChecksumRequired, + requestChecksumCalculation, + requestAlgorithmMember: mockRequestAlgorithmMember, + }; + expect(getChecksumAlgorithmForRequest({}, mockOptions)).toBeUndefined(); + } + ); }); }); it("throws error if input[requestAlgorithmMember] if not supported by client", () => { const unsupportedAlgo = "unsupportedAlgo"; const mockInput = { [mockRequestAlgorithmMember]: unsupportedAlgo }; - const mockOptions = { requestChecksumRequired: true, requestAlgorithmMember: mockRequestAlgorithmMember }; + const mockOptions = { + requestChecksumRequired: true, + requestAlgorithmMember: mockRequestAlgorithmMember, + requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED, + }; expect(() => { getChecksumAlgorithmForRequest(mockInput, mockOptions); - }).toThrowError( + }).toThrow( `The checksum algorithm "${unsupportedAlgo}" is not supported by the client.` + ` Select one of ${CLIENT_SUPPORTED_ALGORITHMS}.` ); @@ -46,7 +74,11 @@ describe(getChecksumAlgorithmForRequest.name, () => { describe("returns input[requestAlgorithmMember] if supported by client", () => { it.each(CLIENT_SUPPORTED_ALGORITHMS)("Supported algorithm: %s", (supportedAlgorithm) => { const mockInput = { [mockRequestAlgorithmMember]: supportedAlgorithm }; - const mockOptions = { requestChecksumRequired: true, requestAlgorithmMember: mockRequestAlgorithmMember }; + const mockOptions = { + requestChecksumRequired: true, + requestAlgorithmMember: mockRequestAlgorithmMember, + requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED, + }; expect(getChecksumAlgorithmForRequest(mockInput, mockOptions)).toEqual(supportedAlgorithm); }); }); diff --git a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts index f3cba4b2313f3..809b6714b24ba 100644 --- a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts +++ b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts @@ -1,4 +1,4 @@ -import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, S3_EXPRESS_DEFAULT_CHECKSUM_ALGORITHM } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { CLIENT_SUPPORTED_ALGORITHMS } from "./types"; export interface GetChecksumAlgorithmForRequestOptions { @@ -11,6 +11,11 @@ export interface GetChecksumAlgorithmForRequestOptions { * Defines a top-level operation input member that is used to configure request checksum behavior. */ requestAlgorithmMember?: string; + + /** + * Determines when a checksum will be calculated for request payloads + */ + requestChecksumCalculation: RequestChecksumCalculation; } /** @@ -20,16 +25,19 @@ export interface GetChecksumAlgorithmForRequestOptions { */ export const getChecksumAlgorithmForRequest = ( input: any, - { requestChecksumRequired, requestAlgorithmMember }: GetChecksumAlgorithmForRequestOptions, - isS3Express?: boolean + { requestChecksumRequired, requestAlgorithmMember, requestChecksumCalculation }: GetChecksumAlgorithmForRequestOptions ): ChecksumAlgorithm | undefined => { - const defaultAlgorithm = isS3Express ? S3_EXPRESS_DEFAULT_CHECKSUM_ALGORITHM : DEFAULT_CHECKSUM_ALGORITHM; + // The Operation input member that is used to configure request checksum behavior is not set. + if (!requestAlgorithmMember) { + // Select an algorithm only if request checksum calculation is supported + // or request checksum is required. + return requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired + ? DEFAULT_CHECKSUM_ALGORITHM + : undefined; + } - // Either the Operation input member that is used to configure request checksum behavior is not set, or - // the value for input member to configure flexible checksum is not set. - if (!requestAlgorithmMember || !input[requestAlgorithmMember]) { - // Select an algorithm only if request checksum is required. - return requestChecksumRequired ? defaultAlgorithm : undefined; + if (!input[requestAlgorithmMember]) { + return undefined; } const checksumAlgorithm = input[requestAlgorithmMember]; diff --git a/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts b/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts index 94dd3ecea9b07..0d6898ea5ea76 100644 --- a/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts +++ b/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts @@ -1,6 +1,11 @@ import { Pluggable } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; +import { + flexibleChecksumsInputMiddleware, + FlexibleChecksumsInputMiddlewareConfig, + flexibleChecksumsInputMiddlewareOptions, +} from "./flexibleChecksumsInputMiddleware"; import { flexibleChecksumsMiddleware, flexibleChecksumsMiddlewareOptions, @@ -14,6 +19,7 @@ import { export interface FlexibleChecksumsMiddlewareConfig extends FlexibleChecksumsRequestMiddlewareConfig, + FlexibleChecksumsInputMiddlewareConfig, FlexibleChecksumsResponseMiddlewareConfig {} export const getFlexibleChecksumsPlugin = ( @@ -22,6 +28,10 @@ export const getFlexibleChecksumsPlugin = ( ): Pluggable => ({ applyToStack: (clientStack) => { clientStack.add(flexibleChecksumsMiddleware(config, middlewareConfig), flexibleChecksumsMiddlewareOptions); + clientStack.addRelativeTo( + flexibleChecksumsInputMiddleware(config, middlewareConfig), + flexibleChecksumsInputMiddlewareOptions + ); clientStack.addRelativeTo( flexibleChecksumsResponseMiddleware(config, middlewareConfig), flexibleChecksumsResponseMiddlewareOptions diff --git a/private/aws-middleware-test/src/middleware-serde.spec.ts b/private/aws-middleware-test/src/middleware-serde.spec.ts index 0d2f4ea3cfe21..9049f5bb0591b 100644 --- a/private/aws-middleware-test/src/middleware-serde.spec.ts +++ b/private/aws-middleware-test/src/middleware-serde.spec.ts @@ -27,7 +27,7 @@ describe("middleware-serde", () => { "x-amz-acl": "private", "content-length": "509", Expect: "100-continue", - "content-md5": "qpwmS0vhCISEXes008aoXA==", + "x-amz-checksum-crc32": "XnKFaw==", host: "s3.us-west-2.amazonaws.com", "x-amz-content-sha256": "c0a89780e1aac5dfa17604e9e25616e7babba0b655db189be49b4c352543bb22", },