Skip to content

Commit

Permalink
feat(core/gcs): Add concurrent write for gcs backend (#4820)
Browse files Browse the repository at this point in the history
Signed-off-by: Xuanwo <[email protected]>
  • Loading branch information
Xuanwo authored Jun 28, 2024
1 parent a8885ba commit b4dda8f
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 70 deletions.
32 changes: 19 additions & 13 deletions core/src/services/gcs/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,18 @@ impl Access for GcsBackend {
write_can_empty: true,
write_can_multi: true,
write_with_content_type: true,
// The buffer size should be a multiple of 256 KiB (256 x 1024 bytes), unless it's the last chunk that completes the upload.
// Larger chunk sizes typically make uploads faster, but note that there's a tradeoff between speed and memory usage.
// It's recommended that you use at least 8 MiB for the chunk size.
// The min multipart size of Gcs is 5 MiB.
//
// Reference: [Perform resumable uploads](https://cloud.google.com/storage/docs/performing-resumable-uploads)
write_multi_align_size: Some(8 * 1024 * 1024),
// ref: <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
write_multi_min_size: Some(5 * 1024 * 1024),
// The max multipart size of Gcs is 5 GiB.
//
// ref: <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
write_multi_max_size: if cfg!(target_pointer_width = "64") {
Some(5 * 1024 * 1024 * 1024)
} else {
Some(usize::MAX)
},

delete: true,
copy: true,
Expand All @@ -388,7 +394,7 @@ impl Access for GcsBackend {
let resp = self.core.gcs_get_object_metadata(path, &args).await?;

if !resp.status().is_success() {
return Err(parse_error(resp).await?);
return Err(parse_error(resp));
}

let slc = resp.into_body();
Expand Down Expand Up @@ -427,16 +433,16 @@ impl Access for GcsBackend {
_ => {
let (part, mut body) = resp.into_parts();
let buf = body.to_buffer().await?;
Err(parse_error(Response::from_parts(part, buf)).await?)
Err(parse_error(Response::from_parts(part, buf)))
}
}
}

async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
let concurrent = args.concurrent();
let executor = args.executor().cloned();
let w = GcsWriter::new(self.core.clone(), path, args);
// Gcs can't support concurrent write, always use concurrent=1 for now.
let w = oio::RangeWriter::new(w, executor, 1);
let w = oio::MultipartWriter::new(w, executor, concurrent);

Ok((RpWrite::default(), w))
}
Expand All @@ -448,7 +454,7 @@ impl Access for GcsBackend {
if resp.status().is_success() || resp.status() == StatusCode::NOT_FOUND {
Ok(RpDelete::default())
} else {
Err(parse_error(resp).await?)
Err(parse_error(resp))
}
}

Expand All @@ -470,7 +476,7 @@ impl Access for GcsBackend {
if resp.status().is_success() {
Ok(RpCopy::default())
} else {
Err(parse_error(resp).await?)
Err(parse_error(resp))
}
}

Expand Down Expand Up @@ -544,15 +550,15 @@ impl Access for GcsBackend {
if resp.status().is_success() || resp.status() == StatusCode::NOT_FOUND {
batched_result.push((path, Ok(RpDelete::default().into())));
} else {
batched_result.push((path, Err(parse_error(resp).await?)));
batched_result.push((path, Err(parse_error(resp))));
}
}

Ok(RpBatch::new(batched_result))
} else {
// If the overall request isn't formatted correctly and Cloud Storage is unable to parse it into sub-requests, you receive a 400 error.
// Otherwise, Cloud Storage returns a 200 status code, even if some or all of the sub-requests fail.
Err(parse_error(resp).await?)
Err(parse_error(resp))
}
}
}
Expand Down
124 changes: 123 additions & 1 deletion core/src/services/gcs/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::time::Duration;

use backon::ExponentialBuilder;
use backon::Retryable;
use bytes::Bytes;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_RANGE;
use http::header::CONTENT_TYPE;
Expand All @@ -37,7 +38,7 @@ use reqsign::GoogleCredentialLoader;
use reqsign::GoogleSigner;
use reqsign::GoogleToken;
use reqsign::GoogleTokenLoader;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::json;

use super::uri::percent_encode_path;
Expand Down Expand Up @@ -488,6 +489,20 @@ impl GcsCore {
self.send(req).await
}

pub async fn gcs_initiate_multipart_upload(&self, path: &str) -> Result<Response<Buffer>> {
let p = build_abs_path(&self.root, path);

let url = format!("{}/{}/{}?uploads", self.endpoint, self.bucket, p);

let mut req = Request::post(&url)
.header(CONTENT_LENGTH, 0)
.body(Buffer::new())
.map_err(new_request_build_error)?;

self.sign(&mut req).await?;
self.send(req).await
}

