Fix gateway handle list objects versions (#7028)
nopcoder authored Dec 3, 2023
1 parent ee17222 commit 847fd07
Showing 19 changed files with 164 additions and 108 deletions.
1 change: 1 addition & 0 deletions cmd/lakefs/cmd/run.go
@@ -298,6 +298,7 @@ var runCmd = &cobra.Command{
s3FallbackURL,
cfg.Logging.AuditLogLevel,
cfg.Logging.TraceRequestHeaders,
cfg.Gateways.S3.VerifyUnsupported,
)
s3gatewayHandler = apiAuthenticator(s3gatewayHandler)

1 change: 1 addition & 0 deletions docs/reference/configuration.md
@@ -195,6 +195,7 @@ This reference uses `.` to denote the nesting of values.
local development, if using [virtual-host addressing](https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html).
* `gateways.s3.region` `(string : "us-east-1")` - AWS region we're pretending to be in; it should match the region configuration used in AWS SDK clients
* `gateways.s3.fallback_url` `(string)` - If specified, requests with a non-existing repository will be forwarded to this URL. This can be useful for using lakeFS side-by-side with S3, with the URL pointing at an [S3Proxy](https://github.com/gaul/s3proxy) instance.
* `gateways.s3.verify_unsupported` `(bool : true)` - When enabled, the S3 gateway returns an error for unsupported requests; when disabled, such requests are deferred to the target-based handlers.
* `stats.enabled` `(bool : true)` - Whether to periodically collect anonymous usage statistics
* `stats.flush_interval` `(duration : 30s)` - Interval used to post anonymous statistics collected
* `stats.flush_size` `(int : 100)` - A size (in records) of anonymous statistics collected in which we post
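For readers wiring this up, here is a hedged sketch of toggling the new key programmatically with viper, mirroring the default added in pkg/config/defaults.go later in this diff. In a real deployment the value would normally come from the lakeFS configuration file rather than code; the snippet only illustrates the key name and its effect.

```go
package main

import (
	"fmt"

	"github.com/spf13/viper"
)

func main() {
	// Mirrors the default added in pkg/config/defaults.go.
	viper.SetDefault("gateways.s3.verify_unsupported", true)

	// Opt out: unsupported S3 requests fall through to the regular
	// path/repository handlers instead of being rejected by the gateway.
	viper.Set("gateways.s3.verify_unsupported", false)

	fmt.Println("verify_unsupported:", viper.GetBool("gateways.s3.verify_unsupported"))
}
```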
19 changes: 9 additions & 10 deletions esti/s3_gateway_test.go
@@ -76,34 +76,33 @@ func TestS3UploadAndDownload(t *testing.T) {
objects = make(chan Object, parallelism*2)
)

client := newMinioClient(t, sig.GetCredentials)
wg.Add(parallelism)
for i := 0; i < parallelism; i++ {
client := newMinioClient(t, sig.GetCredentials)

wg.Add(1)
go func() {
defer wg.Done()
for o := range objects {
_, err := client.PutObject(
ctx, repo, o.Path, strings.NewReader(o.Content), int64(len(o.Content)), minio.PutObjectOptions{})
_, err := client.PutObject(ctx, repo, o.Path, strings.NewReader(o.Content), int64(len(o.Content)), minio.PutObjectOptions{})
if err != nil {
t.Errorf("minio.Client.PutObject(%s): %s", o.Path, err)
continue
}

download, err := client.GetObject(
ctx, repo, o.Path, minio.GetObjectOptions{})
download, err := client.GetObject(ctx, repo, o.Path, minio.GetObjectOptions{})
if err != nil {
t.Errorf("minio.Client.GetObject(%s): %s", o.Path, err)
continue
}
contents := bytes.NewBuffer(nil)
_, err = io.Copy(contents, download)
if err != nil {
t.Errorf("download %s: %s", o.Path, err)
continue
}
if strings.Compare(contents.String(), o.Content) != 0 {
t.Errorf(
"Downloaded bytes %v from uploaded bytes %v", contents.Bytes(), o.Content)
t.Errorf("Downloaded bytes %v from uploaded bytes %v", contents.Bytes(), o.Content)
}
}
wg.Done()
}()
}

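The test now creates a single MinIO client, adds the full worker count to the WaitGroup before spawning goroutines, and uses `defer wg.Done()` inside each worker. A minimal, self-contained sketch of that shape (illustrative job type, not the test's Object struct):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const parallelism = 4
	jobs := make(chan string, parallelism*2)

	var wg sync.WaitGroup
	wg.Add(parallelism) // count all workers up front, as the test now does
	for i := 0; i < parallelism; i++ {
		go func() {
			defer wg.Done() // runs even if the loop body bails out early
			for j := range jobs {
				fmt.Println("processing", j) // PutObject/GetObject would go here
			}
		}()
	}

	for _, j := range []string{"a", "b", "c"} {
		jobs <- j
	}
	close(jobs)
	wg.Wait()
}
```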
7 changes: 4 additions & 3 deletions pkg/config/config.go
@@ -320,9 +320,10 @@ type Config struct {
} `mapstructure:"graveler"`
Gateways struct {
S3 struct {
DomainNames Strings `mapstructure:"domain_name"`
Region string `mapstructure:"region"`
FallbackURL string `mapstructure:"fallback_url"`
DomainNames Strings `mapstructure:"domain_name"`
Region string `mapstructure:"region"`
FallbackURL string `mapstructure:"fallback_url"`
VerifyUnsupported bool `mapstructure:"verify_unsupported"`
} `mapstructure:"s3"`
}
Stats struct {
1 change: 1 addition & 0 deletions pkg/config/defaults.go
@@ -85,6 +85,7 @@ func setDefaults(cfgType string) {

viper.SetDefault("gateways.s3.domain_name", "s3.local.lakefs.io")
viper.SetDefault("gateways.s3.region", "us-east-1")
viper.SetDefault("gateways.s3.verify_unsupported", true)

viper.SetDefault("blockstore.gs.s3_endpoint", "https://storage.googleapis.com")
viper.SetDefault("blockstore.gs.pre_signed_expiry", 15*time.Minute)
36 changes: 19 additions & 17 deletions pkg/gateway/handler.go
@@ -49,17 +49,18 @@ type handler struct {
}

type ServerContext struct {
region string
bareDomains []string
catalog *catalog.Catalog
multipartTracker multipart.Tracker
blockStore block.Adapter
authService auth.GatewayService
stats stats.Collector
pathProvider upload.PathProvider
region string
bareDomains []string
catalog *catalog.Catalog
multipartTracker multipart.Tracker
blockStore block.Adapter
authService auth.GatewayService
stats stats.Collector
pathProvider upload.PathProvider
verifyUnsupported bool
}

func NewHandler(region string, catalog *catalog.Catalog, multipartTracker multipart.Tracker, blockStore block.Adapter, authService auth.GatewayService, bareDomains []string, stats stats.Collector, pathProvider upload.PathProvider, fallbackURL *url.URL, auditLogLevel string, traceRequestHeaders bool) http.Handler {
func NewHandler(region string, catalog *catalog.Catalog, multipartTracker multipart.Tracker, blockStore block.Adapter, authService auth.GatewayService, bareDomains []string, stats stats.Collector, pathProvider upload.PathProvider, fallbackURL *url.URL, auditLogLevel string, traceRequestHeaders bool, verifyUnsupported bool) http.Handler {
var fallbackHandler http.Handler
if fallbackURL != nil {
fallbackProxy := gohttputil.NewSingleHostReverseProxy(fallbackURL)
@@ -75,14 +76,15 @@ func NewHandler(region string, catalog *catalog.Catalog, multipartTracker multip
})
}
sc := &ServerContext{
catalog: catalog,
multipartTracker: multipartTracker,
region: region,
bareDomains: bareDomains,
blockStore: blockStore,
authService: authService,
stats: stats,
pathProvider: pathProvider,
catalog: catalog,
multipartTracker: multipartTracker,
region: region,
bareDomains: bareDomains,
blockStore: blockStore,
authService: authService,
stats: stats,
pathProvider: pathProvider,
verifyUnsupported: verifyUnsupported,
}

// setup routes
57 changes: 29 additions & 28 deletions pkg/gateway/middleware.go
@@ -112,12 +112,13 @@ func EnrichWithOperation(sc *ServerContext, next http.Handler) http.Handler {
ctx := req.Context()
client := httputil.GetRequestLakeFSClient(req)
o := &operations.Operation{
Region: sc.region,
FQDN: getBareDomain(stripPort(req.Host), sc.bareDomains),
Catalog: sc.catalog,
MultipartTracker: sc.multipartTracker,
BlockStore: sc.blockStore,
Auth: sc.authService,
Region: sc.region,
FQDN: getBareDomain(stripPort(req.Host), sc.bareDomains),
Catalog: sc.catalog,
MultipartTracker: sc.multipartTracker,
BlockStore: sc.blockStore,
Auth: sc.authService,
VerifyUnsupported: sc.verifyUnsupported,
Incr: func(action, userID, repository, ref string) {
logging.FromContext(ctx).
WithFields(logging.Fields{
@@ -199,28 +200,21 @@ func OperationLookupHandler(next http.Handler) http.Handler {
ctx := req.Context()
o := ctx.Value(ContextKeyOperation).(*operations.Operation)
repoID := ctx.Value(ContextKeyRepositoryID).(string)
o.OperationID = operations.OperationIDOperationNotFound
if repoID == "" {
if req.Method == http.MethodGet {
o.OperationID = operations.OperationIDListBuckets
} else {
_ = o.EncodeError(w, req, nil, gatewayerrors.ERRLakeFSNotSupported.ToAPIErr())
return
}
} else {
ref := ctx.Value(ContextKeyRef).(string)
pth := ctx.Value(ContextKeyPath).(string)
switch {
case ref != "" && pth != "":
req = req.WithContext(ctx)
o.OperationID = pathBasedOperationID(req.Method)
case ref == "" && pth == "":
o.OperationID = repositoryBasedOperationID(req.Method)
default:
w.WriteHeader(http.StatusNotFound)
return
}
ref := ctx.Value(ContextKeyRef).(string)
pth := ctx.Value(ContextKeyPath).(string)

// based on the operation level, we can determine the operation id
switch {
case repoID == "":
o.OperationID = rootBasedOperationID(req.Method)
case ref != "" && pth != "":
o.OperationID = pathBasedOperationID(req.Method)
case ref == "" && pth == "":
o.OperationID = repositoryBasedOperationID(req.Method)
default:
o.OperationID = operations.OperationIDOperationNotFound
}

req = req.WithContext(logging.AddFields(ctx, logging.Fields{"operation_id": o.OperationID}))
next.ServeHTTP(w, req)
})
@@ -277,7 +271,7 @@ func ParseRequestParts(host string, urlPath string, bareDomains []string) Reques
}

if !parts.MatchedHost {
// assume path based for domains we don't explicitly know
// assume path-based for domains we don't explicitly know
p = strings.SplitN(urlPath, path.Separator, 3) //nolint: gomnd
parts.Repository = p[0]
if len(p) >= 1 {
@@ -295,6 +289,13 @@ func ParseRequestParts(host string, urlPath string, bareDomains []string) Reques
return parts
}

func rootBasedOperationID(method string) operations.OperationID {
if method == http.MethodGet {
return operations.OperationIDListBuckets
}
return operations.OperationIDOperationNotFound
}

func pathBasedOperationID(method string) operations.OperationID {
switch method {
case http.MethodDelete:
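A hedged, self-contained sketch of the dispatch above: the operation is chosen by how much of `<repository>/<ref>/<path>` the request resolves to, with the repository-less case now routed through `rootBasedOperationID` instead of being rejected inline. The return values are illustrative strings, not the real `operations` IDs.

```go
package main

import (
	"fmt"
	"net/http"
)

// classify mimics the switch in OperationLookupHandler using plain strings.
func classify(method, repoID, ref, pth string) string {
	switch {
	case repoID == "":
		if method == http.MethodGet {
			return "ListBuckets"
		}
		return "OperationNotFound"
	case ref != "" && pth != "":
		return "path-based:" + method // e.g. GetObject, PutObject, DeleteObject
	case ref == "" && pth == "":
		return "repository-based:" + method // e.g. ListObjects, DeleteObjects
	default:
		return "OperationNotFound" // a ref without a path (or vice versa)
	}
}

func main() {
	fmt.Println(classify(http.MethodGet, "", "", ""))                   // ListBuckets
	fmt.Println(classify(http.MethodGet, "repo", "", ""))               // repository-based:GET
	fmt.Println(classify(http.MethodGet, "repo", "main", "data/a.txt")) // path-based:GET
	fmt.Println(classify(http.MethodPut, "", "", ""))                   // OperationNotFound
}
```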
34 changes: 24 additions & 10 deletions pkg/gateway/operations/base.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"slices"

"github.com/treeverse/lakefs/pkg/auth"
"github.com/treeverse/lakefs/pkg/auth/keys"
@@ -44,16 +45,17 @@ const (
type ActionIncr func(action, userID, repository, ref string)

type Operation struct {
OperationID OperationID
Region string
FQDN string
Catalog *catalog.Catalog
MultipartTracker multipart.Tracker
BlockStore block.Adapter
Auth auth.GatewayService
Incr ActionIncr
MatchedHost bool
PathProvider upload.PathProvider
OperationID OperationID
Region string
FQDN string
Catalog *catalog.Catalog
MultipartTracker multipart.Tracker
BlockStore block.Adapter
Auth auth.GatewayService
Incr ActionIncr
MatchedHost bool
PathProvider upload.PathProvider
VerifyUnsupported bool
}

func StorageClassFromHeader(header http.Header) *string {
@@ -84,6 +86,18 @@ func (o *Operation) EncodeXMLBytes(w http.ResponseWriter, req *http.Request, t [
}
}

func (o *Operation) HandleUnsupported(w http.ResponseWriter, req *http.Request, keys ...string) bool {
if !o.VerifyUnsupported {
return false
}
query := req.URL.Query()
if slices.ContainsFunc(keys, query.Has) {
_ = o.EncodeError(w, req, nil, gwerrors.ERRLakeFSNotSupported.ToAPIErr())
return true
}
return false
}

func EncodeResponse(w http.ResponseWriter, entity interface{}, statusCode int) error {
// We don't indent the XML document because of Java.
// See: https://github.com/spulec/moto/issues/1870
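The new `HandleUnsupported` helper is the heart of this change: each operation lists the S3 subresources it does not implement, and when `verify_unsupported` is on, a request carrying any of them is answered with `ERRLakeFSNotSupported`. A standalone sketch of just the query-key check, using the same `slices.ContainsFunc` / `url.Values.Has` combination:

```go
package main

import (
	"fmt"
	"net/url"
	"slices"
)

// hasUnsupportedKey reports whether the raw query string carries any of the
// listed subresource keys, e.g. "acl", "tagging" or "torrent".
func hasUnsupportedKey(rawQuery string, keys ...string) bool {
	query, err := url.ParseQuery(rawQuery)
	if err != nil {
		return false
	}
	return slices.ContainsFunc(keys, query.Has)
}

func main() {
	fmt.Println(hasUnsupportedKey("acl=", "tagging", "acl", "torrent"))         // true -> reject
	fmt.Println(hasUnsupportedKey("uploadId=abc", "tagging", "acl", "torrent")) // false -> handle normally
}
```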
7 changes: 4 additions & 3 deletions pkg/gateway/operations/deleteobject.go
@@ -60,10 +60,11 @@ func (controller *DeleteObject) HandleAbortMultipartUpload(w http.ResponseWriter
}

func (controller *DeleteObject) Handle(w http.ResponseWriter, req *http.Request, o *PathOperation) {
if o.HandleUnsupported(w, req, "tagging", "acl", "torrent") {
return
}
query := req.URL.Query()

_, hasUploadID := query[QueryParamUploadID]
if hasUploadID {
if query.Has(QueryParamUploadID) {
controller.HandleAbortMultipartUpload(w, req, o)
return
}
7 changes: 7 additions & 0 deletions pkg/gateway/operations/deleteobjects.go
@@ -27,6 +27,13 @@ func (controller *DeleteObjects) RequiredPermissions(_ *http.Request, _ string)
}

func (controller *DeleteObjects) Handle(w http.ResponseWriter, req *http.Request, o *RepoOperation) {
// verify we only handle delete requests
query := req.URL.Query()
if !query.Has("delete") {
_ = o.EncodeError(w, req, nil, gerrors.ERRLakeFSNotSupported.ToAPIErr())
return
}

o.Incr("delete_objects", o.Principal, o.Repository.Name, "")
decodedXML := &serde.Delete{}
err := DecodeXMLBody(req.Body, decodedXML)
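For context, the multi-object delete endpoint is `POST /<repository>?delete` with a `<Delete>` XML body, which is why the handler now refuses requests lacking the `delete` query key before decoding. A hedged sketch of such a body and its decoding (local struct, not the lakeFS `serde.Delete` type):

```go
package main

import (
	"encoding/xml"
	"fmt"
	"strings"
)

type deleteRequest struct {
	XMLName xml.Name `xml:"Delete"`
	Objects []struct {
		Key string `xml:"Key"`
	} `xml:"Object"`
}

func main() {
	body := `<Delete>
  <Object><Key>main/data/a.parquet</Key></Object>
  <Object><Key>main/data/b.parquet</Key></Object>
</Delete>`

	var req deleteRequest
	if err := xml.NewDecoder(strings.NewReader(body)).Decode(&req); err != nil {
		panic(err)
	}
	for _, o := range req.Objects {
		fmt.Println("delete:", o.Key)
	}
}
```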
3 changes: 3 additions & 0 deletions pkg/gateway/operations/getobject.go
@@ -28,6 +28,9 @@ func (controller *GetObject) RequiredPermissions(_ *http.Request, repoID, _, pat
}

func (controller *GetObject) Handle(w http.ResponseWriter, req *http.Request, o *PathOperation) {
if o.HandleUnsupported(w, req, "torrent", "acl", "retention", "legal-hold", "lambdaArn") {
return
}
o.Incr("get_object", o.Principal, o.Repository.Name, o.Reference)
ctx := req.Context()
query := req.URL.Query()
5 changes: 4 additions & 1 deletion pkg/gateway/operations/headbucket.go
@@ -17,7 +17,10 @@ func (controller *HeadBucket) RequiredPermissions(_ *http.Request, repoID string
}, nil
}

func (controller *HeadBucket) Handle(w http.ResponseWriter, _ *http.Request, o *RepoOperation) {
func (controller *HeadBucket) Handle(w http.ResponseWriter, req *http.Request, o *RepoOperation) {
if o.HandleUnsupported(w, req, "acl") {
return
}
o.Incr("get_repo", o.Principal, o.Repository.Name, "")
w.WriteHeader(http.StatusOK)
}
4 changes: 4 additions & 0 deletions pkg/gateway/operations/listbuckets.go
@@ -21,6 +21,10 @@ func (controller *ListBuckets) RequiredPermissions(_ *http.Request) (permissions

// Handle - list buckets (repositories)
func (controller *ListBuckets) Handle(w http.ResponseWriter, req *http.Request, o *AuthorizedOperation) {
if o.HandleUnsupported(w, req, "events") {
return
}

o.Incr("list_repos", o.Principal, "", "")

buckets := make([]serde.Bucket, 0)
33 changes: 27 additions & 6 deletions pkg/gateway/operations/listobjects.go
@@ -18,6 +18,9 @@ import (

const (
ListObjectMaxKeys = 1000

// defaultBucketLocation used to identify if we need to specify the location constraint
defaultBucketLocation = "us-east-1"
)

type ListObjects struct{}
@@ -349,16 +352,34 @@ func (controller *ListObjects) ListV1(w http.ResponseWriter, req *http.Request,
}

func (controller *ListObjects) Handle(w http.ResponseWriter, req *http.Request, o *RepoOperation) {
o.Incr("list_objects", o.Principal, o.Repository.Name, "")
// parse request parameters
// GET /example?list-type=2&prefix=main%2F&delimiter=%2F&encoding-type=url HTTP/1.1

// handle GET /?versioning
if o.HandleUnsupported(w, req, "inventory", "metrics", "publicAccessBlock", "ownershipControls",
"intelligent-tiering", "analytics", "policy", "lifecycle", "encryption", "object-lock", "replication",
"notification", "events", "acl", "cors", "website", "accelerate",
"requestPayment", "logging", "tagging", "uploads", "versions", "policyStatus") {
return
}
query := req.URL.Query()
if _, found := query["versioning"]; found {

// getbucketlocation support
if query.Has("location") {
o.Incr("get_bucket_location", o.Principal, o.Repository.Name, "")
response := serde.LocationResponse{}
if o.Region != "" && o.Region != defaultBucketLocation {
response.Location = o.Region
}
o.EncodeResponse(w, req, response, http.StatusOK)
return
}

// getbucketversioning support
if query.Has("versioning") {
o.EncodeXMLBytes(w, req, []byte(serde.VersioningResponse), http.StatusOK)
return
}
o.Incr("list_objects", o.Principal, o.Repository.Name, "")

// parse request parameters
// GET /example?list-type=2&prefix=main%2F&delimiter=%2F&encoding-type=url HTTP/1.1

// handle ListObjects versions
listType := query.Get("list-type")
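The new `GET /<repository>?location` branch reproduces S3's GetBucketLocation quirk: the response's `LocationConstraint` is empty for the default `us-east-1` region and carries the region name otherwise. A hedged, self-contained sketch of that response shape (local struct, not the lakeFS `serde.LocationResponse` type):

```go
package main

import (
	"encoding/xml"
	"fmt"
)

type locationResponse struct {
	XMLName  xml.Name `xml:"LocationConstraint"`
	Location string   `xml:",chardata"`
}

func bucketLocation(region string) locationResponse {
	const defaultBucketLocation = "us-east-1" // same sentinel as in the handler
	resp := locationResponse{}
	if region != "" && region != defaultBucketLocation {
		resp.Location = region
	}
	return resp
}

func main() {
	for _, region := range []string{"us-east-1", "eu-west-1"} {
		out, _ := xml.Marshal(bucketLocation(region))
		fmt.Printf("%s -> %s\n", region, out)
	}
}
```

Running the sketch prints an empty `<LocationConstraint>` for `us-east-1` and `<LocationConstraint>eu-west-1</LocationConstraint>` otherwise, which is the behavior AWS SDK clients expect from GetBucketLocation.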