Skip to content

Commit

Permalink
cleanup: simplify feature flag checks.
Browse files Browse the repository at this point in the history
Reduces code duplication.

Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Nov 12, 2023
1 parent fb02340 commit 7c9576a
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 56 deletions.
18 changes: 12 additions & 6 deletions internal/fflags/fflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fflags

import (
"fmt"
"github.com/gin-gonic/gin"
"os"
"strconv"

Expand All @@ -24,7 +25,7 @@ type FFlag struct {

var hardCodedFlags = map[string]FFlag{
"multi-organization": {"NEXAPI_FFLAG_MULTI_ORGANIZATION", true},
"security-groups": {"NEXAPI_FFLAG_SECURITY_GROUPS", false},
"security-groups": {"NEXAPI_FFLAG_SECURITY_GROUPS", true},
}

func NewFFlags(logger *zap.SugaredLogger) *FFlags {
Expand All @@ -33,7 +34,11 @@ func NewFFlags(logger *zap.SugaredLogger) *FFlags {
}
}

func (f *FFlags) getFlagValue(fflag FFlag) bool {
func (f *FFlags) getFlagValue(c *gin.Context, name string, fflag FFlag) bool {
ctxName := fmt.Sprintf("nexodus.fflag.%s", name)
if _, found := c.Get(ctxName); found {
return c.GetBool(ctxName)
}
if envValue, err := strconv.ParseBool(os.Getenv(fflag.env)); err == nil {
return envValue
}
Expand All @@ -42,22 +47,23 @@ func (f *FFlags) getFlagValue(fflag FFlag) bool {

// ListFlags returns a map of all currently defined feature flags and
// whether those features are enabled (true) or not (false).
func (f *FFlags) ListFlags() map[string]bool {
func (f *FFlags) ListFlags(c *gin.Context) map[string]bool {
result := map[string]bool{}
for name, fflag := range hardCodedFlags {
result[name] = f.getFlagValue(fflag)
result[name] = f.getFlagValue(c, name, fflag)
}
return result
}

// GetFlag returns whether the feature named by the string parameter
// flag is enabled (true) or not (false). An error is returned if
// the flag name is invalid.
func (f *FFlags) GetFlag(flag string) (bool, error) {
func (f *FFlags) GetFlag(c *gin.Context, flag string) (bool, error) {
fflag, ok := hardCodedFlags[flag]

if !ok {
f.logger.Errorf("Invalid feature flag name: %s", flag)
return false, fmt.Errorf("Invalid feature flag name: %s", flag)
}
return f.getFlagValue(fflag), nil
return f.getFlagValue(c, flag, fflag), nil
}
13 changes: 13 additions & 0 deletions internal/handlers/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,16 @@ func (api *API) GetCurrentUserID(c *gin.Context) uuid.UUID {
}
return userId.(uuid.UUID)
}

func (api *API) FlagCheck(c *gin.Context, name string) bool {
enabled, err := api.fflags.GetFlag(c, name)
if err != nil {
api.SendInternalServerError(c, err)
return false
}
if !enabled {
c.JSON(http.StatusMethodNotAllowed, models.NewNotAllowedError(fmt.Sprintf("%s support is disabled", name)))
return false
}
return enabled
}
4 changes: 2 additions & 2 deletions internal/handlers/fflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// @Failure 500 {object} models.InternalServerError "Internal Server Error"
// @Router /api/fflags [get]
func (api *API) ListFeatureFlags(c *gin.Context) {
c.JSON(http.StatusOK, api.fflags.ListFlags())
c.JSON(http.StatusOK, api.fflags.ListFlags(c))
}

// GetFeatureFlag gets a feature flag by name
Expand All @@ -43,7 +43,7 @@ func (api *API) GetFeatureFlag(c *gin.Context) {
return
}

enabled, err := api.fflags.GetFlag(flagName)
enabled, err := api.fflags.GetFlag(c, flagName)
if err != nil {
c.JSON(http.StatusNotFound, models.NewNotFoundError("flag"))
return
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/invitations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (suite *HandlerTestSuite) TestCreateAcceptRefuseInvitation() {
http.MethodPost,
"/:id", fmt.Sprintf("/%s", inviteID.String()),
func(c *gin.Context) {
c.Set("_apex.testCreateOrganization", "true")
c.Set("nexodus.fflag.multi-organization", true)
suite.api.AcceptInvitation(c)
}, nil,
)
Expand All @@ -136,7 +136,7 @@ func (suite *HandlerTestSuite) TestCreateAcceptRefuseInvitation() {
http.MethodPost,
"/:id", fmt.Sprintf("/%s", inviteID.String()),
func(c *gin.Context) {
c.Set("_apex.testCreateOrganization", "true")
c.Set("nexodus.fflag.multi-organization", true)
suite.api.DeleteInvitation(c)
}, nil,
)
Expand Down
22 changes: 6 additions & 16 deletions internal/handlers/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,11 @@ func (e errDuplicateOrganization) Error() string {
func (api *API) CreateOrganization(c *gin.Context) {
ctx, span := tracer.Start(c.Request.Context(), "CreateOrganization")
defer span.End()
multiOrganizationEnabled, err := api.fflags.GetFlag("multi-organization")
if err != nil {
api.SendInternalServerError(c, err)
return
}
allowForTests := c.GetString("nexodus.testCreateOrganization")
if (!multiOrganizationEnabled && allowForTests != "true") || allowForTests == "false" {
c.JSON(http.StatusMethodNotAllowed, models.NewNotAllowedError("multi-organization support is disabled"))

if !api.FlagCheck(c, "multi-organization") {
return
}

userId := api.GetCurrentUserID(c)

var request models.AddOrganization
Expand All @@ -66,7 +61,7 @@ func (api *API) CreateOrganization(c *gin.Context) {
}

var org models.Organization
err = api.transaction(ctx, func(tx *gorm.DB) error {
err := api.transaction(ctx, func(tx *gorm.DB) error {
var user models.User
if res := tx.First(&user, "id = ?", userId); res.Error != nil {
return errUserNotFound
Expand Down Expand Up @@ -221,13 +216,8 @@ func (api *API) DeleteOrganization(c *gin.Context) {
attribute.String("id", c.Param("id")),
))
defer span.End()
multiOrganizationEnabled, err := api.fflags.GetFlag("multi-organization")
if err != nil {
api.SendInternalServerError(c, err)
return
}
if !multiOrganizationEnabled {
c.JSON(http.StatusMethodNotAllowed, models.NewNotAllowedError("multi-organization support is disabled"))

if !api.FlagCheck(c, "multi-organization") {
return
}

Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/reg_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (api *API) DeleteRegKey(c *gin.Context) {

id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, models.NewBadPathParameterError("key-id"))
c.JSON(http.StatusBadRequest, models.NewBadPathParameterError("id"))
return
}

Expand Down
20 changes: 3 additions & 17 deletions internal/handlers/security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,20 +218,6 @@ func (api *API) GetSecurityGroup(c *gin.Context) {
c.JSON(http.StatusOK, securityGroup)
}

func (api *API) secGroupsEnabled(c *gin.Context) bool {
secGroupsEnabled, err := api.fflags.GetFlag("security-groups")
if err != nil {
api.SendInternalServerError(c, err)
return false
}
allowForTests := c.GetString("nexodus.secGroupsEnabled")
if (!secGroupsEnabled && allowForTests != "true") || allowForTests == "false" {
c.JSON(http.StatusMethodNotAllowed, models.NewNotAllowedError("security-groups support is disabled"))
return false
}
return true
}

// CreateSecurityGroup handles adding a new SecurityGroup
// @Summary Add SecurityGroup
// @Id CreateSecurityGroup
Expand All @@ -252,7 +238,7 @@ func (api *API) CreateSecurityGroup(c *gin.Context) {
ctx, span := tracer.Start(c.Request.Context(), "CreateSecurityGroup")
defer span.End()

if !api.secGroupsEnabled(c) {
if !api.FlagCheck(c, "security-groups") {
return
}

Expand Down Expand Up @@ -360,7 +346,7 @@ func (api *API) DeleteSecurityGroup(c *gin.Context) {

defer span.End()

if !api.secGroupsEnabled(c) {
if !api.FlagCheck(c, "security-groups") {
return
}

Expand Down Expand Up @@ -444,7 +430,7 @@ func (api *API) UpdateSecurityGroup(c *gin.Context) {
))
defer span.End()

if !api.secGroupsEnabled(c) {
if !api.FlagCheck(c, "security-groups") {
return
}

Expand Down
22 changes: 11 additions & 11 deletions internal/handlers/security_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (suite *HandlerTestSuite) TestCreateGetSecurityGroups() {
http.MethodPost,
"/security-groups", "/security-groups",
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.CreateSecurityGroup(c)
},
bytes.NewBuffer(resBody),
Expand Down Expand Up @@ -120,7 +120,7 @@ func (suite *HandlerTestSuite) TestDeleteSecurityGroup() {
http.MethodPost,
"/security-groups", "/security-groups",
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.CreateSecurityGroup(c)
},
bytes.NewBuffer(resBody),
Expand All @@ -141,7 +141,7 @@ func (suite *HandlerTestSuite) TestDeleteSecurityGroup() {
http.MethodDelete,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actual.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.DeleteSecurityGroup(c)
},
nil,
Expand Down Expand Up @@ -188,7 +188,7 @@ func (suite *HandlerTestSuite) TestListSecurityGroups() {
http.MethodPost,
"/security-groups", "/security-groups",
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.CreateSecurityGroup(c)
},
bytes.NewBuffer(resBody),
Expand Down Expand Up @@ -237,7 +237,7 @@ func (suite *HandlerTestSuite) TestUpdateSecurityGroup() {
http.MethodPost,
"/security-groups", "/security-groups",
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.CreateSecurityGroup(c)
},
bytes.NewBuffer(resBody),
Expand Down Expand Up @@ -266,7 +266,7 @@ func (suite *HandlerTestSuite) TestUpdateSecurityGroup() {
http.MethodPatch,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actualGroup.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.UpdateSecurityGroup(c)
},
bytes.NewBuffer(updateBody),
Expand Down Expand Up @@ -306,7 +306,7 @@ func (suite *HandlerTestSuite) TestInvalidUpdateSecurityGroup() {
http.MethodPost,
"/security-groups", "/security-groups",
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.CreateSecurityGroup(c)
},
bytes.NewBuffer(resBody),
Expand Down Expand Up @@ -337,7 +337,7 @@ func (suite *HandlerTestSuite) TestInvalidUpdateSecurityGroup() {
http.MethodPatch,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actualGroup.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.UpdateSecurityGroup(c)
},
bytes.NewBuffer(updateBody),
Expand All @@ -363,7 +363,7 @@ func (suite *HandlerTestSuite) TestInvalidUpdateSecurityGroup() {
http.MethodPatch,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actualGroup.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.UpdateSecurityGroup(c)
},
bytes.NewBuffer(updateBody),
Expand All @@ -389,7 +389,7 @@ func (suite *HandlerTestSuite) TestInvalidUpdateSecurityGroup() {
http.MethodPatch,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actualGroup.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.UpdateSecurityGroup(c)
},
bytes.NewBuffer(updateBody),
Expand All @@ -415,7 +415,7 @@ func (suite *HandlerTestSuite) TestInvalidUpdateSecurityGroup() {
http.MethodPatch,
"/security-groups/:id", fmt.Sprintf("/security-groups/%s", actualGroup.ID),
func(c *gin.Context) {
c.Set("nexodus.secGroupsEnabled", "true")
c.Set("nexodus.fflag.security-groups", true)
suite.api.UpdateSecurityGroup(c)
},
bytes.NewBuffer(updateBody),
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/vpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (suite *HandlerTestSuite) TestListVPCs() {
http.MethodPost,
"/", "/",
func(c *gin.Context) {
c.Set("nexodus.testCreateVPC", "false")
c.Set("nexodus.fflag.multi-organization", false)
suite.api.CreateVPC(c)
},

Expand Down

0 comments on commit 7c9576a

Please sign in to comment.