Skip to content

Commit

Permalink
fix: permission system blockers (#207)
Browse files Browse the repository at this point in the history
* fix: permission system blockers

* chore(api): format code

* fix(api): bugs in groups and organizations services

* refactor(api): consistent method naming GetIDsBy*()

* fixup! refactor(api): consistent method naming GetIDsBy*()
  • Loading branch information
bouassaba authored Jul 23, 2024
1 parent b110ef5 commit 03794b8
Show file tree
Hide file tree
Showing 19 changed files with 217 additions and 174 deletions.
7 changes: 7 additions & 0 deletions api/cache/group_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ func (c *GroupCache) Refresh(id string) (model.Group, error) {
}
return res, nil
}

func (c *GroupCache) Delete(id string) error {
if err := c.redis.Delete(c.keyPrefix + id); err != nil {
return err
}
return nil
}
15 changes: 13 additions & 2 deletions api/errorpkg/error_creators.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,22 @@ func NewOrganizationPermissionError(userID string, org model.Organization, permi
)
}

func NewCannotRemoveLastRemainingOwnerOfOrganizationError(id string) *ErrorResponse {
func NewCannotRemoveLastRemainingOwnerOfOrganizationError(org model.Organization) *ErrorResponse {
return NewErrorResponse(
"cannot_remove_last_owner_of_organization",
http.StatusBadRequest,
fmt.Sprintf("Cannot remove the last remaining owner of organization '%s'.", id), MsgInvalidRequest,
fmt.Sprintf("Cannot remove the last remaining owner of organization '%s'.", org.GetID()),
fmt.Sprintf("Cannot remove the last remaining owner of organization '%s'.", org.GetName()),
nil,
)
}

func NewCannotRemoveLastRemainingOwnerOfGroupError(group model.Group) *ErrorResponse {
return NewErrorResponse(
"cannot_remove_last_owner_of_group",
http.StatusBadRequest,
fmt.Sprintf("Cannot remove the last remaining owner of group '%s'.", group.GetID()),
fmt.Sprintf("Cannot remove the last remaining owner of group '%s'.", group.GetName()),
nil,
)
}
Expand Down
65 changes: 24 additions & 41 deletions api/repo/group_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ type GroupRepo interface {
Insert(opts GroupInsertOptions) (model.Group, error)
Find(id string) (model.Group, error)
Count() (int64, error)
GetIDsForFile(fileID string) ([]string, error)
GetIDsForUser(userID string) ([]string, error)
GetIDsForOrganization(id string) ([]string, error)
GetIDs() ([]string, error)
GetIDsByFile(fileID string) ([]string, error)
GetIDsByOrganization(id string) ([]string, error)
Save(group model.Group) error
Delete(id string) error
AddUser(id string, userID string) error
RemoveMember(id string, userID string) error
GetIDs() ([]string, error)
GetMembers(id string) ([]model.User, error)
GetOwnerCount(id string) (int64, error)
GrantUserPermission(id string, userID string, permission string) error
RevokeUserPermission(id string, userID string) error
}
Expand Down Expand Up @@ -193,7 +191,7 @@ func (repo *groupRepo) Count() (int64, error) {
return res.Result, nil
}

func (repo *groupRepo) GetIDsForFile(fileID string) ([]string, error) {
func (repo *groupRepo) GetIDsByFile(fileID string) ([]string, error) {
type Value struct {
Result string
}
Expand All @@ -214,23 +212,7 @@ func (repo *groupRepo) GetIDsForFile(fileID string) ([]string, error) {
return res, nil
}

func (repo *groupRepo) GetIDsForUser(userID string) ([]string, error) {
type Value struct {
Result string
}
var values []Value
db := repo.db.Raw(`SELECT group_id from group_user WHERE user_id = ?`, userID).Scan(&values)
if db.Error != nil {
return []string{}, db.Error
}
res := []string{}
for _, v := range values {
res = append(res, v.Result)
}
return res, nil
}

func (repo *groupRepo) GetIDsForOrganization(id string) ([]string, error) {
func (repo *groupRepo) GetIDsByOrganization(id string) ([]string, error) {
type Value struct {
Result string
}
Expand Down Expand Up @@ -270,22 +252,6 @@ func (repo *groupRepo) Delete(id string) error {
return nil
}

func (repo *groupRepo) AddUser(id string, userID string) error {
db := repo.db.Exec("INSERT INTO group_user (group_id, user_id, create_time) VALUES (?, ?, ?)", id, userID, helper.NewTimestamp())
if db.Error != nil {
return db.Error
}
return nil
}

func (repo *groupRepo) RemoveMember(id string, userID string) error {
db := repo.db.Exec("DELETE FROM group_user WHERE group_id = ? AND user_id = ?", id, userID)
if db.Error != nil {
return db.Error
}
return nil
}

func (repo *groupRepo) GetIDs() ([]string, error) {
type Value struct {
Result string
Expand All @@ -305,7 +271,8 @@ func (repo *groupRepo) GetIDs() ([]string, error) {
func (repo *groupRepo) GetMembers(id string) ([]model.User, error) {
var entities []*userEntity
db := repo.db.
Raw(`SELECT DISTINCT u.* FROM "user" u INNER JOIN group_user gu ON u.id = gu.user_id WHERE gu.group_id = ?`, id).
Raw(`SELECT u.* FROM "user" u INNER JOIN userpermission up on
u.id = up.user_id AND up.resource_id = ?`, id).
Scan(&entities)
if db.Error != nil {
return nil, db.Error
Expand All @@ -317,6 +284,22 @@ func (repo *groupRepo) GetMembers(id string) ([]model.User, error) {
return res, nil
}

func (repo *groupRepo) GetOwnerCount(id string) (int64, error) {
type Result struct {
Result int64
}
var res Result
db := repo.db.
Raw(`SELECT count(*) as result FROM userpermission
WHERE resource_id = ? and permission = ?`,
id, model.PermissionOwner).
Scan(&res)
if db.Error != nil {
return 0, db.Error
}
return res.Result, nil
}

func (repo *groupRepo) GrantUserPermission(id string, userID string, permission string) error {
db := repo.db.
Exec(`INSERT INTO userpermission (id, user_id, resource_id, permission, create_time)
Expand Down
23 changes: 2 additions & 21 deletions api/repo/organization_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ type OrganizationRepo interface {
Save(org model.Organization) error
Delete(id string) error
GetIDs() ([]string, error)
AddUser(id string, userID string) error
RemoveMember(id string, userID string) error
GetMembers(id string) ([]model.User, error)
GetGroups(id string) ([]model.Group, error)
GetOwnerCount(id string) (int64, error)
Expand Down Expand Up @@ -222,28 +220,11 @@ func (repo *organizationRepo) GetIDs() ([]string, error) {
return res, nil
}

func (repo *organizationRepo) AddUser(id string, userID string) error {
db := repo.db.Exec("INSERT INTO organization_user (organization_id, user_id, create_time) VALUES (?, ?, ?)", id, userID, helper.NewTimestamp())
if db.Error != nil {
return db.Error
}
return nil
}

func (repo *organizationRepo) RemoveMember(id string, userID string) error {
db := repo.db.Exec("DELETE FROM organization_user WHERE organization_id = ? AND user_id = ?", id, userID)
if db.Error != nil {
return db.Error
}
return nil
}

func (repo *organizationRepo) GetMembers(id string) ([]model.User, error) {
var entities []*userEntity
db := repo.db.
Raw(`SELECT DISTINCT u.* FROM "user" u
INNER JOIN organization_user ou ON u.id = ou.user_id
WHERE ou.organization_id = ? ORDER BY u.full_name`,
Raw(`SELECT u.* FROM "user" u INNER JOIN userpermission up on
u.id = up.user_id AND up.resource_id = ?`,
id).
Scan(&entities)
if db.Error != nil {
Expand Down
4 changes: 2 additions & 2 deletions api/repo/snapshot_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type SnapshotRepo interface {
FindAllForFile(fileID string) ([]model.Snapshot, error)
FindAllDangling() ([]model.Snapshot, error)
FindAllPrevious(fileID string, version int64) ([]model.Snapshot, error)
GetIDsForFile(fileID string) ([]string, error)
GetIDsByFile(fileID string) ([]string, error)
Insert(snapshot model.Snapshot) error
Save(snapshot model.Snapshot) error
Delete(id string) error
Expand Down Expand Up @@ -561,7 +561,7 @@ func (repo *snapshotRepo) FindAllPrevious(fileID string, version int64) ([]model
return res, nil
}

func (repo *snapshotRepo) GetIDsForFile(fileID string) ([]string, error) {
func (repo *snapshotRepo) GetIDsByFile(fileID string) ([]string, error) {
type Value struct {
Result string
}
Expand Down
4 changes: 2 additions & 2 deletions api/repo/task_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type TaskRepo interface {
Find(id string) (model.Task, error)
Count() (int64, error)
GetIDs(userID string) ([]string, error)
GetCount(email string) (int64, error)
GetCountByEmail(email string) (int64, error)
Save(task model.Task) error
Delete(id string) error
}
Expand Down Expand Up @@ -258,7 +258,7 @@ func (repo *taskRepo) GetIDs(userID string) ([]string, error) {
return res, nil
}

func (repo *taskRepo) GetCount(userID string) (int64, error) {
func (repo *taskRepo) GetCountByEmail(userID string) (int64, error) {
var count int64
db := repo.db.
Model(&taskEntity{}).
Expand Down
9 changes: 9 additions & 0 deletions api/repo/workspace_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type WorkspaceRepo interface {
GetIDs() ([]string, error)
GetIDsByOrganization(orgID string) ([]string, error)
GrantUserPermission(id string, userID string, permission string) error
RevokeUserPermission(id string, userID string) error
}

func NewWorkspaceRepo() WorkspaceRepo {
Expand Down Expand Up @@ -313,6 +314,14 @@ func (repo *workspaceRepo) GrantUserPermission(id string, userID string, permiss
return nil
}

func (repo *workspaceRepo) RevokeUserPermission(id string, userID string) error {
db := repo.db.Exec("DELETE FROM userpermission WHERE user_id = ? AND resource_id = ?", userID, id)
if db.Error != nil {
return db.Error
}
return nil
}

func (repo *workspaceRepo) populateModelFields(workspaces []*workspaceEntity) error {
for _, w := range workspaces {
w.UserPermissions = make([]*UserPermissionValue, 0)
Expand Down
8 changes: 4 additions & 4 deletions api/router/user_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func (r *UserRouter) List(c *fiber.Ctx) error {
return errorpkg.NewInvalidQueryParamError("sort_order")
}
userID := GetUserID(c)
var nonGroupMembersOnly bool
if c.Query("non_group_members_only") != "" {
nonGroupMembersOnly, err = strconv.ParseBool(c.Query("non_group_members_only"))
var excludeGroupMembers bool
if c.Query("exclude_group_members") != "" {
excludeGroupMembers, err = strconv.ParseBool(c.Query("exclude_group_members"))
if err != nil {
return err
}
Expand All @@ -96,7 +96,7 @@ func (r *UserRouter) List(c *fiber.Ctx) error {
Query: query,
OrganizationID: c.Query("organization_id"),
GroupID: c.Query("group_id"),
NonGroupMembersOnly: nonGroupMembersOnly,
ExcludeGroupMembers: excludeGroupMembers,
SortBy: sortBy,
SortOrder: sortOrder,
Page: uint(page),
Expand Down
77 changes: 28 additions & 49 deletions api/service/group_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/kouprlabs/voltaserve/api/cache"
"github.com/kouprlabs/voltaserve/api/config"
"github.com/kouprlabs/voltaserve/api/errorpkg"
"github.com/kouprlabs/voltaserve/api/guard"
"github.com/kouprlabs/voltaserve/api/helper"
"github.com/kouprlabs/voltaserve/api/infra"
Expand Down Expand Up @@ -251,7 +252,7 @@ func (svc *GroupService) Delete(id string, userID string) error {
if err := svc.groupSearch.Delete([]string{group.GetID()}); err != nil {
return err
}
if err := svc.refreshCacheForOrganization(group.GetOrganizationID()); err != nil {
if err := svc.groupCache.Delete(group.GetID()); err != nil {
return err
}
return nil
Expand All @@ -262,24 +263,28 @@ func (svc *GroupService) AddMember(id string, memberID string, userID string) er
if err != nil {
return nil
}
if err := svc.groupGuard.Authorize(userID, group, model.PermissionOwner); err != nil {
return err
}

/* Ensure the member exists before proceeding. */
if _, err := svc.userRepo.Find(memberID); err != nil {
return err
}
if err := svc.groupRepo.AddUser(id, memberID); err != nil {
return err
}
if err := svc.groupRepo.GrantUserPermission(group.GetID(), memberID, model.PermissionViewer); err != nil {
return err
}
if _, err := svc.groupCache.Refresh(group.GetID()); err != nil {

if err := svc.groupGuard.Authorize(userID, group, model.PermissionOwner); err != nil {
return err
}
if err := svc.refreshCacheForOrganization(group.GetOrganizationID()); err != nil {
return err

/* Ensure that the member doesn't already have a higher permission on the group.
If we don't check that, we risk downgrading the existing permission.*/
if !svc.groupGuard.IsAuthorized(memberID, group, model.PermissionViewer) &&
!svc.groupGuard.IsAuthorized(memberID, group, model.PermissionEditor) {
if err := svc.groupRepo.GrantUserPermission(group.GetID(), memberID, model.PermissionViewer); err != nil {
return err
}
if _, err := svc.groupCache.Refresh(group.GetID()); err != nil {
return err
}
}

return nil
}

Expand All @@ -288,57 +293,31 @@ func (svc *GroupService) RemoveMember(id string, memberID string, userID string)
if err != nil {
return nil
}
if err := svc.groupGuard.Authorize(userID, group, model.PermissionOwner); err != nil {

/* Ensure the member exists before proceeding. */
if _, err := svc.userRepo.Find(memberID); err != nil {
return err
}
if err := svc.RemoveMemberUnauthorized(id, memberID); err != nil {

if err := svc.groupGuard.Authorize(userID, group, model.PermissionOwner); err != nil {
return err
}
return nil
}

func (svc *GroupService) RemoveMemberUnauthorized(id string, memberID string) error {
group, err := svc.groupCache.Get(id)
/* Make sure member is not the last remaining owner of the group */
ownerCount, err := svc.groupRepo.GetOwnerCount(group.GetID())
if err != nil {
return nil
}
if _, err := svc.userRepo.Find(memberID); err != nil {
return err
}
if err := svc.groupRepo.RemoveMember(id, memberID); err != nil {
return err
if svc.groupGuard.IsAuthorized(memberID, group, model.PermissionOwner) && ownerCount == 1 {
return errorpkg.NewCannotRemoveLastRemainingOwnerOfGroupError(group)
}

if err := svc.groupRepo.RevokeUserPermission(id, memberID); err != nil {
return err
}
if _, err := svc.groupCache.Refresh(group.GetID()); err != nil {
return err
}
if err := svc.refreshCacheForOrganization(group.GetOrganizationID()); err != nil {
return err
}
return nil
}

func (svc *GroupService) refreshCacheForOrganization(orgID string) error {
workspaceIDs, err := svc.workspaceRepo.GetIDsByOrganization(orgID)
if err != nil {
return err
}
for _, workspaceID := range workspaceIDs {
if _, err := svc.workspaceCache.Refresh(workspaceID); err != nil {
return err
}
filesIDs, err := svc.fileRepo.GetIDsByWorkspace(workspaceID)
if err != nil {
return err
}
for _, id := range filesIDs {
if _, err := svc.fileCache.Refresh(id); err != nil {
return err
}
}
}
return nil
}

Expand Down
Loading

0 comments on commit 03794b8

Please sign in to comment.