diff --git a/pkg/runtime/context/context.go b/pkg/runtime/context/context.go index 478f4baa..4326461b 100644 --- a/pkg/runtime/context/context.go +++ b/pkg/runtime/context/context.go @@ -32,12 +32,29 @@ const ( _flagWrite ) +// TxState Transaction status +type TxState int64 + +const ( + _ TxState = iota + TrxStarted // CompositeTx Default state + TrxPreparing // All SQL statements are executed, and before the Commit statement executes + TrxPrepared // All SQL statements are executed, and before the Commit statement executes + TrxCommitting // After preparing is completed, ready to start execution + TrxCommitted // Officially complete the Commit action + TrxRolledBacking + TrxRolledBacked + TrxAborted + TrxUnknown // Unknown transaction +) + type ( - keyFlag struct{} - keyNodeLabel struct{} - keyDefaultDBGroup struct{} - keyHints struct{} - keyTransactionID struct{} + keyFlag struct{} + keyNodeLabel struct{} + keyDefaultDBGroup struct{} + keyHints struct{} + keyTransactionID struct{} + keyTransactionStatus struct{} ) type cFlag uint8 @@ -75,7 +92,7 @@ func WithHints(ctx context.Context, hints []*hint.Hint) context.Context { // Tenant extracts the tenant. func Tenant(ctx context.Context) string { - return isString(ctx, proto.ContextKeyTenant{}) + return getString(ctx, proto.ContextKeyTenant{}) } // IsRead returns true if this is a read operation @@ -95,25 +112,29 @@ func IsDirect(ctx context.Context) bool { // SQL returns the original sql string. func SQL(ctx context.Context) string { - return isString(ctx, proto.ContextKeySQL{}) + return getString(ctx, proto.ContextKeySQL{}) } func Schema(ctx context.Context) string { - return isString(ctx, proto.ContextKeySchema{}) + return getString(ctx, proto.ContextKeySchema{}) } func Version(ctx context.Context) string { - return isString(ctx, proto.ContextKeyServerVersion{}) + return getString(ctx, proto.ContextKeyServerVersion{}) } // NodeLabel returns the label of node. func NodeLabel(ctx context.Context) string { - return isString(ctx, keyNodeLabel{}) + return getString(ctx, keyNodeLabel{}) } // TransactionID returns the transactions id func TransactionID(ctx context.Context) string { - return isString(ctx, keyTransactionID{}) + return getString(ctx, keyTransactionID{}) +} + +func TransactionStatus(ctx context.Context) TxState { + return getTxStatus(ctx, keyTransactionStatus{}) } // Hints extracts the hints. @@ -144,9 +165,18 @@ func getFlag(ctx context.Context) cFlag { return f } -func isString(ctx context.Context, v any) string { +func getString(ctx context.Context, v any) string { if data, ok := ctx.Value(v).(string); ok { return data } return "" } + +func getTxStatus(ctx context.Context, v any) TxState { + if data, ok := ctx.Value(v).(int32); ok { + if data >= int32(TrxStarted) && data <= int32(TrxAborted) { + return TxState(data) + } + } + return TrxUnknown +} diff --git a/pkg/runtime/transaction/fault_decision.go b/pkg/runtime/transaction/fault_decision.go index 8d5eafd9..f1da0a8b 100644 --- a/pkg/runtime/transaction/fault_decision.go +++ b/pkg/runtime/transaction/fault_decision.go @@ -31,15 +31,15 @@ type TxFaultDecisionExecutor struct { func (bm *TxFaultDecisionExecutor) Run() { } -func (bm *TxFaultDecisionExecutor) scanUnFinishTxLog() ([]TrxLog, error) { +func (bm *TxFaultDecisionExecutor) scanUnFinishTxLog() ([]GlobalTrxLog, error) { return nil, nil } -func (bm *TxFaultDecisionExecutor) handlePreparing(tx TrxLog) { +func (bm *TxFaultDecisionExecutor) handlePreparing(tx GlobalTrxLog) { } -func (bm *TxFaultDecisionExecutor) handleCommitting(tx TrxLog) { +func (bm *TxFaultDecisionExecutor) handleCommitting(tx GlobalTrxLog) { } -func (bm *TxFaultDecisionExecutor) handleAborting(tx TrxLog) { +func (bm *TxFaultDecisionExecutor) handleAborting(tx GlobalTrxLog) { } diff --git a/pkg/runtime/transaction/hook.go b/pkg/runtime/transaction/hook.go index c7994816..c7a3f0f0 100644 --- a/pkg/runtime/transaction/hook.go +++ b/pkg/runtime/transaction/hook.go @@ -19,6 +19,7 @@ package transaction import ( "context" + rcontext "github.com/arana-db/arana/pkg/runtime/context" ) import ( @@ -40,19 +41,19 @@ func NewXAHook(tenant string, enable bool) (*xaHook, error) { enable: enable, } - trxStateChangeFunc := map[runtime.TxState]handleFunc{ - runtime.TrxActive: xh.onActive, - runtime.TrxPreparing: xh.onPreparing, - runtime.TrxPrepared: xh.onPrepared, - runtime.TrxCommitting: xh.onCommitting, - runtime.TrxCommitted: xh.onCommitted, - runtime.TrxAborting: xh.onAborting, - runtime.TrxRollback: xh.onRollbackOnly, - runtime.TrxRolledBack: xh.onRolledBack, + trxStateChangeFunc := map[rcontext.TxState]handleFunc{ + rcontext.TrxStarted: xh.onStarted, + rcontext.TrxPreparing: xh.onPreparing, + rcontext.TrxPrepared: xh.onPrepared, + rcontext.TrxCommitting: xh.onCommitting, + rcontext.TrxCommitted: xh.onCommitted, + rcontext.TrxAborted: xh.onAborting, + rcontext.TrxRolledBacking: xh.onRollbackOnly, + rcontext.TrxRolledBacked: xh.onRolledBack, } xh.trxMgr = trxMgr - xh.trxLog = &TrxLog{} + xh.trxLog = &GlobalTrxLog{} xh.trxStateChangeFunc = trxStateChangeFunc return xh, nil @@ -63,15 +64,15 @@ func NewXAHook(tenant string, enable bool) (*xaHook, error) { type xaHook struct { enable bool trxMgr *TrxManager - trxLog *TrxLog - trxStateChangeFunc map[runtime.TxState]handleFunc + trxLog *GlobalTrxLog + trxStateChangeFunc map[rcontext.TxState]handleFunc } -func (xh *xaHook) OnTxStateChange(ctx context.Context, state runtime.TxState, tx runtime.CompositeTx) error { +func (xh *xaHook) OnTxStateChange(ctx context.Context, state rcontext.TxState, tx runtime.CompositeTx) error { if !xh.enable { return nil } - xh.trxLog.State = state + xh.trxLog.Status = state handle, ok := xh.trxStateChangeFunc[state] if ok { return handle(ctx, tx) @@ -84,18 +85,22 @@ func (xh *xaHook) OnCreateBranchTx(ctx context.Context, tx runtime.BranchTx) { if !xh.enable { return } - xh.trxLog.Participants = append(xh.trxLog.Participants, TrxParticipant{ - NodeID: "", - RemoteAddr: tx.GetConn().GetDatabaseConn().GetNetConn().RemoteAddr().String(), - Schema: tx.GetConn().DBName(), - }) + // TODO: add branch trx log + //xh.trxLog.BranchTrxLogs = append(xh.trxLog.BranchTrxLogs, BranchTrxLog{ + // NodeID: "", + // RemoteAddr: tx.GetConn().GetDatabaseConn().GetNetConn().RemoteAddr().String(), + // Schema: tx.GetConn().DBName(), + //}) } -func (xh *xaHook) onActive(ctx context.Context, tx runtime.CompositeTx) error { +func (xh *xaHook) onStarted(ctx context.Context, tx runtime.CompositeTx) error { tx.SetBeginFunc(StartXA) xh.trxLog.TrxID = tx.GetTrxID() - xh.trxLog.State = tx.GetTxState() + xh.trxLog.Status = tx.GetTxState() xh.trxLog.Tenant = tx.GetTenant() + xh.trxLog.StartTime = tx.GetStartTime() + xh.trxLog.ExpectedEndTime = tx.GetExpectedEndTime() + return nil } @@ -103,14 +108,14 @@ func (xh *xaHook) onPreparing(ctx context.Context, tx runtime.CompositeTx) error tx.Range(func(tx runtime.BranchTx) { tx.SetPrepareFunc(PrepareXA) }) - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil } func (xh *xaHook) onPrepared(ctx context.Context, tx runtime.CompositeTx) error { - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil @@ -120,14 +125,14 @@ func (xh *xaHook) onCommitting(ctx context.Context, tx runtime.CompositeTx) erro tx.Range(func(tx runtime.BranchTx) { tx.SetCommitFunc(CommitXA) }) - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil } func (xh *xaHook) onCommitted(ctx context.Context, tx runtime.CompositeTx) error { - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil @@ -137,7 +142,7 @@ func (xh *xaHook) onAborting(ctx context.Context, tx runtime.CompositeTx) error tx.Range(func(bTx runtime.BranchTx) { bTx.SetCommitFunc(RollbackXA) }) - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } // auto execute XA rollback action @@ -151,15 +156,15 @@ func (xh *xaHook) onRollbackOnly(ctx context.Context, tx runtime.CompositeTx) er tx.Range(func(tx runtime.BranchTx) { tx.SetCommitFunc(RollbackXA) }) - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil } func (xh *xaHook) onRolledBack(ctx context.Context, tx runtime.CompositeTx) error { - xh.trxLog.State = runtime.TrxRolledBack - if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil { + xh.trxLog.Status = rcontext.TrxRolledBacking + if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil { return err } return nil diff --git a/pkg/runtime/transaction/trx_log.go b/pkg/runtime/transaction/trx_log.go index 705b8254..279d147a 100644 --- a/pkg/runtime/transaction/trx_log.go +++ b/pkg/runtime/transaction/trx_log.go @@ -19,8 +19,8 @@ package transaction import ( "context" - "encoding/json" "fmt" + rcontext "github.com/arana-db/arana/pkg/runtime/context" "strings" "sync" "time" @@ -28,7 +28,6 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/runtime" ) var ( @@ -48,28 +47,46 @@ var ( const ( // TODO 启用 mysql 的二级分区功能,解决清理 tx log 的问题 - _initTxLog = ` -CREATE TABLE IF NOT EXISTS __arana_trx_log -( - log_id bigint(20) auto_increment COMMENT 'primary key', - txr_id varchar(255) NOT NULL COMMENT 'transaction uniq id', - tenant varchar(255) NOT NULL COMMENT 'tenant info', - server_id int(10) UNSIGNED NOT NULL COMMENT 'arana server node id', - status int(10) NOT NULL COMMENT 'transaction status, preparing:2,prepared:3,committing:4,committed:5,aborting:6,rollback:7,finish:8,rolledBack:9', - participant varchar(500) COMMENT 'transaction participants, content is mysql node info', - start_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - PRIMARY KEY (log_id), - UNIQUE KEY (txr_id) -) ENGINE = InnoDB - CHARSET = utf8 -` - insSql = "REPLACE INTO __arana_trx_log(trx_id, tenant, server_id, status, participant, start_time, update_time) VALUES (?,?,?,?,?,sysdate(),sysdate())" - delSql = "DELETE FROM __arana_trx_log WHERE trx_id = ?" - selectSql = "SELECT trx_id, tenant, server_id, status, participant, start_time, update_time FROM __arana_trx_log WHERE 1=1 %s ORDER BY update_time LIMIT ? OFFSET ?" + _initGlobalTxLog = ` + CREATE TABLE __arana_global_trx_log ( + log_id bigint NOT NULL AUTO_INCREMENT COMMENT 'primary key', + txr_id varchar(255) NOT NULL COMMENT 'transaction uniq id', + tenant varchar(255) NOT NULL COMMENT 'tenant info', + server_id int unsigned NOT NULL COMMENT 'arana server node id', + status int NOT NULL COMMENT 'transaction status: started:1,preparing:2,prepared:3,committing:4,committed:5,rollbacking:6,rollbacked:7,failed:8', + start_time datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'transaction start time', + expected_end_time datetime NOT NULL COMMENT 'global transaction expected end time', + update_time datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (log_id), + UNIQUE KEY txr_id (txr_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb3 + ` + insertGlobalSql = "INSERT INTO __arana_global_trx_log (txr_id, tenant, server_id, status, start_time, expected_end_time) VALUES (?, ?, ?, ?, ?, ?);" + deleteGlobalSql = "DELETE FROM __arana_global_trx_log WHERE trx_id = ?" + selectGlobalSql = "SELECT log_id, txr_id, tenant, server_id, status, start_time, expected_end_time, update_time FROM __arana_trx_log WHERE 1=1 %s ORDER BY expected_end_time LIMIT ? OFFSET ?" +) + +const ( + _initBranchTxLog = ` + CREATE TABLE __arana_branch_trx_log ( + log_id bigint NOT NULL AUTO_INCREMENT COMMENT 'primary key', + txr_id varchar(255) NOT NULL COMMENT 'transaction uniq id', + branch_id varchar(255) NOT NULL COMMENT 'branch transaction key', + participant_id int unsigned NOT NULL COMMENT 'transaction participants, content is mysql node info', + status int NOT NULL COMMENT 'transaction status: started:1,preparing:2,prepared:3,committing:4,committed:5,rollbacking:6,rollbacked:7,failed:8', + start_time datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'branch transaction start time', + update_time datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (log_id), + UNIQUE KEY txr_branch_id (txr_id, branch_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb3 + ` + insertBranchSql = "INSERT INTO __arana_branch_trx_log (txr_id, branch_id, participant_id, status, start_time) VALUES (?, ?, ?, ?, ?);" + deleteBranchSql = "DELETE FROM __arana_global_trx_log WHERE trx_id = ? and branch_id=?" + selectBranchSql = "SELECT log_id, txr_id, branch_id, participant_id, status, start_time, update_time FROM __arana_branch_trx_log WHERE 1=1 %s ORDER BY expected_end_time LIMIT ? OFFSET ?" ) // TxLogManager Transaction log management +// TODO type TxLogManager struct { sysDB proto.DB } @@ -79,60 +96,75 @@ func (gm *TxLogManager) Init(delay time.Duration) error { var err error _initTxLogOnce.Do(func() { ctx := context.Background() - res, _, err := gm.sysDB.Call(ctx, _initTxLog) + res, _, err := gm.sysDB.Call(ctx, _initGlobalTxLog) if err != nil { return } _, _ = res.RowsAffected() - _txLogCleanTimer = time.AfterFunc(delay, gm.runCleanTxLogTask) + _txLogCleanTimer = time.AfterFunc(delay, gm.runCleanGlobalTxLogTask) + + res, _, err = gm.sysDB.Call(ctx, _initBranchTxLog) + if err != nil { + return + } + _, _ = res.RowsAffected() + _txLogCleanTimer = time.AfterFunc(delay, gm.runCleanBranchTxLogTask) }) return err } -// AddOrUpdateTxLog Add or update transaction log -func (gm *TxLogManager) AddOrUpdateTxLog(l TrxLog) error { - participants, err := json.Marshal(l.Participants) - if err != nil { - return err - } +// AddOrUpdateTxLog Add or update global transaction log +func (gm *TxLogManager) AddOrUpdateGlobalTxLog(l GlobalTrxLog) error { trxIdVal, _ := proto.NewValue(l.TrxID) tenantVal, _ := proto.NewValue(l.Tenant) serverIdVal, _ := proto.NewValue(l.ServerID) - stateVal, _ := proto.NewValue(int32(l.State)) - participantsVal, _ := proto.NewValue(string(participants)) + statusVal, _ := proto.NewValue(int32(l.Status)) + startTimeVal, _ := proto.NewValue(l.StartTime) + expectedEndTimeVal, _ := proto.NewValue(l.ExpectedEndTime) args := []proto.Value{ trxIdVal, tenantVal, serverIdVal, - stateVal, - participantsVal, + statusVal, + startTimeVal, + expectedEndTimeVal, } - _, _, err = gm.sysDB.Call(context.Background(), insSql, args...) + _, _, err := gm.sysDB.Call(context.Background(), insertGlobalSql, args...) return err } +// AddOrUpdateTxLog Add or update branch transaction log +func (gm *TxLogManager) AddOrUpdateBranchTxLog(l BranchTrxLog) error { + panic("implement me") +} + // DeleteTxLog Delete transaction log -func (gm *TxLogManager) DeleteTxLog(l TrxLog) error { +func (gm *TxLogManager) DeleteGlobalTxLog(l GlobalTrxLog) error { trxIdVal, _ := proto.NewValue(l.TrxID) args := []proto.Value{ trxIdVal, } - _, _, err := gm.sysDB.Call(context.Background(), delSql, args...) + _, _, err := gm.sysDB.Call(context.Background(), deleteGlobalSql, args...) return err } -// ScanTxLog Scanning transaction -func (gm *TxLogManager) ScanTxLog(pageNo, pageSize uint64, conditions []Condition) (uint32, []TrxLog, error) { +// TODO +func (gm *TxLogManager) DeleteBranchTxLog(l BranchTrxLog) error { + panic("implement me") +} + +// Global ScanTxLog Scanning transaction +func (gm *TxLogManager) ScanGlobalTxLog(pageNo, pageSize uint64, conditions []Condition) (uint32, []GlobalTrxLog, error) { var ( - whereBuilder []string - args []proto.Value - logs []TrxLog - num uint32 - dest []proto.Value - log TrxLog - participants []TrxParticipant - serverId int64 - state int64 + whereBuilder []string + args []proto.Value + logs []GlobalTrxLog + num uint32 + dest []proto.Value + serverId int64 + expectedEndTime int64 + startTime int64 + state int64 ) for i := range conditions { @@ -149,7 +181,7 @@ func (gm *TxLogManager) ScanTxLog(pageNo, pageSize uint64, conditions []Conditio offset := proto.NewValueUint64((pageNo - 1) * pageSize) args = append(args, limit, offset) - conditionSelectSql := fmt.Sprintf(selectSql, strings.Join(whereBuilder, " ")) + conditionSelectSql := fmt.Sprintf(selectGlobalSql, strings.Join(whereBuilder, " ")) rows, _, err := gm.sysDB.Call(context.Background(), conditionSelectSql, args...) if err != nil { return 0, nil, err @@ -163,7 +195,8 @@ func (gm *TxLogManager) ScanTxLog(pageNo, pageSize uint64, conditions []Conditio if row == nil { break } - if err := row.Scan(dest[:]); err != nil { + var log GlobalTrxLog + if err = row.Scan(dest[:]); err != nil { return 0, nil, err } log.TrxID = dest[0].String() @@ -171,36 +204,41 @@ func (gm *TxLogManager) ScanTxLog(pageNo, pageSize uint64, conditions []Conditio serverId, _ = dest[2].Int64() log.ServerID = int32(serverId) state, _ = dest[3].Int64() - log.State = runtime.TxState(int32(state)) - - if err := json.Unmarshal([]byte(dest[4].String()), &participants); err != nil { - return 0, nil, err - } - log.Participants = participants + log.Status = rcontext.TxState(state) + expectedEndTime, _ = dest[4].Int64() + log.ExpectedEndTime = time.UnixMilli(expectedEndTime) + startTime, _ = dest[5].Int64() + log.StartTime = time.UnixMilli(startTime) logs = append(logs, log) num++ } return num, logs, nil } +// Branch ScanTxLog Scanning transaction +// TODO +func (gm *TxLogManager) ScanBranchTxLog(pageNo, pageSize uint64, conditions []Condition) (uint32, []BranchTrxLog, error) { + panic("implement me") +} + // runCleanTxLogTask execute the transaction log cleanup action, and clean up the __arana_tx_log secondary // partition table according to the day level or hour level. // the execution of this task requires distributed task preemption based on the metadata DB -func (gm *TxLogManager) runCleanTxLogTask() { +func (gm *TxLogManager) runCleanGlobalTxLogTask() { var ( pageNo uint64 pageSize uint64 = 50 conditions = []Condition{ { FiledName: "status", - Operation: Equal, - Value: runtime.TrxFinish, + Operation: In, + Value: []int32{int32(rcontext.TrxRolledBacked), int32(rcontext.TrxCommitted), int32(rcontext.TrxAborted)}, }, } ) - var txLogs []TrxLog + var txLogs []GlobalTrxLog for { - total, logs, err := gm.ScanTxLog(pageNo, pageSize, conditions) + total, logs, err := gm.ScanGlobalTxLog(pageNo, pageSize, conditions) if err != nil { break } @@ -210,6 +248,11 @@ func (gm *TxLogManager) runCleanTxLogTask() { } } for _, l := range txLogs { - gm.DeleteTxLog(l) + gm.DeleteGlobalTxLog(l) } } + +// TODO +func (gm *TxLogManager) runCleanBranchTxLogTask() { + panic("implement me") +} diff --git a/pkg/runtime/transaction/trx_log_test.go b/pkg/runtime/transaction/trx_log_test.go index 9bfdf686..6c686a89 100644 --- a/pkg/runtime/transaction/trx_log_test.go +++ b/pkg/runtime/transaction/trx_log_test.go @@ -19,7 +19,7 @@ package transaction import ( "context" - "encoding/json" + rcontext "github.com/arana-db/arana/pkg/runtime/context" "testing" ) @@ -31,7 +31,6 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/runtime" "github.com/arana-db/arana/testdata" ) @@ -42,20 +41,19 @@ func TestDeleteTxLog(t *testing.T) { txLogManager := &TxLogManager{ sysDB: mockDB, } - testTrxLog := TrxLog{ - TrxID: "test_delete_id", - ServerID: 1, - State: runtime.TrxActive, - Participants: []TrxParticipant{{NodeID: "1", RemoteAddr: "127.0.0.1", Schema: "schema"}}, - Tenant: "test_tenant", + testTrxLog := GlobalTrxLog{ + TrxID: "test_delete_id", + ServerID: 1, + Status: rcontext.TrxStarted, + Tenant: "test_tenant", } trxIdVal, _ := proto.NewValue("test_delete_id") mockDB.EXPECT().Call( context.Background(), - "DELETE FROM __arana_trx_log WHERE trx_id = ?", + "DELETE FROM __arana_global_trx_log WHERE trx_id = ?", gomock.Eq([]proto.Value{trxIdVal}), ).Return(nil, uint16(0), nil).Times(1) - err := txLogManager.DeleteTxLog(testTrxLog) + err := txLogManager.DeleteGlobalTxLog(testTrxLog) assert.NoError(t, err) } @@ -66,33 +64,32 @@ func TestAddOrUpdateTxLog(t *testing.T) { txLogManager := &TxLogManager{ sysDB: mockDB, } - testTrxLog := TrxLog{ - TrxID: "test_add_or_update_id", - ServerID: 1, - State: runtime.TrxActive, - Participants: []TrxParticipant{{NodeID: "1", RemoteAddr: "127.0.0.1", Schema: "schema"}}, - Tenant: "test_tenant", + testTrxLog := GlobalTrxLog{ + TrxID: "test_add_or_update_id", + ServerID: 1, + Status: rcontext.TrxStarted, + Tenant: "test_tenant", } - participants, err := json.Marshal(testTrxLog.Participants) - assert.NoError(t, err) trxIdVal, _ := proto.NewValue(testTrxLog.TrxID) tenantVal, _ := proto.NewValue(testTrxLog.Tenant) serverIdVal, _ := proto.NewValue(testTrxLog.ServerID) - stateVal, _ := proto.NewValue(int32(testTrxLog.State)) - participantsVal, _ := proto.NewValue(string(participants)) + stateVal, _ := proto.NewValue(int32(testTrxLog.Status)) + startTime, _ := proto.NewValue(testTrxLog.StartTime) + exceptEndTime, _ := proto.NewValue(testTrxLog.ExpectedEndTime) args := []proto.Value{ trxIdVal, tenantVal, serverIdVal, stateVal, - participantsVal, + startTime, + exceptEndTime, } mockDB.EXPECT().Call( context.Background(), - "REPLACE INTO __arana_trx_log(trx_id, tenant, server_id, status, participant, start_time, update_time) VALUES (?,?,?,?,?,sysdate(),sysdate())", + "INSERT INTO __arana_global_trx_log (txr_id, tenant, server_id, status, start_time, expected_end_time) VALUES (?, ?, ?, ?, ?, ?);", args, ).Return(nil, uint16(0), nil).Times(1) - err = txLogManager.AddOrUpdateTxLog(testTrxLog) + err := txLogManager.AddOrUpdateGlobalTxLog(testTrxLog) assert.NoError(t, err) } diff --git a/pkg/runtime/transaction/types.go b/pkg/runtime/transaction/types.go index f8bebe29..247311d4 100644 --- a/pkg/runtime/transaction/types.go +++ b/pkg/runtime/transaction/types.go @@ -18,16 +18,30 @@ package transaction import ( - "github.com/arana-db/arana/pkg/runtime" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "time" ) -// TrxLog arana tx log -type TrxLog struct { - TrxID string - ServerID int32 - State runtime.TxState - Participants []TrxParticipant - Tenant string +// Global TrxLog arana tx log +type GlobalTrxLog struct { + TrxID string + Tenant string + ServerID int32 + Status rcontext.TxState + StartTime time.Time + ExpectedEndTime time.Time + BranchTrxLogs []BranchTrxLog +} + +// Branch TrxLog arana tx log +type BranchTrxLog struct { + TrxID string + BranchID string + ParticipantID string + Tenant string + ServerID int32 + Status rcontext.TxState + StartTime int64 } // TrxParticipant join target trx all node info @@ -41,6 +55,7 @@ type dBOperation string const ( Like dBOperation = "LIKE" + In dBOperation = "IN" Equal dBOperation = "=" NotEqual dBOperation = "<>" LessThan dBOperation = "<" diff --git a/pkg/runtime/transaction/xa.go b/pkg/runtime/transaction/xa.go index aaef7c5d..ddfb9639 100644 --- a/pkg/runtime/transaction/xa.go +++ b/pkg/runtime/transaction/xa.go @@ -30,6 +30,7 @@ import ( ) var ErrorInvalidTxId = errors.New("invalid transaction id") +var ErrorInvalidTxStatus = errors.New("invalid transaction status") // StartXA do start xa transaction action func StartXA(ctx context.Context, bc *mysql.BackendConnection) (proto.Result, error) { @@ -37,7 +38,9 @@ func StartXA(ctx context.Context, bc *mysql.BackendConnection) (proto.Result, er if len(txId) == 0 { return nil, ErrorInvalidTxId } - + if rcontext.TransactionStatus(ctx) != rcontext.TrxStarted { + return nil, ErrorInvalidTxStatus + } return bc.ExecuteWithWarningCount(fmt.Sprintf("XA START '%s'", txId), false) } diff --git a/pkg/runtime/tx.go b/pkg/runtime/tx.go index 52722caa..edcc1240 100644 --- a/pkg/runtime/tx.go +++ b/pkg/runtime/tx.go @@ -57,22 +57,6 @@ var ( _ proto.VersionSupport = (*compositeTx)(nil) ) -// TxState Transaction status -type TxState int32 - -const ( - _ TxState = iota - TrxActive // CompositeTx Default state - TrxPreparing // Start executing the first SQL statement - TrxPrepared // All SQL statements are executed, and before the Commit statement executes - TrxCommitting // After preparing is completed, ready to start execution - TrxCommitted // Officially complete the Commit action - TrxAborting // There are abnormalities during the execution of the branch, and the composite transaction is prohibited to continue to execute - TrxRollback - TrxFinish - TrxRolledBack -) - // CompositeTx distribute transaction type ( // CompositeTx distribute transaction @@ -82,7 +66,11 @@ type ( // GetTenant get cur tx owner tenant GetTenant() string // GetTxState get cur tx state - GetTxState() TxState + GetTxState() rcontext.TxState + // GetExpectedEndTime + GetStartTime() time.Time + // GetExpectedEndTime get cur tx expected end time + GetExpectedEndTime() time.Time // SetBeginFunc sets begin func SetBeginFunc(f dbFunc) // Range range branchTx map @@ -104,7 +92,7 @@ type ( // GetConn gets mysql connection GetConn() *mysql.BackendConnection // GetTxState get cur tx state - GetTxState() TxState + GetTxState() rcontext.TxState // Commit commit tx Commit(ctx context.Context) (res proto.Result, warn uint16, err error) // Rollback rollback tx @@ -114,7 +102,7 @@ type ( // TxHook transaction hook TxHook interface { // OnTxStateChange Fired when CompositeTx TrxState change - OnTxStateChange(ctx context.Context, state TxState, tx CompositeTx) error + OnTxStateChange(ctx context.Context, state rcontext.TxState, tx CompositeTx) error // OnCreateBranchTx Fired when BranchTx create OnCreateBranchTx(ctx context.Context, tx BranchTx) } @@ -131,18 +119,22 @@ type ( ) func newCompositeTx(ctx context.Context, pi *defaultRuntime, hooks ...TxHook) *compositeTx { + now := time.Now() tx := &compositeTx{ - tenant: rcontext.Tenant(ctx), - id: gtid.NewID(), - rt: pi, - txs: make(map[string]*branchTx), - hooks: hooks, + tenant: rcontext.Tenant(ctx), + id: gtid.NewID(), + rt: pi, + txs: make(map[string]*branchTx), + beginTime: now, + // TODO: set expected end time from config, it is assumed here that the timeout of a global transaction is 30 seconds + expectedEndTime: now.Add(time.Second * 30), + hooks: hooks, beginFunc: func(ctx context.Context, bc *mysql.BackendConnection) (proto.Result, error) { return bc.ExecuteWithWarningCount("begin", true) }, } - tx.setTxState(ctx, TrxActive) + tx.setTxState(ctx, rcontext.TrxStarted) return tx } @@ -152,11 +144,12 @@ type compositeTx struct { closed atomic.Bool id gtid.ID - beginTime time.Time - endTime time.Time + beginTime time.Time + expectedEndTime time.Time + endTime time.Time isoLevel sql.IsolationLevel - txState TxState + txState rcontext.TxState beginFunc dbFunc @@ -174,6 +167,14 @@ func (tx *compositeTx) GetTenant() string { return tx.tenant } +func (tx *compositeTx) GetExpectedEndTime() time.Time { + return tx.expectedEndTime +} + +func (tx *compositeTx) GetStartTime() time.Time { + return tx.beginTime +} + func (tx *compositeTx) Version(ctx context.Context) (string, error) { return tx.rt.Version(ctx) } @@ -317,7 +318,7 @@ func (tx *compositeTx) Commit(ctx context.Context) (proto.Result, uint16, error) } func (tx *compositeTx) doPrepareCommit(ctx context.Context) error { - tx.setTxState(ctx, TrxPreparing) + tx.setTxState(ctx, rcontext.TrxPreparing) var g errgroup.Group for k, v := range tx.txs { @@ -331,16 +332,16 @@ func (tx *compositeTx) doPrepareCommit(ctx context.Context) error { }) } if err := g.Wait(); err != nil { - tx.setTxState(ctx, TrxAborting) + tx.setTxState(ctx, rcontext.TrxAborted) return err } - tx.setTxState(ctx, TrxPrepared) + tx.setTxState(ctx, rcontext.TrxPrepared) return nil } func (tx *compositeTx) doCommit(ctx context.Context) error { - tx.setTxState(ctx, TrxCommitting) + tx.setTxState(ctx, rcontext.TrxCommitting) var g errgroup.Group for k, v := range tx.txs { @@ -358,7 +359,7 @@ func (tx *compositeTx) doCommit(ctx context.Context) error { return err } - tx.setTxState(ctx, TrxCommitted) + tx.setTxState(ctx, rcontext.TrxCommitted) return nil } @@ -386,7 +387,7 @@ func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, erro } func (tx *compositeTx) doPrepareRollback(ctx context.Context) error { - tx.setTxState(ctx, TrxPreparing) + tx.setTxState(ctx, rcontext.TrxPreparing) var g errgroup.Group for k, v := range tx.txs { @@ -401,15 +402,15 @@ func (tx *compositeTx) doPrepareRollback(ctx context.Context) error { } if err := g.Wait(); err != nil { - tx.setTxState(ctx, TrxAborting) + tx.setTxState(ctx, rcontext.TrxAborted) return err } - tx.setTxState(ctx, TrxPrepared) + tx.setTxState(ctx, rcontext.TrxPrepared) return nil } func (tx *compositeTx) doRollback(ctx context.Context) error { - tx.setTxState(ctx, TrxRollback) + tx.setTxState(ctx, rcontext.TrxRolledBacking) var g errgroup.Group for k, v := range tx.txs { @@ -426,7 +427,7 @@ func (tx *compositeTx) doRollback(ctx context.Context) error { if err := g.Wait(); err != nil { return err } - tx.setTxState(ctx, TrxRolledBack) + tx.setTxState(ctx, rcontext.TrxRolledBacked) return nil } @@ -437,11 +438,11 @@ func (tx *compositeTx) Range(f func(tx BranchTx)) { } } -func (tx *compositeTx) GetTxState() TxState { +func (tx *compositeTx) GetTxState() rcontext.TxState { return tx.txState } -func (tx *compositeTx) setTxState(ctx context.Context, state TxState) { +func (tx *compositeTx) setTxState(ctx context.Context, state rcontext.TxState) { tx.txState = state for i := range tx.hooks { if err := tx.hooks[i].OnTxStateChange(ctx, state, tx); err != nil { @@ -456,7 +457,7 @@ type branchTx struct { closed atomic.Bool parent *AtomDB - state TxState + state rcontext.TxState prepare dbFunc commit dbFunc @@ -481,12 +482,12 @@ func newBranchTx(parent *AtomDB, bc *mysql.BackendConnection) *branchTx { } // GetTxState get cur tx state -func (tx *branchTx) GetTxState() TxState { +func (tx *branchTx) GetTxState() rcontext.TxState { return tx.state } func (tx *branchTx) Commit(ctx context.Context) (res proto.Result, warn uint16, err error) { - tx.state = TrxCommitting + tx.state = rcontext.TrxCommitting _ = ctx if !tx.closed.CAS(false, true) { err = errTxClosed @@ -494,7 +495,7 @@ func (tx *branchTx) Commit(ctx context.Context) (res proto.Result, warn uint16, } defer tx.dispose() if res, err = tx.commit(ctx, tx.bc); err != nil { - tx.state = TrxAborting + tx.state = rcontext.TrxAborted return } @@ -508,14 +509,14 @@ func (tx *branchTx) Commit(ctx context.Context) (res proto.Result, warn uint16, } res = resultx.New(resultx.WithRowsAffected(affected), resultx.WithLastInsertID(lastInsertId)) - tx.state = TrxCommitted + tx.state = rcontext.TrxCommitted return } func (tx *branchTx) Prepare(ctx context.Context) error { - tx.state = TrxPreparing + tx.state = rcontext.TrxPreparing _, err := tx.prepare(ctx, tx.bc) - tx.state = TrxPrepared + tx.state = rcontext.TrxPrepared return err } @@ -525,9 +526,9 @@ func (tx *branchTx) Rollback(ctx context.Context) (res proto.Result, warn uint16 return } defer tx.dispose() - tx.state = TrxRollback + tx.state = rcontext.TrxRolledBacking res, err = tx.rollback(ctx, tx.bc) - tx.state = TrxRolledBack + tx.state = rcontext.TrxRolledBacked return } @@ -581,7 +582,7 @@ func (tx *branchTx) GetConn() *mysql.BackendConnection { return tx.bc } -func NumOfStateBranchTx(state TxState, tx CompositeTx) int32 { +func NumOfStateBranchTx(state rcontext.TxState, tx CompositeTx) int32 { cnt := int32(0) tx.Range(func(bTx BranchTx) { if bTx.GetTxState() == state { diff --git a/pkg/runtime/tx_test.go b/pkg/runtime/tx_test.go index 19832e1e..0050cde2 100644 --- a/pkg/runtime/tx_test.go +++ b/pkg/runtime/tx_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + rcontext "github.com/arana-db/arana/pkg/runtime/context" "testing" "time" ) @@ -41,7 +42,7 @@ func Test_branchTx_CallFieldList(t *testing.T) { type fields struct { closed atomic.Bool parent *AtomDB - state TxState + state rcontext.TxState prepare dbFunc commit dbFunc rollback dbFunc @@ -89,7 +90,7 @@ func Test_compositeTx_Rollback(t *testing.T) { beginTime time.Time endTime time.Time isoLevel sql.IsolationLevel - txState TxState + txState rcontext.TxState beginFunc dbFunc rt *defaultRuntime txs map[string]*branchTx