From 8a47c8d92fbcc0353147b31be5ac481faeb2a16e Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Thu, 16 Nov 2023 15:02:52 +0100 Subject: [PATCH] feat(GODT-2567): Simulate Answered/Forwarded behavior in GPA server --- server/backend/api.go | 24 ++++++++++- server/backend/message.go | 44 +++++++++++++------- server/messages.go | 2 +- server/server_test.go | 88 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 137 insertions(+), 21 deletions(-) diff --git a/server/backend/api.go b/server/backend/api.go index 323f3cf..d064cae 100644 --- a/server/backend/api.go +++ b/server/backend/api.go @@ -636,19 +636,21 @@ func (b *Backend) DeleteMessage(userID, messageID string) error { }) } -func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string) (proton.Message, error) { +func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string, action proton.CreateDraftAction) (proton.Message, error) { return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { return withLabels(b, func(labels map[string]*label) (proton.Message, error) { // Convert the parentID into externalRef.\ var parentRef string + var internalParentID string if parentID != "" { parentMsg, ok := messages[parentID] if ok { parentRef = "<" + parentMsg.externalID + ">" + internalParentID = parentID } } - msg := newMessageFromTemplate(addrID, draft, parentRef) + msg := newMessageFromTemplate(addrID, draft, parentRef, internalParentID, action) // Drafts automatically get the sysLabel "Drafts". msg.addLabel(proton.DraftsLabel, labels) @@ -712,6 +714,24 @@ func (b *Backend) SendMessage(userID, messageID string, packages []*proton.Messa msg.flags |= proton.MessageFlagSent msg.addLabel(proton.SentLabel, labels) + if parent, ok := messages[msg.internalParentID]; ok { + switch msg.draftAction { + case proton.ReplyAction: + parent.flags |= proton.MessageFlagReplied + case proton.ReplyAllAction: + parent.flags |= proton.MessageFlagRepliedAll + case proton.ForwardAction: + parent.flags |= proton.MessageFlagForwarded + } + + updateID, err := b.newUpdate(&messageUpdated{messageID: msg.internalParentID}) + if err != nil { + return proton.Message{}, err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + } + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return proton.Message{}, err diff --git a/server/backend/message.go b/server/backend/message.go index a3f7ce5..59c5873 100644 --- a/server/backend/message.go +++ b/server/backend/message.go @@ -13,12 +13,13 @@ import ( ) type message struct { - messageID string - externalID string - addrID string - labelIDs []string - attIDs []string - inReplyTo string + messageID string + externalID string + addrID string + labelIDs []string + attIDs []string + inReplyTo string + internalParentID string // sysLabel is the system label for the message. // If nil, the message's flags are used to determine the system label (inbox, sent, drafts). @@ -34,6 +35,8 @@ type message struct { replytos []*mail.Address date time.Time + draftAction proton.CreateDraftAction + armBody string mimeType rfc822.MIMEType @@ -92,13 +95,20 @@ func newMessageFromSent(addrID, armBody string, msg *message) *message { } } -func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parentRef string) *message { +func newMessageFromTemplate( + addrID string, + template proton.DraftTemplate, + parentRef string, + internalParentID string, + action proton.CreateDraftAction, +) *message { return &message{ - messageID: uuid.NewString(), - externalID: template.ExternalID, - addrID: addrID, - sysLabel: pointer(""), - inReplyTo: parentRef, + messageID: uuid.NewString(), + externalID: template.ExternalID, + addrID: addrID, + sysLabel: pointer(""), + inReplyTo: parentRef, + internalParentID: internalParentID, subject: template.Subject, sender: template.Sender, @@ -107,6 +117,8 @@ func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parent bccList: template.BCCList, unread: bool(template.Unread), + draftAction: action, + armBody: template.Body, mimeType: template.MIMEType, } @@ -186,9 +198,11 @@ func (msg *message) toMetadata(attData map[string][]byte, att map[string]*attach ReplyTos: msg.replytos, Size: messageSize, - Flags: msg.flags, - Unread: proton.Bool(msg.unread), - IsForwarded: msg.flags&proton.MessageFlagForwarded != 0, + Flags: msg.flags, + Unread: proton.Bool(msg.unread), + IsForwarded: msg.flags&proton.MessageFlagForwarded != 0, + IsReplied: msg.flags&proton.MessageFlagReplied != 0, + IsRepliedAll: msg.flags&proton.MessageFlagRepliedAll != 0, NumAttachments: len(attData), } diff --git a/server/messages.go b/server/messages.go index ad283ba..8aa5ef2 100644 --- a/server/messages.go +++ b/server/messages.go @@ -102,7 +102,7 @@ func (s *Server) postMailMessages(c *gin.Context) { return } - message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID) + message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID, req.Action) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return diff --git a/server/server_test.go b/server/server_test.go index da240e7..50c2cd0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/ProtonMail/go-proton-api/server/backend" "net/http" "net/mail" "net/url" @@ -17,14 +16,14 @@ import ( "testing" "time" - "github.com/bradenaw/juniper/parallel" - "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server/backend" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/iterator" + "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" @@ -2232,6 +2231,89 @@ func TestServer_GetMessageGroupCount(t *testing.T) { }) } +func TestServer_TestDraftActions(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) + require.NoError(t, err) + + type testData struct { + action proton.CreateDraftAction + flag proton.MessageFlag + } + + tests := []testData{ + { + action: proton.ReplyAction, + flag: proton.MessageFlagReplied, + }, + { + action: proton.ReplyAllAction, + flag: proton.MessageFlagRepliedAll, + }, + { + action: proton.ForwardAction, + flag: proton.MessageFlagForwarded, + }, + } + + importedMessages := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, 0, len(tests)) + + for i := 0; i < len(tests); i++ { + importedMessageID := importedMessages[i].MessageID + + msg, err := c.GetMessage(ctx, importedMessageID) + require.NoError(t, err) + + { + kr := addrKRs[addr[0].ID] + msg, err := c.CreateDraft(ctx, kr, proton.CreateDraftReq{ + Message: proton.DraftTemplate{ + Subject: "Foo", + Sender: &mail.Address{Address: addr[0].Email}, + ToList: []*mail.Address{{Address: "foo@bar"}}, + CCList: nil, + BCCList: nil, + }, + AttachmentKeyPackets: nil, + ParentID: msg.ID, + Action: tests[i].action, + }) + + require.NoError(t, err) + + var sreq proton.SendDraftReq + + require.NoError(t, sreq.AddTextPackage(kr, "Hello", "text/plain", map[string]proton.SendPreferences{}, map[string]*crypto.SessionKey{})) + + _, err = c.SendDraft(ctx, msg.ID, sreq) + require.NoError(t, err) + + msg, err = c.GetMessage(ctx, importedMessageID) + require.NoError(t, err) + require.True(t, msg.Flags&tests[i].flag != 0) + } + } + + }) + }) +} + func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()