diff --git a/internal/rpc/chat/password.go b/internal/rpc/chat/password.go index 7c9657a0..73d34a36 100644 --- a/internal/rpc/chat/password.go +++ b/internal/rpc/chat/password.go @@ -16,7 +16,6 @@ package chat import ( "context" - "github.com/openimsdk/tools/errs" "github.com/openimsdk/chat/pkg/common/constant" @@ -28,6 +27,11 @@ func (o *chatSvr) ResetPassword(ctx context.Context, req *chat.ResetPasswordReq) if req.Password == "" { return nil, errs.ErrArgs.WrapMsg("password must be set") } + if req.AreaCode == "" || req.PhoneNumber == "" { + if !(req.AreaCode == "" && req.PhoneNumber == "") { + return nil, errs.ErrArgs.WrapMsg("area code and phone number must set together") + } + } var verifyCodeID string var err error if req.Email == "" { @@ -39,22 +43,19 @@ func (o *chatSvr) ResetPassword(ctx context.Context, req *chat.ResetPasswordReq) if err != nil { return nil, err } - + var account string if req.Email == "" { - attribute, err := o.Database.TakeAttributeByPhone(ctx, req.AreaCode, req.PhoneNumber) - if err != nil { - return nil, err - } - err = o.Database.UpdatePasswordAndDeleteVerifyCode(ctx, attribute.UserID, req.Password, verifyCodeID) + account = BuildCredentialPhone(req.AreaCode, req.PhoneNumber) } else { - attribute, err := o.Database.TakeAttributeByEmail(ctx, req.Email) - if err != nil { - return nil, err - } - err = o.Database.UpdatePasswordAndDeleteVerifyCode(ctx, attribute.UserID, req.Password, verifyCodeID) - if err != nil { - return nil, err - } + account = req.Email + } + cred, err := o.Database.TakeCredentialByAccount(ctx, account) + if err != nil { + return nil, err + } + err = o.Database.UpdatePasswordAndDeleteVerifyCode(ctx, cred.UserID, req.Password, verifyCodeID) + if err != nil { + return nil, err } return &chat.ResetPasswordResp{}, nil } diff --git a/pkg/common/db/database/chat.go b/pkg/common/db/database/chat.go index 55acd558..5898fd73 100644 --- a/pkg/common/db/database/chat.go +++ b/pkg/common/db/database/chat.go @@ -255,6 +255,9 @@ func (o *ChatDatabase) UpdatePasswordAndDeleteVerifyCode(ctx context.Context, us if err := o.account.UpdatePassword(ctx, userID, password); err != nil { return err } + if codeID == "" { + return nil + } if err := o.verifyCode.Delete(ctx, codeID); err != nil { return err }