diff --git a/execute/executetest/source.go b/execute/executetest/source.go index 96f9ede17f..23cc719de7 100644 --- a/execute/executetest/source.go +++ b/execute/executetest/source.go @@ -4,6 +4,7 @@ import ( "context" "github.com/influxdata/flux/execute" + "github.com/influxdata/flux/memory" "github.com/influxdata/flux/plan" uuid "github.com/satori/go.uuid" ) @@ -63,3 +64,45 @@ func (src *FromProcedureSpec) Run(ctx context.Context) { func CreateFromSource(spec plan.ProcedureSpec, id execute.DatasetID, a execute.Administration) (execute.Source, error) { return spec.(*FromProcedureSpec), nil } + +// AllocatingFromProcedureSpec is a procedure spec AND an execution node +// that allocates ByteCount bytes during execution. +type AllocatingFromProcedureSpec struct { + ByteCount int + + alloc *memory.Allocator + ts []execute.Transformation +} + +const AllocatingFromTestKind = "allocating-from-test" + +func (AllocatingFromProcedureSpec) Kind() plan.ProcedureKind { + return AllocatingFromTestKind +} + +func (s *AllocatingFromProcedureSpec) Copy() plan.ProcedureSpec { + return &AllocatingFromProcedureSpec{ + ByteCount: s.ByteCount, + alloc: s.alloc, + } +} + +func (AllocatingFromProcedureSpec) Cost(inStats []plan.Statistics) (cost plan.Cost, outStats plan.Statistics) { + return plan.Cost{}, plan.Statistics{} +} + +func CreateAllocatingFromSource(spec plan.ProcedureSpec, id execute.DatasetID, a execute.Administration) (execute.Source, error) { + s := spec.(*AllocatingFromProcedureSpec) + s.alloc = a.Allocator() + return s, nil +} + +func (s *AllocatingFromProcedureSpec) Run(ctx context.Context) { + if err := s.alloc.Allocate(s.ByteCount); err != nil { + panic(err) + } +} + +func (s *AllocatingFromProcedureSpec) AddTransformation(t execute.Transformation) { + s.ts = append(s.ts, t) +} diff --git a/execute/executor_test.go b/execute/executor_test.go index feca9bf334..1f52bc1828 100644 --- a/execute/executor_test.go +++ b/execute/executor_test.go @@ -12,6 +12,7 @@ import ( _ "github.com/influxdata/flux/builtin" "github.com/influxdata/flux/execute" "github.com/influxdata/flux/execute/executetest" + "github.com/influxdata/flux/memory" "github.com/influxdata/flux/plan" "github.com/influxdata/flux/plan/plantest" "github.com/influxdata/flux/semantic" @@ -20,16 +21,19 @@ import ( ) func init() { - execute.RegisterSource("from-test", executetest.CreateFromSource) + execute.RegisterSource(executetest.FromTestKind, executetest.CreateFromSource) + execute.RegisterSource(executetest.AllocatingFromTestKind, executetest.CreateAllocatingFromSource) execute.RegisterTransformation(executetest.ToTestKind, executetest.CreateToTransformation) plan.RegisterProcedureSpecWithSideEffect(executetest.ToTestKind, executetest.NewToProcedure, executetest.ToTestKind) } func TestExecutor_Execute(t *testing.T) { testcases := []struct { - name string - spec *plantest.PlanSpec - want map[string][]*executetest.Table + name string + spec *plantest.PlanSpec + want map[string][]*executetest.Table + allocator *memory.Allocator + wantErr error }{ { name: `from`, @@ -694,6 +698,20 @@ func TestExecutor_Execute(t *testing.T) { }}, }, }, + { + name: "memory limit exceeded", + spec: &plantest.PlanSpec{ + Nodes: []plan.Node{ + plan.CreatePhysicalNode("allocating-from-test", &executetest.AllocatingFromProcedureSpec{ByteCount: 65}), + plan.CreatePhysicalNode("yield", &universe.YieldProcedureSpec{Name: "_result"}), + }, + Edges: [][2]int{ + {0, 1}, + }, + }, + allocator: &memory.Allocator{Limit: func(v int64) *int64 { return &v }(64)}, + wantErr: memory.LimitExceededError{Limit: 64, Wanted: 65}, + }, } for _, tc := range testcases { @@ -711,24 +729,46 @@ func TestExecutor_Execute(t *testing.T) { plan := plantest.CreatePlanSpec(tc.spec) exe := execute.NewExecutor(nil, zaptest.NewLogger(t)) - results, _, err := exe.Execute(context.Background(), plan, executetest.UnlimitedAllocator) - if err != nil { - t.Fatal(err) + + alloc := tc.allocator + if alloc == nil { + alloc = executetest.UnlimitedAllocator } - got := make(map[string][]*executetest.Table, len(results)) - for name, r := range results { - if err := r.Tables().Do(func(tbl flux.Table) error { - cb, err := executetest.ConvertTable(tbl) - if err != nil { - return err + + // Execute the query and preserve any error returned + results, _, err := exe.Execute(context.Background(), plan, alloc) + var got map[string][]*executetest.Table + if err == nil { + got = make(map[string][]*executetest.Table, len(results)) + for name, r := range results { + if err = r.Tables().Do(func(tbl flux.Table) error { + cb, err := executetest.ConvertTable(tbl) + if err != nil { + return err + } + got[name] = append(got[name], cb) + return nil + }); err != nil { + break } - got[name] = append(got[name], cb) - return nil - }); err != nil { - t.Fatal(err) } } + if tc.wantErr == nil && err != nil { + t.Fatal(err) + } + + if tc.wantErr != nil { + if err == nil { + t.Fatalf(`expected an error "%v" but got none`, tc.wantErr) + } + + if diff := cmp.Diff(tc.wantErr, err); diff != "" { + t.Fatalf("unexpected error: -want/+got: %v", diff) + } + return + } + for _, g := range got { executetest.NormalizeTables(g) }