Skip to content

Commit

Permalink
fix: download from s3 custom domain
Browse files Browse the repository at this point in the history
  • Loading branch information
lihsai0 committed Nov 21, 2024
1 parent 0d1f63f commit 8b18b99
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 40 deletions.
16 changes: 14 additions & 2 deletions adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ export abstract class Adapter {

abstract getObjectInfo(region: string, object: StorageObject): Promise<ObjectInfo>;
abstract getObjectHeader(region: string, object: StorageObject, domain?: Domain): Promise<ObjectHeader>;
abstract getObject(region: string, object: StorageObject, domain?: Domain): Promise<ObjectGetResult>;
abstract getObject(
region: string,
object: StorageObject,
domain?: Domain,
style?: UrlStyle,
): Promise<ObjectGetResult>;
abstract getObjectURL(
region: string,
object: StorageObject,
domain?: Domain,
deadline?: Date,
style?: 'path' | 'virtualHost' | 'bucketEndpoint'
style?: UrlStyle,
): Promise<URL>;
abstract getObjectStream(s3RegionId: string, object: StorageObject, domain?: Domain, option?: GetObjectStreamOption): Promise<Readable>;
abstract putObject(
Expand Down Expand Up @@ -261,8 +266,15 @@ export interface PutObjectOption {
accelerateUploading?: boolean;
}

export enum UrlStyle {
Path = 'path',
VirtualHost = 'virtualHost',
BucketEndpoint = 'bucketEndpoint',
}

export interface GetObjectStreamOption {
rangeStart?: number;
rangeEnd?: number;
abortSignal?: AbortSignal;
urlStyle?: UrlStyle,
}
6 changes: 5 additions & 1 deletion downloader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { ThrottleGroup, ThrottleOptions } from 'stream-throttle';

import { Ref } from './types';
import { Progress, ProgressStream, SpeedMonitor } from './progress-stream';
import { Adapter, Domain, ObjectHeader, StorageObject } from './adapter';
import { Adapter, Domain, ObjectHeader, StorageObject, UrlStyle } from './adapter';
import { HttpClient } from './http-client';

export class Downloader {
Expand Down Expand Up @@ -92,6 +92,7 @@ export class Downloader {
domain,
{
rangeStart: recoveredFrom,
urlStyle: getFileOption?.urlStyle,
},
);
pipeList.unshift(reader);
Expand Down Expand Up @@ -154,6 +155,7 @@ export class Downloader {
domain: Domain | undefined,
option?: {
rangeStart?: number,
urlStyle?: UrlStyle,
},
): Promise<Readable> {
// default values
Expand All @@ -167,6 +169,7 @@ export class Downloader {
{
rangeStart: start,
abortSignal: this.abortController?.signal,
urlStyle: option?.urlStyle,
}
);
}
Expand Down Expand Up @@ -265,4 +268,5 @@ export interface GetFileOption {
downloadThrottleGroup?: ThrottleGroup;
downloadThrottleOption?: ThrottleOptions;
getCallback?: GetCallback;
urlStyle?: UrlStyle;
}
13 changes: 10 additions & 3 deletions kodo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
StorageObject,
TransferObject,
UploadPartOutput,
UrlStyle,
} from './adapter';
import { KodoHttpClient, RequestOptions, ServiceName } from './kodo-http-client';
import { RequestStats, URLRequestOptions } from './http-client';
Expand Down Expand Up @@ -676,7 +677,13 @@ export class Kodo implements Adapter {
headers.Range = `bytes=${option?.rangeStart ?? ''}-${option?.rangeEnd ?? ''}`;
}

const url = await this.getObjectURL(s3RegionId, object, domain);
const url = await this.getObjectURL(
s3RegionId,
object,
domain,
undefined,
option?.urlStyle,
);
const response = await this.callUrl(
[
url.toString(),
Expand Down Expand Up @@ -704,7 +711,7 @@ export class Kodo implements Adapter {
object: StorageObject,
domain?: Domain,
deadline?: Date,
style: 'path' | 'virtualHost' | 'bucketEndpoint' = 'bucketEndpoint',
style: UrlStyle = UrlStyle.BucketEndpoint,
): Promise<URL> {
if (!domain) {
const domains = await this._listDomains(s3RegionId, object.bucket);
Expand All @@ -714,7 +721,7 @@ export class Kodo implements Adapter {
domain = domains[0];
}

if (style !== 'bucketEndpoint') {
if (style !== UrlStyle.BucketEndpoint) {
throw new Error('Only support "bucketEndpoint" style for now');
}

Expand Down
68 changes: 34 additions & 34 deletions s3.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
StorageObject,
TransferObject,
UploadPartOutput,
UrlStyle,
} from './adapter';
import {
ErrorRequestUplogEntry,
Expand Down Expand Up @@ -65,8 +66,15 @@ export class S3 extends Kodo {
protected clientsLock = new AsyncLock();
protected listKodoBucketsPromise?: Promise<Bucket[]>;

private async getClient(s3RegionId?: string, s3ForcePathStyle = true): Promise<AWS.S3> {
const cacheKey = [s3RegionId ?? '', s3ForcePathStyle ? 's3ForcePathStyle' : ''].join(':');
/**
* if domain exists, the urlStyle will be forced to 'bucketEndpoint'
*/
private async getClient(
s3RegionId?: string,
urlStyle: UrlStyle = UrlStyle.Path,
domain?: Domain,
): Promise<AWS.S3> {
const cacheKey = [s3RegionId ?? '', urlStyle, domain?.name ?? ''].join(':');
if (this.clients[cacheKey]) {
return this.clients[cacheKey];
}
Expand All @@ -77,12 +85,26 @@ export class S3 extends Kodo {
userAgent += `/${this.adapterOption.appendedUserAgent}`;
}
const s3IdEndpoint = await this.regionService.getS3Endpoint(s3RegionId, this.getRegionRequestOptions());
const urlStyleOptions: {
endpoint: string,
s3ForcePathStyle?: boolean,
s3BucketEndpoint?: boolean,
} = {
endpoint: !domain
? s3IdEndpoint.s3Endpoint
: `${domain.protocol}://${domain.name}`,
};
if (urlStyle === UrlStyle.BucketEndpoint) {
urlStyleOptions.s3BucketEndpoint = true;
} else {
urlStyleOptions.s3ForcePathStyle = urlStyle === UrlStyle.Path;
}

return new AWS.S3({
apiVersion: '2006-03-01',
customUserAgent: userAgent,
computeChecksums: true,
region: s3IdEndpoint.s3Id,
endpoint: s3IdEndpoint.s3Endpoint,
maxRetries: 10,
signatureVersion: 'v4',
useDualstack: true,
Expand All @@ -94,11 +116,11 @@ export class S3 extends Kodo {
httpOptions: {
connectTimeout: 30000,
timeout: 300000,
agent: s3IdEndpoint.s3Endpoint.startsWith('https://')
agent: urlStyleOptions.endpoint.startsWith('https://')
? HttpClient.httpsKeepaliveAgent
: HttpClient.httpKeepaliveAgent,
},
s3ForcePathStyle
...urlStyleOptions
});
});
this.clients[cacheKey] = client;
Expand Down Expand Up @@ -540,10 +562,11 @@ export class S3 extends Kodo {
async getObject(
s3RegionId: string,
object: StorageObject,
_domain?: Domain
domain?: Domain,
style?: UrlStyle,
): Promise<ObjectGetResult> {
const [s3, bucketId] = await Promise.all([
this.getClient(s3RegionId),
this.getClient(s3RegionId, style, domain),
this.fromKodoBucketNameToS3BucketId(object.bucket),
]);
const request = s3.getObject({ Bucket: bucketId, Key: object.key });
Expand All @@ -562,11 +585,11 @@ export class S3 extends Kodo {
async getObjectStream(
s3RegionId: string,
object: StorageObject,
_domain?: Domain,
domain?: Domain,
option?: GetObjectStreamOption,
): Promise<Readable> {
const [s3, bucketId] = await Promise.all([
this.getClient(s3RegionId),
this.getClient(s3RegionId, option?.urlStyle, domain),
this.fromKodoBucketNameToS3BucketId(object.bucket),
]);
let range: string | undefined;
Expand Down Expand Up @@ -595,33 +618,10 @@ export class S3 extends Kodo {
object: StorageObject,
domain?: Domain,
deadline?: Date,
style: 'path' | 'virtualHost' | 'bucketEndpoint' = 'path',
style: UrlStyle = UrlStyle.Path,
): Promise<URL> {
let s3Promise: Promise<AWS.S3>;
// if domain is not undefined, use the domain, else use the default s3 endpoint
if (domain) {
if (style !== 'bucketEndpoint') {
throw new Error('Custom S3 endpoint only support "bucketEndpoint" style');
}
s3Promise = Promise.resolve(new AWS.S3({
apiVersion: '2006-03-01',
region: s3RegionId,
endpoint: `${domain.protocol}://${domain.name}`,
credentials: {
accessKeyId: this.adapterOption.accessKey,
secretAccessKey: this.adapterOption.secretKey,
},
signatureVersion: 'v4',
s3BucketEndpoint: true, // use bucketEndpoint style
}));
} else {
if (style === 'bucketEndpoint') {
throw new Error('Default S3 endpoint not support "bucketEndpoint" style');
}
s3Promise = this.getClient(s3RegionId, style === 'path');
}
const [s3, bucketId] = await Promise.all([
s3Promise,
this.getClient(s3RegionId, style, domain),
this.fromKodoBucketNameToS3BucketId(object.bucket),
]);
const expires = deadline
Expand Down

0 comments on commit 8b18b99

Please sign in to comment.