From d47dd5910aa283a79a344e7f100b3586089e1ad0 Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Tue, 22 Aug 2023 21:29:43 +0300 Subject: [PATCH 1/4] Validate sort order in admin.ResourceListRequest Signed-off-by: Iaroslav Ciupin --- go.mod | 2 +- go.sum | 4 ++ .../impl/db_admin_data_provider.go | 7 +- .../impl/db_admin_data_provider_test.go | 11 ++-- pkg/clusterresource/impl/shared.go | 5 +- pkg/common/sorting.go | 37 +++++++++-- .../impl/description_entity_manager.go | 30 ++++----- pkg/manager/impl/execution_manager.go | 49 ++++++-------- pkg/manager/impl/execution_manager_test.go | 33 +++++----- pkg/manager/impl/launch_plan_manager.go | 66 +++++++++---------- pkg/manager/impl/launch_plan_manager_test.go | 12 ++-- pkg/manager/impl/named_entity_manager.go | 31 +++++---- pkg/manager/impl/node_execution_manager.go | 34 +++++----- .../impl/node_execution_manager_test.go | 24 +++---- pkg/manager/impl/project_manager.go | 35 +++++----- pkg/manager/impl/project_manager_test.go | 7 +- pkg/manager/impl/signal_manager.go | 20 +++--- pkg/manager/impl/task_execution_manager.go | 32 ++++----- pkg/manager/impl/task_manager.go | 45 ++++++------- pkg/manager/impl/task_manager_test.go | 14 ++-- pkg/manager/impl/workflow_manager.go | 59 +++++++++-------- pkg/manager/impl/workflow_manager_test.go | 8 ++- pkg/repositories/gormimpl/common.go | 39 +++++++---- .../gormimpl/description_entity_repo.go | 20 ++++-- pkg/repositories/gormimpl/execution_repo.go | 38 +++++++++-- .../gormimpl/execution_repo_test.go | 7 +- pkg/repositories/gormimpl/launch_plan_repo.go | 24 +++++-- .../gormimpl/launch_plan_repo_test.go | 7 +- .../gormimpl/named_entity_repo.go | 24 +++++-- .../gormimpl/named_entity_repo_test.go | 9 +-- .../gormimpl/node_execution_repo.go | 27 +++++++- .../gormimpl/node_execution_repo_test.go | 8 ++- pkg/repositories/gormimpl/project_repo.go | 16 ++++- .../gormimpl/project_repo_test.go | 25 +++---- pkg/repositories/gormimpl/signal_repo.go | 13 +++- .../gormimpl/task_execution_repo.go | 23 ++++++- pkg/repositories/gormimpl/task_repo.go | 21 ++++-- pkg/repositories/gormimpl/task_repo_test.go | 5 +- pkg/repositories/gormimpl/workflow_repo.go | 30 ++++++--- .../gormimpl/workflow_repo_test.go | 9 +-- pkg/repositories/interfaces/common.go | 4 +- 41 files changed, 564 insertions(+), 350 deletions(-) diff --git a/go.mod b/go.mod index b30095519..347b785eb 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.5.14 + github.com/flyteorg/flyteidl v1.5.17-0.20230822102414-1c76702b5f6a github.com/flyteorg/flyteplugins v1.0.67 github.com/flyteorg/flytepropeller v1.1.98 github.com/flyteorg/flytestdlib v1.0.22 diff --git a/go.sum b/go.sum index ed4b6ac85..2d463eac6 100644 --- a/go.sum +++ b/go.sum @@ -295,6 +295,8 @@ github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8S github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.5.14 h1:+3ewipoOp82fPyIVgvvrMq1lorl5Kz3Lh6sh/a9+loI= github.com/flyteorg/flyteidl v1.5.14/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.17-0.20230822102414-1c76702b5f6a h1:eqQ7g71w8NG5fnTq1QiyPNoIaryNnU2wW/2DtATFHa0= +github.com/flyteorg/flyteidl v1.5.17-0.20230822102414-1c76702b5f6a/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.0.67 h1:d2FXpwxQwX/k4YdmhuusykOemHb/cUTPEob4WBmdpjE= github.com/flyteorg/flyteplugins v1.0.67/go.mod h1:HHt4nKDKVwrZPKDsj99dNtDSIJL378xNotYMA3a/TFA= github.com/flyteorg/flytepropeller v1.1.98 h1:Zk2ENYB9VZRT5tFUIFjm+aCkr0TU2EuyJ5gh52fpLoA= @@ -760,6 +762,7 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.3.2 h1:IqNFLAmvJOgVlpdEBiQbDc2EwKW77amAycfTuWKdfvw= +github.com/google/martian/v3 v3.3.2/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -785,6 +788,7 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.7.1 h1:gF4c0zjUP2H/s/hEGyLA3I0fA2ZWjzYiONAD6cvPr8A= github.com/googleapis/gax-go/v2 v2.7.1/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI= +github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gopherjs/gopherjs v0.0.0-20181004151105-1babbf986f6f/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= diff --git a/pkg/clusterresource/impl/db_admin_data_provider.go b/pkg/clusterresource/impl/db_admin_data_provider.go index 3612cb5ec..12b837331 100644 --- a/pkg/clusterresource/impl/db_admin_data_provider.go +++ b/pkg/clusterresource/impl/db_admin_data_provider.go @@ -3,13 +3,14 @@ package impl import ( "context" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" ) // Implementation of an interfaces.FlyteAdminDataProvider which fetches data directly from the provided database connection. @@ -52,8 +53,8 @@ func (p dbAdminProvider) GetProjects(ctx context.Context) (*admin.Projects, erro return nil, err } projectModels, err := p.db.ProjectRepo().List(ctx, repositoryInterfaces.ListResourceInput{ - SortParameter: descCreatedAtSortDBParam, - InlineFilters: []common.InlineFilter{filter}, + SortParameters: descCreatedAtSortDBParam, + InlineFilters: []common.InlineFilter{filter}, }) if err != nil { return nil, err diff --git a/pkg/clusterresource/impl/db_admin_data_provider_test.go b/pkg/clusterresource/impl/db_admin_data_provider_test.go index 452a55995..076c53dd2 100644 --- a/pkg/clusterresource/impl/db_admin_data_provider_test.go +++ b/pkg/clusterresource/impl/db_admin_data_provider_test.go @@ -5,6 +5,11 @@ import ( "errors" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" "github.com/flyteorg/flyteadmin/pkg/manager/mocks" repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" @@ -12,10 +17,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" configMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) var errFoo = errors.New("foo") @@ -104,7 +105,7 @@ func TestGetProjects(t *testing.T) { mockRepo.(*repoMocks.MockRepository).ProjectRepoIface = &repoMocks.MockProjectRepo{ ListProjectsFunction: func(ctx context.Context, input repoInterfaces.ListResourceInput) ([]models.Project, error) { assert.Len(t, input.InlineFilters, 1) - assert.Equal(t, input.SortParameter.GetGormOrderExpr(), "created_at desc") + assert.Equal(t, input.SortParameters.GetGormOrderExpr(), "created_at desc") return []models.Project{ { Identifier: "flytesnacks", diff --git a/pkg/clusterresource/impl/shared.go b/pkg/clusterresource/impl/shared.go index 6ba07856f..bd350eb55 100644 --- a/pkg/clusterresource/impl/shared.go +++ b/pkg/clusterresource/impl/shared.go @@ -1,10 +1,11 @@ package impl import ( - "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/errors" ) func NewMissingEntityError(entity string) error { diff --git a/pkg/common/sorting.go b/pkg/common/sorting.go index c4922d0b1..f0751e51f 100644 --- a/pkg/common/sorting.go +++ b/pkg/common/sorting.go @@ -3,9 +3,12 @@ package common import ( "fmt" - "github.com/flyteorg/flyteadmin/pkg/errors" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyteadmin/pkg/errors" ) const gormDescending = "%s desc" @@ -23,7 +26,11 @@ func (s *sortParamImpl) GetGormOrderExpr() string { return s.gormOrderExpression } -func NewSortParameter(sort admin.Sort) (SortParameter, error) { +func NewSortParameter(sort *admin.Sort, allowed sets.String) ([]SortParameter, error) { + if !allowed.Has(sort.Key) { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort_key: %s", sort.Key) + } + var gormOrderExpression string switch sort.Direction { case admin.Sort_DESCENDING: @@ -33,7 +40,27 @@ func NewSortParameter(sort admin.Sort) (SortParameter, error) { default: return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort order specified: %v", sort) } - return &sortParamImpl{ - gormOrderExpression: gormOrderExpression, - }, nil + + return []SortParameter{&sortParamImpl{gormOrderExpression: gormOrderExpression}}, nil +} + +func NewSortParameters(request *admin.ResourceListRequest, allowed sets.String) ([]SortParameter, error) { + if len(request.SortKeys) > 0 && request.SortBy != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "cannot specify both sort_keys and sort_by") + } + + if request.SortBy != nil { + request.SortKeys = append(request.SortKeys, request.SortBy) + } + + sortParams := make([]SortParameter, 0, len(request.SortKeys)) + for _, sortKey := range request.SortKeys { + params, err := NewSortParameter(sortKey, allowed) + if err != nil { + return sortParams, err + } + sortParams = append(sortParams, params...) + } + + return sortParams, nil } diff --git a/pkg/manager/impl/description_entity_manager.go b/pkg/manager/impl/description_entity_manager.go index 3dcd7ab3e..c8b67f613 100644 --- a/pkg/manager/impl/description_entity_manager.go +++ b/pkg/manager/impl/description_entity_manager.go @@ -4,22 +4,22 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "google.golang.org/grpc/codes" "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" "github.com/flyteorg/flyteadmin/pkg/manager/impl/validation" "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" - "google.golang.org/grpc/codes" ) type DescriptionEntityMetrics struct { @@ -65,23 +65,21 @@ func (d *DescriptionEntityManager) ListDescriptionEntity(ctx context.Context, re logger.Error(ctx, "failed to get database filter") return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.DescriptionEntityColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListWorkflows", request.Token) } listDescriptionEntitiesInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := d.db.DescriptionEntityRepo().List(ctx, listDescriptionEntitiesInput) if err != nil { diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index e881ef10e..540132b31 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -6,49 +6,42 @@ import ( "strconv" "time" - "github.com/flyteorg/flytestdlib/promutils/labeled" - - "github.com/flyteorg/flyteadmin/plugins" - + "github.com/benbjohnson/clock" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - - "github.com/flyteorg/flyteadmin/auth" - - "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" - - dataInterfaces "github.com/flyteorg/flyteadmin/pkg/data/interfaces" "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc/codes" - "github.com/flyteorg/flyteadmin/pkg/common" - - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/storage" - + "github.com/flyteorg/flyteadmin/auth" cloudeventInterfaces "github.com/flyteorg/flyteadmin/pkg/async/cloudevent/interfaces" eventWriter "github.com/flyteorg/flyteadmin/pkg/async/events/interfaces" "github.com/flyteorg/flyteadmin/pkg/async/notifications" notificationInterfaces "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyteadmin/pkg/common" + dataInterfaces "github.com/flyteorg/flyteadmin/pkg/data/interfaces" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/executions" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" "github.com/flyteorg/flyteadmin/pkg/manager/impl/validation" "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "google.golang.org/grpc/codes" - - "github.com/benbjohnson/clock" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" - "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteadmin/plugins" ) const childContainerQueueKey = "child_queue" @@ -1434,12 +1427,10 @@ func (m *ExecutionManager) ListExecutions( if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameters(&request, gormimpl.ExecutionColumns) + if err != nil { + return nil, err } offset, err := validation.ValidateToken(request.Token) @@ -1461,7 +1452,7 @@ func (m *ExecutionManager) ListExecutions( Limit: int(request.Limit), Offset: offset, InlineFilters: filters, - SortParameter: sortParameter, + SortParameters: sortParameters, JoinTableEntities: joinTableEntities, } output, err := m.db.ExecutionRepo().List(ctx, listExecutionsInput) diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 34d436ee9..d31a4e04d 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -13,6 +13,13 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/benbjohnson/clock" + "github.com/flyteorg/flyteidl/clients/go/coreutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/gogo/protobuf/jsonpb" + "github.com/golang/protobuf/ptypes" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" commonTestUtils "github.com/flyteorg/flyteadmin/pkg/common/testutils" flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" @@ -21,18 +28,13 @@ import ( managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" managerMocks "github.com/flyteorg/flyteadmin/pkg/manager/mocks" "github.com/flyteorg/flyteadmin/pkg/runtime" - "github.com/flyteorg/flyteidl/clients/go/coreutils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/gogo/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes" - "github.com/stretchr/testify/mock" - "google.golang.org/grpc/codes" "k8s.io/apimachinery/pkg/api/resource" - eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" "k8s.io/apimachinery/pkg/util/sets" + eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" + "github.com/flyteorg/flyteadmin/auth" commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" @@ -43,6 +45,13 @@ import ( "fmt" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/stretchr/testify/assert" + notificationMocks "github.com/flyteorg/flyteadmin/pkg/async/notifications/mocks" dataMocks "github.com/flyteorg/flyteadmin/pkg/data/mocks" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" @@ -55,12 +64,6 @@ import ( runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" workflowengineMocks "github.com/flyteorg/flyteadmin/pkg/workflowengine/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/wrappers" - "github.com/stretchr/testify/assert" ) var spec = testutils.GetExecutionRequest().Spec @@ -2979,7 +2982,7 @@ func TestListExecutions(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.False(t, nameFilter, "Included name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) assert.Equal(t, 2, input.Offset) assert.EqualValues(t, map[common.Entity]bool{ common.Execution: true, @@ -3965,7 +3968,7 @@ func TestListExecutions_LegacyModel(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.False(t, nameFilter, "Included name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) assert.Equal(t, 2, input.Offset) return interfaces.ExecutionCollectionOutput{ Executions: []models.Execution{ diff --git a/pkg/manager/impl/launch_plan_manager.go b/pkg/manager/impl/launch_plan_manager.go index f2192b701..88b53ca5c 100644 --- a/pkg/manager/impl/launch_plan_manager.go +++ b/pkg/manager/impl/launch_plan_manager.go @@ -5,6 +5,8 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" @@ -14,6 +16,11 @@ import ( "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" @@ -23,10 +30,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/golang/protobuf/proto" - "google.golang.org/grpc/codes" ) type launchPlanMetrics struct { @@ -408,23 +411,21 @@ func (m *LaunchPlanManager) ListLaunchPlans(ctx context.Context, request admin.R return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.LaunchPlanColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListLaunchPlans", request.Token) } listLaunchPlansInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := m.db.LaunchPlanRepo().List(ctx, listLaunchPlansInput) @@ -463,23 +464,21 @@ func (m *LaunchPlanManager) ListActiveLaunchPlans(ctx context.Context, request a return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.LaunchPlanColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListActiveLaunchPlans", request.Token) } listLaunchPlansInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := m.db.LaunchPlanRepo().List(ctx, listLaunchPlansInput) @@ -514,22 +513,21 @@ func (m *LaunchPlanManager) ListLaunchPlanIds(ctx context.Context, request admin if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.LaunchPlanColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s", request.Token) } listLaunchPlansInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := m.db.LaunchPlanRepo().ListLaunchPlanIdentifiers(ctx, listLaunchPlansInput) diff --git a/pkg/manager/impl/launch_plan_manager_test.go b/pkg/manager/impl/launch_plan_manager_test.go index 61b5db80b..f5f91b184 100644 --- a/pkg/manager/impl/launch_plan_manager_test.go +++ b/pkg/manager/impl/launch_plan_manager_test.go @@ -11,9 +11,12 @@ import ( "github.com/flyteorg/flyteadmin/pkg/async/schedule/mocks" - scheduleInterfaces "github.com/flyteorg/flyteadmin/pkg/async/schedule/interfaces" "github.com/golang/protobuf/ptypes" + scheduleInterfaces "github.com/flyteorg/flyteadmin/pkg/async/schedule/interfaces" + + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteadmin/pkg/common" flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" @@ -21,7 +24,6 @@ import ( repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" - "github.com/golang/protobuf/proto" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" @@ -1096,7 +1098,7 @@ func TestLaunchPlanManager_ListLaunchPlans(t *testing.T) { assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, 10, input.Limit) assert.Equal(t, 2, input.Offset) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ @@ -1193,7 +1195,7 @@ func TestLaunchPlanManager_ListLaunchPlanIds(t *testing.T) { assert.True(t, projectFilter, "Missing project equality filter") assert.True(t, domainFilter, "Missing domain equality filter") assert.Equal(t, 10, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ @@ -1280,7 +1282,7 @@ func TestLaunchPlanManager_ListActiveLaunchPlans(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, activeFilter, "Missing active filter") assert.Equal(t, 10, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ diff --git a/pkg/manager/impl/named_entity_manager.go b/pkg/manager/impl/named_entity_manager.go index d65f5ec4d..d73e8a565 100644 --- a/pkg/manager/impl/named_entity_manager.go +++ b/pkg/manager/impl/named_entity_manager.go @@ -5,13 +5,20 @@ import ( "strconv" "strings" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/contextutils" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" - "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" "github.com/flyteorg/flyteadmin/pkg/manager/impl/validation" @@ -19,9 +26,6 @@ import ( repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" ) const state = "state" @@ -119,13 +123,12 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.NamedEntityColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, @@ -133,10 +136,10 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi } listInput := repoInterfaces.ListNamedEntityInput{ ListResourceInput: repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, }, Project: request.Project, Domain: request.Domain, diff --git a/pkg/manager/impl/node_execution_manager.go b/pkg/manager/impl/node_execution_manager.go index bcc4362db..15aa7b84e 100644 --- a/pkg/manager/impl/node_execution_manager.go +++ b/pkg/manager/impl/node_execution_manager.go @@ -4,23 +4,27 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + cloudeventInterfaces "github.com/flyteorg/flyteadmin/pkg/async/cloudevent/interfaces" "github.com/flyteorg/flytestdlib/promutils/labeled" eventWriter "github.com/flyteorg/flyteadmin/pkg/async/events/interfaces" - notificationInterfaces "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" "github.com/golang/protobuf/proto" + notificationInterfaces "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" "github.com/flyteorg/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flyteadmin/pkg/common" @@ -28,6 +32,10 @@ import ( "fmt" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "google.golang.org/grpc/codes" + dataInterfaces "github.com/flyteorg/flyteadmin/pkg/data/interfaces" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" @@ -36,9 +44,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "google.golang.org/grpc/codes" ) type nodeExecutionMetrics struct { @@ -378,23 +383,22 @@ func (m *NodeExecutionManager) listNodeExecutions( if err != nil { return nil, err } - var sortParameter common.SortParameter - if sortBy != nil { - sortParameter, err = common.NewSortParameter(*sortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(sortBy, gormimpl.NodeExecutionColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(requestToken) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListNodeExecutions", requestToken) } listInput := repoInterfaces.ListResourceInput{ - Limit: int(limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } listInput.MapFilters = mapFilters diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index 134880347..029a2af5c 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -15,16 +15,10 @@ import ( eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flytestdlib/storage" - "github.com/flyteorg/flyteadmin/pkg/common" - commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" - dataMocks "github.com/flyteorg/flyteadmin/pkg/data/mocks" - flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" - "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" - repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" - "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" @@ -33,6 +27,14 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyteadmin/pkg/common" + commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" + dataMocks "github.com/flyteorg/flyteadmin/pkg/data/mocks" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) var occurredAt = time.Now().UTC() @@ -807,7 +809,7 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { "parent_task_execution_id": nil, }, filter) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { @@ -925,7 +927,7 @@ func TestListNodeExecutionsWithParent(t *testing.T) { assert.Equal(t, parentID, queryExpr.Args) assert.Equal(t, "parent_id = ?", queryExpr.Query) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { @@ -1139,7 +1141,7 @@ func TestListNodeExecutionsForTask(t *testing.T) { assert.Equal(t, uint(8), queryExpr.Args) assert.Equal(t, "parent_task_execution_id = ?", queryExpr.Query) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { diff --git a/pkg/manager/impl/project_manager.go b/pkg/manager/impl/project_manager.go index 27f429799..abe5c8625 100644 --- a/pkg/manager/impl/project_manager.go +++ b/pkg/manager/impl/project_manager.go @@ -4,6 +4,11 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" @@ -12,8 +17,6 @@ import ( repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" ) type ProjectManager struct { @@ -21,10 +24,10 @@ type ProjectManager struct { config runtimeInterfaces.Configuration } -var alphabeticalSortParam, _ = common.NewSortParameter(admin.Sort{ +var alphabeticalSortParam, _ = common.NewSortParameter(&admin.Sort{ Direction: admin.Sort_ASCENDING, Key: "identifier", -}) +}, gormimpl.ProjectColumns) func (m *ProjectManager) CreateProject(ctx context.Context, request admin.ProjectRegisterRequest) ( *admin.ProjectRegisterResponse, error) { @@ -61,14 +64,13 @@ func (m *ProjectManager) ListProjects(ctx context.Context, request admin.Project return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } - } else { - sortParameter = alphabeticalSortParam + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.ProjectColumns) + if err != nil { + return nil, err + } + + if len(sortParameters) == 0 { + sortParameters = alphabeticalSortParam } offset, err := validation.ValidateToken(request.Token) @@ -79,10 +81,10 @@ func (m *ProjectManager) ListProjects(ctx context.Context, request admin.Project // And finally, query the database listProjectsInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } projectModels, err := m.db.ProjectRepo().List(ctx, listProjectsInput) if err != nil { @@ -119,7 +121,6 @@ func (m *ProjectManager) UpdateProject(ctx context.Context, projectUpdate admin. // Transform the provided project into a model and apply to the DB. projectUpdateModel := transformers.CreateProjectModel(&projectUpdate) err = projectRepo.UpdateProject(ctx, projectUpdateModel) - if err != nil { return nil, err } diff --git a/pkg/manager/impl/project_manager_test.go b/pkg/manager/impl/project_manager_test.go index 5b2b437c0..5e3760fd5 100644 --- a/pkg/manager/impl/project_manager_test.go +++ b/pkg/manager/impl/project_manager_test.go @@ -9,14 +9,15 @@ import ( "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" "github.com/flyteorg/flyteadmin/pkg/repositories/models" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" ) var mockProjectConfigProvider = runtimeMocks.NewMockConfigurationProvider( @@ -55,7 +56,7 @@ func testListProjects(request admin.ProjectListRequest, token string, orderExpr q, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, *queryExpr, q) } - assert.Equal(t, orderExpr, input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, orderExpr, input.SortParameters.GetGormOrderExpr()) activeState := int32(admin.Project_ACTIVE) return []models.Project{ { diff --git a/pkg/manager/impl/signal_manager.go b/pkg/manager/impl/signal_manager.go index df2fbcc7b..6676cf107 100644 --- a/pkg/manager/impl/signal_manager.go +++ b/pkg/manager/impl/signal_manager.go @@ -4,6 +4,8 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flyteadmin/pkg/common" @@ -83,12 +85,10 @@ func (s *SignalManager) ListSignals(ctx context.Context, request admin.SignalLis if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.SignalColumns) + if err != nil { + return nil, err } offset, err := validation.ValidateToken(request.Token) @@ -98,10 +98,10 @@ func (s *SignalManager) ListSignals(ctx context.Context, request admin.SignalLis } signalModelList, err := s.db.SignalRepo().List(ctx, repoInterfaces.ListResourceInput{ - InlineFilters: filters, - Offset: offset, - Limit: int(request.Limit), - SortParameter: sortParameter, + InlineFilters: filters, + Offset: offset, + Limit: int(request.Limit), + SortParameters: sortParameters, }) if err != nil { logger.Debugf(ctx, "Failed to list signals with request [%+v] with err %v", diff --git a/pkg/manager/impl/task_execution_manager.go b/pkg/manager/impl/task_execution_manager.go index 46967f264..70c4b73fd 100644 --- a/pkg/manager/impl/task_execution_manager.go +++ b/pkg/manager/impl/task_execution_manager.go @@ -5,13 +5,16 @@ import ( "fmt" "strconv" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + cloudeventInterfaces "github.com/flyteorg/flyteadmin/pkg/async/cloudevent/interfaces" "github.com/flyteorg/flytestdlib/promutils/labeled" - notificationInterfaces "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" "github.com/golang/protobuf/proto" + notificationInterfaces "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytestdlib/contextutils" @@ -19,6 +22,11 @@ import ( "github.com/flyteorg/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" dataInterfaces "github.com/flyteorg/flyteadmin/pkg/data/interfaces" "github.com/flyteorg/flyteadmin/pkg/errors" @@ -29,10 +37,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/logger" - "google.golang.org/grpc/codes" ) type taskExecutionMetrics struct { @@ -258,12 +262,10 @@ func (m *TaskExecutionManager) ListTaskExecutions( if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.TaskExecutionColumns) + if err != nil { + return nil, err } offset, err := validation.ValidateToken(request.Token) @@ -273,10 +275,10 @@ func (m *TaskExecutionManager) ListTaskExecutions( } output, err := m.db.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{ - InlineFilters: filters, - Offset: offset, - Limit: int(request.Limit), - SortParameter: sortParameter, + InlineFilters: filters, + Offset: offset, + Limit: int(request.Limit), + SortParameters: sortParameters, }) if err != nil { logger.Debugf(ctx, "Failed to list task executions with request [%+v] with err %v", diff --git a/pkg/manager/impl/task_manager.go b/pkg/manager/impl/task_manager.go index b4346fcd9..b5eb5c45a 100644 --- a/pkg/manager/impl/task_manager.go +++ b/pkg/manager/impl/task_manager.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/contextutils" @@ -18,6 +20,9 @@ import ( "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" @@ -28,8 +33,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" workflowengine "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" ) type taskMetrics struct { @@ -169,13 +172,12 @@ func (t *TaskManager) ListTasks(ctx context.Context, request admin.ResourceListR if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.TaskColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, @@ -183,10 +185,10 @@ func (t *TaskManager) ListTasks(ctx context.Context, request admin.ResourceListR } // And finally, query the database listTasksInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := t.db.TaskRepo().List(ctx, listTasksInput) if err != nil { @@ -226,23 +228,22 @@ func (t *TaskManager) ListUniqueTaskIdentifiers(ctx context.Context, request adm if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.TaskColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListUniqueTaskIdentifiers", request.Token) } listTasksInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := t.db.TaskRepo().ListTaskIdentifiers(ctx, listTasksInput) diff --git a/pkg/manager/impl/task_manager_test.go b/pkg/manager/impl/task_manager_test.go index 8b8a38e76..3101ba6b4 100644 --- a/pkg/manager/impl/task_manager_test.go +++ b/pkg/manager/impl/task_manager_test.go @@ -6,6 +6,9 @@ import ( "fmt" "testing" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" @@ -14,16 +17,15 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/golang/protobuf/proto" - workflowengine "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" - workflowMocks "github.com/flyteorg/flyteadmin/pkg/workflowengine/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mockScope "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" + + workflowengine "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" + workflowMocks "github.com/flyteorg/flyteadmin/pkg/workflowengine/mocks" ) // Static values for test @@ -242,7 +244,7 @@ func TestListTasks(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, 2, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.TaskCollectionOutput{ Tasks: []models.Task{ { @@ -365,7 +367,7 @@ func TestListUniqueTaskIdentifiers(t *testing.T) { } } assert.Equal(t, 10, input.Offset) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.TaskCollectionOutput{ Tasks: []models.Task{ diff --git a/pkg/manager/impl/workflow_manager.go b/pkg/manager/impl/workflow_manager.go index 09d6a0db2..fe0017634 100644 --- a/pkg/manager/impl/workflow_manager.go +++ b/pkg/manager/impl/workflow_manager.go @@ -6,8 +6,20 @@ import ( "strconv" "time" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + compiler "github.com/flyteorg/flytepropeller/pkg/compiler/common" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" + "github.com/golang/protobuf/ptypes" + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" @@ -19,15 +31,6 @@ import ( runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" workflowengine "github.com/flyteorg/flyteadmin/pkg/workflowengine/impl" workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - compiler "github.com/flyteorg/flytepropeller/pkg/compiler/common" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/storage" - "github.com/golang/protobuf/ptypes" - "github.com/prometheus/client_golang/prometheus" - "google.golang.org/grpc/codes" ) var defaultStorageOptions = storage.Options{} @@ -252,23 +255,22 @@ func (w *WorkflowManager) ListWorkflows( if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.WorkflowColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListWorkflows", request.Token) } listWorkflowsInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := w.db.WorkflowRepo().List(ctx, listWorkflowsInput) if err != nil { @@ -306,23 +308,22 @@ func (w *WorkflowManager) ListWorkflowIdentifiers(ctx context.Context, request a if err != nil { return nil, err } - var sortParameter common.SortParameter - if request.SortBy != nil { - sortParameter, err = common.NewSortParameter(*request.SortBy) - if err != nil { - return nil, err - } + + sortParameters, err := common.NewSortParameter(request.SortBy, gormimpl.WorkflowColumns) + if err != nil { + return nil, err } + offset, err := validation.ValidateToken(request.Token) if err != nil { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListWorkflowIdentifiers", request.Token) } listWorkflowsInput := repoInterfaces.ListResourceInput{ - Limit: int(request.Limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(request.Limit), + Offset: offset, + InlineFilters: filters, + SortParameters: sortParameters, } output, err := w.db.WorkflowRepo().ListIdentifiers(ctx, listWorkflowsInput) diff --git a/pkg/manager/impl/workflow_manager_test.go b/pkg/manager/impl/workflow_manager_test.go index cc30e8aaf..b539104c0 100644 --- a/pkg/manager/impl/workflow_manager_test.go +++ b/pkg/manager/impl/workflow_manager_test.go @@ -6,14 +6,16 @@ import ( "fmt" "testing" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteadmin/pkg/common" commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" + adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/golang/protobuf/proto" flyteErrors "github.com/flyteorg/flyteadmin/pkg/errors" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" @@ -387,7 +389,7 @@ func TestListWorkflows(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) assert.Equal(t, 10, input.Offset) return interfaces.WorkflowCollectionOutput{ Workflows: []models.Workflow{ @@ -524,7 +526,7 @@ func TestWorkflowManager_ListWorkflowIdentifiers(t *testing.T) { assert.True(t, projectFilter, "Missing project equality filter") assert.True(t, domainFilter, "Missing domain equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameter.GetGormOrderExpr()) + assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) return interfaces.WorkflowCollectionOutput{ Workflows: []models.Workflow{ { diff --git a/pkg/repositories/gormimpl/common.go b/pkg/repositories/gormimpl/common.go index c022bd973..b27cfe458 100644 --- a/pkg/repositories/gormimpl/common.go +++ b/pkg/repositories/gormimpl/common.go @@ -3,23 +3,33 @@ package gormimpl import ( "fmt" + "google.golang.org/grpc/codes" + "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" - - "google.golang.org/grpc/codes" - "gorm.io/gorm" ) -const Project = "project" -const Domain = "domain" -const Name = "name" -const Version = "version" -const Description = "description" -const ResourceType = "resource_type" -const State = "state" -const ID = "id" +const ( + Project = "project" + Domain = "domain" + Name = "name" + Version = "version" + Description = "description" + ResourceType = "resource_type" + State = "state" + ID = "id" + CreatedAt = "created_at" + UpdatedAt = "updated_at" + DeletedAt = "deleted_at" + ExecutionProject = "execution_project" + ExecutionDomain = "execution_domain" + ExecutionName = "execution_name" + NodeID = "node_id" +) const executionTableName = "executions" const namedEntityMetadataTableName = "named_entity_metadata" @@ -34,6 +44,13 @@ const executionAdminTagsTableName = "execution_admin_tags" const limit = "limit" const filters = "filters" +var ( + BaseColumnSet = sets.NewString(ID, CreatedAt, UpdatedAt, DeletedAt, ResourceType) + TaskKeyColumnSet = sets.NewString(Project, Domain, Name, Version) + ExecutionKeyColumnSet = sets.NewString(ExecutionProject, ExecutionDomain, ExecutionName) + NodeExecutionKeyColumnSet = ExecutionKeyColumnSet.Union(sets.NewString(NodeID)) +) + var identifierGroupBy = fmt.Sprintf("%s, %s, %s", Project, Domain, Name) var entityToTableName = map[common.Entity]string{ diff --git a/pkg/repositories/gormimpl/description_entity_repo.go b/pkg/repositories/gormimpl/description_entity_repo.go index 1f5dceb5a..594eaba2b 100644 --- a/pkg/repositories/gormimpl/description_entity_repo.go +++ b/pkg/repositories/gormimpl/description_entity_repo.go @@ -3,16 +3,28 @@ package gormimpl import ( "context" - "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteadmin/pkg/common" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "gorm.io/gorm" ) +var DescriptionEntityColumns = BaseColumnSet.Union(sets.NewString( + ResourceType, + Project, + Domain, + Name, + Version, + "short_description", + "long_description", + "link", +)) + // DescriptionEntityRepo Implementation of DescriptionEntityRepoInterface. type DescriptionEntityRepo struct { db *gorm.DB @@ -61,8 +73,8 @@ func (r *DescriptionEntityRepo) List( return interfaces.DescriptionEntityCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() tx.Find(&descriptionEntities) diff --git a/pkg/repositories/gormimpl/execution_repo.go b/pkg/repositories/gormimpl/execution_repo.go index b128a2805..7359806b1 100644 --- a/pkg/repositories/gormimpl/execution_repo.go +++ b/pkg/repositories/gormimpl/execution_repo.go @@ -5,16 +5,42 @@ import ( "errors" "fmt" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/promutils" - - "gorm.io/gorm" ) +var ExecutionColumns = BaseColumnSet. + Union(ExecutionKeyColumnSet). + Union(sets.NewString( + "launch_plan_id", + "workflow_id", + "task_id", + "phase", + "started_at", + "execution_created_at", + "execution_updated_at", + "duration", + "abort_cause", + "mode", + "source_execution_id", + "parent_node_execution_id", + "cluster", + "inputs_uri", + "user_inputs_uri", + "error_kind", + "error_code", + "user", + "state", + "launch_entity", + )) + // Implementation of ExecutionInterface. type ExecutionRepo struct { db *gorm.DB @@ -102,8 +128,8 @@ func (r *ExecutionRepo) List(_ context.Context, input interfaces.ListResourceInp return interfaces.ExecutionCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() diff --git a/pkg/repositories/gormimpl/execution_repo_test.go b/pkg/repositories/gormimpl/execution_repo_test.go index 17cb85777..28038d20c 100644 --- a/pkg/repositories/gormimpl/execution_repo_test.go +++ b/pkg/repositories/gormimpl/execution_repo_test.go @@ -13,11 +13,12 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/stretchr/testify/assert" ) var createdAt = time.Date(2018, time.February, 17, 00, 00, 00, 00, time.UTC).UTC() @@ -258,7 +259,7 @@ func TestListExecutions_Order(t *testing.T) { Key: "name", }) _, err := executionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), @@ -287,7 +288,7 @@ func TestListExecutions_WithTags(t *testing.T) { tagFilter, err := common.NewRepeatedValueFilter(common.ExecutionAdminTag, common.ValueIn, "admin_tag_name", vals) assert.NoError(t, err) _, err = executionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), diff --git a/pkg/repositories/gormimpl/launch_plan_repo.go b/pkg/repositories/gormimpl/launch_plan_repo.go index dc379ed03..854af2ae2 100644 --- a/pkg/repositories/gormimpl/launch_plan_repo.go +++ b/pkg/repositories/gormimpl/launch_plan_repo.go @@ -7,16 +7,28 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/flyteorg/flytestdlib/logger" + "gorm.io/gorm" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flytestdlib/logger" - "gorm.io/gorm" ) const launchPlanTableName = "launch_plans" +var LaunchPlanColumns = BaseColumnSet.Union(sets.NewString( + Project, + Domain, + Name, + Version, + "workflow_id", + "state", + "schedule_type", +)) + type launchPlanMetrics struct { SetActiveDuration promutils.StopWatch } @@ -125,8 +137,8 @@ func (r *LaunchPlanRepo) List(ctx context.Context, input interfaces.ListResource return interfaces.LaunchPlanCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() @@ -159,8 +171,8 @@ func (r *LaunchPlanRepo) ListLaunchPlanIdentifiers(ctx context.Context, input in return interfaces.LaunchPlanCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } // Scan the results into a list of launch plans diff --git a/pkg/repositories/gormimpl/launch_plan_repo_test.go b/pkg/repositories/gormimpl/launch_plan_repo_test.go index f96bd7964..64ec30754 100644 --- a/pkg/repositories/gormimpl/launch_plan_repo_test.go +++ b/pkg/repositories/gormimpl/launch_plan_repo_test.go @@ -8,12 +8,13 @@ import ( mockScope "github.com/flyteorg/flytestdlib/promutils" mocket "github.com/Selvatico/go-mocket" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" ) const workflowID = uint(1) @@ -350,7 +351,7 @@ func TestListLaunchPlans_Order(t *testing.T) { Key: "project", }) _, err := launchPlanRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.LaunchPlan, "project", project), getEqualityFilter(common.LaunchPlan, "domain", domain), diff --git a/pkg/repositories/gormimpl/named_entity_repo.go b/pkg/repositories/gormimpl/named_entity_repo.go index 8e02390dd..a9eaa5b23 100644 --- a/pkg/repositories/gormimpl/named_entity_repo.go +++ b/pkg/repositories/gormimpl/named_entity_repo.go @@ -4,20 +4,32 @@ import ( "context" "fmt" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "google.golang.org/grpc/codes" + "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flytestdlib/promutils" - "gorm.io/gorm" ) const innerJoinTableAlias = "entities" +var NamedEntityColumns = sets.NewString( + ResourceType, + Project, + Domain, + Name, + "description", + "state", +) + var resourceTypeToTableName = map[core.ResourceType]string{ core.ResourceType_LAUNCH_PLAN: launchPlanTableName, core.ResourceType_WORKFLOW: workflowTableName, @@ -37,8 +49,8 @@ func getSubQueryJoin(db *gorm.DB, tableName string, input interfaces.ListNamedEn Group(identifierGroupBy) // Apply consistent sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } return db.Joins(fmt.Sprintf(joinString, input.ResourceType), tx) @@ -190,8 +202,8 @@ func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEn return interfaces.NamedEntityCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } // Scan the results into a list of named entities diff --git a/pkg/repositories/gormimpl/named_entity_repo_test.go b/pkg/repositories/gormimpl/named_entity_repo_test.go index d586a2c8f..67e2f3be8 100644 --- a/pkg/repositories/gormimpl/named_entity_repo_test.go +++ b/pkg/repositories/gormimpl/named_entity_repo_test.go @@ -9,11 +9,12 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" mocket "github.com/Selvatico/go-mocket" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" ) func getMockNamedEntityResponseFromDb(expected models.NamedEntity) map[string]interface{} { @@ -166,8 +167,8 @@ func TestListNamedEntity(t *testing.T) { Project: "admintests", Domain: "development", ListResourceInput: interfaces.ListResourceInput{ - Limit: 20, - SortParameter: sortParameter, + Limit: 20, + SortParameters: sortParameter, }, }) assert.NoError(t, err) diff --git a/pkg/repositories/gormimpl/node_execution_repo.go b/pkg/repositories/gormimpl/node_execution_repo.go index 65cd8a774..599d1367c 100644 --- a/pkg/repositories/gormimpl/node_execution_repo.go +++ b/pkg/repositories/gormimpl/node_execution_repo.go @@ -5,16 +5,37 @@ import ( "errors" "fmt" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "gorm.io/gorm" ) +var NodeExecutionColumns = BaseColumnSet. + Union(NodeExecutionKeyColumnSet). + Union(sets.NewString( + "phase", + "input_uri", + "started_at", + "node_execution_created_at", + "node_execution_updated_at", + "duration", + "parent_id", + "parent_task_execution_id", + "parent_node_execution_id", + "error_kind", + "error_code", + "cache_status", + "dynamic_workflow_remote_closure_reference", + )) + // Implementation of NodeExecutionInterface. type NodeExecutionRepo struct { db *gorm.DB @@ -127,8 +148,8 @@ func (r *NodeExecutionRepo) List(ctx context.Context, input interfaces.ListResou return interfaces.NodeExecutionCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() diff --git a/pkg/repositories/gormimpl/node_execution_repo_test.go b/pkg/repositories/gormimpl/node_execution_repo_test.go index d3f778f10..610c4e2a0 100644 --- a/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -5,21 +5,23 @@ import ( "testing" "time" - flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "google.golang.org/grpc/codes" "gorm.io/gorm" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" + mockScope "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/stretchr/testify/assert" ) var nodePhase = core.NodeExecution_RUNNING.String() @@ -256,7 +258,7 @@ func TestListNodeExecutions_Order(t *testing.T) { Key: "project", }) _, err := nodeExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.NodeExecution, "phase", nodePhase), }, diff --git a/pkg/repositories/gormimpl/project_repo.go b/pkg/repositories/gormimpl/project_repo.go index 7541fce3c..441649d3b 100644 --- a/pkg/repositories/gormimpl/project_repo.go +++ b/pkg/repositories/gormimpl/project_repo.go @@ -4,9 +4,12 @@ import ( "context" "errors" - flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" + "k8s.io/apimachinery/pkg/util/sets" + "google.golang.org/grpc/codes" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/promutils" @@ -17,6 +20,13 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) +var ProjectColumns = BaseColumnSet.Union(sets.NewString( + "identifier", + "name", + "description", + "state", +)) + type ProjectRepo struct { db *gorm.DB errorTransformer flyteAdminDbErrors.ErrorTransformer @@ -72,8 +82,8 @@ func (r *ProjectRepo) List(ctx context.Context, input interfaces.ListResourceInp } // Apply sort ordering - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() diff --git a/pkg/repositories/gormimpl/project_repo_test.go b/pkg/repositories/gormimpl/project_repo_test.go index 145072133..f138aa6aa 100644 --- a/pkg/repositories/gormimpl/project_repo_test.go +++ b/pkg/repositories/gormimpl/project_repo_test.go @@ -5,13 +5,14 @@ import ( "testing" mocket "github.com/Selvatico/go-mocket" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" ) var alphabeticalSortParam, _ = common.NewSortParameter(admin.Sort{ @@ -96,25 +97,25 @@ func TestListProjects(t *testing.T) { filter, err := common.NewSingleValueFilter(common.Project, common.Equal, "name", "foo") assert.Nil(t, err) testListProjects(interfaces.ListResourceInput{ - Offset: 0, - Limit: 1, - InlineFilters: []common.InlineFilter{filter}, - SortParameter: alphabeticalSortParam, + Offset: 0, + Limit: 1, + InlineFilters: []common.InlineFilter{filter}, + SortParameters: alphabeticalSortParam, }, `SELECT * FROM "projects" WHERE name = $1 ORDER BY identifier asc LIMIT 1`, t) } func TestListProjects_NoFilters(t *testing.T) { testListProjects(interfaces.ListResourceInput{ - Offset: 0, - Limit: 1, - SortParameter: alphabeticalSortParam, + Offset: 0, + Limit: 1, + SortParameters: alphabeticalSortParam, }, `SELECT * FROM "projects" WHERE state != $1 ORDER BY identifier asc`, t) } func TestListProjects_NoLimit(t *testing.T) { testListProjects(interfaces.ListResourceInput{ - Offset: 0, - SortParameter: alphabeticalSortParam, + Offset: 0, + SortParameters: alphabeticalSortParam, }, `SELECT * FROM "projects" WHERE state != $1 ORDER BY identifier asc`, t) } diff --git a/pkg/repositories/gormimpl/signal_repo.go b/pkg/repositories/gormimpl/signal_repo.go index b87f70316..75c2a92d5 100644 --- a/pkg/repositories/gormimpl/signal_repo.go +++ b/pkg/repositories/gormimpl/signal_repo.go @@ -4,6 +4,8 @@ import ( "context" "errors" + "k8s.io/apimachinery/pkg/util/sets" + adminerrors "github.com/flyteorg/flyteadmin/pkg/errors" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" @@ -16,6 +18,12 @@ import ( "gorm.io/gorm" ) +var SignalColumns = BaseColumnSet. + Union(ExecutionKeyColumnSet). + Union(sets.NewString( + "signal_id", + )) + // SignalRepo is an implementation of SignalRepoInterface. type SignalRepo struct { db *gorm.DB @@ -66,9 +74,10 @@ func (s *SignalRepo) List(ctx context.Context, input interfaces.ListResourceInpu return nil, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } + timer := s.metrics.ListDuration.Start() tx.Find(&signals) timer.Stop() diff --git a/pkg/repositories/gormimpl/task_execution_repo.go b/pkg/repositories/gormimpl/task_execution_repo.go index b864d802e..b55e53c5f 100644 --- a/pkg/repositories/gormimpl/task_execution_repo.go +++ b/pkg/repositories/gormimpl/task_execution_repo.go @@ -4,16 +4,33 @@ import ( "context" "errors" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "gorm.io/gorm" ) +var TaskExecutionColumns = BaseColumnSet. + Union(TaskKeyColumnSet). + Union(ExecutionKeyColumnSet). + Union(sets.NewString( + "retry_attempt", + "phase", + "phase_version", + "input_uri", + "started_at", + "task_execution_started_at", + "task_execution_updated_at", + "duration", + )) + // Implementation of TaskExecutionInterface. type TaskExecutionRepo struct { db *gorm.DB @@ -113,8 +130,8 @@ func (r *TaskExecutionRepo) List(ctx context.Context, input interfaces.ListResou } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } timer := r.metrics.ListDuration.Start() diff --git a/pkg/repositories/gormimpl/task_repo.go b/pkg/repositories/gormimpl/task_repo.go index fae18c0db..41feaa7eb 100644 --- a/pkg/repositories/gormimpl/task_repo.go +++ b/pkg/repositories/gormimpl/task_repo.go @@ -4,16 +4,26 @@ import ( "context" "errors" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "gorm.io/gorm" + flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "gorm.io/gorm" ) +var TaskColumns = BaseColumnSet. + Union(TaskKeyColumnSet). + Union(sets.NewString( + "type", + "short_description", + )) + // Implementation of TaskRepoInterface. type TaskRepo struct { db *gorm.DB @@ -88,9 +98,10 @@ func (r *TaskRepo) List( return interfaces.TaskCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } + timer := r.metrics.ListDuration.Start() tx.Find(&tasks) timer.Stop() @@ -122,8 +133,8 @@ func (r *TaskRepo) ListTaskIdentifiers(ctx context.Context, input interfaces.Lis tx = tx.Where(mapFilter.GetFilter()) } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } // Scan the results into a list of tasks diff --git a/pkg/repositories/gormimpl/task_repo_test.go b/pkg/repositories/gormimpl/task_repo_test.go index 678a5c382..62a044ad1 100644 --- a/pkg/repositories/gormimpl/task_repo_test.go +++ b/pkg/repositories/gormimpl/task_repo_test.go @@ -9,11 +9,12 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/stretchr/testify/assert" ) const pythonTestTaskType = "python-task" @@ -202,7 +203,7 @@ func TestListTasks_Order(t *testing.T) { Key: "project", }) _, err := taskRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), diff --git a/pkg/repositories/gormimpl/workflow_repo.go b/pkg/repositories/gormimpl/workflow_repo.go index 69b711dab..afd5accc9 100644 --- a/pkg/repositories/gormimpl/workflow_repo.go +++ b/pkg/repositories/gormimpl/workflow_repo.go @@ -4,14 +4,27 @@ import ( "context" "errors" - flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" - "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" - "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" + + flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) +var WorkflowColumns = BaseColumnSet. + Union(sets.NewString( + Project, + Domain, + Name, + Version, + "remote_closure_identifier", + "short_description", + )) + // Implementation of WorkflowRepoInterface. type WorkflowRepo struct { db *gorm.DB @@ -80,9 +93,10 @@ func (r *WorkflowRepo) List( return interfaces.WorkflowCollectionOutput{}, err } // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } + timer := r.metrics.ListDuration.Start() tx.Find(&workflows) timer.Stop() @@ -110,9 +124,9 @@ func (r *WorkflowRepo) ListIdentifiers(ctx context.Context, input interfaces.Lis return interfaces.WorkflowCollectionOutput{}, err } - // Apply sort ordering. - if input.SortParameter != nil { - tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + // Apply sort ordering + for _, sortParam := range input.SortParameters { + tx = tx.Order(sortParam.GetGormOrderExpr()) } // Scan the results into a list of workflows diff --git a/pkg/repositories/gormimpl/workflow_repo_test.go b/pkg/repositories/gormimpl/workflow_repo_test.go index ee300d609..3cb043b95 100644 --- a/pkg/repositories/gormimpl/workflow_repo_test.go +++ b/pkg/repositories/gormimpl/workflow_repo_test.go @@ -5,13 +5,14 @@ import ( "testing" mocket "github.com/Selvatico/go-mocket" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" ) var typedInterface = []byte{1, 2, 3} @@ -187,7 +188,7 @@ func TestListWorkflows_Order(t *testing.T) { Key: "project", }) _, err := workflowRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameter: sortParameter, + SortParameters: sortParameter, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Workflow, "project", project), getEqualityFilter(common.Workflow, "domain", domain), diff --git a/pkg/repositories/interfaces/common.go b/pkg/repositories/interfaces/common.go index 60065a20c..40d358f16 100644 --- a/pkg/repositories/interfaces/common.go +++ b/pkg/repositories/interfaces/common.go @@ -20,8 +20,8 @@ type ListResourceInput struct { // MapFilters refers to primary entity filters defined as map values rather than inline sql queries. // These exist to permit filtering on "IS NULL" which isn't permitted with inline filter queries and // pq driver value substitution. - MapFilters []common.MapFilter - SortParameter common.SortParameter + MapFilters []common.MapFilter + SortParameters []common.SortParameter // A set of the entities (besides the primary table being queried) that should be joined with when performing // the list query. This enables filtering on non-primary entity attributes. JoinTableEntities map[common.Entity]bool From 193a8ad249b022588a91f4f325df9261a1d32a4d Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Wed, 23 Aug 2023 10:16:55 +0300 Subject: [PATCH 2/4] fix columns Signed-off-by: Iaroslav Ciupin --- pkg/repositories/gormimpl/node_execution_repo.go | 6 +----- pkg/repositories/gormimpl/task_execution_repo.go | 7 ++----- pkg/repositories/gormimpl/workflow_repo.go | 3 +-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/pkg/repositories/gormimpl/node_execution_repo.go b/pkg/repositories/gormimpl/node_execution_repo.go index 599d1367c..563918336 100644 --- a/pkg/repositories/gormimpl/node_execution_repo.go +++ b/pkg/repositories/gormimpl/node_execution_repo.go @@ -5,13 +5,10 @@ import ( "errors" "fmt" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/promutils" - "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" @@ -29,7 +26,6 @@ var NodeExecutionColumns = BaseColumnSet. "duration", "parent_id", "parent_task_execution_id", - "parent_node_execution_id", "error_kind", "error_code", "cache_status", diff --git a/pkg/repositories/gormimpl/task_execution_repo.go b/pkg/repositories/gormimpl/task_execution_repo.go index b55e53c5f..8115de4c4 100644 --- a/pkg/repositories/gormimpl/task_execution_repo.go +++ b/pkg/repositories/gormimpl/task_execution_repo.go @@ -4,13 +4,10 @@ import ( "context" "errors" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/promutils" - "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" @@ -19,7 +16,7 @@ import ( var TaskExecutionColumns = BaseColumnSet. Union(TaskKeyColumnSet). - Union(ExecutionKeyColumnSet). + Union(NodeExecutionKeyColumnSet). Union(sets.NewString( "retry_attempt", "phase", diff --git a/pkg/repositories/gormimpl/workflow_repo.go b/pkg/repositories/gormimpl/workflow_repo.go index afd5accc9..6f67d60f1 100644 --- a/pkg/repositories/gormimpl/workflow_repo.go +++ b/pkg/repositories/gormimpl/workflow_repo.go @@ -4,11 +4,10 @@ import ( "context" "errors" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" + "k8s.io/apimachinery/pkg/util/sets" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" From cbbe1ca50946702ef282b7a872916d6124beb014 Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Wed, 23 Aug 2023 19:45:42 +0300 Subject: [PATCH 3/4] adjust tests Signed-off-by: Iaroslav Ciupin --- .../impl/db_admin_data_provider.go | 8 +++ .../impl/db_admin_data_provider_test.go | 12 +++- pkg/clusterresource/impl/shared.go | 9 --- pkg/common/sorting.go | 15 +++-- pkg/common/sorting_test.go | 40 +++++++++--- .../impl/description_entity_manager_test.go | 9 +-- pkg/manager/impl/execution_manager_test.go | 63 ++++++++----------- pkg/manager/impl/launch_plan_manager_test.go | 6 +- .../impl/node_execution_manager_test.go | 16 ++--- pkg/manager/impl/project_manager_test.go | 2 +- pkg/manager/impl/task_manager_test.go | 4 +- pkg/manager/impl/util_test.go | 15 +++++ pkg/manager/impl/workflow_manager_test.go | 4 +- pkg/repositories/gormimpl/common.go | 40 ++++++------ pkg/repositories/gormimpl/common_test.go | 39 ++++++++++++ .../gormimpl/description_entity_repo.go | 12 +--- .../gormimpl/description_entity_repo_test.go | 6 +- pkg/repositories/gormimpl/execution_repo.go | 26 +------- .../gormimpl/execution_repo_test.go | 18 ++---- pkg/repositories/gormimpl/launch_plan_repo.go | 14 +---- .../gormimpl/launch_plan_repo_test.go | 7 +-- .../gormimpl/named_entity_repo.go | 14 +---- .../gormimpl/named_entity_repo_test.go | 9 +-- .../gormimpl/node_execution_repo.go | 18 +----- .../gormimpl/node_execution_repo_test.go | 7 +-- pkg/repositories/gormimpl/project_repo.go | 16 +---- .../gormimpl/project_repo_test.go | 4 +- pkg/repositories/gormimpl/signal_repo.go | 16 ++--- .../gormimpl/task_execution_repo.go | 15 +---- pkg/repositories/gormimpl/task_repo.go | 9 +-- pkg/repositories/gormimpl/task_repo_test.go | 7 +-- pkg/repositories/gormimpl/utils_for_test.go | 16 ++++- pkg/repositories/gormimpl/workflow_repo.go | 11 +--- .../gormimpl/workflow_repo_test.go | 7 +-- tests/shared.go | 3 +- 35 files changed, 235 insertions(+), 282 deletions(-) create mode 100644 pkg/manager/impl/util_test.go create mode 100644 pkg/repositories/gormimpl/common_test.go diff --git a/pkg/clusterresource/impl/db_admin_data_provider.go b/pkg/clusterresource/impl/db_admin_data_provider.go index 12b837331..470b99ff4 100644 --- a/pkg/clusterresource/impl/db_admin_data_provider.go +++ b/pkg/clusterresource/impl/db_admin_data_provider.go @@ -8,6 +8,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/gormimpl" repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" @@ -47,6 +48,13 @@ func (p dbAdminProvider) getDomains() []*admin.Domain { return domains } +var descCreatedAtSortParam = admin.Sort{ + Direction: admin.Sort_DESCENDING, + Key: "created_at", +} + +var descCreatedAtSortDBParam, _ = common.NewSortParameter(&descCreatedAtSortParam, gormimpl.ProjectColumns) + func (p dbAdminProvider) GetProjects(ctx context.Context) (*admin.Projects, error) { filter, err := common.NewSingleValueFilter(common.Project, common.NotEqual, "state", int32(admin.Project_ARCHIVED)) if err != nil { diff --git a/pkg/clusterresource/impl/db_admin_data_provider_test.go b/pkg/clusterresource/impl/db_admin_data_provider_test.go index 076c53dd2..c2d3238ce 100644 --- a/pkg/clusterresource/impl/db_admin_data_provider_test.go +++ b/pkg/clusterresource/impl/db_admin_data_provider_test.go @@ -3,6 +3,7 @@ package impl import ( "context" "errors" + "strings" "testing" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -10,6 +11,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" "github.com/flyteorg/flyteadmin/pkg/manager/mocks" repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" @@ -105,7 +107,7 @@ func TestGetProjects(t *testing.T) { mockRepo.(*repoMocks.MockRepository).ProjectRepoIface = &repoMocks.MockProjectRepo{ ListProjectsFunction: func(ctx context.Context, input repoInterfaces.ListResourceInput) ([]models.Project, error) { assert.Len(t, input.InlineFilters, 1) - assert.Equal(t, input.SortParameters.GetGormOrderExpr(), "created_at desc") + assert.Equal(t, "created_at desc", sortParamsSQL(input.SortParameters)) return []models.Project{ { Identifier: "flytesnacks", @@ -142,3 +144,11 @@ func TestGetProjects(t *testing.T) { assert.EqualError(t, err, errFoo.Error()) }) } + +func sortParamsSQL(params []common.SortParameter) string { + sqls := make([]string, len(params)) + for i, param := range params { + sqls[i] = param.GetGormOrderExpr() + } + return strings.Join(sqls, ", ") +} diff --git a/pkg/clusterresource/impl/shared.go b/pkg/clusterresource/impl/shared.go index bd350eb55..1b9a5d6a7 100644 --- a/pkg/clusterresource/impl/shared.go +++ b/pkg/clusterresource/impl/shared.go @@ -1,20 +1,11 @@ package impl import ( - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/codes" - "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" ) func NewMissingEntityError(entity string) error { return errors.NewFlyteAdminErrorf(codes.NotFound, "Failed to find [%s]", entity) } - -var descCreatedAtSortParam = admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "created_at", -} - -var descCreatedAtSortDBParam, _ = common.NewSortParameter(descCreatedAtSortParam) diff --git a/pkg/common/sorting.go b/pkg/common/sorting.go index f0751e51f..3d09aa347 100644 --- a/pkg/common/sorting.go +++ b/pkg/common/sorting.go @@ -27,16 +27,21 @@ func (s *sortParamImpl) GetGormOrderExpr() string { } func NewSortParameter(sort *admin.Sort, allowed sets.String) ([]SortParameter, error) { - if !allowed.Has(sort.Key) { - return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort_key: %s", sort.Key) + if sort == nil { + return nil, nil + } + + key := sort.GetKey() + if !allowed.Has(key) { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort_key: %s", key) } var gormOrderExpression string - switch sort.Direction { + switch sort.GetDirection() { case admin.Sort_DESCENDING: - gormOrderExpression = fmt.Sprintf(gormDescending, sort.Key) + gormOrderExpression = fmt.Sprintf(gormDescending, key) case admin.Sort_ASCENDING: - gormOrderExpression = fmt.Sprintf(gormAscending, sort.Key) + gormOrderExpression = fmt.Sprintf(gormAscending, key) default: return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort order specified: %v", sort) } diff --git a/pkg/common/sorting_test.go b/pkg/common/sorting_test.go index 20cb69f6d..97a7de6a3 100644 --- a/pkg/common/sorting_test.go +++ b/pkg/common/sorting_test.go @@ -5,22 +5,46 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/flyteorg/flyteadmin/pkg/errors" ) +func TestSortParameter_Empty(t *testing.T) { + sortParameter, err := NewSortParameter(nil, sets.NewString()) + + assert.NoError(t, err) + assert.Nil(t, sortParameter) +} + +func TestSortParameter_InvalidSortKey(t *testing.T) { + expected := errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort_key: wrong") + + _, err := NewSortParameter(&admin.Sort{ + Direction: admin.Sort_ASCENDING, + Key: "wrong", + }, sets.NewString("name")) + + assert.Equal(t, expected, err) +} + func TestSortParameter_Ascending(t *testing.T) { - sortParameter, err := NewSortParameter(admin.Sort{ + sortParameter, err := NewSortParameter(&admin.Sort{ Direction: admin.Sort_ASCENDING, Key: "name", - }) - assert.Nil(t, err) - assert.Equal(t, "name asc", sortParameter.GetGormOrderExpr()) + }, sets.NewString("name")) + + assert.NoError(t, err) + assert.Equal(t, "name asc", sortParameter[0].GetGormOrderExpr()) } func TestSortParameter_Descending(t *testing.T) { - sortParameter, err := NewSortParameter(admin.Sort{ + sortParameter, err := NewSortParameter(&admin.Sort{ Direction: admin.Sort_DESCENDING, Key: "project", - }) - assert.Nil(t, err) - assert.Equal(t, "project desc", sortParameter.GetGormOrderExpr()) + }, sets.NewString("project")) + + assert.NoError(t, err) + assert.Equal(t, "project desc", sortParameter[0].GetGormOrderExpr()) } diff --git a/pkg/manager/impl/description_entity_manager_test.go b/pkg/manager/impl/description_entity_manager_test.go index 66cdcc4a9..1b31dad90 100644 --- a/pkg/manager/impl/description_entity_manager_test.go +++ b/pkg/manager/impl/description_entity_manager_test.go @@ -4,15 +4,16 @@ import ( "context" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" ) var descriptionEntityIdentifier = core.Identifier{ diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index d31a4e04d..e189a1ec8 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -3,67 +3,54 @@ package impl import ( "context" "errors" + "fmt" "strings" "testing" - - "github.com/flyteorg/flyteadmin/plugins" - - "google.golang.org/grpc/status" - - "google.golang.org/protobuf/types/known/timestamppb" + "time" "github.com/benbjohnson/clock" "github.com/flyteorg/flyteidl/clients/go/coreutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + mockScope "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" "github.com/gogo/protobuf/jsonpb" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteadmin/auth" + eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" + notificationMocks "github.com/flyteorg/flyteadmin/pkg/async/notifications/mocks" "github.com/flyteorg/flyteadmin/pkg/common" + commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" commonTestUtils "github.com/flyteorg/flyteadmin/pkg/common/testutils" + dataMocks "github.com/flyteorg/flyteadmin/pkg/data/mocks" flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/executions" "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" managerMocks "github.com/flyteorg/flyteadmin/pkg/manager/mocks" - "github.com/flyteorg/flyteadmin/pkg/runtime" - - "k8s.io/apimachinery/pkg/api/resource" - - "k8s.io/apimachinery/pkg/util/sets" - - eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" - - "github.com/flyteorg/flyteadmin/auth" - - commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" - - "github.com/flyteorg/flytestdlib/storage" - - "time" - - "fmt" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/wrappers" - "github.com/stretchr/testify/assert" - - notificationMocks "github.com/flyteorg/flyteadmin/pkg/async/notifications/mocks" - dataMocks "github.com/flyteorg/flyteadmin/pkg/data/mocks" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" + "github.com/flyteorg/flyteadmin/pkg/runtime" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeIFaceMocks "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces/mocks" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" workflowengineMocks "github.com/flyteorg/flyteadmin/pkg/workflowengine/mocks" + "github.com/flyteorg/flyteadmin/plugins" ) var spec = testutils.GetExecutionRequest().Spec @@ -2982,7 +2969,7 @@ func TestListExecutions(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.False(t, nameFilter, "Included name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "execution_domain asc", sortParamsSQL(input.SortParameters)) assert.Equal(t, 2, input.Offset) assert.EqualValues(t, map[common.Entity]bool{ common.Execution: true, @@ -3030,7 +3017,7 @@ func TestListExecutions(t *testing.T) { Limit: limit, SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, Token: "2", }) @@ -3968,7 +3955,7 @@ func TestListExecutions_LegacyModel(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.False(t, nameFilter, "Included name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "execution_domain asc", sortParamsSQL(input.SortParameters)) assert.Equal(t, 2, input.Offset) return interfaces.ExecutionCollectionOutput{ Executions: []models.Execution{ @@ -4013,7 +4000,7 @@ func TestListExecutions_LegacyModel(t *testing.T) { Limit: limit, SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, Token: "2", }) diff --git a/pkg/manager/impl/launch_plan_manager_test.go b/pkg/manager/impl/launch_plan_manager_test.go index f5f91b184..83aa6512a 100644 --- a/pkg/manager/impl/launch_plan_manager_test.go +++ b/pkg/manager/impl/launch_plan_manager_test.go @@ -1098,7 +1098,7 @@ func TestLaunchPlanManager_ListLaunchPlans(t *testing.T) { assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, 10, input.Limit) assert.Equal(t, 2, input.Offset) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ @@ -1195,7 +1195,7 @@ func TestLaunchPlanManager_ListLaunchPlanIds(t *testing.T) { assert.True(t, projectFilter, "Missing project equality filter") assert.True(t, domainFilter, "Missing domain equality filter") assert.Equal(t, 10, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ @@ -1282,7 +1282,7 @@ func TestLaunchPlanManager_ListActiveLaunchPlans(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, activeFilter, "Missing active filter") assert.Equal(t, 10, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.LaunchPlanCollectionOutput{ LaunchPlans: []models.LaunchPlan{ diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index 029a2af5c..04211fd8e 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -809,7 +809,7 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { "parent_task_execution_id": nil, }, filter) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "execution_domain asc", sortParamsSQL(input.SortParameters)) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { @@ -860,10 +860,10 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { Token: "2", SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, }) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, nodeExecutions.NodeExecutions, 1) assert.True(t, proto.Equal(&admin.NodeExecution{ Id: &core.NodeExecutionIdentifier{ @@ -927,7 +927,7 @@ func TestListNodeExecutionsWithParent(t *testing.T) { assert.Equal(t, parentID, queryExpr.Args) assert.Equal(t, "parent_id = ?", queryExpr.Query) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "execution_domain asc", sortParamsSQL(input.SortParameters)) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { @@ -960,7 +960,7 @@ func TestListNodeExecutionsWithParent(t *testing.T) { Token: "2", SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, UniqueParentId: "parent_1", }) @@ -1087,7 +1087,7 @@ func TestListNodeExecutions_NothingToReturn(t *testing.T) { Token: "2", SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, }) assert.Nil(t, err) @@ -1141,7 +1141,7 @@ func TestListNodeExecutionsForTask(t *testing.T) { assert.Equal(t, uint(8), queryExpr.Args) assert.Equal(t, "parent_task_execution_id = ?", queryExpr.Query) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "execution_domain asc", sortParamsSQL(input.SortParameters)) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ { @@ -1186,7 +1186,7 @@ func TestListNodeExecutionsForTask(t *testing.T) { Token: "2", SortBy: &admin.Sort{ Direction: admin.Sort_ASCENDING, - Key: "domain", + Key: "execution_domain", }, }) assert.Nil(t, err) diff --git a/pkg/manager/impl/project_manager_test.go b/pkg/manager/impl/project_manager_test.go index 5e3760fd5..e25a7c608 100644 --- a/pkg/manager/impl/project_manager_test.go +++ b/pkg/manager/impl/project_manager_test.go @@ -56,7 +56,7 @@ func testListProjects(request admin.ProjectListRequest, token string, orderExpr q, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, *queryExpr, q) } - assert.Equal(t, orderExpr, input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, orderExpr, sortParamsSQL(input.SortParameters)) activeState := int32(admin.Project_ACTIVE) return []models.Project{ { diff --git a/pkg/manager/impl/task_manager_test.go b/pkg/manager/impl/task_manager_test.go index 3101ba6b4..1100155b5 100644 --- a/pkg/manager/impl/task_manager_test.go +++ b/pkg/manager/impl/task_manager_test.go @@ -244,7 +244,7 @@ func TestListTasks(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, 2, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.TaskCollectionOutput{ Tasks: []models.Task{ { @@ -367,7 +367,7 @@ func TestListUniqueTaskIdentifiers(t *testing.T) { } } assert.Equal(t, 10, input.Offset) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.TaskCollectionOutput{ Tasks: []models.Task{ diff --git a/pkg/manager/impl/util_test.go b/pkg/manager/impl/util_test.go new file mode 100644 index 000000000..a0bcba2c5 --- /dev/null +++ b/pkg/manager/impl/util_test.go @@ -0,0 +1,15 @@ +package impl + +import ( + "strings" + + "github.com/flyteorg/flyteadmin/pkg/common" +) + +func sortParamsSQL(params []common.SortParameter) string { + sqls := make([]string, len(params)) + for i, param := range params { + sqls[i] = param.GetGormOrderExpr() + } + return strings.Join(sqls, ", ") +} diff --git a/pkg/manager/impl/workflow_manager_test.go b/pkg/manager/impl/workflow_manager_test.go index b539104c0..69a2c6948 100644 --- a/pkg/manager/impl/workflow_manager_test.go +++ b/pkg/manager/impl/workflow_manager_test.go @@ -389,7 +389,7 @@ func TestListWorkflows(t *testing.T) { assert.True(t, domainFilter, "Missing domain equality filter") assert.True(t, nameFilter, "Missing name equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) assert.Equal(t, 10, input.Offset) return interfaces.WorkflowCollectionOutput{ Workflows: []models.Workflow{ @@ -526,7 +526,7 @@ func TestWorkflowManager_ListWorkflowIdentifiers(t *testing.T) { assert.True(t, projectFilter, "Missing project equality filter") assert.True(t, domainFilter, "Missing domain equality filter") assert.Equal(t, limit, input.Limit) - assert.Equal(t, "domain asc", input.SortParameters.GetGormOrderExpr()) + assert.Equal(t, "domain asc", sortParamsSQL(input.SortParameters)) return interfaces.WorkflowCollectionOutput{ Workflows: []models.Workflow{ { diff --git a/pkg/repositories/gormimpl/common.go b/pkg/repositories/gormimpl/common.go index b27cfe458..8e2172864 100644 --- a/pkg/repositories/gormimpl/common.go +++ b/pkg/repositories/gormimpl/common.go @@ -2,9 +2,11 @@ package gormimpl import ( "fmt" + "sync" "google.golang.org/grpc/codes" "gorm.io/gorm" + "gorm.io/gorm/schema" "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flyteadmin/pkg/common" @@ -14,21 +16,14 @@ import ( ) const ( - Project = "project" - Domain = "domain" - Name = "name" - Version = "version" - Description = "description" - ResourceType = "resource_type" - State = "state" - ID = "id" - CreatedAt = "created_at" - UpdatedAt = "updated_at" - DeletedAt = "deleted_at" - ExecutionProject = "execution_project" - ExecutionDomain = "execution_domain" - ExecutionName = "execution_name" - NodeID = "node_id" + Project = "project" + Domain = "domain" + Name = "name" + Version = "version" + Description = "description" + ResourceType = "resource_type" + State = "state" + ID = "id" ) const executionTableName = "executions" @@ -44,13 +39,6 @@ const executionAdminTagsTableName = "execution_admin_tags" const limit = "limit" const filters = "filters" -var ( - BaseColumnSet = sets.NewString(ID, CreatedAt, UpdatedAt, DeletedAt, ResourceType) - TaskKeyColumnSet = sets.NewString(Project, Domain, Name, Version) - ExecutionKeyColumnSet = sets.NewString(ExecutionProject, ExecutionDomain, ExecutionName) - NodeExecutionKeyColumnSet = ExecutionKeyColumnSet.Union(sets.NewString(NodeID)) -) - var identifierGroupBy = fmt.Sprintf("%s, %s, %s", Project, Domain, Name) var entityToTableName = map[common.Entity]string{ @@ -132,3 +120,11 @@ func applyScopedFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFil } return tx, nil } + +func modelColumns(v any) sets.String { + s, err := schema.Parse(v, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + panic(err) + } + return sets.NewString(s.DBNames...) +} diff --git a/pkg/repositories/gormimpl/common_test.go b/pkg/repositories/gormimpl/common_test.go new file mode 100644 index 000000000..f3701cc06 --- /dev/null +++ b/pkg/repositories/gormimpl/common_test.go @@ -0,0 +1,39 @@ +package gormimpl + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/flyteorg/flyteadmin/pkg/repositories/models" +) + +func Test_modelColumns(t *testing.T) { + expected := sets.NewString( + "closure", + "created_at", + "deleted_at", + "domain", + "duration", + "execution_domain", + "execution_name", + "execution_project", + "id", + "input_uri", + "name", + "node_id", + "phase", + "phase_version", + "project", + "retry_attempt", + "started_at", + "task_execution_created_at", + "task_execution_updated_at", + "updated_at", + "version") + + actual := modelColumns(models.TaskExecution{}) + + assert.Equal(t, expected, actual) +} diff --git a/pkg/repositories/gormimpl/description_entity_repo.go b/pkg/repositories/gormimpl/description_entity_repo.go index 594eaba2b..1725ae435 100644 --- a/pkg/repositories/gormimpl/description_entity_repo.go +++ b/pkg/repositories/gormimpl/description_entity_repo.go @@ -6,7 +6,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" - "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flyteadmin/pkg/common" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" @@ -14,16 +13,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var DescriptionEntityColumns = BaseColumnSet.Union(sets.NewString( - ResourceType, - Project, - Domain, - Name, - Version, - "short_description", - "long_description", - "link", -)) +var DescriptionEntityColumns = modelColumns(models.DescriptionEntity{}) // DescriptionEntityRepo Implementation of DescriptionEntityRepoInterface. type DescriptionEntityRepo struct { diff --git a/pkg/repositories/gormimpl/description_entity_repo_test.go b/pkg/repositories/gormimpl/description_entity_repo_test.go index 9ae447417..f8bb8d11c 100644 --- a/pkg/repositories/gormimpl/description_entity_repo_test.go +++ b/pkg/repositories/gormimpl/description_entity_repo_test.go @@ -8,9 +8,10 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" mocket "github.com/Selvatico/go-mocket" - "github.com/flyteorg/flyteadmin/pkg/repositories/errors" mockScope "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteadmin/pkg/repositories/errors" ) const shortDescription = "hello" @@ -35,7 +36,8 @@ func TestGetDescriptionEntity(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`SELECT * FROM "description_entities" WHERE project = $1 AND domain = $2 AND name = $3 AND version = $4 LIMIT 1`). + GlobalMock.NewMock(). + WithQuery(`SELECT * FROM "description_entities" WHERE project = $1 AND domain = $2 AND name = $3 AND version = $4 LIMIT 1`). WithReply(descriptionEntities) output, err = descriptionEntityRepo.Get(context.Background(), interfaces.GetDescriptionEntityInput{ ResourceType: resourceType, diff --git a/pkg/repositories/gormimpl/execution_repo.go b/pkg/repositories/gormimpl/execution_repo.go index 7359806b1..557814422 100644 --- a/pkg/repositories/gormimpl/execution_repo.go +++ b/pkg/repositories/gormimpl/execution_repo.go @@ -8,7 +8,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" - "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" @@ -16,30 +15,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var ExecutionColumns = BaseColumnSet. - Union(ExecutionKeyColumnSet). - Union(sets.NewString( - "launch_plan_id", - "workflow_id", - "task_id", - "phase", - "started_at", - "execution_created_at", - "execution_updated_at", - "duration", - "abort_cause", - "mode", - "source_execution_id", - "parent_node_execution_id", - "cluster", - "inputs_uri", - "user_inputs_uri", - "error_kind", - "error_code", - "user", - "state", - "launch_entity", - )) +var ExecutionColumns = modelColumns(models.Execution{}) // Implementation of ExecutionInterface. type ExecutionRepo struct { diff --git a/pkg/repositories/gormimpl/execution_repo_test.go b/pkg/repositories/gormimpl/execution_repo_test.go index 28038d20c..97934ed4f 100644 --- a/pkg/repositories/gormimpl/execution_repo_test.go +++ b/pkg/repositories/gormimpl/execution_repo_test.go @@ -6,12 +6,12 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mockScope "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - mocket "github.com/Selvatico/go-mocket" "github.com/stretchr/testify/assert" @@ -254,12 +254,9 @@ func TestListExecutions_Order(t *testing.T) { mockQuery := GlobalMock.NewMock().WithQuery(`name asc`) mockQuery.WithReply(executions) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_ASCENDING, - Key: "name", - }) + sortParameters := makeSortParameters(t, admin.Sort_ASCENDING, "name") _, err := executionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), @@ -280,15 +277,12 @@ func TestListExecutions_WithTags(t *testing.T) { mockQuery := GlobalMock.NewMock().WithQuery(`name asc`) mockQuery.WithReply(executions) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_ASCENDING, - Key: "name", - }) + sortParameters := makeSortParameters(t, admin.Sort_ASCENDING, "name") vals := []string{"tag1", "tag2"} tagFilter, err := common.NewRepeatedValueFilter(common.ExecutionAdminTag, common.ValueIn, "admin_tag_name", vals) assert.NoError(t, err) _, err = executionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), diff --git a/pkg/repositories/gormimpl/launch_plan_repo.go b/pkg/repositories/gormimpl/launch_plan_repo.go index 854af2ae2..773be097f 100644 --- a/pkg/repositories/gormimpl/launch_plan_repo.go +++ b/pkg/repositories/gormimpl/launch_plan_repo.go @@ -6,10 +6,8 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/promutils" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" @@ -19,15 +17,7 @@ import ( const launchPlanTableName = "launch_plans" -var LaunchPlanColumns = BaseColumnSet.Union(sets.NewString( - Project, - Domain, - Name, - Version, - "workflow_id", - "state", - "schedule_type", -)) +var LaunchPlanColumns = modelColumns(models.LaunchPlan{}) type launchPlanMetrics struct { SetActiveDuration promutils.StopWatch diff --git a/pkg/repositories/gormimpl/launch_plan_repo_test.go b/pkg/repositories/gormimpl/launch_plan_repo_test.go index 64ec30754..724c8cd8f 100644 --- a/pkg/repositories/gormimpl/launch_plan_repo_test.go +++ b/pkg/repositories/gormimpl/launch_plan_repo_test.go @@ -346,12 +346,9 @@ func TestListLaunchPlans_Order(t *testing.T) { mockQuery.WithQuery(`project desc`) mockQuery.WithReply(launchPlans) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "project", - }) + sortParameters := makeSortParameters(t, admin.Sort_DESCENDING, "project") _, err := launchPlanRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.LaunchPlan, "project", project), getEqualityFilter(common.LaunchPlan, "domain", domain), diff --git a/pkg/repositories/gormimpl/named_entity_repo.go b/pkg/repositories/gormimpl/named_entity_repo.go index a9eaa5b23..49776017f 100644 --- a/pkg/repositories/gormimpl/named_entity_repo.go +++ b/pkg/repositories/gormimpl/named_entity_repo.go @@ -4,12 +4,9 @@ import ( "context" "fmt" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "google.golang.org/grpc/codes" - "github.com/flyteorg/flytestdlib/promutils" + "google.golang.org/grpc/codes" "gorm.io/gorm" "github.com/flyteorg/flyteadmin/pkg/common" @@ -21,14 +18,7 @@ import ( const innerJoinTableAlias = "entities" -var NamedEntityColumns = sets.NewString( - ResourceType, - Project, - Domain, - Name, - "description", - "state", -) +var NamedEntityColumns = modelColumns(models.NamedEntity{}) var resourceTypeToTableName = map[core.ResourceType]string{ core.ResourceType_LAUNCH_PLAN: launchPlanTableName, diff --git a/pkg/repositories/gormimpl/named_entity_repo_test.go b/pkg/repositories/gormimpl/named_entity_repo_test.go index 67e2f3be8..8c0f72d7f 100644 --- a/pkg/repositories/gormimpl/named_entity_repo_test.go +++ b/pkg/repositories/gormimpl/named_entity_repo_test.go @@ -4,8 +4,6 @@ import ( "context" "testing" - "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" mocket "github.com/Selvatico/go-mocket" @@ -158,17 +156,14 @@ func TestListNamedEntity(t *testing.T) { mockQuery.WithQuery( `SELECT entities.project,entities.domain,entities.name,'2' AS resource_type,named_entity_metadata.description,named_entity_metadata.state FROM "named_entity_metadata" RIGHT JOIN (SELECT project,domain,name FROM "workflows" WHERE "domain" = $1 AND "project" = $2 GROUP BY project, domain, name ORDER BY name desc LIMIT 20) AS entities ON named_entity_metadata.resource_type = 2 AND named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND named_entity_metadata.name = entities.name GROUP BY entities.project, entities.domain, entities.name, named_entity_metadata.description, named_entity_metadata.state ORDER BY name desc`).WithReply(results) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "name", - }) + sortParameters := makeSortParameters(t, admin.Sort_DESCENDING, "name") output, err := metadataRepo.List(context.Background(), interfaces.ListNamedEntityInput{ ResourceType: resourceType, Project: "admintests", Domain: "development", ListResourceInput: interfaces.ListResourceInput{ Limit: 20, - SortParameters: sortParameter, + SortParameters: sortParameters, }, }) assert.NoError(t, err) diff --git a/pkg/repositories/gormimpl/node_execution_repo.go b/pkg/repositories/gormimpl/node_execution_repo.go index 563918336..6bd02e2ed 100644 --- a/pkg/repositories/gormimpl/node_execution_repo.go +++ b/pkg/repositories/gormimpl/node_execution_repo.go @@ -8,29 +8,13 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" - "k8s.io/apimachinery/pkg/util/sets" adminErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var NodeExecutionColumns = BaseColumnSet. - Union(NodeExecutionKeyColumnSet). - Union(sets.NewString( - "phase", - "input_uri", - "started_at", - "node_execution_created_at", - "node_execution_updated_at", - "duration", - "parent_id", - "parent_task_execution_id", - "error_kind", - "error_code", - "cache_status", - "dynamic_workflow_remote_closure_reference", - )) +var NodeExecutionColumns = modelColumns(models.NodeExecution{}) // Implementation of NodeExecutionInterface. type NodeExecutionRepo struct { diff --git a/pkg/repositories/gormimpl/node_execution_repo_test.go b/pkg/repositories/gormimpl/node_execution_repo_test.go index 610c4e2a0..f83dc3394 100644 --- a/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -253,12 +253,9 @@ func TestListNodeExecutions_Order(t *testing.T) { mockQuery.WithQuery(`project desc`) mockQuery.WithReply(nodeExecutions) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "project", - }) + sortParameters := makeSortParameters(t, admin.Sort_DESCENDING, "project") _, err := nodeExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.NodeExecution, "phase", nodePhase), }, diff --git a/pkg/repositories/gormimpl/project_repo.go b/pkg/repositories/gormimpl/project_repo.go index 441649d3b..00e3b7a8e 100644 --- a/pkg/repositories/gormimpl/project_repo.go +++ b/pkg/repositories/gormimpl/project_repo.go @@ -4,28 +4,18 @@ import ( "context" "errors" - "k8s.io/apimachinery/pkg/util/sets" - - "google.golang.org/grpc/codes" - - flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/promutils" - + "google.golang.org/grpc/codes" "gorm.io/gorm" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var ProjectColumns = BaseColumnSet.Union(sets.NewString( - "identifier", - "name", - "description", - "state", -)) +var ProjectColumns = modelColumns(models.Project{}) type ProjectRepo struct { db *gorm.DB diff --git a/pkg/repositories/gormimpl/project_repo_test.go b/pkg/repositories/gormimpl/project_repo_test.go index f138aa6aa..efce83bf3 100644 --- a/pkg/repositories/gormimpl/project_repo_test.go +++ b/pkg/repositories/gormimpl/project_repo_test.go @@ -15,10 +15,10 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var alphabeticalSortParam, _ = common.NewSortParameter(admin.Sort{ +var alphabeticalSortParam, _ = common.NewSortParameter(&admin.Sort{ Direction: admin.Sort_ASCENDING, Key: "identifier", -}) +}, ProjectColumns) func TestCreateProject(t *testing.T) { projectRepo := NewProjectRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) diff --git a/pkg/repositories/gormimpl/signal_repo.go b/pkg/repositories/gormimpl/signal_repo.go index 75c2a92d5..754188109 100644 --- a/pkg/repositories/gormimpl/signal_repo.go +++ b/pkg/repositories/gormimpl/signal_repo.go @@ -4,25 +4,17 @@ import ( "context" "errors" - "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flytestdlib/promutils" + "google.golang.org/grpc/codes" + "gorm.io/gorm" adminerrors "github.com/flyteorg/flyteadmin/pkg/errors" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" - - "github.com/flyteorg/flytestdlib/promutils" - - "google.golang.org/grpc/codes" - - "gorm.io/gorm" ) -var SignalColumns = BaseColumnSet. - Union(ExecutionKeyColumnSet). - Union(sets.NewString( - "signal_id", - )) +var SignalColumns = modelColumns(models.Signal{}) // SignalRepo is an implementation of SignalRepoInterface. type SignalRepo struct { diff --git a/pkg/repositories/gormimpl/task_execution_repo.go b/pkg/repositories/gormimpl/task_execution_repo.go index 8115de4c4..c66e81ee7 100644 --- a/pkg/repositories/gormimpl/task_execution_repo.go +++ b/pkg/repositories/gormimpl/task_execution_repo.go @@ -7,26 +7,13 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" - "k8s.io/apimachinery/pkg/util/sets" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var TaskExecutionColumns = BaseColumnSet. - Union(TaskKeyColumnSet). - Union(NodeExecutionKeyColumnSet). - Union(sets.NewString( - "retry_attempt", - "phase", - "phase_version", - "input_uri", - "started_at", - "task_execution_started_at", - "task_execution_updated_at", - "duration", - )) +var TaskExecutionColumns = modelColumns(models.TaskExecution{}) // Implementation of TaskExecutionInterface. type TaskExecutionRepo struct { diff --git a/pkg/repositories/gormimpl/task_repo.go b/pkg/repositories/gormimpl/task_repo.go index 41feaa7eb..b6b1328ab 100644 --- a/pkg/repositories/gormimpl/task_repo.go +++ b/pkg/repositories/gormimpl/task_repo.go @@ -4,8 +4,6 @@ import ( "context" "errors" - "k8s.io/apimachinery/pkg/util/sets" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" @@ -17,12 +15,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var TaskColumns = BaseColumnSet. - Union(TaskKeyColumnSet). - Union(sets.NewString( - "type", - "short_description", - )) +var TaskColumns = modelColumns(models.Task{}) // Implementation of TaskRepoInterface. type TaskRepo struct { diff --git a/pkg/repositories/gormimpl/task_repo_test.go b/pkg/repositories/gormimpl/task_repo_test.go index 62a044ad1..ff911f3df 100644 --- a/pkg/repositories/gormimpl/task_repo_test.go +++ b/pkg/repositories/gormimpl/task_repo_test.go @@ -198,12 +198,9 @@ func TestListTasks_Order(t *testing.T) { mockQuery.WithQuery(`project desc`) mockQuery.WithReply(tasks) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "project", - }) + sortParameters := makeSortParameters(t, admin.Sort_DESCENDING, "project") _, err := taskRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Task, "project", project), getEqualityFilter(common.Task, "domain", domain), diff --git a/pkg/repositories/gormimpl/utils_for_test.go b/pkg/repositories/gormimpl/utils_for_test.go index 3959033b6..afaaaa091 100644 --- a/pkg/repositories/gormimpl/utils_for_test.go +++ b/pkg/repositories/gormimpl/utils_for_test.go @@ -4,9 +4,14 @@ package gormimpl import ( "testing" - "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteadmin/pkg/common" + mocket "github.com/Selvatico/go-mocket" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -32,3 +37,12 @@ func getEqualityFilter(entity common.Entity, field string, value interface{}) co filter, _ := common.NewSingleValueFilter(entity, common.Equal, field, value) return filter } + +func makeSortParameters(t *testing.T, direction admin.Sort_Direction, key string) []common.SortParameter { + params, err := common.NewSortParameter(&admin.Sort{ + Direction: direction, + Key: key, + }, sets.NewString(key)) + require.NoError(t, err) + return params +} diff --git a/pkg/repositories/gormimpl/workflow_repo.go b/pkg/repositories/gormimpl/workflow_repo.go index 6f67d60f1..ee04cb6f7 100644 --- a/pkg/repositories/gormimpl/workflow_repo.go +++ b/pkg/repositories/gormimpl/workflow_repo.go @@ -7,22 +7,13 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" "gorm.io/gorm" - "k8s.io/apimachinery/pkg/util/sets" flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -var WorkflowColumns = BaseColumnSet. - Union(sets.NewString( - Project, - Domain, - Name, - Version, - "remote_closure_identifier", - "short_description", - )) +var WorkflowColumns = modelColumns(models.Workflow{}) // Implementation of WorkflowRepoInterface. type WorkflowRepo struct { diff --git a/pkg/repositories/gormimpl/workflow_repo_test.go b/pkg/repositories/gormimpl/workflow_repo_test.go index 3cb043b95..176ebfd93 100644 --- a/pkg/repositories/gormimpl/workflow_repo_test.go +++ b/pkg/repositories/gormimpl/workflow_repo_test.go @@ -183,12 +183,9 @@ func TestListWorkflows_Order(t *testing.T) { mockQuery.WithQuery(`project desc`) mockQuery.WithReply(workflows) - sortParameter, _ := common.NewSortParameter(admin.Sort{ - Direction: admin.Sort_DESCENDING, - Key: "project", - }) + sortParameters := makeSortParameters(t, admin.Sort_DESCENDING, "project") _, err := workflowRepo.List(context.Background(), interfaces.ListResourceInput{ - SortParameters: sortParameter, + SortParameters: sortParameters, InlineFilters: []common.InlineFilter{ getEqualityFilter(common.Workflow, "project", project), getEqualityFilter(common.Workflow, "domain", domain), diff --git a/tests/shared.go b/tests/shared.go index 2dbaee3a4..d9403196c 100644 --- a/tests/shared.go +++ b/tests/shared.go @@ -13,10 +13,11 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" ) const project, domain, name = "project", "domain", "execution name" From d8991bef6e565a33300fb66d7f5a8897a8ac1852 Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Wed, 23 Aug 2023 21:17:12 +0300 Subject: [PATCH 4/4] improve coverage Signed-off-by: Iaroslav Ciupin --- pkg/common/sorting_test.go | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pkg/common/sorting_test.go b/pkg/common/sorting_test.go index 97a7de6a3..502c7cece 100644 --- a/pkg/common/sorting_test.go +++ b/pkg/common/sorting_test.go @@ -48,3 +48,50 @@ func TestSortParameter_Descending(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "project desc", sortParameter[0].GetGormOrderExpr()) } + +func TestSortParameters_SingleAndMultipleSortKeys(t *testing.T) { + expected := errors.NewFlyteAdminErrorf(codes.InvalidArgument, "cannot specify both sort_keys and sort_by") + + _, err := NewSortParameters(&admin.ResourceListRequest{ + SortBy: &admin.Sort{}, + SortKeys: []*admin.Sort{{}}, + }, sets.NewString()) + + assert.Equal(t, expected, err) +} + +func TestSortParameters_SingleSortKey(t *testing.T) { + params, err := NewSortParameters(&admin.ResourceListRequest{ + SortBy: &admin.Sort{Key: "foo"}, + }, sets.NewString("foo")) + + assert.NoError(t, err) + assert.Equal(t, "foo desc", params[0].GetGormOrderExpr()) +} + +func TestSortParameters_OK(t *testing.T) { + params, err := NewSortParameters(&admin.ResourceListRequest{ + SortKeys: []*admin.Sort{ + {Key: "key"}, + {Key: "foo", Direction: admin.Sort_ASCENDING}, + }, + }, sets.NewString("key", "foo")) + + assert.NoError(t, err) + if assert.Len(t, params, 2) { + assert.Equal(t, "key desc", params[0].GetGormOrderExpr()) + assert.Equal(t, "foo asc", params[1].GetGormOrderExpr()) + } +} + +func TestSortParameters_Invalid(t *testing.T) { + expected := errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid sort_key: foo") + + _, err := NewSortParameters(&admin.ResourceListRequest{ + SortKeys: []*admin.Sort{ + {Key: "foo"}, + }, + }, sets.NewString("key")) + + assert.Equal(t, expected, err) +}