diff --git a/client.go b/client.go index f9c056d..f30e45b 100644 --- a/client.go +++ b/client.go @@ -312,12 +312,12 @@ func (c *Client) Session(opt ...*options.SessionOptions) (*Session, error) { // - version of mongoDB server >= v4.0 // - Topology of mongoDB server is not Single // At the same time, please pay attention to the following -// - make sure all operations in callback use the sessCtx as context parameter -// - if operations in callback takes more than(include equal) 120s, the operations will not take effect, -// - if operation in callback return qmgo.ErrTransactionRetry, -// the whole transaction will retry, so this transaction must be idempotent -// - if operations in callback return qmgo.ErrTransactionNotSupported, -// - If the ctx parameter already has a Session attached to it, it will be replaced by this session. +// - make sure all operations in callback use the sessCtx as context parameter +// - if operations in callback takes more than(include equal) 120s, the operations will not take effect, +// - if operation in callback return qmgo.ErrTransactionRetry, +// the whole transaction will retry, so this transaction must be idempotent +// - if operations in callback return qmgo.ErrTransactionNotSupported, +// - If the ctx parameter already has a Session attached to it, it will be replaced by this session. func (c *Client) DoTransaction(ctx context.Context, callback func(sessCtx context.Context) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { if !c.transactionAllowed() { return nil, ErrTransactionNotSupported diff --git a/session.go b/session.go index 900cb37..f11c158 100644 --- a/session.go +++ b/session.go @@ -28,17 +28,17 @@ type Session struct { } // StartTransaction starts transaction -//precondition: -//- version of mongoDB server >= v4.0 -//- Topology of mongoDB server is not Single -//At the same time, please pay attention to the following -//- make sure all operations in callback use the sessCtx as context parameter -//- Dont forget to call EndSession if session is not used anymore -//- if operations in callback takes more than(include equal) 120s, the operations will not take effect, -//- if operation in callback return qmgo.ErrTransactionRetry, -// the whole transaction will retry, so this transaction must be idempotent -//- if operations in callback return qmgo.ErrTransactionNotSupported, -//- If the ctx parameter already has a Session attached to it, it will be replaced by this session. +// precondition: +// - version of mongoDB server >= v4.0 +// - Topology of mongoDB server is not Single +// At the same time, please pay attention to the following +// - make sure all operations in callback use the sessCtx as context parameter +// - Dont forget to call EndSession if session is not used anymore +// - if operations in callback takes more than(include equal) 120s, the operations will not take effect, +// - if operation in callback return qmgo.ErrTransactionRetry, +// the whole transaction will retry, so this transaction must be idempotent +// - if operations in callback return qmgo.ErrTransactionNotSupported, +// - If the ctx parameter already has a Session attached to it, it will be replaced by this session. func (s *Session) StartTransaction(ctx context.Context, cb func(sessCtx context.Context) (interface{}, error), opts ...*opts.TransactionOptions) (interface{}, error) { transactionOpts := options.Transaction() if len(opts) > 0 && opts[0].TransactionOptions != nil { @@ -51,6 +51,27 @@ func (s *Session) StartTransaction(ctx context.Context, cb func(sessCtx context. return result, nil } +func (s *Session) StartAsyncTransaction(ctx context.Context, opts ...*opts.TransactionOptions) (context.Context, error) { + transactionOpts := options.Transaction() + if len(opts) > 0 && opts[0].TransactionOptions != nil { + transactionOpts = opts[0].TransactionOptions + } + + sCtx := mongo.NewSessionContext(ctx, s.session) + + err := s.session.StartTransaction(transactionOpts) + + return sCtx, err +} + +func (s *Session) CommitAsyncTransaction(ctx context.Context) error { + return s.session.CommitTransaction(ctx) +} + +func (s *Session) AbortAsyncTransaction(ctx context.Context) error { + return s.session.AbortTransaction(ctx) +} + // EndSession will abort any existing transactions and close the session. func (s *Session) EndSession(ctx context.Context) { s.session.EndSession(ctx) diff --git a/session_test.go b/session_test.go index df39f5a..a5ededc 100644 --- a/session_test.go +++ b/session_test.go @@ -50,6 +50,82 @@ func initTransactionClient(coll string) *QmgoClient { return qClient } + +func TestClient_AsyncTransaction(t *testing.T) { + ast := require.New(t) + ctx := context.Background() + cli := initTransactionClient("test") + defer cli.DropDatabase(ctx) + + session, err := cli.Session() + if err != nil { + ast.NoError(err) + } + + defer session.EndSession(ctx) + + sCtx, err := session.StartAsyncTransaction(ctx) + if err != nil { + ast.NoError(err) + } + + if _, err := cli.InsertOne(sCtx, bson.D{{"abc", int32(1)}}); err != nil { + ast.NoError(err) + } + if _, err := cli.InsertOne(sCtx, bson.D{{"xyz", int32(999)}}); err != nil { + ast.NoError(err) + } + + if err := session.CommitAsyncTransaction(sCtx); err != nil { + ast.NoError(err) + } + + r := bson.M{} + cli.Find(ctx, bson.M{"abc": 1}).One(&r) + ast.Equal(r["abc"], int32(1)) + + cli.Find(ctx, bson.M{"xyz": 999}).One(&r) + ast.Equal(r["xyz"], int32(999)) +} + +func TestClient_AsyncAbortTransaction(t *testing.T) { + ast := require.New(t) + ctx := context.Background() + cli := initTransactionClient("test") + defer cli.DropDatabase(ctx) + session, err := cli.Session() + if err != nil { + ast.NoError(err) + } + + sCtx, err := session.StartAsyncTransaction(ctx) + if err != nil { + ast.NoError(err) + } + + defer session.EndSession(ctx) + defer session.AbortAsyncTransaction(sCtx) + + if _, err := cli.InsertOne(sCtx, bson.D{{"abc", int32(1)}}); err != nil { + ast.NoError(err) + } + + if _, err := cli.InsertOne(sCtx, bson.D{{"xyz", int32(999)}}); err != nil { + ast.NoError(err) + } + + if err := session.AbortAsyncTransaction(sCtx); err != nil { + ast.NoError(err) + } + + r := bson.M{} + cli.Find(ctx, bson.M{"abc": 1}).One(&r) + ast.Empty(r) + + cli.Find(ctx, bson.M{"xyz": 999}).One(&r) + ast.Empty(r) +} + func TestClient_DoTransaction(t *testing.T) { ast := require.New(t) ctx := context.Background()