diff --git a/pkg/abstractions/common/deployment.go b/pkg/abstractions/common/deployment.go new file mode 100644 index 000000000..d9c618f30 --- /dev/null +++ b/pkg/abstractions/common/deployment.go @@ -0,0 +1,55 @@ +package abstractions + +import ( + "context" + "strconv" + + apiv1 "github.com/beam-cloud/beta9/pkg/api/v1" + "github.com/beam-cloud/beta9/pkg/auth" + "github.com/beam-cloud/beta9/pkg/repository" + "github.com/beam-cloud/beta9/pkg/types" +) + +func ParseAndValidateDeploymentStubId( + ctx context.Context, + authInfo *auth.AuthInfo, + stubId string, + deploymentName string, + version string, + stubType string, + backendRepo repository.BackendRepository, +) (string, error) { + if deploymentName != "" { + var deployment *types.DeploymentWithRelated + + if version == "" { + var err error + deployment, err = backendRepo.GetLatestDeploymentByName(ctx, authInfo.Workspace.Id, deploymentName, stubType, true) + if err != nil { + return "", apiv1.HTTPBadRequest("Invalid deployment") + } + } else { + version, err := strconv.Atoi(version) + if err != nil { + return "", apiv1.HTTPBadRequest("Invalid deployment version") + } + + deployment, err = backendRepo.GetDeploymentByNameAndVersion(ctx, authInfo.Workspace.Id, deploymentName, uint(version), stubType) + if err != nil { + return "", apiv1.HTTPBadRequest("Invalid deployment") + } + } + + if deployment == nil { + return "", apiv1.HTTPBadRequest("Invalid deployment") + } + + if !deployment.Active { + return "", apiv1.HTTPBadRequest("Deployment is not active") + } + + stubId = deployment.Stub.ExternalId + } + + return stubId, nil +} diff --git a/pkg/abstractions/endpoint/endpoint.go b/pkg/abstractions/endpoint/endpoint.go index fb8c4f9d1..e94713936 100644 --- a/pkg/abstractions/endpoint/endpoint.go +++ b/pkg/abstractions/endpoint/endpoint.go @@ -146,7 +146,7 @@ func (es *HttpEndpointService) endpointTaskFactory(ctx context.Context, msg type }, nil } -func (es *HttpEndpointService) isPublic(stubId string) (*types.Workspace, error) { +func (es *HttpEndpointService) IsPublic(stubId string) (*types.Workspace, error) { instance, err := es.getOrCreateEndpointInstance(es.ctx, stubId) if err != nil { return nil, err diff --git a/pkg/abstractions/endpoint/http.go b/pkg/abstractions/endpoint/http.go index 610b2c7de..4605e56a1 100644 --- a/pkg/abstractions/endpoint/http.go +++ b/pkg/abstractions/endpoint/http.go @@ -1,9 +1,7 @@ package endpoint import ( - "strconv" - - apiv1 "github.com/beam-cloud/beta9/pkg/api/v1" + abstractions "github.com/beam-cloud/beta9/pkg/abstractions/common" "github.com/beam-cloud/beta9/pkg/auth" "github.com/beam-cloud/beta9/pkg/types" "github.com/labstack/echo/v4" @@ -17,22 +15,22 @@ type endpointGroup struct { func registerEndpointRoutes(g *echo.Group, es *HttpEndpointService) *endpointGroup { group := &endpointGroup{routeGroup: g, es: es} - g.POST("/id/:stubId", auth.WithAuth(group.endpointRequest)) - g.POST("/:deploymentName", auth.WithAuth(group.endpointRequest)) - g.POST("/:deploymentName/latest", auth.WithAuth(group.endpointRequest)) - g.POST("/:deploymentName/v:version", auth.WithAuth(group.endpointRequest)) - g.POST("/public/:stubId", auth.WithAssumedStubAuth(group.endpointRequest, group.es.isPublic)) + g.POST("/id/:stubId", auth.WithAuth(group.EndpointRequest)) + g.POST("/:deploymentName", auth.WithAuth(group.EndpointRequest)) + g.POST("/:deploymentName/latest", auth.WithAuth(group.EndpointRequest)) + g.POST("/:deploymentName/v:version", auth.WithAuth(group.EndpointRequest)) + g.POST("/public/:stubId", auth.WithAssumedStubAuth(group.EndpointRequest, group.es.IsPublic)) - g.GET("/id/:stubId", auth.WithAuth(group.endpointRequest)) - g.GET("/:deploymentName", auth.WithAuth(group.endpointRequest)) - g.GET("/:deploymentName/latest", auth.WithAuth(group.endpointRequest)) - g.GET("/:deploymentName/v:version", auth.WithAuth(group.endpointRequest)) - g.GET("/public/:stubId", auth.WithAssumedStubAuth(group.endpointRequest, group.es.isPublic)) + g.GET("/id/:stubId", auth.WithAuth(group.EndpointRequest)) + g.GET("/:deploymentName", auth.WithAuth(group.EndpointRequest)) + g.GET("/:deploymentName/latest", auth.WithAuth(group.EndpointRequest)) + g.GET("/:deploymentName/v:version", auth.WithAuth(group.EndpointRequest)) + g.GET("/public/:stubId", auth.WithAssumedStubAuth(group.EndpointRequest, group.es.IsPublic)) - g.POST("/id/:stubId/warmup", auth.WithAuth(group.warmUpEndpoint)) - g.POST("/:deploymentName/warmup", auth.WithAuth(group.warmUpEndpoint)) - g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.warmUpEndpoint)) - g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.warmUpEndpoint)) + g.POST("/id/:stubId/warmup", auth.WithAuth(group.WarmUpEndpoint)) + g.POST("/:deploymentName/warmup", auth.WithAuth(group.WarmUpEndpoint)) + g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.WarmUpEndpoint)) + g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.WarmUpEndpoint)) return group } @@ -48,49 +46,31 @@ func registerASGIRoutes(g *echo.Group, es *HttpEndpointService) *endpointGroup { g.Any("/:deploymentName/latest/:subPath", auth.WithAuth(group.ASGIRequest)) g.Any("/:deploymentName/v:version", auth.WithAuth(group.ASGIRequest)) g.Any("/:deploymentName/v:version/:subPath", auth.WithAuth(group.ASGIRequest)) - g.Any("/public/:stubId", auth.WithAssumedStubAuth(group.ASGIRequest, group.es.isPublic)) - g.Any("/public/:stubId/:subPath", auth.WithAssumedStubAuth(group.ASGIRequest, group.es.isPublic)) + g.Any("/public/:stubId", auth.WithAssumedStubAuth(group.ASGIRequest, group.es.IsPublic)) + g.Any("/public/:stubId/:subPath", auth.WithAssumedStubAuth(group.ASGIRequest, group.es.IsPublic)) + + g.POST("/id/:stubId/warmup", auth.WithAuth(group.WarmupASGI)) + g.POST("/:deploymentName/warmup", auth.WithAuth(group.WarmupASGI)) + g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.WarmupASGI)) + g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.WarmupASGI)) return group } -func (g *endpointGroup) endpointRequest(ctx echo.Context) error { +func (g *endpointGroup) EndpointRequest(ctx echo.Context) error { cc, _ := ctx.(*auth.HttpAuthContext) - stubId := ctx.Param("stubId") - deploymentName := ctx.Param("deploymentName") - version := ctx.Param("version") - - if deploymentName != "" { - var deployment *types.DeploymentWithRelated - - if version == "" { - var err error - deployment, err = g.es.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeEndpointDeployment, true) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } else { - version, err := strconv.Atoi(version) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment version") - } - - deployment, err = g.es.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeEndpointDeployment) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } - - if deployment == nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - - if !deployment.Active { - return apiv1.HTTPBadRequest("Deployment is not active") - } - - stubId = deployment.Stub.ExternalId + stubId, err := abstractions.ParseAndValidateDeploymentStubId( + ctx.Request().Context(), + cc.AuthInfo, + ctx.Param("stubId"), + ctx.Param("deploymentName"), + ctx.Param("version"), + types.StubTypeEndpointDeployment, + g.es.backendRepo, + ) + if err != nil { + return err } return g.es.forwardRequest(ctx, cc.AuthInfo, stubId) @@ -99,82 +79,48 @@ func (g *endpointGroup) endpointRequest(ctx echo.Context) error { func (g *endpointGroup) ASGIRequest(ctx echo.Context) error { cc, _ := ctx.(*auth.HttpAuthContext) - stubId := ctx.Param("stubId") - deploymentName := ctx.Param("deploymentName") - version := ctx.Param("version") - - if deploymentName != "" { - var deployment *types.DeploymentWithRelated - - if version == "" { - var err error - deployment, err = g.es.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeASGIDeployment, true) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } else { - version, err := strconv.Atoi(version) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment version") - } - - deployment, err = g.es.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeASGIDeployment) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } - - if deployment == nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - - if !deployment.Active { - return apiv1.HTTPBadRequest("Deployment is not active") - } - - stubId = deployment.Stub.ExternalId + stubId, err := abstractions.ParseAndValidateDeploymentStubId( + ctx.Request().Context(), + cc.AuthInfo, + ctx.Param("stubId"), + ctx.Param("deploymentName"), + ctx.Param("version"), + types.StubTypeASGIDeployment, + g.es.backendRepo, + ) + if err != nil { + return err } return g.es.forwardRequest(ctx, cc.AuthInfo, stubId) } -func (g *endpointGroup) warmUpEndpoint(ctx echo.Context) error { +func (g *endpointGroup) WarmUpEndpoint(ctx echo.Context) error { + return g.warmup(ctx, types.StubTypeEndpointDeployment) +} + +func (g *endpointGroup) WarmupASGI(ctx echo.Context) error { + return g.warmup(ctx, types.StubTypeASGIDeployment) +} + +func (g *endpointGroup) warmup( + ctx echo.Context, + deploymentType string, + +) error { cc, _ := ctx.(*auth.HttpAuthContext) - stubId := ctx.Param("stubId") - deploymentName := ctx.Param("deploymentName") - version := ctx.Param("version") - - if deploymentName != "" { - var deployment *types.DeploymentWithRelated - - if version == "" { - var err error - deployment, err = g.es.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeEndpointDeployment, true) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } else { - version, err := strconv.Atoi(version) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment version") - } - - deployment, err = g.es.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeEndpointDeployment) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } - - if deployment == nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - - if !deployment.Active { - return apiv1.HTTPBadRequest("Deployment is not active") - } - - stubId = deployment.Stub.ExternalId + stubId, err := abstractions.ParseAndValidateDeploymentStubId( + ctx.Request().Context(), + cc.AuthInfo, + ctx.Param("stubId"), + ctx.Param("deploymentName"), + ctx.Param("version"), + deploymentType, + g.es.backendRepo, + ) + if err != nil { + return err } return g.es.warmup( diff --git a/pkg/abstractions/taskqueue/http.go b/pkg/abstractions/taskqueue/http.go index fb898336b..10281e109 100644 --- a/pkg/abstractions/taskqueue/http.go +++ b/pkg/abstractions/taskqueue/http.go @@ -2,9 +2,8 @@ package taskqueue import ( "net/http" - "strconv" - apiv1 "github.com/beam-cloud/beta9/pkg/api/v1" + abstractions "github.com/beam-cloud/beta9/pkg/abstractions/common" "github.com/beam-cloud/beta9/pkg/auth" "github.com/beam-cloud/beta9/pkg/task" "github.com/beam-cloud/beta9/pkg/types" @@ -37,40 +36,17 @@ func registerTaskQueueRoutes(g *echo.Group, tq *RedisTaskQueue) *taskQueueGroup func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error { cc, _ := ctx.(*auth.HttpAuthContext) - stubId := ctx.Param("stubId") - deploymentName := ctx.Param("deploymentName") - version := ctx.Param("version") - - if deploymentName != "" { - var deployment *types.DeploymentWithRelated - - if version == "" { - var err error - deployment, err = g.tq.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeTaskQueueDeployment, true) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } else { - version, err := strconv.Atoi(version) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment version") - } - - deployment, err = g.tq.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeTaskQueueDeployment) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } - - if deployment == nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - - if !deployment.Active { - return apiv1.HTTPBadRequest("Deployment is not active") - } - - stubId = deployment.Stub.ExternalId + stubId, err := abstractions.ParseAndValidateDeploymentStubId( + ctx.Request().Context(), + cc.AuthInfo, + ctx.Param("stubId"), + ctx.Param("deploymentName"), + ctx.Param("version"), + types.StubTypeTaskQueueDeployment, + g.tq.backendRepo, + ) + if err != nil { + return err } payload, err := task.SerializeHttpPayload(ctx) @@ -101,43 +77,20 @@ func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error { func (g *taskQueueGroup) TaskQueueWarmUp(ctx echo.Context) error { cc, _ := ctx.(*auth.HttpAuthContext) - stubId := ctx.Param("stubId") - deploymentName := ctx.Param("deploymentName") - version := ctx.Param("version") - - if deploymentName != "" { - var deployment *types.DeploymentWithRelated - - if version == "" { - var err error - deployment, err = g.tq.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeTaskQueueDeployment, true) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } else { - version, err := strconv.Atoi(version) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment version") - } - - deployment, err = g.tq.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeTaskQueueDeployment) - if err != nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - } - - if deployment == nil { - return apiv1.HTTPBadRequest("Invalid deployment") - } - - if !deployment.Active { - return apiv1.HTTPBadRequest("Deployment is not active") - } - - stubId = deployment.Stub.ExternalId + stubId, err := abstractions.ParseAndValidateDeploymentStubId( + ctx.Request().Context(), + cc.AuthInfo, + ctx.Param("stubId"), + ctx.Param("deploymentName"), + ctx.Param("version"), + types.StubTypeTaskQueueDeployment, + g.tq.backendRepo, + ) + if err != nil { + return err } - err := g.tq.warmup( + err = g.tq.warmup( stubId, ) if err != nil {