diff --git a/backendRedisSession.go b/backendRedisSession.go index 8482af0..0f783b3 100644 --- a/backendRedisSession.go +++ b/backendRedisSession.go @@ -7,17 +7,17 @@ import ( "time" ) -type BackendRedisSession struct { +type backendRedisSession struct { db onedb.DBer prefix string } func NewBackendRedisSession(server string, port int, password string, maxIdle, maxConnections int, keyPrefix string) SessionBackender { r := onedb.NewRedis(server, port, password, maxIdle, maxConnections) - return &BackendRedisSession{db: r, prefix: keyPrefix} + return &backendRedisSession{db: r, prefix: keyPrefix} } -func (r *BackendRedisSession) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, +func (r *backendRedisSession) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, includeRememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { session := UserLoginSession{LoginID: loginID, UserID: userID, SessionHash: sessionHash, RenewTimeUTC: sessionRenewTimeUTC, ExpireTimeUTC: sessionExpireTimeUTC} err := r.saveSession(&session) @@ -37,15 +37,15 @@ func (r *BackendRedisSession) CreateSession(loginID, userID int, sessionHash str return &session, &rememberMe, nil } -func (r *BackendRedisSession) GetSession(sessionHash string) (*UserLoginSession, error) { +func (r *backendRedisSession) GetSession(sessionHash string) (*UserLoginSession, error) { session := &UserLoginSession{} - return session, r.db.QueryStruct(onedb.NewRedisGetCommand(r.getSessionUrl(sessionHash)), session) + return session, r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getSessionUrl(sessionHash)), session) } -func (r *BackendRedisSession) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) { +func (r *backendRedisSession) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) { session := &UserLoginSession{} key := r.getSessionUrl(sessionHash) - err := r.db.QueryStruct(onedb.NewRedisGetCommand(key), session) + err := r.db.QueryStructRow(onedb.NewRedisGetCommand(key), session) if err != nil { return nil, err } @@ -53,24 +53,24 @@ func (r *BackendRedisSession) RenewSession(sessionHash string, renewTimeUTC time return session, r.saveSession(session) } -func (r *BackendRedisSession) InvalidateSession(sessionHash string) error { +func (r *backendRedisSession) InvalidateSession(sessionHash string) error { return r.db.Execute(onedb.NewRedisDelCommand(r.getSessionUrl(sessionHash))) } -func (r *BackendRedisSession) InvalidateSessions(email string) error { +func (r *backendRedisSession) InvalidateSessions(email string) error { return nil } -func (r *BackendRedisSession) GetRememberMe(selector string) (*UserLoginRememberMe, error) { +func (r *backendRedisSession) GetRememberMe(selector string) (*UserLoginRememberMe, error) { rememberMe := &UserLoginRememberMe{} - return rememberMe, r.db.QueryStruct(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe) + return rememberMe, r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe) } -func (r *BackendRedisSession) RenewRememberMe(selector string, renewTimeUTC time.Time) (*UserLoginRememberMe, error) { +func (r *backendRedisSession) RenewRememberMe(selector string, renewTimeUTC time.Time) (*UserLoginRememberMe, error) { rememberMe := &UserLoginRememberMe{} - err := r.db.QueryStruct(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe) + err := r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe) if err != nil { - return nil, errRememberMeNotFound + return nil, err } else if rememberMe.ExpireTimeUTC.Before(time.Now().UTC()) { return nil, errRememberMeExpired } else if rememberMe.ExpireTimeUTC.Before(renewTimeUTC) || renewTimeUTC.Before(time.Now().UTC()) { @@ -80,33 +80,33 @@ func (r *BackendRedisSession) RenewRememberMe(selector string, renewTimeUTC time return rememberMe, nil } -func (r *BackendRedisSession) InvalidateRememberMe(selector string) error { +func (r *backendRedisSession) InvalidateRememberMe(selector string) error { return r.db.Execute(onedb.NewRedisDelCommand(r.getRememberMeUrl(selector))) } -func (r *BackendRedisSession) Close() error { +func (r *backendRedisSession) Close() error { return r.db.Close() } -func (r *BackendRedisSession) saveSession(session *UserLoginSession) error { +func (r *backendRedisSession) saveSession(session *UserLoginSession) error { if time.Since(session.ExpireTimeUTC).Seconds() >= 0 { return errors.New("Unable to save expired session") } return r.save(r.getSessionUrl(session.SessionHash), session, round(rememberMeExpireDuration.Seconds())) } -func (r *BackendRedisSession) saveRememberMe(rememberMe *UserLoginRememberMe) error { +func (r *backendRedisSession) saveRememberMe(rememberMe *UserLoginRememberMe) error { if time.Since(rememberMe.ExpireTimeUTC).Seconds() >= 0 { return errors.New("Unable to save expired rememberMe") } return r.save(r.getRememberMeUrl(rememberMe.Selector), rememberMe, round(rememberMeExpireDuration.Seconds())) } -func (r *BackendRedisSession) getSessionUrl(sessionHash string) string { +func (r *backendRedisSession) getSessionUrl(sessionHash string) string { return r.prefix + "/session/" + sessionHash } -func (r *BackendRedisSession) getRememberMeUrl(selector string) string { +func (r *backendRedisSession) getRememberMeUrl(selector string) string { return r.prefix + "/rememberMe/" + selector } @@ -114,7 +114,7 @@ func round(num float64) int { return int(math.Floor(0.5 + num)) } -func (r *BackendRedisSession) save(key string, value interface{}, expireSeconds int) error { +func (r *backendRedisSession) save(key string, value interface{}, expireSeconds int) error { cmd, err := onedb.NewRedisSetCommand(key, value, expireSeconds) if err != nil { return err diff --git a/backendRedisSession_test.go b/backendRedisSession_test.go new file mode 100644 index 0000000..eb74c07 --- /dev/null +++ b/backendRedisSession_test.go @@ -0,0 +1,126 @@ +package main + +import ( + "github.com/robarchibald/onedb" + "testing" + "time" +) + +func TestNewBackendRedisSession(t *testing.T) { + +} + +func TestRedisCreateSession(t *testing.T) { + // expired session error + m := onedb.NewMock(nil, nil, nil) + r := backendRedisSession{db: m, prefix: "test"} + _, _, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now()) + if err == nil || len(m.QueriesRun()) != 0 { + t.Error("expected error") + } + + // expired rememberMe, but session should save. + _, _, err = r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now()) + if q := m.QueriesRun(); err == nil || len(q) != 1 || q[0].(*onedb.RedisCommand).Command != "SETEX" || len(q[0].(*onedb.RedisCommand).Args) != 3 || q[0].(*onedb.RedisCommand).Args[0] != "test/session/hash" { + t.Error("expected error") + } + + // success + m = onedb.NewMock(nil, nil, nil) + r = backendRedisSession{db: m, prefix: "test"} + session, rememberMe, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0)) + if q := m.QueriesRun(); err != nil || len(q) != 2 || q[1].(*onedb.RedisCommand).Command != "SETEX" || len(q[1].(*onedb.RedisCommand).Args) != 3 || q[1].(*onedb.RedisCommand).Args[0] != "test/rememberMe/selector" { + t.Error("expected success") + } + if session.SessionHash != "hash" || rememberMe.Selector != "selector" || rememberMe.TokenHash != "token" { + t.Error("expected valid session and rememberMe") + } +} + +func TestRedisGetSession(t *testing.T) { + data := UserLoginSession{LoginID: 1, SessionHash: "hash"} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + s, err := r.GetSession("hash") + if err != nil || s.LoginID != 1 || s.SessionHash != "hash" { + t.Error("expected error") + } +} + +func TestRedisRenewSession(t *testing.T) { + // success + data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + s, err := r.RenewSession("hash", time.Now().AddDate(1, 0, 0)) + if err != nil || s == nil { + t.Error("expected success") + } + + // error. No data + m = onedb.NewMock(nil, nil, nil) + r = backendRedisSession{db: m, prefix: "test"} + s, err = r.RenewSession("hash", time.Now().AddDate(1, 0, 0)) + if err == nil || s != nil { + t.Error("expected success") + } +} + +func TestRedisInvalidateSession(t *testing.T) { + // success + data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + if err := r.InvalidateSession("hash"); err != nil { + t.Error("expected success") + } +} + +func TestRedisGetRememberMe(t *testing.T) { + // success + data := UserLoginRememberMe{Selector: "selector"} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + rememberMe, err := r.GetRememberMe("selector") + if err != nil || rememberMe.Selector != "selector" { + t.Error("expected to find rememberMe", err, rememberMe) + } +} + +func TestRedisRenewRememberMe(t *testing.T) { + // success + data := UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + renew := time.Now().AddDate(0, 1, 0) + remember, err := r.RenewRememberMe("selector", renew) + if err != nil || remember == nil || remember.RenewTimeUTC != renew { + t.Error("expected success", remember, err) + } + + // nothing to renew + m = onedb.NewMock(nil, nil, nil) + r = backendRedisSession{db: m, prefix: "test"} + remember, err = r.RenewRememberMe("selector", time.Now()) + if err == nil || remember != nil { + t.Error("expected error", remember, err) + } + + // expired + data = UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(0, 0, -1)} + m = onedb.NewMock(nil, nil, data) + r = backendRedisSession{db: m, prefix: "test"} + remember, err = r.RenewRememberMe("selector", time.Now()) + if err != errRememberMeExpired || remember != nil { + t.Error("expected error", remember, err) + } + + // invalid renew time + data = UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(0, 0, 1)} + m = onedb.NewMock(nil, nil, data) + r = backendRedisSession{db: m, prefix: "test"} + remember, err = r.RenewRememberMe("selector", time.Now().AddDate(0, 0, -1)) + if err != errInvalidRenewTimeUTC || remember != nil { + t.Error("expected error", remember, err) + } +}