pub async fn gcs_initiate_resumable_upload(&self, path: &str) -> Result<Response<Buffer>> {
let p = build_abs_path(&self.root, path);
let url = format!(
Expand All @@ -504,6 +519,90 @@ impl GcsCore {
self.send(req).await
}

/// Upload one part of an in-progress multipart upload.
///
/// `part_number` is the 1-based index GCS uses to order parts; `size` is the
/// byte length of `body` and is set as the request's content length so the
/// part is not sent with chunked transfer encoding.
///
/// ref: <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
pub async fn gcs_upload_part(
    &self,
    path: &str,
    upload_id: &str,
    part_number: usize,
    size: u64,
    body: Buffer,
) -> Result<Response<Buffer>> {
    let abs_path = build_abs_path(&self.root, path);

    let url = format!(
        "{}/{}/{}?partNumber={}&uploadId={}",
        self.endpoint,
        self.bucket,
        percent_encode_path(&abs_path),
        part_number,
        percent_encode_path(upload_id)
    );

    let mut req = Request::put(&url)
        .header(CONTENT_LENGTH, size)
        .body(body)
        .map_err(new_request_build_error)?;

    self.sign(&mut req).await?;
    self.send(req).await
}

/// Complete a multipart upload by posting the collected part list.
///
/// ref: <https://cloud.google.com/storage/docs/xml-api/post-object-complete>
pub async fn gcs_complete_multipart_upload(
    &self,
    path: &str,
    upload_id: &str,
    parts: Vec<CompleteMultipartUploadRequestPart>,
) -> Result<Response<Buffer>> {
    let abs_path = build_abs_path(&self.root, path);

    let url = format!(
        "{}/{}/{}?uploadId={}",
        self.endpoint,
        self.bucket,
        percent_encode_path(&abs_path),
        percent_encode_path(upload_id)
    );

    // Serialize the part list into the XML document GCS expects.
    // NOTE(review): this maps a *serialization* failure through
    // `new_xml_deserialize_error` — confirm whether a serialize-specific
    // error constructor exists and should be used instead.
    let content = quick_xml::se::to_string(&CompleteMultipartUploadRequest { part: parts })
        .map_err(new_xml_deserialize_error)?;

    let mut req = Request::post(&url)
        // Make sure content length has been set to avoid posting with
        // chunked transfer encoding.
        .header(CONTENT_LENGTH, content.len())
        // Set content-type to `application/xml` to avoid the body being
        // mistaken for a form post.
        .header(CONTENT_TYPE, "application/xml")
        .body(Buffer::from(Bytes::from(content)))
        .map_err(new_request_build_error)?;

    self.sign(&mut req).await?;
    self.send(req).await
}

/// Abort an in-progress multipart upload identified by `upload_id`.
///
/// ref: <https://cloud.google.com/storage/docs/xml-api/delete-multipart>
pub async fn gcs_abort_multipart_upload(
    &self,
    path: &str,
    upload_id: &str,
) -> Result<Response<Buffer>> {
    let abs_path = build_abs_path(&self.root, path);

    let object_url = format!(
        "{}/{}/{}",
        self.endpoint,
        self.bucket,
        percent_encode_path(&abs_path)
    );
    let url = format!("{}?uploadId={}", object_url, percent_encode_path(upload_id));

    let mut req = Request::delete(&url)
        .body(Buffer::new())
        .map_err(new_request_build_error)?;

    self.sign(&mut req).await?;
    self.send(req).await
}

pub fn gcs_upload_in_resumable_upload(
&self,
location: &str,
Expand Down Expand Up @@ -592,6 +691,29 @@ pub struct ListResponseItem {
pub content_type: String,
}

/// Result of CreateMultipartUpload
///
/// Deserialized from the XML body returned by `POST <object>?uploads`;
/// `rename_all = "PascalCase"` maps `upload_id` to the `<UploadId>` element.
#[derive(Default, Debug, Deserialize)]
#[serde(default, rename_all = "PascalCase")]
pub struct InitiateMultipartUploadResult {
// Opaque id of the upload; passed to every subsequent part/complete/abort call.
pub upload_id: String,
}

/// Request of CompleteMultipartUploadRequest
///
/// Serialized into the `<CompleteMultipartUpload>` XML document posted to
/// finish a multipart upload; each entry of `part` becomes a `<Part>` element.
#[derive(Default, Debug, Serialize)]
#[serde(default, rename = "CompleteMultipartUpload", rename_all = "PascalCase")]
pub struct CompleteMultipartUploadRequest {
// The uploaded parts to assemble into the final object.
pub part: Vec<CompleteMultipartUploadRequestPart>,
}

/// A single `<Part>` entry in the CompleteMultipartUpload request body.
#[derive(Clone, Default, Debug, Serialize)]
#[serde(default, rename_all = "PascalCase")]
pub struct CompleteMultipartUploadRequestPart {
    /// 1-based index of the part. `rename_all = "PascalCase"` already
    /// serializes this field as `<PartNumber>`, so no per-field rename is
    /// needed.
    pub part_number: usize,
    /// ETag returned by the corresponding upload-part response.
    /// PascalCase conversion would emit `Etag`, so the exact casing is
    /// pinned explicitly.
    #[serde(rename = "ETag")]
    pub etag: String,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
9 changes: 4 additions & 5 deletions core/src/services/gcs/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use bytes::Buf;
use http::Response;
use http::StatusCode;
use serde::Deserialize;
Expand Down Expand Up @@ -49,9 +48,9 @@ struct GcsErrorDetail {
}

/// Parse error response into Error.
pub async fn parse_error(resp: Response<Buffer>) -> Result<Error> {
let (parts, mut body) = resp.into_parts();
let bs = body.copy_to_bytes(body.remaining());
pub fn parse_error(resp: Response<Buffer>) -> Error {
let (parts, body) = resp.into_parts();
let bs = body.to_bytes();

let (kind, retryable) = match parts.status {
StatusCode::NOT_FOUND => (ErrorKind::NotFound, false),
Expand Down Expand Up @@ -79,7 +78,7 @@ pub async fn parse_error(resp: Response<Buffer>) -> Result<Error> {
err = err.set_temporary();
}

Ok(err)
err
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion core/src/services/gcs/lister.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl oio::PageList for GcsLister {
.await?;

if !resp.status().is_success() {
return Err(parse_error(resp).await?);
return Err(parse_error(resp));
}
let bytes = resp.into_body();

Expand Down
Loading

0 comments on commit b4dda8f

Please sign in to comment.