Skip to content

Commit

Permalink
bug fixes. Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob Archibald committed Jan 17, 2017
1 parent 5129566 commit 2fd9ea5
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 21 deletions.
42 changes: 21 additions & 21 deletions backendRedisSession.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ import (
"time"
)

type BackendRedisSession struct {
type backendRedisSession struct {
db onedb.DBer
prefix string
}

func NewBackendRedisSession(server string, port int, password string, maxIdle, maxConnections int, keyPrefix string) SessionBackender {
r := onedb.NewRedis(server, port, password, maxIdle, maxConnections)
return &BackendRedisSession{db: r, prefix: keyPrefix}
return &backendRedisSession{db: r, prefix: keyPrefix}
}

func (r *BackendRedisSession) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time,
func (r *backendRedisSession) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time,
includeRememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) {
session := UserLoginSession{LoginID: loginID, UserID: userID, SessionHash: sessionHash, RenewTimeUTC: sessionRenewTimeUTC, ExpireTimeUTC: sessionExpireTimeUTC}
err := r.saveSession(&session)
Expand All @@ -37,40 +37,40 @@ func (r *BackendRedisSession) CreateSession(loginID, userID int, sessionHash str
return &session, &rememberMe, nil
}

func (r *BackendRedisSession) GetSession(sessionHash string) (*UserLoginSession, error) {
func (r *backendRedisSession) GetSession(sessionHash string) (*UserLoginSession, error) {
session := &UserLoginSession{}
return session, r.db.QueryStruct(onedb.NewRedisGetCommand(r.getSessionUrl(sessionHash)), session)
return session, r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getSessionUrl(sessionHash)), session)
}

func (r *BackendRedisSession) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) {
func (r *backendRedisSession) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) {
session := &UserLoginSession{}
key := r.getSessionUrl(sessionHash)
err := r.db.QueryStruct(onedb.NewRedisGetCommand(key), session)
err := r.db.QueryStructRow(onedb.NewRedisGetCommand(key), session)
if err != nil {
return nil, err
}
session.RenewTimeUTC = renewTimeUTC
return session, r.saveSession(session)
}

