diff --git a/auth.go b/auth.go index bff94b6..b60aad1 100644 --- a/auth.go +++ b/auth.go @@ -3,14 +3,9 @@ package main import ( "fmt" "net/http" - "strconv" ) -type loginData struct { - Email string - Password string -} - +// move together with nginxauth.go func auth(authStore AuthStorer, w http.ResponseWriter, r *http.Request) { session, err := authStore.GetSession() if err != nil { @@ -72,5 +67,5 @@ func run(method func() error, w http.ResponseWriter) { } func addUserHeader(session *UserLoginSession, w http.ResponseWriter) { - w.Header().Add("X-User-Id", strconv.Itoa(session.UserID)) + w.Header().Add("X-User", session.Email) } diff --git a/authStore.go b/authStore.go index 5f5e672..fda7cf8 100644 --- a/authStore.go +++ b/authStore.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "regexp" + "strconv" "strings" "time" ) @@ -75,11 +76,11 @@ func (s *authStore) Login() error { } func (s *authStore) login(email, password string, rememberMe bool) (*UserLoginSession, error) { - login, err := s.loginStore.Login(email, password, rememberMe) + _, err := s.loginStore.Login(email, password, rememberMe) if err != nil { return nil, err } - return s.sessionStore.CreateSession(login.LoginID, login.UserID, rememberMe) + return s.sessionStore.CreateSession(email, rememberMe) } type sendVerifyParams struct { @@ -101,12 +102,17 @@ func (s *authStore) register(email string) error { return newAuthError("Invalid email", nil) } - emailConfirmCode, err := s.addUser(email) + user, err := s.backend.GetUser(email) + if user != nil { + return newAuthError("User already registered", err) + } + + verifyCode, err := s.addEmailSession(email) if err != nil { return newLoggedError("Unable to save user", err) } - code := emailConfirmCode[:len(emailConfirmCode)-1] // drop the "=" at the end of the code since it makes it look like a querystring + code := verifyCode[:len(verifyCode)-1] // drop the "=" at the end of the code since it makes it look like a querystring if err := s.mailer.SendVerify(email, &sendVerifyParams{code, email, getBaseURL(s.r.Referer())}); err != nil { return newLoggedError("Unable to send verification email", err) } @@ -126,17 +132,18 @@ func getBaseURL(url string) string { return url[:protoIndex+3+firstSlash] } -func (s *authStore) addUser(email string) (string, error) { - emailConfirmCode, emailConfimHash, err := generateStringAndHash() +func (s *authStore) addEmailSession(email string) (string, error) { + verifyCode, verifyHash, err := generateStringAndHash() if err != nil { return "", newLoggedError("Problem generating email confirmation code", err) } - err = s.backend.AddUser(email, emailConfimHash) + err = s.backend.CreateEmailSession(email, verifyHash) if err != nil { return "", newLoggedError("Problem adding user to database", err) } - return emailConfirmCode, nil + + return verifyCode, nil } func (s *authStore) CreateProfile() error { @@ -144,10 +151,10 @@ func (s *authStore) CreateProfile() error { if err != nil { return newAuthError("Unable to get profile information from form", err) } - return s.createProfile(profile.FullName, profile.Organization, profile.Password, profile.PicturePath) + return s.createProfile(profile.FullName, profile.Organization, profile.Password, profile.PicturePath, profile.MailQuota, profile.FileQuota) } -func (s *authStore) createProfile(fullName, organization, password, picturePath string) error { +func (s *authStore) createProfile(fullName, organization, password, picturePath string, mailQuota, fileQuota int) error { emailCookie, err := s.getEmailCookie() if err != nil || emailCookie.EmailVerificationCode == "" { return newLoggedError("Unable to get email verification cookie", err) @@ -158,17 +165,27 @@ func (s *authStore) createProfile(fullName, organization, password, picturePath return newLoggedError("Invalid email verification cookie", err) } - email, err := s.backend.UpdateUser(emailVerifyHash, fullName, organization, picturePath) + session, err := s.backend.GetEmailSession(emailVerifyHash) + if err != nil { + return newLoggedError("Invalid email verification", err) + } + + err = s.backend.UpdateUser(session.Email, fullName, organization, picturePath) if err != nil { return newLoggedError("Unable to update user", err) } - login, err := s.loginStore.CreateLogin(email, fullName, password) + err = s.backend.DeleteEmailSession(session.EmailVerifyHash) + if err != nil { + return newLoggedError("Error verifying email", err) + } + + _, err = s.loginStore.CreateLogin(session.Email, fullName, password, mailQuota, fileQuota) if err != nil { return newLoggedError("Unable to create login", err) } - _, err = s.sessionStore.CreateSession(login.LoginID, login.UserID, false) + _, err = s.sessionStore.CreateSession(session.Email, false) if err != nil { return err } @@ -177,6 +194,7 @@ func (s *authStore) createProfile(fullName, organization, password, picturePath return nil } +// move to sessionStore func (s *authStore) VerifyEmail() error { verify, err := getVerificationCode(s.r) if err != nil { @@ -194,17 +212,22 @@ func (s *authStore) verifyEmail(emailVerificationCode string) error { return newLoggedError("Invalid verification code", err) } - email, err := s.backend.VerifyEmail(emailVerifyHash) + session, err := s.backend.GetEmailSession(emailVerifyHash) if err != nil { return newLoggedError("Failed to verify email", err) } + err = s.backend.AddUser(session.Email) + if err != nil { + return newLoggedError("Failed to create new user in database", err) + } + err = s.saveEmailCookie(emailVerificationCode, time.Now().UTC().Add(emailExpireDuration)) if err != nil { return newLoggedError("Failed to save email cookie", err) } - err = s.mailer.SendWelcome(email, nil) + err = s.mailer.SendWelcome(session.Email, nil) if err != nil { return newLoggedError("Failed to send welcome email", err) } @@ -286,6 +309,8 @@ type profile struct { Organization string Password string PicturePath string + MailQuota int + FileQuota int } func getProfile(r *http.Request) (*profile, error) { @@ -304,10 +329,17 @@ func getProfile(r *http.Request) (*profile, error) { defer f.Close() io.Copy(f, file) + // ************** TODO: change to generic way to get other parameters ******************* + + // get quota. will be zero if not found + mailQuota, _ := strconv.Atoi(r.FormValue("mailQuota")) + fileQuota, _ := strconv.Atoi(r.FormValue("fileQuota")) profile.FullName = r.FormValue("fullName") profile.Organization = r.FormValue("Organization") profile.Password = r.FormValue("password") profile.PicturePath = handler.Filename + profile.MailQuota = mailQuota + profile.FileQuota = fileQuota return profile, nil } diff --git a/authStore_test.go b/authStore_test.go index 5845cd3..59aa00d 100644 --- a/authStore_test.go +++ b/authStore_test.go @@ -64,9 +64,6 @@ func TestAuthGetBasicAuth(t *testing.T) { } func TestAuthStoreEndToEnd(t *testing.T) { - if testing.Short() { - t.SkipNow() - } w := httptest.NewRecorder() r := &http.Request{Header: http.Header{}} b := NewBackendMemory().(*backendMemory) @@ -76,8 +73,8 @@ func TestAuthStoreEndToEnd(t *testing.T) { // register new user // adds to users, logins and sessions err := s.register("test@test.com") - if err != nil || len(b.Users) != 1 || b.Users[0].EmailVerified || len(b.Sessions) != 0 { - t.Fatal("expected to be able to add user") + if err != nil || len(b.EmailSessions) != 1 || b.EmailSessions[0].Email != "test@test.com" || len(b.Sessions) != 0 { + t.Fatal("expected to be able to add user", err, len(b.EmailSessions), b.EmailSessions[0], len(b.Sessions)) } // get code from "email" @@ -91,17 +88,17 @@ func TestAuthStoreEndToEnd(t *testing.T) { emailCookie := emailCookie{} cookieStoreInstance.Decode("prefixEmail", value, &emailCookie) emailVerifyHash, _ := decodeStringToHash(emailCookie.EmailVerificationCode) - if err != nil || len(b.Users) != 1 || !b.Users[0].EmailVerified || emailVerifyHash != b.Users[0].EmailVerifyHash { - t.Fatal("expected email to be verified", err, data.Email, b.Users) + if len(b.EmailSessions) != 1 || b.EmailSessions[0].EmailVerifyHash != emailVerifyHash { + t.Fatal("expected emailVerifyHash to be saved", b.EmailSessions[0], emailVerifyHash) } // add email cookie to the next request r.AddCookie(newCookie("prefixEmail", value, false, emailExpireMins)) // create profile - err = s.createProfile("fullName", "company", "password", "picturePath") + err = s.createProfile("fullName", "company", "password", "picturePath", 1, 1) hashErr := cryptoHashEquals("password", b.Logins[0].ProviderKey) - if err != nil || len(b.Users) != 1 || len(b.Sessions) != 1 || len(b.Logins) != 1 || b.Logins[0].LoginID != 1 || b.Logins[0].UserID != 1 || hashErr != nil { + if err != nil || len(b.Users) != 1 || len(b.Sessions) != 1 || len(b.Logins) != 1 || b.Logins[0].Email != "test@test.com" || len(b.EmailSessions) != 0 || hashErr != nil { t.Fatal("expected valid user, login and session", b.Logins[0], b.Logins[0].ProviderKey, hashErr) } @@ -114,14 +111,14 @@ func TestAuthStoreEndToEnd(t *testing.T) { // add session cookie to the next request r.AddCookie(newCookie("prefixSession", value, false, emailExpireMins)) - if err != nil || len(b.Sessions) != 1 || b.Sessions[0].SessionHash != sessionHash || len(b.Logins) != 1 || b.Logins[0].UserID != 1 || + if err != nil || len(b.Sessions) != 1 || b.Sessions[0].SessionHash != sessionHash || len(b.Logins) != 1 || b.Logins[0].Email != "test@test.com" || b.Users[0].FullName != "fullName" || b.Users[0].PrimaryEmail != "test@test.com" { - t.Fatal("expected profile to be created", err, b.Sessions[0].SessionHash, b.Logins[0].UserID, b.Users[0].FullName, b.Users[0].PrimaryEmail) + t.Fatal("expected profile to be created", err, len(b.Sessions), b.Sessions[0].SessionHash != sessionHash, len(b.Logins) != 1, b.Logins[0].Email, b.Users[0].FullName, b.Users[0].PrimaryEmail) } // login on same browser with same existing session session, err := s.login("test@test.com", "password", true) - if err != nil || len(b.Logins) != 1 || len(b.Sessions) != 1 || len(b.Users) != 1 || session.SessionHash != b.Sessions[0].SessionHash || session.UserID != 1 { + if err != nil || len(b.Logins) != 1 || len(b.Sessions) != 1 || len(b.Users) != 1 || session.SessionHash != b.Sessions[0].SessionHash || session.Email != "test@test.com" { t.Fatal("expected to login to existing session", err, len(b.Logins), len(b.Sessions), len(b.Users), session, b.Sessions[0].SessionHash) } @@ -133,11 +130,12 @@ func TestAuthStoreEndToEnd(t *testing.T) { } var registerTests = []struct { - Scenario string - Email string - AddUserReturn error - MethodsCalled []string - ExpectedErr string + Scenario string + Email string + CreateEmailSessionReturn error + GetUserReturn *GetUserReturn + MethodsCalled []string + ExpectedErr string }{ { Scenario: "Invalid email", @@ -145,22 +143,31 @@ var registerTests = []struct { ExpectedErr: "Invalid email", }, { - Scenario: "Add User error", + Scenario: "User Already Exists", Email: "validemail@test.com", - AddUserReturn: errors.New("failed"), - MethodsCalled: []string{"AddUser"}, - ExpectedErr: "Unable to save user", + GetUserReturn: getUserSuccess(), + MethodsCalled: []string{"GetUser"}, + ExpectedErr: "User already registered", + }, + { + Scenario: "Add User error", + Email: "validemail@test.com", + CreateEmailSessionReturn: errors.New("failed"), + GetUserReturn: getUserErr(), + MethodsCalled: []string{"GetUser", "CreateEmailSession"}, + ExpectedErr: "Unable to save user", }, { Scenario: "Send verify email", + GetUserReturn: getUserErr(), Email: "validemail@test.com", - MethodsCalled: []string{"AddUser"}, + MethodsCalled: []string{"GetUser", "CreateEmailSession"}, }, } func TestAuthRegister(t *testing.T) { for i, test := range registerTests { - backend := &MockBackend{AddUserReturn: test.AddUserReturn} + backend := &MockBackend{ErrReturn: test.CreateEmailSessionReturn, GetUserReturn: test.GetUserReturn} store := getAuthStore(nil, nil, nil, false, false, nil, backend) err := store.register(test.Email) methods := store.backend.(*MockBackend).MethodsCalled @@ -172,15 +179,16 @@ func TestAuthRegister(t *testing.T) { } var createProfileTests = []struct { - Scenario string - HasCookieGetError bool - HasCookiePutError bool - EmailCookie *emailCookie - LoginReturn *LoginReturn - UpdateUserReturn error - CreateSessionReturn *SessionReturn - MethodsCalled []string - ExpectedErr string + Scenario string + HasCookieGetError bool + HasCookiePutError bool + GetEmailSessionReturn *GetEmailSessionReturn + EmailCookie *emailCookie + LoginReturn *LoginReturn + UpdateUserReturn error + CreateSessionReturn *SessionReturn + MethodsCalled []string + ExpectedErr string }{ { Scenario: "Error Getting email cookie", @@ -193,42 +201,53 @@ var createProfileTests = []struct { ExpectedErr: "Invalid email verification cookie", }, { - Scenario: "Error Updating user", - EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, - UpdateUserReturn: errors.New("failed"), - LoginReturn: loginErr(), - MethodsCalled: []string{"UpdateUser"}, - ExpectedErr: "Unable to update user", + Scenario: "Can't get EmailSession", + EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, + GetEmailSessionReturn: getEmailSessionErr(), + MethodsCalled: []string{"GetEmailSession"}, + ExpectedErr: "Invalid email verification", + }, + { + Scenario: "Error Updating user", + EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, + GetEmailSessionReturn: getEmailSessionSuccess(), + UpdateUserReturn: errors.New("failed"), + LoginReturn: loginErr(), + MethodsCalled: []string{"GetEmailSession", "UpdateUser"}, + ExpectedErr: "Unable to update user", }, { - Scenario: "Error Creating login", - EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, - LoginReturn: loginErr(), - MethodsCalled: []string{"UpdateUser"}, - ExpectedErr: "Unable to create login", + Scenario: "Error Creating login", + EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, + GetEmailSessionReturn: getEmailSessionSuccess(), + LoginReturn: loginErr(), + MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession"}, + ExpectedErr: "Unable to create login", }, { - Scenario: "Error creating session", - EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, - LoginReturn: loginSuccess(), - CreateSessionReturn: sessionErr(), - MethodsCalled: []string{"UpdateUser"}, - ExpectedErr: "failed", + Scenario: "Error creating session", + EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, + GetEmailSessionReturn: getEmailSessionSuccess(), + LoginReturn: loginSuccess(), + CreateSessionReturn: sessionErr(), + MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession"}, + ExpectedErr: "failed", }, { - Scenario: "Success", - EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, - LoginReturn: loginSuccess(), - CreateSessionReturn: sessionSuccess(futureTime, futureTime), - MethodsCalled: []string{"UpdateUser"}, + Scenario: "Success", + EmailCookie: &emailCookie{EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0=", ExpireTimeUTC: time.Now()}, + GetEmailSessionReturn: getEmailSessionSuccess(), + LoginReturn: loginSuccess(), + CreateSessionReturn: sessionSuccess(futureTime, futureTime), + MethodsCalled: []string{"GetEmailSession", "UpdateUser", "DeleteEmailSession"}, }, } func TestAuthCreateProfile(t *testing.T) { for i, test := range createProfileTests { - backend := &MockBackend{ErrReturn: test.UpdateUserReturn} + backend := &MockBackend{ErrReturn: test.UpdateUserReturn, GetEmailSessionReturn: test.GetEmailSessionReturn} store := getAuthStore(test.CreateSessionReturn, test.LoginReturn, test.EmailCookie, test.HasCookieGetError, test.HasCookiePutError, nil, backend) - err := store.createProfile("name", "organization", "password", "path") + err := store.createProfile("name", "organization", "password", "path", 1, 1) methods := store.backend.(*MockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -241,7 +260,7 @@ var verifyEmailTests = []struct { Scenario string EmailVerificationCode string HasCookiePutError bool - VerifyEmailReturn *VerifyEmailReturn + GetEmailSessionReturn *GetEmailSessionReturn MailErr error MethodsCalled []string ExpectedErr string @@ -249,43 +268,43 @@ var verifyEmailTests = []struct { { Scenario: "Decode error", EmailVerificationCode: "code", - VerifyEmailReturn: verifyEmailErr(), + GetEmailSessionReturn: getEmailSessionErr(), ExpectedErr: "Invalid verification code", }, { Scenario: "Verify Email Error", EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0", - VerifyEmailReturn: verifyEmailErr(), - MethodsCalled: []string{"VerifyEmail"}, + GetEmailSessionReturn: getEmailSessionErr(), + MethodsCalled: []string{"GetEmailSession"}, ExpectedErr: "Failed to verify email", }, { Scenario: "Cookie Save Error", EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0", - VerifyEmailReturn: verifyEmailSuccess(), + GetEmailSessionReturn: getEmailSessionSuccess(), HasCookiePutError: true, - MethodsCalled: []string{"VerifyEmail"}, + MethodsCalled: []string{"GetEmailSession", "AddUser"}, ExpectedErr: "Failed to save email cookie", }, { Scenario: "Mail Error", EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0", - VerifyEmailReturn: verifyEmailSuccess(), - MethodsCalled: []string{"VerifyEmail"}, + GetEmailSessionReturn: getEmailSessionSuccess(), + MethodsCalled: []string{"GetEmailSession", "AddUser"}, MailErr: errors.New("test"), ExpectedErr: "Failed to send welcome email", }, { Scenario: "Email sent", EmailVerificationCode: "nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0", - VerifyEmailReturn: verifyEmailSuccess(), - MethodsCalled: []string{"VerifyEmail"}, + GetEmailSessionReturn: getEmailSessionSuccess(), + MethodsCalled: []string{"GetEmailSession", "AddUser"}, }, } func TestAuthVerifyEmail(t *testing.T) { for i, test := range verifyEmailTests { - backend := &MockBackend{VerifyEmailReturn: test.VerifyEmailReturn} + backend := &MockBackend{GetEmailSessionReturn: test.GetEmailSessionReturn} store := getAuthStore(nil, nil, nil, false, test.HasCookiePutError, test.MailErr, backend) err := store.verifyEmail(test.EmailVerificationCode) methods := store.backend.(*MockBackend).MethodsCalled @@ -339,7 +358,7 @@ func TestVerifyEmailPub(t *testing.T) { var buf bytes.Buffer buf.WriteString(`{"EmailVerificationCode":"nfwRDzfxxJj2_HY-_mLz6jWyWU7bF0zUlIUUVkQgbZ0"}`) // random valid base64 encoded data r := &http.Request{Body: ioutil.NopCloser(&buf)} - backend := &MockBackend{VerifyEmailReturn: verifyEmailErr()} + backend := &MockBackend{GetEmailSessionReturn: getEmailSessionErr()} store := getAuthStore(nil, nil, nil, true, false, nil, backend) store.r = r err := store.VerifyEmail() @@ -390,6 +409,8 @@ func TestGetProfile(t *testing.T) { w.WriteField("fullName", "name") w.WriteField("Organization", "org") w.WriteField("password", "pass") + w.WriteField("mailQuota", "1") + w.WriteField("fileQuota", "1") file, _ := os.Open("cover.out") data, _ := ioutil.ReadAll(file) tmpFile, _ := ioutil.TempFile("", "profile") @@ -399,9 +420,9 @@ func TestGetProfile(t *testing.T) { r, _ := http.NewRequest("PUT", "url", &buf) r.Header.Add("Content-Type", w.FormDataContentType()) - profile, _ := getProfile(r) - if profile.FullName != "name" || profile.Organization != "org" || profile.Password != "pass" { - t.Error("expected correct profile", profile) + profile, err := getProfile(r) + if err != nil || profile == nil || profile.FullName != "name" || profile.Organization != "org" || profile.Password != "pass" || profile.MailQuota != 1 || profile.FileQuota != 1 { + t.Error("expected correct profile", profile, err) } } diff --git a/auth_test.go b/auth_test.go index 3207adb..0c93fe4 100644 --- a/auth_test.go +++ b/auth_test.go @@ -15,10 +15,10 @@ func TestAuth(t *testing.T) { } w = httptest.NewRecorder() - storer = &MockAuthStorer{SessionReturn: &UserLoginSession{UserID: 12}} + storer = &MockAuthStorer{SessionReturn: &UserLoginSession{Email: "test@test.com"}} auth(storer, w, nil) - if w.Header().Get("X-User-Id") != "12" || storer.LastRun != "GetSession" { - t.Error("expected UserId header to be set", w.Header().Get("X-User-Id"), storer.LastRun) + if w.Header().Get("X-User") != "test@test.com" || storer.LastRun != "GetSession" { + t.Error("expected User header to be set", w.Header().Get("X-User"), storer.LastRun) } } @@ -31,10 +31,10 @@ func TestAuthBasic(t *testing.T) { } w = httptest.NewRecorder() - storer = &MockAuthStorer{SessionReturn: &UserLoginSession{UserID: 12}} + storer = &MockAuthStorer{SessionReturn: &UserLoginSession{Email: "test@test.com"}} authBasic(storer, w, nil) - if w.Header().Get("X-User-Id") != "12" || storer.LastRun != "GetBasicAuth" { - t.Error("expected UserId header to be set", w.Header().Get("X-User-Id"), storer.LastRun) + if w.Header().Get("X-User") != "test@test.com" || storer.LastRun != "GetBasicAuth" { + t.Error("expected User header to be set", w.Header().Get("X-User"), storer.LastRun) } } @@ -101,8 +101,8 @@ func TestVerifyEmail(t *testing.T) { func TestAddUserHeader(t *testing.T) { w := httptest.NewRecorder() - addUserHeader(&UserLoginSession{UserID: 42}, w) - if w.Header().Get("X-User-Id") != "42" { + addUserHeader(&UserLoginSession{Email: "test@test.com"}, w) + if w.Header().Get("X-User") != "test@test.com" { t.Error("expected halfauth header", w.Header()) } } diff --git a/backend.go b/backend.go index 063e53f..e419554 100644 --- a/backend.go +++ b/backend.go @@ -7,9 +7,9 @@ import ( type Backender interface { // UserBackender. Write out since it contains duplicate BackendCloser - AddUser(email, emailVerifyHash string) error - VerifyEmail(emailVerifyHash string) (string, error) - UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) + AddUser(email string) error + GetUser(email string) (*User, error) + UpdateUser(email, fullname string, company string, pictureURL string) error // LoginBackender. Write out since it contains duplicate BackendCloser CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) @@ -25,9 +25,9 @@ type BackendCloser interface { } type UserBackender interface { - AddUser(email, emailVerifyHash string) error - VerifyEmail(emailVerifyHash string) (string, error) - UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) + AddUser(email string) error + GetUser(email string) (*User, error) + UpdateUser(email, fullname string, company string, pictureURL string) error BackendCloser } @@ -40,7 +40,11 @@ type LoginBackender interface { } type SessionBackender interface { - CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) + CreateEmailSession(email, emailVerifyHash string) error + GetEmailSession(verifyHash string) (*emailSession, error) + DeleteEmailSession(verifyHash string) error + + CreateSession(email string, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) GetSession(sessionHash string) (*UserLoginSession, error) RenewSession(sessionHash string, renewTimeUTC time.Time) (*UserLoginSession, error) InvalidateSession(sessionHash string) error @@ -60,38 +64,39 @@ var errRememberMeSelectorExists = errors.New("DB: RememberMe selector already ex var errUserNotFound = errors.New("DB: User not found") var errLoginNotFound = errors.New("DB: Login not found") var errSessionNotFound = errors.New("DB: Session not found") +var errSessionAlreadyExists = errors.New("DB: Session already exists") var errRememberMeNotFound = errors.New("DB: RememberMe not found") var errRememberMeNeedsRenew = errors.New("DB: RememberMe needs to be renewed") var errRememberMeExpired = errors.New("DB: RememberMe is expired") var errUserAlreadyExists = errors.New("DB: User already exists") +type emailSession struct { + Email string + EmailVerifyHash string +} + type User struct { - UserID int FullName string PrimaryEmail string - EmailVerifyHash string - EmailVerified bool LockoutEndTimeUTC *time.Time AccessFailedCount int } type UserLogin struct { - LoginID int - UserID int + Email string LoginProviderID int ProviderKey string } type UserLoginSession struct { - LoginID int + Email string SessionHash string - UserID int RenewTimeUTC time.Time ExpireTimeUTC time.Time } type UserLoginRememberMe struct { - LoginID int + Email string Selector string TokenHash string RenewTimeUTC time.Time @@ -152,8 +157,8 @@ func (b *backend) GetLogin(email, loginProvider string) (*UserLogin, error) { return b.l.GetLogin(email, loginProvider) } -func (b *backend) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { - return b.s.CreateSession(loginID, userID, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC) +func (b *backend) CreateSession(email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { + return b.s.CreateSession(email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC, rememberMe, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC) } func (b *backend) GetSession(sessionHash string) (*UserLoginSession, error) { @@ -172,16 +177,28 @@ func (b *backend) RenewRememberMe(selector string, renewTimeUTC time.Time) (*Use return b.s.RenewRememberMe(selector, renewTimeUTC) } -func (b *backend) AddUser(email, emailVerifyHash string) error { - return b.u.AddUser(email, emailVerifyHash) +func (b *backend) CreateEmailSession(email, emailVerifyHash string) error { + return b.s.CreateEmailSession(email, emailVerifyHash) +} + +func (b *backend) GetEmailSession(emailVerifyHash string) (*emailSession, error) { + return b.s.GetEmailSession(emailVerifyHash) +} + +func (b *backend) DeleteEmailSession(emailVerifyHash string) error { + return b.s.DeleteEmailSession(emailVerifyHash) +} + +func (b *backend) AddUser(email string) error { + return b.u.AddUser(email) } -func (b *backend) VerifyEmail(emailVerifyHash string) (string, error) { - return b.u.VerifyEmail(emailVerifyHash) +func (b *backend) GetUser(email string) (*User, error) { + return b.u.GetUser(email) } -func (b *backend) UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) { - return b.u.UpdateUser(emailVerifyHash, fullname, company, pictureURL) +func (b *backend) UpdateUser(email, fullname string, company string, pictureURL string) error { + return b.u.UpdateUser(email, fullname, company, pictureURL) } func (b *backend) CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { diff --git a/backendDbUser.go b/backendDbUser.go index 8c63809..cfa8870 100644 --- a/backendDbUser.go +++ b/backendDbUser.go @@ -32,21 +32,21 @@ func (u *backendDbUser) GetLogin(email, loginProvider string) (*UserLogin, error return login, u.Db.QueryStruct(onedb.NewSqlQuery(u.GetUserLoginQuery, email, loginProvider), login) } -func (u *backendDbUser) AddUser(email, emailVerifyHash string) error { - return u.Db.Execute(onedb.NewSqlQuery(u.AddUserQuery, email, emailVerifyHash)) +func (u *backendDbUser) AddUser(email string) error { + return u.Db.Execute(onedb.NewSqlQuery(u.AddUserQuery, email)) } -func (u *backendDbUser) VerifyEmail(emailVerifyHash string) (string, error) { +func (u *backendDbUser) GetUser(email string) (*User, error) { var user *User - err := u.Db.QueryStructRow(onedb.NewSqlQuery(u.VerifyEmailQuery, emailVerifyHash), user) + err := u.Db.QueryStructRow(onedb.NewSqlQuery(u.VerifyEmailQuery, email), user) if err != nil || user == nil { - return "", errors.New("Unable to verify email: " + err.Error()) + return nil, errors.New("Unable to get user: " + err.Error()) } - return user.PrimaryEmail, err + return user, err } -func (u *backendDbUser) UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) { - return "email", nil +func (u *backendDbUser) UpdateUser(email, fullname string, company string, pictureURL string) error { + return nil } func (u *backendDbUser) Close() error { diff --git a/backendLDAPLogin.go b/backendLDAPLogin.go index 1bb9a4c..fc5fa8e 100644 --- a/backendLDAPLogin.go +++ b/backendLDAPLogin.go @@ -43,9 +43,10 @@ func (l *backendLDAPLogin) GetLogin(email, loginProvider string) (*UserLogin, er return &UserLogin{ProviderKey: password}, nil } +/**************** TODO: create different type of user if not using file and mail quotas **********************/ func (l *backendLDAPLogin) CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { req := ldap.NewAddRequest("uid=" + email + ",ou=Users,dc=endfirst,dc=com") - req.Attribute("objectClass", []string{"posixAccount", "account"}) + req.Attribute("objectClass", []string{"posixAccount", "account", "ownCloud", "systemQuotas"}) req.Attribute("uid", []string{email}) req.Attribute("cn", []string{fullName}) req.Attribute("userPassword", []string{passwordHash}) diff --git a/backendMemory.go b/backendMemory.go index f85faa7..e996f9a 100644 --- a/backendMemory.go +++ b/backendMemory.go @@ -8,6 +8,7 @@ import ( type backendMemory struct { Backender + EmailSessions []*emailSession Users []*User Logins []*UserLogin Sessions []*UserLoginSession @@ -24,45 +25,30 @@ func NewBackendMemory() Backender { } func (m *backendMemory) GetLogin(email, loginProvider string) (*UserLogin, error) { - user := m.getUserByEmail(email) - if user == nil { - return nil, errUserNotFound - } - login := m.getLoginByUser(user.UserID, loginProvider) + login := m.getLoginByEmail(email) if login == nil { return nil, errLoginNotFound } return login, nil } -func (m *backendMemory) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { - login := m.getLoginByLoginID(loginID) - if login == nil { - return nil, nil, errLoginNotFound - } +func (m *backendMemory) CreateSession(email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { session := m.getSessionByHash(sessionHash) if session != nil { - session.SessionHash = sessionHash // update sessionHash - session.ExpireTimeUTC = sessionExpireTimeUTC - session.RenewTimeUTC = sessionRenewTimeUTC - } else { - session = &UserLoginSession{loginID, sessionHash, login.UserID, sessionRenewTimeUTC, sessionExpireTimeUTC} - m.Sessions = append(m.Sessions, session) + return nil, nil, errSessionAlreadyExists } + + session = &UserLoginSession{email, sessionHash, sessionRenewTimeUTC, sessionExpireTimeUTC} + m.Sessions = append(m.Sessions, session) var rememberItem *UserLoginRememberMe if rememberMe { rememberItem = m.getRememberMe(rememberMeSelector) - if rememberItem != nil && rememberItem.LoginID != login.LoginID { // existing is for different login, so can't continue + if rememberItem != nil { return nil, nil, errRememberMeSelectorExists - } else if rememberItem != nil { // update the existing rememberMe - rememberItem.Selector = rememberMeSelector - rememberItem.TokenHash = rememberMeTokenHash - rememberItem.ExpireTimeUTC = rememberMeExpireTimeUTC - rememberItem.RenewTimeUTC = rememberMeRenewTimeUTC - } else { - rememberItem = &UserLoginRememberMe{login.LoginID, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC} - m.RememberMes = append(m.RememberMes, rememberItem) } + + rememberItem = &UserLoginRememberMe{email, rememberMeSelector, rememberMeTokenHash, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC} + m.RememberMes = append(m.RememberMes, rememberItem) } return session, rememberItem, nil } @@ -105,50 +91,64 @@ func (m *backendMemory) RenewRememberMe(selector string, renewTimeUTC time.Time) return rememberMe, nil } -func (m *backendMemory) AddUser(email, emailVerifyHash string) error { +func (m *backendMemory) CreateEmailSession(email, emailVerifyHash string) error { if m.getUserByEmail(email) != nil { return errUserAlreadyExists } - if m.getUserByEmailVerifyHash(emailVerifyHash) != nil { + if m.getEmailSessionByEmailVerifyHash(emailVerifyHash) != nil { return errEmailVerifyHashExists } - m.LastUserID = m.LastUserID + 1 - user := &User{m.LastUserID, "", email, emailVerifyHash, false, nil, 0} - m.Users = append(m.Users, user) + + m.EmailSessions = append(m.EmailSessions, &emailSession{Email: email, EmailVerifyHash: emailVerifyHash}) return nil } -func (m *backendMemory) VerifyEmail(emailVerifyHash string) (string, error) { - user := m.getUserByEmailVerifyHash(emailVerifyHash) - if user == nil { - return "", errInvalidEmailVerifyHash +func (m *backendMemory) GetEmailSession(emailVerifyHash string) (*emailSession, error) { + session := m.getEmailSessionByEmailVerifyHash(emailVerifyHash) + if session == nil { + return nil, errInvalidEmailVerifyHash } - user.EmailVerified = true - return user.PrimaryEmail, nil + return session, nil } -func (m *backendMemory) UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) { - user := m.getUserByEmailVerifyHash(emailVerifyHash) - if user == nil { - return "", errUserNotFound +// ***************** TODO: need to come up with a way to clean up all sessions for this email ************** +func (m *backendMemory) DeleteEmailSession(emailVerifyHash string) error { + m.removeEmailSession(emailVerifyHash) + return nil +} + +func (m *backendMemory) AddUser(email string) error { + user := m.getUserByEmail(email) + if user != nil { + return errUserAlreadyExists } - user.FullName = fullname - // need to be able to create company and set pictureURL - return user.PrimaryEmail, nil + m.Users = append(m.Users, &User{"", email, nil, 0}) + return nil } -// This method needs to be fixed to work with the new data model using LDAP -func (m *backendMemory) CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { +func (m *backendMemory) GetUser(email string) (*User, error) { user := m.getUserByEmail(email) if user == nil { return nil, errUserNotFound } - user.FullName = fullName + return user, nil +} + +func (m *backendMemory) UpdateUser(email, fullname string, company string, pictureURL string) error { + user := m.getUserByEmail(email) + if user == nil { + return errUserNotFound + } + user.FullName = fullname + // need to be able to create company and set pictureURL + return nil +} - m.LastLoginID = m.LastLoginID + 1 - login := UserLogin{m.LastLoginID, user.UserID, 1, passwordHash} +// This method needs to be fixed to work with the new data model using LDAP +func (m *backendMemory) CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { + login := UserLogin{email, 1, passwordHash} m.Logins = append(m.Logins, &login) return &login, nil @@ -201,6 +201,16 @@ func (m *backendMemory) Close() error { return nil } +func (m *backendMemory) removeEmailSession(emailVerifyHash string) { + for i := 0; i < len(m.EmailSessions); i++ { + session := m.EmailSessions[i] + if session.EmailVerifyHash == emailVerifyHash { + m.EmailSessions = append(m.EmailSessions[:i], m.EmailSessions[i+1:]...) // remove item + break + } + } +} + func (m *backendMemory) removeRememberMe(selector string) { for i := 0; i < len(m.RememberMes); i++ { rememberMe := m.RememberMes[i] @@ -230,22 +240,22 @@ func (m *backendMemory) getLoginProvider(name string) *UserLoginProvider { return nil } -func (m *backendMemory) getLoginByUser(userID int, loginProvider string) *UserLogin { +func (m *backendMemory) getLoginByUser(email, loginProvider string) *UserLogin { provider := m.getLoginProvider(loginProvider) if provider == nil { return nil } for _, login := range m.Logins { - if login.UserID == userID && login.LoginProviderID == provider.LoginProviderID { + if login.Email == email && login.LoginProviderID == provider.LoginProviderID { return login } } return nil } -func (m *backendMemory) getLoginByLoginID(loginID int) *UserLogin { +func (m *backendMemory) getLoginByEmail(email string) *UserLogin { for _, login := range m.Logins { - if login.LoginID == loginID { + if login.Email == email { return login } } @@ -261,10 +271,10 @@ func (m *backendMemory) getUserByEmail(email string) *User { return nil } -func (m *backendMemory) getUserByEmailVerifyHash(hash string) *User { - for _, user := range m.Users { - if user.EmailVerifyHash == hash { - return user +func (m *backendMemory) getEmailSessionByEmailVerifyHash(hash string) *emailSession { + for _, session := range m.EmailSessions { + if session.EmailVerifyHash == hash { + return session } } return nil diff --git a/backendMemory_test.go b/backendMemory_test.go index fc2b276..4dd28d5 100644 --- a/backendMemory_test.go +++ b/backendMemory_test.go @@ -10,14 +10,10 @@ var in1Hour = time.Now().UTC().Add(time.Hour) func TestMemoryGetLogin(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if _, err := backend.GetLogin("email", loginProviderDefaultName); err != errUserNotFound { - t.Error("expected no login since nothing added yet", err) - } - backend.Users = append(backend.Users, &User{PrimaryEmail: "email", UserID: 1}) if _, err := backend.GetLogin("email", loginProviderDefaultName); err != errLoginNotFound { t.Error("expected no login since login not added yet", err) } - expected := &UserLogin{UserID: 1, LoginProviderID: 1} + expected := &UserLogin{Email: "email", LoginProviderID: 1} backend.Logins = append(backend.Logins, expected) if actual, _ := backend.GetLogin("email", loginProviderDefaultName); expected != actual { t.Error("expected no login since login not added yet") @@ -26,38 +22,29 @@ func TestMemoryGetLogin(t *testing.T) { func TestMemoryCreateSession(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if _, _, err := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err != errLoginNotFound { - t.Error("expected error since login doesn't exist") - } - backend.Logins = append(backend.Logins, &UserLogin{UserID: 1, LoginID: 1}) - if session, _, _ := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginID != 1 || session.UserID != 1 { + if session, _, _ := backend.CreateSession("test@test.com", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.Email != "test@test.com" { t.Error("expected matching session", session) } - // create again, shouldn't create new Session, just update - if session, _, _ := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginID != 1 || session.UserID != 1 || len(backend.Sessions) != 1 { - t.Error("expected matching session", session) + // create again, should error + if _, _, err := backend.CreateSession("test@test.com", "sessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); err == nil { + t.Error("expected error since session exists", err) } - // new session ID since it was generated when no cookie was found - if session, _, _ := backend.CreateSession(1, 1, "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { + // new session ID since it was generated when no cookie was found (e.g. on another computer or browser) + if session, _, _ := backend.CreateSession("test@test.com", "newSessionHash", in5Minutes, in1Hour, false, "", "", time.Time{}, time.Time{}); session.SessionHash != "newSessionHash" || len(backend.Sessions) != 2 { t.Error("expected matching session", session) } - // existing remember already exists - backend.RememberMes = append(backend.RememberMes, &UserLoginRememberMe{LoginID: 1, Selector: "selector"}) - if session, rememberMe, err := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginID != 1 || session.UserID != 1 || - rememberMe.LoginID != 1 || rememberMe.Selector != "selector" || rememberMe.TokenHash != "hash" { - t.Error("expected RememberMe to be created", session, rememberMe, err) - } - - // create new rememberMe - if session, rememberMe, err := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "newselector", "hash", time.Time{}, time.Time{}); session.SessionHash != "sessionHash" || session.LoginID != 1 || session.UserID != 1 || - rememberMe.LoginID != 1 || rememberMe.Selector != "newselector" || rememberMe.TokenHash != "hash" { + // new rememberMe + backend.Sessions = nil + backend.RememberMes = nil + if session, rememberMe, err := backend.CreateSession("test@test.com", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); session == nil || session.SessionHash != "sessionHash" || session.Email != "test@test.com" || + rememberMe == nil || rememberMe.Selector != "selector" || rememberMe.TokenHash != "hash" { t.Error("expected RememberMe to be created", session, rememberMe, err) } - // existing remember is for different login... error - backend.RememberMes = append(backend.RememberMes, &UserLoginRememberMe{LoginID: 2, Selector: "otherselector"}) - if _, _, err := backend.CreateSession(1, 1, "sessionHash", in5Minutes, in1Hour, true, "otherselector", "hash", time.Time{}, time.Time{}); err != errRememberMeSelectorExists { + // existing rememberMe. Error + backend.Sessions = nil + if _, _, err := backend.CreateSession("test@test.com", "sessionHash", in5Minutes, in1Hour, true, "selector", "hash", time.Time{}, time.Time{}); err != errRememberMeSelectorExists { t.Error("expected error", err) } } @@ -128,55 +115,46 @@ func TestMemoryRenewRememberMe(t *testing.T) { func TestMemoryAddUser(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if err := backend.AddUser("email", "emailVerifyHash"); err != nil || len(backend.Users) != 1 { + if err := backend.AddUser("email"); err != nil || len(backend.Users) != 1 { t.Error("expected valid session", err, backend.Users) } - if err := backend.AddUser("email", "emailVerifyHash"); err != errUserAlreadyExists { + if err := backend.AddUser("email"); err != errUserAlreadyExists { t.Error("expected user to already exist", err) } - - if err := backend.AddUser("email1", "emailVerifyHash"); err != errEmailVerifyHashExists { - t.Error("expected failure due to existing email verify code", err) - } } -func TestMemoryVerifyEmail(t *testing.T) { +func TestMemoryGetEmailSession(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if _, err := backend.VerifyEmail("verifyHash"); err != errInvalidEmailVerifyHash { + if _, err := backend.GetEmailSession("verifyHash"); err != errInvalidEmailVerifyHash { t.Error("expected login not found err", err) } // success - backend.Users = append(backend.Users, &User{EmailVerifyHash: "verifyHash", UserID: 1, PrimaryEmail: "email"}) - if email, _ := backend.VerifyEmail("verifyHash"); email != "email" { + backend.EmailSessions = append(backend.EmailSessions, &emailSession{Email: "email", EmailVerifyHash: "verifyHash"}) + if email, _ := backend.GetEmailSession("verifyHash"); email.Email != "email" { t.Error("expected valid session", email) } } func TestMemoryUpdateUser(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - _, err := backend.UpdateUser("emailHash", "fullname", "company", "pictureUrl") + err := backend.UpdateUser("email", "fullname", "company", "pictureUrl") if err != errUserNotFound { t.Error("expected to be unable to update non-existant user") } backend = NewBackendMemory().(*backendMemory) - backend.Users = append(backend.Users, &User{EmailVerifyHash: "verifyHash", UserID: 1, PrimaryEmail: "email"}) - email, err := backend.UpdateUser("verifyHash", "fullname", "company", "pictureUrl") - if email != "email" || err != nil { - t.Error("expected success", email, err) + backend.Users = append(backend.Users, &User{PrimaryEmail: "email"}) + err = backend.UpdateUser("email", "fullname", "company", "pictureUrl") + if err != nil { + t.Error("expected success", err) } } func TestMemoryCreateLogin(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if _, err := backend.CreateLogin("email", "passwordHash", "fullName", "homeDirectory", 1, 1, "mailQuota", "fileQuota"); err != errUserNotFound { - t.Error("expected login not found err", err) - } - - backend.Users = append(backend.Users, &User{EmailVerifyHash: "emailVerifyHash", UserID: 1, PrimaryEmail: "email"}) - if login, err := backend.CreateLogin("email", "passwordHash", "fullName", "homeDirectory", 1, 1, "mailQuota", "fileQuota"); err != nil || login.LoginID != 1 || login.UserID != 1 { + if login, err := backend.CreateLogin("email", "passwordHash", "fullName", "homeDirectory", 1, 1, "mailQuota", "fileQuota"); err != nil || login.Email != "email" { t.Error("expected valid login", login) } } @@ -227,7 +205,7 @@ func TestToString(t *testing.T) { backend.RememberMes = append(backend.RememberMes, &UserLoginRememberMe{}) actual := backend.ToString() - expected := "Users:\n {0 false 0}\nLogins:\n {0 0 0 }\nSessions:\n {0 0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n {0 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" + expected := "Users:\n { 0}\nLogins:\n { 0 }\nSessions:\n { 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\nRememberMe:\n { 0001-01-01 00:00:00 +0000 UTC 0001-01-01 00:00:00 +0000 UTC}\n" if actual != expected { t.Error("expected different value", actual) } @@ -242,7 +220,7 @@ func TestGetLoginProvider(t *testing.T) { func TestGetLoginByUser(t *testing.T) { backend := NewBackendMemory().(*backendMemory) - if backend.getLoginByUser(1, "bogus") != nil { + if backend.getLoginByUser("email", "bogus") != nil { t.Error("expected no login") } } diff --git a/backendRedisSession.go b/backendRedisSession.go index 0f783b3..279d55c 100644 --- a/backendRedisSession.go +++ b/backendRedisSession.go @@ -17,9 +17,23 @@ func NewBackendRedisSession(server string, port int, password string, maxIdle, m return &backendRedisSession{db: r, prefix: keyPrefix} } -func (r *backendRedisSession) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, +// need to first check that this emailVerifyHash isn't being used, otherwise we'll clobber existing +func (r *backendRedisSession) CreateEmailSession(email, emailVerifyHash string) error { + return r.saveEmailSession(&emailSession{Email: email, EmailVerifyHash: emailVerifyHash}) +} + +func (r *backendRedisSession) GetEmailSession(emailVerifyHash string) (*emailSession, error) { + session := &emailSession{} + return session, r.db.QueryStructRow(onedb.NewRedisGetCommand(r.getEmailSessionUrl(emailVerifyHash)), session) +} + +func (r *backendRedisSession) DeleteEmailSession(emailVerifyHash string) error { + return nil +} + +func (r *backendRedisSession) CreateSession(email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, includeRememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { - session := UserLoginSession{LoginID: loginID, UserID: userID, SessionHash: sessionHash, RenewTimeUTC: sessionRenewTimeUTC, ExpireTimeUTC: sessionExpireTimeUTC} + session := UserLoginSession{Email: email, SessionHash: sessionHash, RenewTimeUTC: sessionRenewTimeUTC, ExpireTimeUTC: sessionExpireTimeUTC} err := r.saveSession(&session) if err != nil { return nil, nil, err @@ -27,7 +41,7 @@ func (r *backendRedisSession) CreateSession(loginID, userID int, sessionHash str var rememberMe UserLoginRememberMe if includeRememberMe { - rememberMe = UserLoginRememberMe{LoginID: loginID, Selector: rememberMeSelector, TokenHash: rememberMeTokenHash, RenewTimeUTC: rememberMeRenewTimeUTC, ExpireTimeUTC: rememberMeExpireTimeUTC} + rememberMe = UserLoginRememberMe{Email: email, Selector: rememberMeSelector, TokenHash: rememberMeTokenHash, RenewTimeUTC: rememberMeRenewTimeUTC, ExpireTimeUTC: rememberMeExpireTimeUTC} err = r.saveRememberMe(&rememberMe) if err != nil { return nil, nil, err @@ -88,6 +102,10 @@ func (r *backendRedisSession) Close() error { return r.db.Close() } +func (r *backendRedisSession) saveEmailSession(session *emailSession) error { + return r.save(r.getEmailSessionUrl(session.EmailVerifyHash), session, emailExpireMins*60) +} + func (r *backendRedisSession) saveSession(session *UserLoginSession) error { if time.Since(session.ExpireTimeUTC).Seconds() >= 0 { return errors.New("Unable to save expired session") @@ -102,6 +120,10 @@ func (r *backendRedisSession) saveRememberMe(rememberMe *UserLoginRememberMe) er return r.save(r.getRememberMeUrl(rememberMe.Selector), rememberMe, round(rememberMeExpireDuration.Seconds())) } +func (r *backendRedisSession) getEmailSessionUrl(emailVerifyHash string) string { + return r.prefix + "/email/" + emailVerifyHash +} + func (r *backendRedisSession) getSessionUrl(sessionHash string) string { return r.prefix + "/session/" + sessionHash } diff --git a/backendRedisSession_test.go b/backendRedisSession_test.go index eb74c07..59ffce1 100644 --- a/backendRedisSession_test.go +++ b/backendRedisSession_test.go @@ -14,13 +14,13 @@ func TestRedisCreateSession(t *testing.T) { // expired session error m := onedb.NewMock(nil, nil, nil) r := backendRedisSession{db: m, prefix: "test"} - _, _, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now()) + _, _, err := r.CreateSession("test@test.com", "hash", time.Now(), time.Now(), false, "selector", "token", time.Now(), time.Now()) if err == nil || len(m.QueriesRun()) != 0 { t.Error("expected error") } // expired rememberMe, but session should save. - _, _, err = r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now()) + _, _, err = r.CreateSession("test@test.com", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now()) if q := m.QueriesRun(); err == nil || len(q) != 1 || q[0].(*onedb.RedisCommand).Command != "SETEX" || len(q[0].(*onedb.RedisCommand).Args) != 3 || q[0].(*onedb.RedisCommand).Args[0] != "test/session/hash" { t.Error("expected error") } @@ -28,7 +28,7 @@ func TestRedisCreateSession(t *testing.T) { // success m = onedb.NewMock(nil, nil, nil) r = backendRedisSession{db: m, prefix: "test"} - session, rememberMe, err := r.CreateSession(1, 1, "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0)) + session, rememberMe, err := r.CreateSession("test@test.com", "hash", time.Now(), time.Now().AddDate(1, 0, 0), true, "selector", "token", time.Now(), time.Now().AddDate(1, 0, 0)) if q := m.QueriesRun(); err != nil || len(q) != 2 || q[1].(*onedb.RedisCommand).Command != "SETEX" || len(q[1].(*onedb.RedisCommand).Args) != 3 || q[1].(*onedb.RedisCommand).Args[0] != "test/rememberMe/selector" { t.Error("expected success") } @@ -38,18 +38,18 @@ func TestRedisCreateSession(t *testing.T) { } func TestRedisGetSession(t *testing.T) { - data := UserLoginSession{LoginID: 1, SessionHash: "hash"} + data := UserLoginSession{Email: "test@test.com", SessionHash: "hash"} m := onedb.NewMock(nil, nil, data) r := backendRedisSession{db: m, prefix: "test"} s, err := r.GetSession("hash") - if err != nil || s.LoginID != 1 || s.SessionHash != "hash" { + if err != nil || s.Email != "test@test.com" || s.SessionHash != "hash" { t.Error("expected error") } } func TestRedisRenewSession(t *testing.T) { // success - data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + data := UserLoginSession{Email: "test@test.com", SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} m := onedb.NewMock(nil, nil, data) r := backendRedisSession{db: m, prefix: "test"} s, err := r.RenewSession("hash", time.Now().AddDate(1, 0, 0)) @@ -68,7 +68,7 @@ func TestRedisRenewSession(t *testing.T) { func TestRedisInvalidateSession(t *testing.T) { // success - data := UserLoginSession{LoginID: 1, SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + data := UserLoginSession{Email: "test@test.com", SessionHash: "hash", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} m := onedb.NewMock(nil, nil, data) r := backendRedisSession{db: m, prefix: "test"} if err := r.InvalidateSession("hash"); err != nil { @@ -124,3 +124,12 @@ func TestRedisRenewRememberMe(t *testing.T) { t.Error("expected error", remember, err) } } + +func TestRedisInvalidateRememberMe(t *testing.T) { + data := UserLoginRememberMe{Selector: "selector", ExpireTimeUTC: time.Now().AddDate(1, 0, 0)} + m := onedb.NewMock(nil, nil, data) + r := backendRedisSession{db: m, prefix: "test"} + if err := r.InvalidateRememberMe("selector"); err != nil { + t.Error("expected success") + } +} diff --git a/backend_test.go b/backend_test.go index f8585f4..5cb69ea 100644 --- a/backend_test.go +++ b/backend_test.go @@ -30,7 +30,7 @@ func TestGetLogin(t *testing.T) { func TestBackendCreateSession(t *testing.T) { m := &MockBackend{CreateSessionReturn: sessionRemember(time.Now(), time.Now())} b := backend{u: m, l: m, s: m} - b.CreateSession(1, 1, "hash", time.Now(), time.Now(), false, "", "", time.Now(), time.Now()) + b.CreateSession("test@test.com", "hash", time.Now(), time.Now(), false, "", "", time.Now(), time.Now()) if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "CreateSession" { t.Error("Expected it would call backend", m.MethodsCalled) } @@ -75,17 +75,17 @@ func TestBackendRenewRememberMe(t *testing.T) { func TestBackendAddUser(t *testing.T) { m := &MockBackend{AddUserReturn: nil} b := backend{u: m, l: m, s: m} - b.AddUser("mail", "hash") + b.AddUser("mail") if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "AddUser" { t.Error("Expected it would call backend", m.MethodsCalled) } } -func TestBackendVerifyEmail(t *testing.T) { - m := &MockBackend{VerifyEmailReturn: verifyEmailErr()} +func TestBackendGetEmailSession(t *testing.T) { + m := &MockBackend{GetEmailSessionReturn: getEmailSessionErr()} b := backend{u: m, l: m, s: m} - b.VerifyEmail("hash") - if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "VerifyEmail" { + b.GetEmailSession("hash") + if len(m.MethodsCalled) != 1 || m.MethodsCalled[0] != "GetEmailSession" { t.Error("Expected it would call backend", m.MethodsCalled) } } @@ -213,9 +213,14 @@ type RememberMeReturn struct { Err error } -type VerifyEmailReturn struct { - Email string - Err error +type GetUserReturn struct { + User *User + Err error +} + +type GetEmailSessionReturn struct { + Session *emailSession + Err error } type MockBackend struct { @@ -226,7 +231,8 @@ type MockBackend struct { CreateSessionReturn *SessionRememberReturn RenewSessionReturn *SessionReturn AddUserReturn error - VerifyEmailReturn *VerifyEmailReturn + GetUserReturn *GetUserReturn + GetEmailSessionReturn *GetEmailSessionReturn CreateLoginReturn *LoginReturn UpdateEmailReturn *SessionReturn UpdatePasswordReturn *SessionReturn @@ -247,7 +253,7 @@ func (b *MockBackend) GetSession(sessionHash string) (*UserLoginSession, error) return b.GetSessionReturn.Session, b.GetSessionReturn.Err } -func (b *MockBackend) CreateSession(loginID, userID int, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { +func (b *MockBackend) CreateSession(email, sessionHash string, sessionRenewTimeUTC, sessionExpireTimeUTC time.Time, rememberMe bool, rememberMeSelector, rememberMeTokenHash string, rememberMeRenewTimeUTC, rememberMeExpireTimeUTC time.Time) (*UserLoginSession, *UserLoginRememberMe, error) { b.MethodsCalled = append(b.MethodsCalled, "CreateSession") return b.CreateSessionReturn.Session, b.CreateSessionReturn.RememberMe, b.CreateSessionReturn.Err } @@ -264,19 +270,34 @@ func (b *MockBackend) RenewRememberMe(selector string, renewTimeUTC time.Time) ( b.MethodsCalled = append(b.MethodsCalled, "RenewRememberMe") return b.RenewRememberMeReturn.RememberMe, b.RenewRememberMeReturn.Err } -func (b *MockBackend) AddUser(email, emailVerifyHash string) error { +func (b *MockBackend) AddUser(email string) error { b.MethodsCalled = append(b.MethodsCalled, "AddUser") return b.AddUserReturn } -func (b *MockBackend) VerifyEmail(emailVerifyHash string) (string, error) { - b.MethodsCalled = append(b.MethodsCalled, "VerifyEmail") - return b.VerifyEmailReturn.Email, b.VerifyEmailReturn.Err +func (b *MockBackend) CreateEmailSession(email, emailVerifyHash string) error { + b.MethodsCalled = append(b.MethodsCalled, "CreateEmailSession") + return b.ErrReturn } -func (b *MockBackend) UpdateUser(emailVerifyHash, fullname string, company string, pictureURL string) (string, error) { +func (b *MockBackend) GetEmailSession(emailVerifyHash string) (*emailSession, error) { + b.MethodsCalled = append(b.MethodsCalled, "GetEmailSession") + return b.GetEmailSessionReturn.Session, b.GetEmailSessionReturn.Err +} + +func (b *MockBackend) DeleteEmailSession(emailVerifyHash string) error { + b.MethodsCalled = append(b.MethodsCalled, "DeleteEmailSession") + return b.ErrReturn +} + +func (b *MockBackend) GetUser(email string) (*User, error) { + b.MethodsCalled = append(b.MethodsCalled, "GetUser") + return b.GetUserReturn.User, b.GetUserReturn.Err +} + +func (b *MockBackend) UpdateUser(email, fullname, company, pictureURL string) error { b.MethodsCalled = append(b.MethodsCalled, "UpdateUser") - return "test@test.com", b.ErrReturn + return b.ErrReturn } func (b *MockBackend) CreateLogin(email, passwordHash, fullName, homeDirectory string, uidNumber, gidNumber int, mailQuota, fileQuota string) (*UserLogin, error) { @@ -315,7 +336,7 @@ func (b *MockBackend) Close() error { } func loginSuccess() *LoginReturn { - return &LoginReturn{&UserLogin{LoginID: 1, ProviderKey: "$6$rounds=200000$pYt48w3PgDcRoCMx$sxbuADDhNI9nNe35HcrFYW7vpWLLMNiPBKcbqOgaRxTBYE8hePJWvmuN9dp.783JmDZBhDJRG956Wc/fzghhh."}, nil} // cryptoHash of "correctPassword" + return &LoginReturn{&UserLogin{Email: "test@test.com", ProviderKey: "$6$rounds=200000$pYt48w3PgDcRoCMx$sxbuADDhNI9nNe35HcrFYW7vpWLLMNiPBKcbqOgaRxTBYE8hePJWvmuN9dp.783JmDZBhDJRG956Wc/fzghhh."}, nil} // cryptoHash of "correctPassword" } func loginErr() *LoginReturn { @@ -323,7 +344,7 @@ func loginErr() *LoginReturn { } func sessionSuccess(renewTimeUTC, expireTimeUTC time.Time) *SessionReturn { - return &SessionReturn{&UserLoginSession{1, "sessionHash", 2, renewTimeUTC, expireTimeUTC}, nil} + return &SessionReturn{&UserLoginSession{"test@test.com", "sessionHash", renewTimeUTC, expireTimeUTC}, nil} } func sessionErr() *SessionReturn { @@ -339,16 +360,24 @@ func rememberErr() *RememberMeReturn { } func sessionRemember(renewTimeUTC, expireTimeUTC time.Time) *SessionRememberReturn { - return &SessionRememberReturn{&UserLoginSession{1, "sessionHash", 2, renewTimeUTC, expireTimeUTC}, &UserLoginRememberMe{TokenHash: "PEaenWxYddN6Q_NT1PiOYfz4EsZu7jRXRlpAsNpBU-A=", ExpireTimeUTC: expireTimeUTC, RenewTimeUTC: renewTimeUTC}, nil} + return &SessionRememberReturn{&UserLoginSession{"test@test.com", "sessionHash", renewTimeUTC, expireTimeUTC}, &UserLoginRememberMe{TokenHash: "PEaenWxYddN6Q_NT1PiOYfz4EsZu7jRXRlpAsNpBU-A=", ExpireTimeUTC: expireTimeUTC, RenewTimeUTC: renewTimeUTC}, nil} } func sessionRememberErr() *SessionRememberReturn { return &SessionRememberReturn{nil, nil, errors.New("failed")} } -func verifyEmailSuccess() *VerifyEmailReturn { - return &VerifyEmailReturn{"email", nil} +func getEmailSessionSuccess() *GetEmailSessionReturn { + return &GetEmailSessionReturn{&emailSession{Email: "email", EmailVerifyHash: "hash"}, nil} } -func verifyEmailErr() *VerifyEmailReturn { - return &VerifyEmailReturn{"", errors.New("failed")} +func getEmailSessionErr() *GetEmailSessionReturn { + return &GetEmailSessionReturn{nil, errors.New("failed")} +} + +func getUserSuccess() *GetUserReturn { + return &GetUserReturn{&User{FullName: "name", PrimaryEmail: "test@test.com"}, nil} +} + +func getUserErr() *GetUserReturn { + return &GetUserReturn{nil, errors.New("failed")} } diff --git a/cryptoStore.go b/cryptoStore.go index 48df169..96010c6 100644 --- a/cryptoStore.go +++ b/cryptoStore.go @@ -95,6 +95,9 @@ func getRandomSalt(length, iterations int) (string, error) { } func cryptoHashEquals(in string, hash string) error { + if hash[0:7] == "{CRYPT}" { + hash = hash[7:] + } hashed, err := cryptoHashWSalt(in, hash) // sha512_crypt will strip out salt from hash if err != nil { return err diff --git a/cryptoStore_test.go b/cryptoStore_test.go new file mode 100644 index 0000000..2daaa57 --- /dev/null +++ b/cryptoStore_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "testing" +) + +func TestGetHash(t *testing.T) { + /*for i := 0; i < 13; i++ { + code, hash, err := generateStringAndHash() + if err != nil { + i-- + continue + } + t.Logf(", '%s'), -- code: %s\n", hash, code) + }*/ +} diff --git a/loginStore.go b/loginStore.go index a64166c..3a8d70d 100644 --- a/loginStore.go +++ b/loginStore.go @@ -1,9 +1,13 @@ package main +import ( + "fmt" +) + type LoginStorer interface { Login(email, password string, rememberMe bool) (*UserLogin, error) - CreateLogin(email, fullName, password string) (*UserLogin, error) + CreateLogin(email, fullName, password string, mailQuota, fileQuota int) (*UserLogin, error) UpdateEmail() error UpdatePassword() error } @@ -33,23 +37,24 @@ func (s *loginStore) Login(email, password string, rememberMe bool) (*UserLogin, } if err := cryptoHashEquals(password, login.ProviderKey); err != nil { - return nil, newLoggedError("Invalid username or password", err) + return nil, newLoggedError("Invalid username or password +crypto"+login.ProviderKey, err) } return login, nil } -func (s *loginStore) CreateLogin(email, fullName, password string) (*UserLogin, error) { +/**************** TODO: send 0 for UID and GID numbers and empty quotas if mailQuota and fileQuota are 0 **********************/ +func (s *loginStore) CreateLogin(email, fullName, password string, mailQuota, fileQuota int) (*UserLogin, error) { passwordHash, err := cryptoHash(password) if err != nil { return nil, newLoggedError("Unable to create login", err) } - uidNumber := 0 - gidNumber := 0 + uidNumber := 10000 // vmail user + gidNumber := 10000 // vmail user homeDirectory := "/home" - mailQuota := "10 GB" - fileQuota := "10 GB" - login, err := s.backend.CreateLogin(email, passwordHash, fullName, homeDirectory, uidNumber, gidNumber, mailQuota, fileQuota) + mQuota := fmt.Sprintf("%dGB", mailQuota) + fQuota := fmt.Sprintf("%dGB", fileQuota) + login, err := s.backend.CreateLogin(email, passwordHash, fullName, homeDirectory, uidNumber, gidNumber, mQuota, fQuota) if err != nil { return nil, newLoggedError("Unable to create login", err) } diff --git a/loginStore_test.go b/loginStore_test.go index c3a287a..8d4015e 100644 --- a/loginStore_test.go +++ b/loginStore_test.go @@ -52,7 +52,7 @@ var loginTests = []struct { Scenario: "Incorrect password", Email: "email@example.com", Password: "wrongPassword", - GetUserLoginReturn: &LoginReturn{Login: &UserLogin{LoginID: 1, UserID: 1, ProviderKey: "1234"}}, + GetUserLoginReturn: &LoginReturn{Login: &UserLogin{Email: "test@test.com", ProviderKey: "1234"}}, MethodsCalled: []string{"GetUserLogin"}, ExpectedErr: "Invalid username or password", }, @@ -99,7 +99,7 @@ func (s *MockLoginStore) LoginBasic() (*UserLogin, error) { return s.LoginReturn.Login, s.LoginReturn.Err } -func (s *MockLoginStore) CreateLogin(email, fullName, password string) (*UserLogin, error) { +func (s *MockLoginStore) CreateLogin(email, fullName, password string, cloudQuota, fileQuota int) (*UserLogin, error) { return s.LoginReturn.Login, s.LoginReturn.Err } diff --git a/mailer_test.go b/mailer_test.go index 0bbaeae..2ec2955 100644 --- a/mailer_test.go +++ b/mailer_test.go @@ -38,35 +38,35 @@ func TestSends(t *testing.T) { } m.templateCache = template.Must(template.ParseFiles(m.VerifyEmailTemplate, m.WelcomeTemplate, m.NewLoginTemplate, m.LockedOutTemplate, m.EmailChangedTemplate, m.PasswordChangedTemplate)) - data := &VerifyEmailReturn{Email: "myemail@here.com"} - m.SendVerify("to", data) - if sender.LastBody != "verifyEmail:myemail@here.com" || sender.LastTo != "to" || sender.LastSubject != "verifyEmailSubject" { - t.Error("expected valid values", sender) + data := &emailSession{Email: "myemail@here.com"} + err := m.SendVerify("to", data) + if err != nil || sender.LastBody != "verifyEmail:myemail@here.com" || sender.LastTo != "to" || sender.LastSubject != "verifyEmailSubject" { + t.Error("expected valid values", sender, err) } - m.SendWelcome("to1", data) - if sender.LastBody != "welcomeEmail:myemail@here.com" || sender.LastTo != "to1" || sender.LastSubject != "welcomeSubject" { - t.Error("expected valid values", sender) + err = m.SendWelcome("to1", data) + if err != nil || sender.LastBody != "welcomeEmail:myemail@here.com" || sender.LastTo != "to1" || sender.LastSubject != "welcomeSubject" { + t.Error("expected valid values", sender, err) } - m.SendNewLogin("to2", data) - if sender.LastBody != "newLogin:myemail@here.com" || sender.LastTo != "to2" || sender.LastSubject != "newLoginSubject" { - t.Error("expected valid values", sender) + err = m.SendNewLogin("to2", data) + if err != nil || sender.LastBody != "newLogin:myemail@here.com" || sender.LastTo != "to2" || sender.LastSubject != "newLoginSubject" { + t.Error("expected valid values", sender, err) } - m.SendLockedOut("to3", data) - if sender.LastBody != "lockedOut:myemail@here.com" || sender.LastTo != "to3" || sender.LastSubject != "lockedOutSubject" { - t.Error("expected valid values", sender) + err = m.SendLockedOut("to3", data) + if err != nil || sender.LastBody != "lockedOut:myemail@here.com" || sender.LastTo != "to3" || sender.LastSubject != "lockedOutSubject" { + t.Error("expected valid values", sender, err) } - m.SendEmailChanged("to4", data) - if sender.LastBody != "emailChanged:myemail@here.com" || sender.LastTo != "to4" || sender.LastSubject != "emailChangedSubject" { - t.Error("expected valid values", sender) + err = m.SendEmailChanged("to4", data) + if err != nil || sender.LastBody != "emailChanged:myemail@here.com" || sender.LastTo != "to4" || sender.LastSubject != "emailChangedSubject" { + t.Error("expected valid values", sender, err) } - m.SendPasswordChanged("to5", data) - if sender.LastBody != "passwordChanged:myemail@here.com" || sender.LastTo != "to5" || sender.LastSubject != "passwordChangedSubject" { - t.Error("expected valid values", sender) + err = m.SendPasswordChanged("to5", data) + if err != nil || sender.LastBody != "passwordChanged:myemail@here.com" || sender.LastTo != "to5" || sender.LastSubject != "passwordChangedSubject" { + t.Error("expected valid values", sender, err) } } diff --git a/nginxauth.go b/nginxauth.go index 10ef456..6f7d89a 100644 --- a/nginxauth.go +++ b/nginxauth.go @@ -96,10 +96,11 @@ func newNginxAuth() (*nginxauth, error) { if err != nil { return nil, err } - u, err := newBackendDbUser(config.DbServer, config.DbPort, config.DbUser, config.DbPassword, config.DbDatabase, config.GetUserLoginQuery, config.AddUserQuery, config.VerifyEmailQuery, config.UpdateUserQuery) + u := NewBackendMemory() + /*u, err := newBackendDbUser(config.DbServer, config.DbPort, config.DbUser, config.DbPassword, config.DbDatabase, config.GetUserLoginQuery, config.AddUserQuery, config.VerifyEmailQuery, config.UpdateUserQuery) if err != nil { return nil, err - } + }*/ b := &backend{u: u, l: l, s: s} mailer, err := config.NewEmailer() diff --git a/sessionStore.go b/sessionStore.go index 0041ceb..48b24c2 100644 --- a/sessionStore.go +++ b/sessionStore.go @@ -7,7 +7,7 @@ import ( type SessionStorer interface { GetSession() (*UserLoginSession, error) - CreateSession(loginID, userID int, rememberMe bool) (*UserLoginSession, error) + CreateSession(email string, rememberMe bool) (*UserLoginSession, error) } type sessionCookie struct { @@ -136,7 +136,7 @@ func (s *sessionStore) renewSession(sessionID, sessionHash string, renewTimeUTC, return session, nil } -func (s *sessionStore) CreateSession(loginID, userID int, rememberMe bool) (*UserLoginSession, error) { +func (s *sessionStore) CreateSession(email string, rememberMe bool) (*UserLoginSession, error) { var err error var selector, token, tokenHash string if rememberMe { @@ -150,7 +150,7 @@ func (s *sessionStore) CreateSession(loginID, userID int, rememberMe bool) (*Use return nil, newLoggedError("Problem generating sessionId", nil) } - session, remember, err := s.b.CreateSession(loginID, userID, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration)) + session, remember, err := s.b.CreateSession(email, sessionHash, time.Now().UTC().Add(sessionRenewDuration), time.Now().UTC().Add(sessionExpireDuration), rememberMe, selector, tokenHash, time.Now().UTC().Add(rememberMeRenewDuration), time.Now().UTC().Add(rememberMeExpireDuration)) if err != nil { return nil, newLoggedError("Unable to create new session", err) } diff --git a/sessionStore_test.go b/sessionStore_test.go index 9d652e4..d5210f8 100644 --- a/sessionStore_test.go +++ b/sessionStore_test.go @@ -278,7 +278,7 @@ func TestCreateSession(t *testing.T) { for i, test := range createSessionTests { backend := &MockBackend{CreateSessionReturn: test.CreateSessionReturn} store := getSessionStore(nil, test.SessionCookie, test.RememberMeCookie, test.HasCookieGetError, test.HasCookiePutError, backend) - val, err := store.CreateSession(1, 1, test.RememberMe) + val, err := store.CreateSession("test@test.com", test.RememberMe) methods := store.b.(*MockBackend).MethodsCalled if (err == nil && test.ExpectedErr != "" || err != nil && test.ExpectedErr != err.Error()) || !collectionEqual(test.MethodsCalled, methods) { @@ -296,6 +296,6 @@ func (m *MockSessionStore) GetSession() (*UserLoginSession, error) { return m.SessionReturn.Session, m.SessionReturn.Err } -func (m *MockSessionStore) CreateSession(loginID, userID int, rememberMe bool) (*UserLoginSession, error) { +func (m *MockSessionStore) CreateSession(email string, rememberMe bool) (*UserLoginSession, error) { return m.SessionReturn.Session, m.SessionReturn.Err }