From ab19e3da81c6dd5e52d3cdb33902254d6f46b544 Mon Sep 17 00:00:00 2001 From: Christian Carlsson Date: Mon, 11 Nov 2024 12:24:32 +0000 Subject: [PATCH] feat: store refresh token (#7) --- db/migrations/002_base.up.sql | 9 +-- go/pkg/orm/auth.go | 118 +++++++++++++++++++++++++--------- go/pkg/orm/exercises.go | 50 -------------- go/pkg/repos/auth.go | 16 ++++- go/rpc/auth/auth.go | 22 ++++++- js/src/views/Login.vue | 5 +- js/src/views/Signup.vue | 3 - 7 files changed, 131 insertions(+), 92 deletions(-) diff --git a/db/migrations/002_base.up.sql b/db/migrations/002_base.up.sql index d80a2111..68e50368 100644 --- a/db/migrations/002_base.up.sql +++ b/db/migrations/002_base.up.sql @@ -1,9 +1,10 @@ CREATE TABLE getstronger.auth ( - id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), - email VARCHAR(255) NOT NULL UNIQUE, - password BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (NOW() AT TIME ZONE 'UTC') + id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), + email VARCHAR(128) NOT NULL UNIQUE, + password BYTEA NOT NULL, + refresh_token VARCHAR(256) NULL, + created_at TIMESTAMP NOT NULL DEFAULT (NOW() AT TIME ZONE 'UTC') ); CREATE TABLE getstronger.users diff --git a/go/pkg/orm/auth.go b/go/pkg/orm/auth.go index 4c0fc0ec..037dea6b 100644 --- a/go/pkg/orm/auth.go +++ b/go/pkg/orm/auth.go @@ -14,6 +14,7 @@ import ( "time" "github.com/friendsofgo/errors" + "github.com/volatiletech/null/v8" "github.com/volatiletech/sqlboiler/v4/boil" "github.com/volatiletech/sqlboiler/v4/queries" "github.com/volatiletech/sqlboiler/v4/queries/qm" @@ -23,37 +24,42 @@ import ( // Auth is an object representing the database table. type Auth struct { - ID string `boil:"id" json:"id" toml:"id" yaml:"id"` - Email string `boil:"email" json:"email" toml:"email" yaml:"email"` - Password []byte `boil:"password" json:"password" toml:"password" yaml:"password"` - CreatedAt time.Time `boil:"created_at" json:"created_at" toml:"created_at" yaml:"created_at"` + ID string `boil:"id" json:"id" toml:"id" yaml:"id"` + Email string `boil:"email" json:"email" toml:"email" yaml:"email"` + Password []byte `boil:"password" json:"password" toml:"password" yaml:"password"` + RefreshToken null.String `boil:"refresh_token" json:"refresh_token,omitempty" toml:"refresh_token" yaml:"refresh_token,omitempty"` + CreatedAt time.Time `boil:"created_at" json:"created_at" toml:"created_at" yaml:"created_at"` R *authR `boil:"-" json:"-" toml:"-" yaml:"-"` L authL `boil:"-" json:"-" toml:"-" yaml:"-"` } var AuthColumns = struct { - ID string - Email string - Password string - CreatedAt string + ID string + Email string + Password string + RefreshToken string + CreatedAt string }{ - ID: "id", - Email: "email", - Password: "password", - CreatedAt: "created_at", + ID: "id", + Email: "email", + Password: "password", + RefreshToken: "refresh_token", + CreatedAt: "created_at", } var AuthTableColumns = struct { - ID string - Email string - Password string - CreatedAt string + ID string + Email string + Password string + RefreshToken string + CreatedAt string }{ - ID: "auth.id", - Email: "auth.email", - Password: "auth.password", - CreatedAt: "auth.created_at", + ID: "auth.id", + Email: "auth.email", + Password: "auth.password", + RefreshToken: "auth.refresh_token", + CreatedAt: "auth.created_at", } // Generated where @@ -94,6 +100,56 @@ func (w whereHelper__byte) LTE(x []byte) qm.QueryMod { return qmhelper.Where(w.f func (w whereHelper__byte) GT(x []byte) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GT, x) } func (w whereHelper__byte) GTE(x []byte) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GTE, x) } +type whereHelpernull_String struct{ field string } + +func (w whereHelpernull_String) EQ(x null.String) qm.QueryMod { + return qmhelper.WhereNullEQ(w.field, false, x) +} +func (w whereHelpernull_String) NEQ(x null.String) qm.QueryMod { + return qmhelper.WhereNullEQ(w.field, true, x) +} +func (w whereHelpernull_String) LT(x null.String) qm.QueryMod { + return qmhelper.Where(w.field, qmhelper.LT, x) +} +func (w whereHelpernull_String) LTE(x null.String) qm.QueryMod { + return qmhelper.Where(w.field, qmhelper.LTE, x) +} +func (w whereHelpernull_String) GT(x null.String) qm.QueryMod { + return qmhelper.Where(w.field, qmhelper.GT, x) +} +func (w whereHelpernull_String) GTE(x null.String) qm.QueryMod { + return qmhelper.Where(w.field, qmhelper.GTE, x) +} +func (w whereHelpernull_String) LIKE(x null.String) qm.QueryMod { + return qm.Where(w.field+" LIKE ?", x) +} +func (w whereHelpernull_String) NLIKE(x null.String) qm.QueryMod { + return qm.Where(w.field+" NOT LIKE ?", x) +} +func (w whereHelpernull_String) ILIKE(x null.String) qm.QueryMod { + return qm.Where(w.field+" ILIKE ?", x) +} +func (w whereHelpernull_String) NILIKE(x null.String) qm.QueryMod { + return qm.Where(w.field+" NOT ILIKE ?", x) +} +func (w whereHelpernull_String) IN(slice []string) qm.QueryMod { + values := make([]interface{}, 0, len(slice)) + for _, value := range slice { + values = append(values, value) + } + return qm.WhereIn(fmt.Sprintf("%s IN ?", w.field), values...) +} +func (w whereHelpernull_String) NIN(slice []string) qm.QueryMod { + values := make([]interface{}, 0, len(slice)) + for _, value := range slice { + values = append(values, value) + } + return qm.WhereNotIn(fmt.Sprintf("%s NOT IN ?", w.field), values...) +} + +func (w whereHelpernull_String) IsNull() qm.QueryMod { return qmhelper.WhereIsNull(w.field) } +func (w whereHelpernull_String) IsNotNull() qm.QueryMod { return qmhelper.WhereIsNotNull(w.field) } + type whereHelpertime_Time struct{ field string } func (w whereHelpertime_Time) EQ(x time.Time) qm.QueryMod { @@ -116,15 +172,17 @@ func (w whereHelpertime_Time) GTE(x time.Time) qm.QueryMod { } var AuthWhere = struct { - ID whereHelperstring - Email whereHelperstring - Password whereHelper__byte - CreatedAt whereHelpertime_Time + ID whereHelperstring + Email whereHelperstring + Password whereHelper__byte + RefreshToken whereHelpernull_String + CreatedAt whereHelpertime_Time }{ - ID: whereHelperstring{field: "\"getstronger\".\"auth\".\"id\""}, - Email: whereHelperstring{field: "\"getstronger\".\"auth\".\"email\""}, - Password: whereHelper__byte{field: "\"getstronger\".\"auth\".\"password\""}, - CreatedAt: whereHelpertime_Time{field: "\"getstronger\".\"auth\".\"created_at\""}, + ID: whereHelperstring{field: "\"getstronger\".\"auth\".\"id\""}, + Email: whereHelperstring{field: "\"getstronger\".\"auth\".\"email\""}, + Password: whereHelper__byte{field: "\"getstronger\".\"auth\".\"password\""}, + RefreshToken: whereHelpernull_String{field: "\"getstronger\".\"auth\".\"refresh_token\""}, + CreatedAt: whereHelpertime_Time{field: "\"getstronger\".\"auth\".\"created_at\""}, } // AuthRels is where relationship names are stored. @@ -155,9 +213,9 @@ func (r *authR) GetUsers() UserSlice { type authL struct{} var ( - authAllColumns = []string{"id", "email", "password", "created_at"} + authAllColumns = []string{"id", "email", "password", "refresh_token", "created_at"} authColumnsWithoutDefault = []string{"email", "password"} - authColumnsWithDefault = []string{"id", "created_at"} + authColumnsWithDefault = []string{"id", "refresh_token", "created_at"} authPrimaryKeyColumns = []string{"id"} authGeneratedColumns = []string{} ) diff --git a/go/pkg/orm/exercises.go b/go/pkg/orm/exercises.go index 63033e04..da642f86 100644 --- a/go/pkg/orm/exercises.go +++ b/go/pkg/orm/exercises.go @@ -64,56 +64,6 @@ var ExerciseTableColumns = struct { // Generated where -type whereHelpernull_String struct{ field string } - -func (w whereHelpernull_String) EQ(x null.String) qm.QueryMod { - return qmhelper.WhereNullEQ(w.field, false, x) -} -func (w whereHelpernull_String) NEQ(x null.String) qm.QueryMod { - return qmhelper.WhereNullEQ(w.field, true, x) -} -func (w whereHelpernull_String) LT(x null.String) qm.QueryMod { - return qmhelper.Where(w.field, qmhelper.LT, x) -} -func (w whereHelpernull_String) LTE(x null.String) qm.QueryMod { - return qmhelper.Where(w.field, qmhelper.LTE, x) -} -func (w whereHelpernull_String) GT(x null.String) qm.QueryMod { - return qmhelper.Where(w.field, qmhelper.GT, x) -} -func (w whereHelpernull_String) GTE(x null.String) qm.QueryMod { - return qmhelper.Where(w.field, qmhelper.GTE, x) -} -func (w whereHelpernull_String) LIKE(x null.String) qm.QueryMod { - return qm.Where(w.field+" LIKE ?", x) -} -func (w whereHelpernull_String) NLIKE(x null.String) qm.QueryMod { - return qm.Where(w.field+" NOT LIKE ?", x) -} -func (w whereHelpernull_String) ILIKE(x null.String) qm.QueryMod { - return qm.Where(w.field+" ILIKE ?", x) -} -func (w whereHelpernull_String) NILIKE(x null.String) qm.QueryMod { - return qm.Where(w.field+" NOT ILIKE ?", x) -} -func (w whereHelpernull_String) IN(slice []string) qm.QueryMod { - values := make([]interface{}, 0, len(slice)) - for _, value := range slice { - values = append(values, value) - } - return qm.WhereIn(fmt.Sprintf("%s IN ?", w.field), values...) -} -func (w whereHelpernull_String) NIN(slice []string) qm.QueryMod { - values := make([]interface{}, 0, len(slice)) - for _, value := range slice { - values = append(values, value) - } - return qm.WhereNotIn(fmt.Sprintf("%s NOT IN ?", w.field), values...) -} - -func (w whereHelpernull_String) IsNull() qm.QueryMod { return qmhelper.WhereIsNull(w.field) } -func (w whereHelpernull_String) IsNotNull() qm.QueryMod { return qmhelper.WhereIsNotNull(w.field) } - var ExerciseWhere = struct { ID whereHelperstring UserID whereHelperstring diff --git a/go/pkg/repos/auth.go b/go/pkg/repos/auth.go index 94b1290f..487d4e52 100644 --- a/go/pkg/repos/auth.go +++ b/go/pkg/repos/auth.go @@ -4,9 +4,12 @@ import ( "context" "database/sql" "fmt" - "github.com/crlssn/getstronger/go/pkg/orm" + + "github.com/volatiletech/null/v8" "github.com/volatiletech/sqlboiler/v4/boil" "golang.org/x/crypto/bcrypt" + + "github.com/crlssn/getstronger/go/pkg/orm" ) type Auth struct { @@ -59,3 +62,14 @@ func (a *Auth) CompareEmailAndPassword(ctx context.Context, email, password stri func (a *Auth) FromEmail(ctx context.Context, email string) (*orm.Auth, error) { return orm.Auths(orm.AuthWhere.Email.EQ(email)).One(ctx, a.db) } + +func (a *Auth) UpdateRefreshToken(ctx context.Context, authID string, refreshToken string) error { + auth := &orm.Auth{ID: authID, RefreshToken: null.StringFrom(refreshToken)} + _, err := auth.Update(ctx, a.db, boil.Whitelist(orm.AuthColumns.RefreshToken)) + return err +} + +func (a *Auth) DeleteRefreshToken(ctx context.Context, refreshToken string) error { + _, err := orm.Auths(orm.AuthWhere.RefreshToken.EQ(null.StringFrom(refreshToken))).UpdateAll(ctx, a.db, orm.M{orm.AuthColumns.RefreshToken: nil}) + return err +} diff --git a/go/rpc/auth/auth.go b/go/rpc/auth/auth.go index bd196a7c..e66c53b1 100644 --- a/go/rpc/auth/auth.go +++ b/go/rpc/auth/auth.go @@ -82,6 +82,11 @@ func (h *handler) Login(ctx context.Context, req *connect.Request[v1.LoginReques return nil, connect.NewError(connect.CodeInternal, errors.New("")) } + if err = h.repo.UpdateRefreshToken(ctx, auth.ID, refreshToken); err != nil { + log.Error("refresh token upsert failed", zap.Error(err)) + return nil, connect.NewError(connect.CodeInternal, errors.New("")) + } + res := connect.NewResponse(&v1.LoginResponse{ AccessToken: accessToken, }) @@ -92,7 +97,7 @@ func (h *handler) Login(ctx context.Context, req *connect.Request[v1.LoginReques HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, // TODO: Set to http.SameSiteStrictMode. - Path: apiv1connect.AuthServiceRefreshTokenProcedure, + Path: "/api.v1.AuthService", MaxAge: int(jwt.ExpiryTimeRefresh), } res.Header().Set("Set-Cookie", cookie.String()) @@ -133,9 +138,20 @@ func (h *handler) RefreshToken(ctx context.Context, _ *connect.Request[v1.Refres }), nil } -func (h *handler) Logout(_ context.Context, _ *connect.Request[v1.LogoutRequest]) (*connect.Response[v1.LogoutResponse], error) { +func (h *handler) Logout(ctx context.Context, _ *connect.Request[v1.LogoutRequest]) (*connect.Response[v1.LogoutResponse], error) { log := h.log.With(xzap.FieldRPC(apiv1connect.AuthServiceLogoutProcedure)) + refreshToken, ok := ctx.Value(jwt.ContextKeyRefreshToken).(string) + if !ok { + log.Warn("refresh token not found") + return nil, connect.NewError(connect.CodeUnauthenticated, http.ErrNoCookie) + } + + if err := h.repo.DeleteRefreshToken(ctx, refreshToken); err != nil { + log.Error("refresh token deletion failed", zap.Error(err)) + return nil, connect.NewError(connect.CodeInternal, errors.New("")) + } + res := connect.NewResponse(&v1.LogoutResponse{}) cookie := &http.Cookie{ Name: "refreshToken", @@ -143,7 +159,7 @@ func (h *handler) Logout(_ context.Context, _ *connect.Request[v1.LogoutRequest] HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, // TODO: Set to http.SameSiteStrictMode. - Path: apiv1connect.AuthServiceRefreshTokenProcedure, + Path: "/api.v1.AuthService", MaxAge: -1, } res.Header().Set("Set-Cookie", cookie.String()) diff --git a/js/src/views/Login.vue b/js/src/views/Login.vue index aa214d97..7efe44d1 100644 --- a/js/src/views/Login.vue +++ b/js/src/views/Login.vue @@ -2,7 +2,7 @@ import {LoginRequest} from "@/pb/api/v1/auth_pb"; import {Auth} from "@/clients/clients"; import {ref} from 'vue' -import {RouterLink} from 'vue-router' +import {RouterLink, useRoute} from 'vue-router' import {ConnectError} from "@connectrpc/connect"; import {useAuthStore} from "@/stores/auth"; import router from "@/router/router"; @@ -42,6 +42,9 @@ const login = async () => {
+