diff --git a/server/pkg/repo/repo.go b/server/pkg/repo/repo.go index 484af4f6..4876c9ed 100644 --- a/server/pkg/repo/repo.go +++ b/server/pkg/repo/repo.go @@ -824,6 +824,12 @@ func (r *repo) DeleteWorkout(ctx context.Context, opts ...DeleteWorkoutOpt) erro return fmt.Errorf("workout comments delete: %w", err) } + if _, err = orm.Notifications( + qm.Where("payload ->> 'workoutId' = ?", workout.ID), + ).DeleteAll(ctx, tx.GetTx()); err != nil { + return fmt.Errorf("notifications delete: %w", err) + } + if _, err = workout.Delete(ctx, tx.GetTx()); err != nil { return fmt.Errorf("workout delete: %w", err) } diff --git a/server/pkg/repo/repo_test.go b/server/pkg/repo/repo_test.go index b23e51f0..6c789d40 100644 --- a/server/pkg/repo/repo_test.go +++ b/server/pkg/repo/repo_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/suite" + "github.com/volatiletech/sqlboiler/v4/queries/qm" "github.com/crlssn/getstronger/server/pkg/orm" "github.com/crlssn/getstronger/server/pkg/repo" @@ -293,3 +294,93 @@ func (s *repoSuite) TestGetPreviousWorkoutSets() { }) } } + +func (s *repoSuite) TestDeleteWorkout() { + type expected struct { + err error + } + + type test struct { + name string + opts []repo.DeleteWorkoutOpt + init func(test) *orm.Workout + expected expected + } + + userID := uuid.NewString() + workoutID := uuid.NewString() + + tests := []test{ + { + name: "ok_with_workout_id", + opts: []repo.DeleteWorkoutOpt{ + repo.DeleteWorkoutWithID(workoutID), + }, + init: func(_ test) *orm.Workout { + workout := s.testFactory.NewWorkout(testdb.WorkoutID(workoutID)) + s.testFactory.NewSet(testdb.SetWorkoutID(workoutID)) + s.testFactory.NewWorkoutComment(testdb.WorkoutCommentWorkoutID(workoutID)) + s.testFactory.NewNotification(testdb.NotificationPayload(repo.NotificationPayload{ + WorkoutID: workoutID, + })) + + return workout + }, + expected: expected{ + err: nil, + }, + }, + { + name: "ok_with_user_id", + opts: []repo.DeleteWorkoutOpt{ + repo.DeleteWorkoutWithUserID(userID), + }, + init: func(_ test) *orm.Workout { + user := s.testFactory.NewUser(testdb.UserID(userID)) + workout := s.testFactory.NewWorkout(testdb.WorkoutUserID(user.ID)) + s.testFactory.NewSet(testdb.SetWorkoutID(workout.ID)) + s.testFactory.NewWorkoutComment(testdb.WorkoutCommentWorkoutID(workout.ID)) + s.testFactory.NewNotification(testdb.NotificationPayload(repo.NotificationPayload{ + WorkoutID: workout.ID, + })) + + return workout + }, + expected: expected{ + err: nil, + }, + }, + } + + for _, t := range tests { + s.Run(t.name, func() { + workout := t.init(t) + err := s.repo.DeleteWorkout(context.Background(), t.opts...) + s.Require().ErrorIs(err, t.expected.err) + + exists, err := orm. + Workouts(orm.WorkoutWhere.ID.EQ(workout.ID)). + Exists(context.Background(), s.testContainer.DB) + s.Require().NoError(err) + s.Require().False(exists) + + exists, err = orm. + Sets(orm.SetWhere.WorkoutID.EQ(workout.ID)). + Exists(context.Background(), s.testContainer.DB) + s.Require().NoError(err) + s.Require().False(exists) + + exists, err = orm. + WorkoutComments(orm.WorkoutCommentWhere.WorkoutID.EQ(workout.ID)). + Exists(context.Background(), s.testContainer.DB) + s.Require().NoError(err) + s.Require().False(exists) + + exists, err = orm. + Notifications(qm.Where("payload ->> 'workoutId' = ?", workout.ID)). + Exists(context.Background(), s.testContainer.DB) + s.Require().NoError(err) + s.Require().False(exists) + }) + } +}