func (r *BackendRedisSession) InvalidateSession(sessionHash string) error {
func (r *backendRedisSession) InvalidateSession(sessionHash string) error {
return r.db.Execute(onedb.NewRedisDelCommand(r.getSessionUrl(sessionHash)))
}

func (r *BackendRedisSession) InvalidateSessions(email string) error {
func (r *backendRedisSession) InvalidateSessions(email string) error {
return nil
}

func (r *BackendRedisSession) GetRememberMe(selector string) (*UserLoginRememberMe, error) {
func (r *backendRedisSession) GetRememberMe(selector string) (*UserLoginRememberMe, error) {
rememberMe := &UserLoginRememberMe{}
return rememberMe, r.db.QueryStruct(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe)
return rememberMe, r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe)
}

func (r *BackendRedisSession) RenewRememberMe(selector string, renewTimeUTC time.Time) (*UserLoginRememberMe, error) {
func (r *backendRedisSession) RenewRememberMe(selector string, renewTimeUTC time.Time) (*UserLoginRememberMe, error) {
rememberMe := &UserLoginRememberMe{}
err := r.db.QueryStruct(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe)
err := r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getRememberMeUrl(selector)), rememberMe)
if err != nil {
return nil, errRememberMeNotFound
return nil, err
} else if rememberMe.ExpireTimeUTC.Before(time.Now().UTC()) {
return nil, errRememberMeExpired
} else if rememberMe.ExpireTimeUTC.Before(renewTimeUTC) || renewTimeUTC.Before(time.Now().UTC()) {
Expand All @@ -80,41 +80,41 @@ func (r *BackendRedisSession) RenewRememberMe(selector string, renewTimeUTC time
return rememberMe, nil
}

func (r *BackendRedisSession) InvalidateRememberMe(selector string) error {
func (r *backendRedisSession) InvalidateRememberMe(selector string) error {
return r.db.Execute(onedb.NewRedisDelCommand(r.getRememberMeUrl(selector)))
}

func (r *BackendRedisSession) Close() error {
func (r *backendRedisSession) Close() error {
return r.db.Close()
}

func (r *BackendRedisSession) saveSession(session *UserLoginSession) error {
func (r *backendRedisSession) saveSession(session *UserLoginSession) error {
if time.Since(session.ExpireTimeUTC).Seconds() >= 0 {
return errors.New("Unable to save expired session")
}
return r.save(r.getSessionUrl(session.SessionHash), session, round(rememberMeExpireDuration.Seconds()))
}

func (r *BackendRedisSession) saveRememberMe(rememberMe *UserLoginRememberMe) error {
func (r *backendRedisSession) saveRememberMe(rememberMe *UserLoginRememberMe) error {
if time.Since(rememberMe.ExpireTimeUTC).Seconds() >= 0 {
return errors.New("Unable to save expired rememberMe")
}
return r.save(r.getRememberMeUrl(rememberMe.Selector), rememberMe, round(rememberMeExpireDuration.Seconds()))
}

func (r *BackendRedisSession) getSessionUrl(sessionHash string) string {
func (r *backendRedisSession) getSessionUrl(sessionHash string) string {
return r.prefix + "/session/" + sessionHash
}

func (r *BackendRedisSession) getRememberMeUrl(selector string) string {
func (r *backendRedisSession) getRememberMeUrl(selector string) string {
return r.prefix + "/rememberMe/" + selector
}

func round(num float64) int {
return int(math.Floor(0.5 + num))
}

func (r *BackendRedisSession) save(key string, value interface{}, expireSeconds int) error {
func (r *backendRedisSession) save(key string, value interface{}, expireSeconds int) error {
cmd, err := onedb.NewRedisSetCommand(key, value, expireSeconds)
if err != nil {
return err
Expand Down
126 changes: 126 additions & 0 deletions backendRedisSession_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package main

import (
"github.com/robarchibald/onedb"
"testing"
"time"
)

func TestNewBackendRedisSession(t *testing.T) {

}

func TestRedisCreateSession(t *testing.T) {
// expired session error
m := onedb.NewMock(nil, nil, nil)
r := backendRedisSession{db: m, prefix: "test"}
_, _, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now())
if err == nil || len(m.QueriesRun()) != 0 {
t.Error("expected error")
}

// expired rememberMe, but session should save.
_, _, err = r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now())
if q := m.QueriesRun(); err == nil || len(q) != 1 || q[0].(*onedb.RedisCommand).Command != "SETEX" || len(q[0].(*onedb.RedisCommand).Args) != 3 || q[0].(*onedb.RedisCommand).Args[0] != "test/session/hash" {
t.Error("expected error")
}

// success
m = onedb.NewMock(nil, nil, nil)
r = backendRedisSession{db: m, prefix: "test"}
session, rememberMe, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0))
if q := m.QueriesRun(); err != nil || len(q) != 2 || q[1].(*onedb.RedisCommand).Command != "SETEX" || len(q[1].(*onedb.RedisCommand).Args) != 3 || q[1].(*onedb.RedisCommand).Args[0] != "test/rememberMe/selector" {
t.Error("expected success")
}
if session.SessionHash != "hash" || rememberMe.Selector != "selector" || rememberMe.TokenHash != "token" {
t.Error("expected valid session and rememberMe")
}
}

func TestRedisGetSession(t *testing.T) {
data := UserLoginSession{LoginID: 1, SessionHash: "hash"}
m := onedb.NewMock(nil, nil, data)
r := backendRedisSession{db: m, prefix: "test"}
s, err := r.GetSession("hash")
if err != nil || s.LoginID != 1 || s.SessionHash != "hash" {
t.Error("expected error")
}
}

func TestRedisRenewSession(t *testing.T) {
// success
data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)}
m := onedb.NewMock(nil, nil, data)
r := backendRedisSession{db: m, prefix: "test"}
s, err := r.RenewSession("hash", time.Now().AddDate(1, 0, 0))
if err != nil || s == nil {
t.Error("expected success")
}

// error. No data
m = onedb.NewMock(nil, nil, nil)
r = backendRedisSession{db: m, prefix: "test"}
s, err = r.RenewSession("hash", time.Now().AddDate(1, 0, 0))
if err == nil || s != nil {
t.Error("expected success")
}
}

func TestRedisInvalidateSession(t *testing.T) {
// success
data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)}
m := onedb.NewMock(nil, nil, data)
r := backendRedisSession{db: m, prefix: "test"}
if err := r.InvalidateSession("hash"); err != nil {
t.Error("expected success")
}
}

func TestRedisGetRememberMe(t *testing.T) {
// success
data := UserLoginRememberMe{Selector: "selector"}
m := onedb.NewMock(nil, nil, data)
r := backendRedisSession{db: m, prefix: "test"}
rememberMe, err := r.GetRememberMe("selector")
if err != nil || rememberMe.Selector != "selector" {
t.Error("expected to find rememberMe", err, rememberMe)
}
}

func TestRedisRenewRememberMe(t *testing.T) {
// success
data := UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)}
m := onedb.NewMock(nil, nil, data)
r := backendRedisSession{db: m, prefix: "test"}
renew := time.Now().AddDate(0, 1, 0)
remember, err := r.RenewRememberMe("selector", renew)
if err != nil || remember == nil || remember.RenewTimeUTC != renew {
t.Error("expected success", remember, err)
}

// nothing to renew
m = onedb.NewMock(nil, nil, nil)
r = backendRedisSession{db: m, prefix: "test"}
remember, err = r.RenewRememberMe("selector", time.Now())
if err == nil || remember != nil {
t.Error("expected error", remember, err)
}

// expired
data = UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(0, 0, -1)}
m = onedb.NewMock(nil, nil, data)
r = backendRedisSession{db: m, prefix: "test"}
remember, err = r.RenewRememberMe("selector", time.Now())
if err != errRememberMeExpired || remember != nil {
t.Error("expected error", remember, err)
}

// invalid renew time
data = UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(0, 0, 1)}
m = onedb.NewMock(nil, nil, data)
r = backendRedisSession{db: m, prefix: "test"}
remember, err = r.RenewRememberMe("selector", time.Now().AddDate(0, 0, -1))
if err != errInvalidRenewTimeUTC || remember != nil {
t.Error("expected error", remember, err)
}
}

0 comments on commit 2fd9ea5

Please sign in to comment.