diff --git a/esti/s3_gateway_test.go b/esti/s3_gateway_test.go index f20db0f28e3..d234358c81a 100644 --- a/esti/s3_gateway_test.go +++ b/esti/s3_gateway_test.go @@ -186,18 +186,6 @@ func TestS3UploadAndDownload(t *testing.T) { }) } } -func uploadMultipartCompleteIfNoneMatch(ctx context.Context, svc *s3.Client, resp *s3.CreateMultipartUploadOutput, completedParts []types.CompletedPart) (*s3.CompleteMultipartUploadOutput, error) { - completeInput := &s3.CompleteMultipartUploadInput{ - Bucket: resp.Bucket, - Key: resp.Key, - UploadId: resp.UploadId, - MultipartUpload: &types.CompletedMultipartUpload{ - Parts: completedParts, - }, - } - return svc.CompleteMultipartUpload(ctx, completeInput) -} - func TestMultipartUploadIfNoneMatch(t *testing.T) { // timeResolution is a duration greater than the timestamp resolution of the backing // store. Multipart object on S3 is the time of create-MPU, waiting before completion @@ -272,9 +260,7 @@ func setHTTPHeaders(ifNoneMatch string) func(*middleware.Stack) error { if req, ok := in.Request.(*smithyhttp.Request); ok { // Add the If-None-Match header req.Header.Set("If-None-Match", ifNoneMatch) - fmt.Printf("Set If-None-Match header: %s\n", ifNoneMatch) // Debug logging } - // Continue with the next middleware handler return next.HandleBuild(ctx, in) }), middleware.Before) } diff --git a/pkg/gateway/operations/postobject.go b/pkg/gateway/operations/postobject.go index cc7eb4eeb38..fe1d5aee539 100644 --- a/pkg/gateway/operations/postobject.go +++ b/pkg/gateway/operations/postobject.go @@ -10,6 +10,7 @@ import ( "time" "github.com/treeverse/lakefs/pkg/block" + "github.com/treeverse/lakefs/pkg/catalog" gatewayErrors "github.com/treeverse/lakefs/pkg/gateway/errors" "github.com/treeverse/lakefs/pkg/gateway/multipart" "github.com/treeverse/lakefs/pkg/gateway/path" @@ -94,15 +95,21 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError)) return } - allowOverWrite, err := o.checkIfAbsent(req) - if errors.Is(err, gatewayErrors.ErrPreconditionFailed) { - _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) - return - } - if errors.Is(err, gatewayErrors.ErrNotImplemented) { + // before completing multipart upload, check whether if-none-match header is added, + // in order to not overwrite object + allowOverwrite, err := o.checkIfAbsent(req) + if err != nil { _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented)) return } + if !allowOverwrite { + _, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{}) + if err == nil { + // In case object exists in catalog, no error returns + _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) + return + } + } objName := multiPart.PhysicalAddress req = req.WithContext(logging.AddFields(req.Context(), logging.Fields{logging.PhysicalAddressFieldKey: objName})) xmlMultipartComplete, err := io.ReadAll(req.Body) @@ -133,7 +140,7 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite return } checksum := strings.Split(resp.ETag, "-")[0] - err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType, allowOverWrite) + err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType, allowOverwrite) if errors.Is(err, graveler.ErrPreconditionFailed) { _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) return diff --git a/pkg/gateway/operations/putobject.go b/pkg/gateway/operations/putobject.go index c2c76f9fceb..98c49db8b73 100644 --- a/pkg/gateway/operations/putobject.go +++ b/pkg/gateway/operations/putobject.go @@ -298,15 +298,21 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) { o.Incr("put_object", o.Principal, o.Repository.Name, o.Reference) storageClass := StorageClassFromHeader(req.Header) opts := block.PutOpts{StorageClass: storageClass} - allowOverWrite, err := o.checkIfAbsent(req) - if errors.Is(err, gatewayErrors.ErrPreconditionFailed) { - _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) - return - } - if errors.Is(err, gatewayErrors.ErrNotImplemented) { + // before uploading object, check whether if-none-match header is added, + // in order to not overwrite object + allowOverwrite, err := o.checkIfAbsent(req) + if err != nil { _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented)) return } + if !allowOverwrite { + _, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{}) + if err == nil { + // In case object exists in catalog, no error returns + _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) + return + } + } address := o.PathProvider.NewPath() blob, err := upload.WriteBlob(req.Context(), o.BlockStore, o.Repository.StorageNamespace, address, req.Body, req.ContentLength, opts) if err != nil { @@ -318,7 +324,7 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) { // write metadata metadata := amzMetaAsMetadata(req) contentType := req.Header.Get("Content-Type") - err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType, allowOverWrite) + err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType, allowOverwrite) if errors.Is(err, graveler.ErrPreconditionFailed) { _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed)) return @@ -340,20 +346,12 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) { } func (o *PathOperation) checkIfAbsent(req *http.Request) (bool, error) { - Header := req.Header.Get(IfNoneMatchHeader) - if Header == "" { + headerValue := req.Header.Get(IfNoneMatchHeader) + if headerValue == "" { return true, nil } - if Header == "*" { - _, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{}) - if err == nil { - return false, gatewayErrors.ErrPreconditionFailed - } - if !errors.Is(err, graveler.ErrNotFound) { - return false, gatewayErrors.ErrInternalError - } else { - return true, nil - } + if headerValue == "*" { + return false, nil } return false, gatewayErrors.ErrNotImplemented }