Skip to content

Commit

Permalink
refactor(api): map function + reliability on racing conditions (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
bouassaba authored Oct 5, 2024
1 parent 088fcda commit 5baf835
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 75 deletions.
3 changes: 1 addition & 2 deletions api/errorpkg/error_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ import (
func ErrorHandler(c *fiber.Ctx, err error) error {
var e *ErrorResponse
if errors.As(err, &e) {
v := err.(*ErrorResponse)
return c.Status(v.Status).JSON(v)
return c.Status(e.Status).JSON(e)
} else {
log.GetLogger().Error(err)
return c.Status(http.StatusInternalServerError).JSON(NewInternalServerError(err))
Expand Down
19 changes: 15 additions & 4 deletions api/service/file_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package service

import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -533,7 +534,7 @@ func (svc *FileService) ListByPath(path string, userID string) ([]*File, error)
if err != nil {
return nil, err
}
result := []*File{}
result := make([]*File, 0)
for _, w := range workspaces {
result = append(result, &File{
ID: w.RootID,
Expand Down Expand Up @@ -767,7 +768,7 @@ func (svc *FileService) GetPath(id string, userID string) ([]*File, error) {
if err != nil {
return nil, err
}
res := []*File{}
res := make([]*File, 0)
for _, file := range path {
f, err := svc.fileMapper.mapOne(file, userID)
if err != nil {
Expand Down Expand Up @@ -1687,7 +1688,12 @@ func (svc *FileService) doAuthorizationByIDs(ids []string, userID string) ([]mod
var f model.File
f, err := svc.fileCache.Get(id)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewFileNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
if svc.fileGuard.IsAuthorized(userID, f, model.PermissionViewer) {
res = append(res, f)
Expand Down Expand Up @@ -2105,7 +2111,12 @@ func (mp *FileMapper) mapMany(data []model.File, userID string) ([]*File, error)
for _, file := range data {
f, err := mp.mapOne(file, userID)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewFileNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
res = append(res, f)
}
Expand Down
17 changes: 14 additions & 3 deletions api/service/group_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package service

import (
"errors"
"sort"
"time"

Expand Down Expand Up @@ -337,7 +338,12 @@ func (svc *GroupService) doAuthorizationByIDs(ids []string, userID string) ([]mo
var o model.Group
o, err := svc.groupCache.Get(id)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewGroupNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
if svc.groupGuard.IsAuthorized(userID, o, model.PermissionViewer) {
res = append(res, o)
Expand Down Expand Up @@ -461,11 +467,16 @@ func (mp *groupMapper) mapOne(m model.Group, userID string) (*Group, error) {
}

func (mp *groupMapper) mapMany(groups []model.Group, userID string) ([]*Group, error) {
res := []*Group{}
res := make([]*Group, 0)
for _, group := range groups {
g, err := mp.mapOne(group, userID)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewGroupNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
res = append(res, g)
}
Expand Down
15 changes: 13 additions & 2 deletions api/service/organization_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package service

import (
"errors"
"sort"
"time"

Expand Down Expand Up @@ -300,7 +301,12 @@ func (svc *OrganizationService) doAuthorizationByIDs(ids []string, userID string
var o model.Organization
o, err := svc.orgCache.Get(id)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewOrganizationNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
if svc.orgGuard.IsAuthorized(userID, o, model.PermissionViewer) {
res = append(res, o)
Expand Down Expand Up @@ -415,7 +421,12 @@ func (mp *organizationMapper) mapMany(orgs []model.Organization, userID string)
for _, org := range orgs {
o, err := mp.mapOne(org, userID)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewOrganizationNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
res = append(res, o)
}
Expand Down
133 changes: 72 additions & 61 deletions api/service/task_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package service

import (
"errors"
"slices"
"sort"
"time"
Expand Down Expand Up @@ -226,6 +227,63 @@ func (svc *TaskService) List(opts TaskListOptions, userID string) (*TaskList, er
}, nil
}

func (svc *TaskService) GetCount(userID string) (*int64, error) {
var res int64
var err error
if res, err = svc.taskRepo.GetCountByEmail(userID); err != nil {
return nil, err
}
return &res, nil
}

func (svc *TaskService) Dismiss(id string, userID string) error {
task, err := svc.taskCache.Get(id)
if err != nil {
return err
}
if task.GetUserID() != userID {
return errorpkg.NewTaskBelongsToAnotherUserError(nil)
}
if !task.HasError() {
return errorpkg.NewTaskIsRunningError(nil)
}
return svc.deleteAndSync(id)
}

type TaskDismissAllResult struct {
Succeeded []string `json:"succeeded"`
Failed []string `json:"failed"`
}

func (svc *TaskService) DismissAll(userID string) (*TaskDismissAllResult, error) {
ids, err := svc.taskRepo.GetIDs(userID)
if err != nil {
return nil, err
}
authorized, err := svc.doAuthorizationByIDs(ids, userID)
if err != nil {
return nil, err
}
res := TaskDismissAllResult{
Succeeded: make([]string, 0),
Failed: make([]string, 0),
}
for _, t := range authorized {
if t.HasError() {
if err := svc.deleteAndSync(t.GetID()); err != nil {
res.Failed = append(res.Failed, t.GetID())
} else {
res.Succeeded = append(res.Succeeded, t.GetID())
}
}
}
return &res, nil
}

func (svc *TaskService) Delete(id string) error {
return svc.deleteAndSync(id)
}

func (svc *TaskService) doAuthorization(data []model.Task, userID string) ([]model.Task, error) {
var res []model.Task
for _, t := range data {
Expand All @@ -242,7 +300,12 @@ func (svc *TaskService) doAuthorizationByIDs(ids []string, userID string) ([]mod
var t model.Task
t, err := svc.taskCache.Get(id)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewTaskNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
if t.GetUserID() == userID {
res = append(res, t)
Expand Down Expand Up @@ -291,15 +354,6 @@ func (svc *TaskService) doSorting(data []model.Task, sortBy string, sortOrder st
return data
}

func (svc *TaskService) GetCount(userID string) (*int64, error) {
var res int64
var err error
if res, err = svc.taskRepo.GetCountByEmail(userID); err != nil {
return nil, err
}
return &res, nil
}

func (svc *TaskService) doPagination(data []model.Task, page, size uint) (pageData []model.Task, totalElements uint, totalPages uint) {
totalElements = uint(len(data))
totalPages = (totalElements + size - 1) / size
Expand All @@ -314,54 +368,6 @@ func (svc *TaskService) doPagination(data []model.Task, page, size uint) (pageDa
return data[startIndex:endIndex], totalElements, totalPages
}

func (svc *TaskService) Dismiss(id string, userID string) error {
task, err := svc.taskCache.Get(id)
if err != nil {
return err
}
if task.GetUserID() != userID {
return errorpkg.NewTaskBelongsToAnotherUserError(nil)
}
if !task.HasError() {
return errorpkg.NewTaskIsRunningError(nil)
}
return svc.deleteAndSync(id)
}

type TaskDismissAllResult struct {
Succeeded []string `json:"succeeded"`
Failed []string `json:"failed"`
}

func (svc *TaskService) DismissAll(userID string) (*TaskDismissAllResult, error) {
ids, err := svc.taskRepo.GetIDs(userID)
if err != nil {
return nil, err
}
authorized, err := svc.doAuthorizationByIDs(ids, userID)
if err != nil {
return nil, err
}
res := TaskDismissAllResult{
Succeeded: make([]string, 0),
Failed: make([]string, 0),
}
for _, t := range authorized {
if t.HasError() {
if err := svc.deleteAndSync(t.GetID()); err != nil {
res.Failed = append(res.Failed, t.GetID())
} else {
res.Succeeded = append(res.Succeeded, t.GetID())
}
}
}
return &res, nil
}

func (svc *TaskService) Delete(id string) error {
return svc.deleteAndSync(id)
}

func (svc *TaskService) insertAndSync(opts repo.TaskInsertOptions) (model.Task, error) {
task, err := svc.taskRepo.Insert(opts)
if err != nil {
Expand Down Expand Up @@ -453,12 +459,17 @@ func (mp *taskMapper) mapOne(m model.Task) (*Task, error) {
}, nil
}

func (mp *taskMapper) mapMany(orgs []model.Task) ([]*Task, error) {
func (mp *taskMapper) mapMany(tasks []model.Task) ([]*Task, error) {
res := make([]*Task, 0)
for _, task := range orgs {
for _, task := range tasks {
t, err := mp.mapOne(task)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewTaskNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
res = append(res, t)
}
Expand Down
2 changes: 1 addition & 1 deletion api/service/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (mp *userMapper) mapOne(user model.User) *User {
}

func (mp *userMapper) mapMany(users []model.User) ([]*User, error) {
res := []*User{}
res := make([]*User, 0)
for _, user := range users {
res = append(res, mp.mapOne(user))
}
Expand Down
15 changes: 13 additions & 2 deletions api/service/workspace_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package service

import (
"errors"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -336,7 +337,12 @@ func (svc *WorkspaceService) doAuthorizationByIDs(ids []string, userID string) (
var w model.Workspace
w, err := svc.workspaceCache.Get(id)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewWorkspaceNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
if svc.workspaceGuard.IsAuthorized(userID, w, model.PermissionViewer) {
res = append(res, w)
Expand Down Expand Up @@ -467,7 +473,12 @@ func (mp *workspaceMapper) mapMany(workspaces []model.Workspace, userID string)
for _, workspace := range workspaces {
w, err := mp.mapOne(workspace, userID)
if err != nil {
return nil, err
var e *errorpkg.ErrorResponse
if errors.As(err, &e) && e.Code == errorpkg.NewWorkspaceNotFoundError(nil).Code {
continue
} else {
return nil, err
}
}
res = append(res, w)
}
Expand Down

0 comments on commit 5baf835

Please sign in to comment.