diff --git a/functions/from.go b/functions/from.go index 3088cef751..b0ad6a8fc4 100644 --- a/functions/from.go +++ b/functions/from.go @@ -69,6 +69,22 @@ func (s *FromOpSpec) Kind() query.OperationKind { return FromKind } +func (s *FromOpSpec) BucketsAccessed() (readBuckets, writeBuckets []platform.BucketFilter) { + bf := platform.BucketFilter{} + if s.Bucket != "" { + bf.Name = &s.Bucket + } + + if len(s.BucketID) > 0 { + bf.ID = &s.BucketID + } + + if bf.ID != nil || bf.Name != nil { + readBuckets = append(readBuckets, bf) + } + return readBuckets, writeBuckets +} + type FromProcedureSpec struct { Bucket string BucketID platform.ID diff --git a/functions/from_test.go b/functions/from_test.go index fa3d441462..89c6736a91 100644 --- a/functions/from_test.go +++ b/functions/from_test.go @@ -112,3 +112,29 @@ func TestFromOperation_Marshaling(t *testing.T) { } querytest.OperationMarshalingTestHelper(t, data, op) } + +func TestFromOpSpec_BucketsAccessed(t *testing.T) { + bucketName := "my_bucket" + bucketID, _ := platform.IDFromString("deadbeef") + tests := []querytest.NewQueryTestCase{ + { + Name: "From with bucket", + Raw: `from(bucket:"my_bucket")`, + WantReadBuckets: &[]platform.BucketFilter{{Name: &bucketName}}, + WantWriteBuckets: &[]platform.BucketFilter{}, + }, + { + Name: "From with bucketID", + Raw: `from(bucketID:"deadbeef")`, + WantReadBuckets: &[]platform.BucketFilter{{ID: bucketID}}, + WantWriteBuckets: &[]platform.BucketFilter{}, + }, + } + for _, tc := range tests { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + querytest.NewQueryTestHelper(t, tc) + }) + } +} diff --git a/operation.go b/operation.go index 7b9526640b..6770218ad8 100644 --- a/operation.go +++ b/operation.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/pkg/errors" + "github.com/influxdata/platform" ) // Operation denotes a single operation in a query. @@ -71,6 +72,12 @@ type OperationSpec interface { Kind() OperationKind } +// BucketAwareOperationSpec specifies an operation that reads or writes buckets +type BucketAwareOperationSpec interface { + OperationSpec + BucketsAccessed() (readBuckets, writeBuckets []platform.BucketFilter) +} + // OperationID is a unique ID within a query for the operation. type OperationID string @@ -93,3 +100,7 @@ func RegisterOpSpec(k OperationKind, c NewOperationSpec) { func NumberOfOperations() int { return len(kindToOp) } + +func OperationSpecNewFn(k OperationKind) NewOperationSpec { + return kindToOp[k] +} diff --git a/preauthorizer.go b/preauthorizer.go new file mode 100644 index 0000000000..f51258a515 --- /dev/null +++ b/preauthorizer.go @@ -0,0 +1,63 @@ +package query + +import ( + "github.com/pkg/errors" + "github.com/influxdata/platform" + "context" +) + +// PreAuthorizer provides a method for ensuring that the buckets accessed by a query spec +// are allowed access by the given Authorization. This is a pre-check provided as a way for +// callers to fail early for operations that are not allowed. However, it's still possible +// for authorization to be denied at runtime even if this check passes. +type PreAuthorizer interface { + PreAuthorize(ctx context.Context, spec *Spec, auth *platform.Authorization) error +} + +// NewPreAuthorizer creates a new PreAuthorizer +func NewPreAuthorizer(bucketService platform.BucketService) PreAuthorizer { + return &preAuthorizer{bucketService: bucketService} +} + +type preAuthorizer struct { + bucketService platform.BucketService +} + +// PreAuthorize finds all the buckets read and written by the given spec, and ensures that execution is allowed +// given the Authorization. Returns nil on success, and an error with an appropriate message otherwise. +func (a *preAuthorizer) PreAuthorize(ctx context.Context, spec *Spec, auth *platform.Authorization) error { + + readBuckets, writeBuckets, err := spec.BucketsAccessed() + + if err != nil { + return errors.Wrap(err, "Could not retrieve buckets for query.Spec") + } + + for _, readBucketFilter := range readBuckets { + bucket, err := a.bucketService.FindBucket(ctx, readBucketFilter) + if err != nil { + return errors.Wrapf(err, "Bucket service error") + } else if bucket == nil { + return errors.New("Bucket service returned nil bucket") + } + + reqPerm := platform.ReadBucketPermission(bucket.ID) + if ! platform.Allowed(reqPerm, auth) { + return errors.New("No read permission for bucket: \"" + bucket.Name + "\"") + } + } + + for _, writeBucketFilter := range writeBuckets { + bucket, err := a.bucketService.FindBucket(context.Background(), writeBucketFilter) + if err != nil { + return errors.Wrapf(err, "Could not find bucket %v", writeBucketFilter) + } + + reqPerm := platform.WriteBucketPermission(bucket.ID) + if ! platform.Allowed(reqPerm, auth) { + return errors.New("No write permission for bucket: \"" + bucket.Name + "\"") + } + } + + return nil +} diff --git a/preauthorizer_test.go b/preauthorizer_test.go new file mode 100644 index 0000000000..c60880ae2a --- /dev/null +++ b/preauthorizer_test.go @@ -0,0 +1,71 @@ +package query + +import ( + "testing" + "time" + "context" + "github.com/influxdata/platform" + "github.com/influxdata/platform/mock" + "github.com/google/go-cmp/cmp" + "github.com/influxdata/platform/kit/errors" +) + +func newBucketServiceWithOneBucket(bucket platform.Bucket) platform.BucketService { + bs := mock.NewBucketService() + bs.FindBucketFn = func(ctx context.Context, bucketFilter platform.BucketFilter) (*platform.Bucket, error) { + if *bucketFilter.Name == bucket.Name { + return &bucket, nil + } + + return nil, errors.New("Unknown bucket") + } + + return bs +} + +func TestPreAuthorizer_PreAuthorize(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + q := `from(bucket:"my_bucket") |> range(start:-2h) |> yield()` + spec, err := Compile(ctx, q, now) + if err != nil { + t.Errorf("Error compiling query: %v", q) + } + + // Try to pre-authorize with bucket service with no buckets + // and no authorization + auth := &platform.Authorization{Status:platform.Active} + emptyBucketService := mock.NewBucketService() + preAuthorizer := NewPreAuthorizer(emptyBucketService) + + err = preAuthorizer.PreAuthorize(ctx, spec, auth) + if diagnostic := cmp.Diff("Bucket service returned nil bucket", err.Error()); diagnostic != "" { + t.Errorf("Authorize message mismatch: -want/+got:\n%v", diagnostic) + } + + // Try to authorize with a bucket service that knows about one bucket + // (still no authorization) + id, _ := platform.IDFromString("DEADBEEF") + bucketService := newBucketServiceWithOneBucket(platform.Bucket{ + Name: "my_bucket", + ID: *id, + }) + + preAuthorizer = NewPreAuthorizer(bucketService) + err = preAuthorizer.PreAuthorize(ctx, spec, auth) + if diagnostic := cmp.Diff(`No read permission for bucket: "my_bucket"`, err.Error()); diagnostic != "" { + t.Errorf("Authorize message mismatch: -want/+got:\n%v", diagnostic) + } + + // Try to authorize with read permission on bucket + auth = &platform.Authorization{ + Status:platform.Active, + Permissions: []platform.Permission{platform.ReadBucketPermission(*id)}, + } + + err = preAuthorizer.PreAuthorize(ctx, spec, auth) + if err != nil { + t.Errorf("Expected successful authorization, but got error: \"%v\"", err.Error()) + } +} diff --git a/querytest/compile.go b/querytest/compile.go index 974e08a28b..5295ed31fd 100644 --- a/querytest/compile.go +++ b/querytest/compile.go @@ -10,13 +10,17 @@ import ( "github.com/influxdata/platform/query" "github.com/influxdata/platform/query/functions" "github.com/influxdata/platform/query/semantic/semantictest" + "github.com/influxdata/platform" + "fmt" ) type NewQueryTestCase struct { - Name string - Raw string - Want *query.Spec - WantErr bool + Name string + Raw string + Want *query.Spec + WantErr bool + WantReadBuckets *[]platform.BucketFilter + WantWriteBuckets *[]platform.BucketFilter } var opts = append( @@ -41,8 +45,39 @@ func NewQueryTestHelper(t *testing.T, tc NewQueryTestCase) { } if tc.Want != nil { tc.Want.Now = now + if !cmp.Equal(tc.Want, got, opts...) { + t.Errorf("query.NewQuery() = -want/+got %s", cmp.Diff(tc.Want, got, opts...)) + } } - if !cmp.Equal(tc.Want, got, opts...) { - t.Errorf("query.NewQuery() = -want/+got %s", cmp.Diff(tc.Want, got, opts...)) + + var gotReadBuckets, gotWriteBuckets []platform.BucketFilter + if tc.WantReadBuckets != nil || tc.WantWriteBuckets != nil { + gotReadBuckets, gotWriteBuckets, err = got.BucketsAccessed() + } + + if tc.WantReadBuckets != nil { + if diagnostic := verifyBuckets(*tc.WantReadBuckets, gotReadBuckets); diagnostic != "" { + t.Errorf("Could not verify read buckets: %v", diagnostic) + } } + + if tc.WantWriteBuckets != nil { + if diagnostic := verifyBuckets(*tc.WantWriteBuckets, gotWriteBuckets); diagnostic != "" { + t.Errorf("Could not verify write buckets: %v", diagnostic) + } + } +} + +func verifyBuckets(wantBuckets, gotBuckets []platform.BucketFilter) string { + if len(wantBuckets) != len(gotBuckets) { + return fmt.Sprintf("Expected %v buckets but got %v", len(wantBuckets), len(gotBuckets)) + } + + for i, wantBucket := range wantBuckets { + if diagnostic := cmp.Diff(wantBucket, gotBuckets[i]); diagnostic != "" { + return fmt.Sprintf("Bucket mismatch: -want/+got:\n%v", diagnostic) + } + } + + return "" } diff --git a/semantic/types.go b/semantic/types.go index 7ee4149b4f..8104dbb2b1 100644 --- a/semantic/types.go +++ b/semantic/types.go @@ -32,6 +32,10 @@ type Type interface { // It panics if the type's Kind is not Array. ElementType() Type + // Params reports the parameters of a function type. + // It panics if the type's Kind is not Function. + Params() map[string]Type + // PipeArgument reports the name of the argument that can be pipe into. // It panics if the type's Kind is not Function. PipeArgument() string @@ -98,6 +102,9 @@ func (k Kind) Properties() map[string]Type { func (k Kind) ElementType() Type { panic(fmt.Errorf("cannot get element type from kind %s", k)) } +func (k Kind) Params() map[string]Type { + panic(fmt.Errorf("cannot get parameters from kind %s", k)) +} func (k Kind) PipeArgument() string { panic(fmt.Errorf("cannot get pipe argument name from kind %s", k)) } @@ -126,6 +133,9 @@ func (t *arrayType) Properties() map[string]Type { func (t *arrayType) ElementType() Type { return t.elementType } +func (t *arrayType) Params() map[string]Type { + panic(fmt.Errorf("cannot get parameters from kind %s", t.Kind())) +} func (t *arrayType) PipeArgument() string { panic(fmt.Errorf("cannot get pipe argument name from kind %s", t.Kind())) } @@ -213,6 +223,9 @@ func (t *objectType) Properties() map[string]Type { func (t *objectType) ElementType() Type { panic(fmt.Errorf("cannot get element type of kind %s", t.Kind())) } +func (t *objectType) Params() map[string]Type { + panic(fmt.Errorf("cannot get parameters from kind %s", t.Kind())) +} func (t *objectType) PipeArgument() string { panic(fmt.Errorf("cannot get pipe argument name from kind %s", t.Kind())) } @@ -354,6 +367,9 @@ func (t *functionType) Properties() map[string]Type { func (t *functionType) ElementType() Type { panic(fmt.Errorf("cannot get element type of kind %s", t.Kind())) } +func (t *functionType) Params() map[string]Type { + return t.params +} func (t *functionType) PipeArgument() string { return t.pipeArgument } @@ -362,10 +378,6 @@ func (t *functionType) ReturnType() Type { } func (t *functionType) typ() {} -func (t *functionType) Params() map[string]Type { - return t.params -} - func (t *functionType) equal(o *functionType) bool { if t == o { return true diff --git a/spec.go b/spec.go index 86aeadf9ff..0f638396b9 100644 --- a/spec.go +++ b/spec.go @@ -5,6 +5,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/influxdata/platform" ) // Spec specifies a query. @@ -179,3 +180,22 @@ func (q *Spec) Functions() ([]string, error) { }) return funcs, err } + +// BucketsAccessed returns the set of buckets read and written by a query spec +func (q *Spec) BucketsAccessed() (readBuckets, writeBuckets []platform.BucketFilter, err error) { + err = q.Walk(func(o *Operation) error { + bucketAwareOpSpec, ok := o.Spec.(BucketAwareOperationSpec) + if ok { + opBucketsRead, opBucketsWritten := bucketAwareOpSpec.BucketsAccessed() + readBuckets = append(readBuckets, opBucketsRead...) + writeBuckets = append(writeBuckets, opBucketsWritten...) + } + return nil + }) + + if err != nil { + return nil, nil, err + } + + return readBuckets, writeBuckets, nil +}