diff --git a/e2e/gorm_test.go b/e2e/gorm_test.go index bcfe94c..bc106eb 100644 --- a/e2e/gorm_test.go +++ b/e2e/gorm_test.go @@ -19,7 +19,6 @@ package e2e import ( "context" "fmt" - "log" "sync" "testing" "time" @@ -40,38 +39,58 @@ type GormMQSuite struct { db *gorm.DB } -func (g *GormMQSuite) SetupTest() { - db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ - //Logger: logger.Default.LogMode(logger.Info), - }) - require.NoError(g.T(), err) - g.db = db -} -func (g *GormMQSuite) TearDownTest() { - sqlDB, err := g.db.DB() - require.NoError(g.T(), err) - _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") - require.NoError(g.T(), err) - _, err = sqlDB.Exec("CREATE DATABASE `test`") - require.NoError(g.T(), err) -} +//func (g *GormMQSuite) SetupTest() { +// db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ +// //Logger: logger.Default.LogMode(logger.Info), +// }) +// require.NoError(g.T(), err) +// g.db = db +//} + +//func (g *GormMQSuite) TearDownTest() { +// sqlDB, err := g.db.DB() +// require.NoError(g.T(), err) +// _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") +// require.NoError(g.T(), err) +// _, err = sqlDB.Exec("CREATE DATABASE `test`") +// require.NoError(g.T(), err) +// +//} func (g *GormMQSuite) TestTopic() { testcases := []struct { - name string - topic string - input int - wantVal []string + name string + topic string + input int + wantVal []string + beforeFunc func() + afterFunc func() }{ { name: "建立含有4个分区的topic", topic: "test_topic", input: 4, wantVal: []string{"test_topic_0", "test_topic_1", "test_topic_2", "test_topic_3"}, + beforeFunc: func() { + db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ + //Logger: logger.Default.LogMode(logger.Info), + }) + require.NoError(g.T(), err) + g.db = db + }, + afterFunc: func() { + sqlDB, err := g.db.DB() + require.NoError(g.T(), err) + _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") + require.NoError(g.T(), err) + _, err = sqlDB.Exec("CREATE DATABASE `test`") + require.NoError(g.T(), err) + }, }, } for _, tc := range testcases { g.T().Run(tc.name, func(t *testing.T) { + tc.beforeFunc() mq, err := gorm_mq.NewMq(g.db) require.NoError(t, err) err = mq.Topic(tc.topic, 4) @@ -79,6 +98,7 @@ func (g *GormMQSuite) TestTopic() { tables, err := g.getTables(tc.topic) require.NoError(t, err) assert.Equal(t, tc.wantVal, tables) + tc.afterFunc() }) } @@ -94,6 +114,8 @@ func (g *GormMQSuite) TestConsumer() { // 处理消息 consumerFunc func(c mq.Consumer) []*mq.Message wantVal []*mq.Message + beforeFunc func() + afterFunc func() }{ { name: "一个消费组内多个消费者", @@ -170,6 +192,21 @@ func (g *GormMQSuite) TestConsumer() { Topic: "test_topic", }, }, + beforeFunc: func() { + db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ + //Logger: logger.Default.LogMode(logger.Info), + }) + require.NoError(g.T(), err) + g.db = db + }, + afterFunc: func() { + sqlDB, err := g.db.DB() + require.NoError(g.T(), err) + _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") + require.NoError(g.T(), err) + _, err = sqlDB.Exec("CREATE DATABASE `test`") + require.NoError(g.T(), err) + }, }, { name: "多个消费组,多个消费者", @@ -280,10 +317,26 @@ func (g *GormMQSuite) TestConsumer() { Topic: "test_topic", }, }, + beforeFunc: func() { + db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ + //Logger: logger.Default.LogMode(logger.Info), + }) + require.NoError(g.T(), err) + g.db = db + }, + afterFunc: func() { + sqlDB, err := g.db.DB() + require.NoError(g.T(), err) + _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") + require.NoError(g.T(), err) + _, err = sqlDB.Exec("CREATE DATABASE `test`") + require.NoError(g.T(), err) + }, }, } for _, tc := range testcases { g.T().Run(tc.name, func(t *testing.T) { + tc.beforeFunc() gormMq, err := gorm_mq.NewMq(g.db) require.NoError(t, err) err = gormMq.Topic(tc.topic, int(tc.partitions)) @@ -293,26 +346,29 @@ func (g *GormMQSuite) TestConsumer() { consumers := tc.consumers(gormMq) ans := make([]*mq.Message, 0, len(tc.wantVal)) var wg sync.WaitGroup + locker := sync.RWMutex{} for _, c := range consumers { newc := c wg.Add(1) go func() { defer wg.Done() msgs := tc.consumerFunc(newc) + locker.Lock() ans = append(ans, msgs...) + locker.Unlock() }() } for _, msg := range tc.input { _, err := p.Produce(context.Background(), msg) require.NoError(t, err) } - time.Sleep(10 * time.Second) + time.Sleep(5 * time.Second) err = gormMq.Close() require.NoError(t, err) wg.Wait() assert.ElementsMatch(t, tc.wantVal, ans) // 清理测试环境 - g.TearDownTest() + tc.afterFunc() }) } } @@ -328,6 +384,8 @@ func (g *GormMQSuite) TestConsumer_Sort() { // 处理消息 consumerFunc func(c mq.Consumer) []*mq.Message wantVal []*mq.Message + beforeFunc func() + afterFunc func() }{ { name: "消息有序", @@ -431,10 +489,26 @@ func (g *GormMQSuite) TestConsumer_Sort() { Topic: "test_topic", }, }, + beforeFunc: func() { + db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{ + //Logger: logger.Default.LogMode(logger.Info), + }) + require.NoError(g.T(), err) + g.db = db + }, + afterFunc: func() { + sqlDB, err := g.db.DB() + require.NoError(g.T(), err) + _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`") + require.NoError(g.T(), err) + _, err = sqlDB.Exec("CREATE DATABASE `test`") + require.NoError(g.T(), err) + }, }, } for _, tc := range testcases { g.T().Run(tc.name, func(t *testing.T) { + tc.beforeFunc() gormMq, err := gorm_mq.NewMq(g.db) require.NoError(t, err) err = gormMq.Topic(tc.topic, int(tc.partitions)) @@ -444,20 +518,23 @@ func (g *GormMQSuite) TestConsumer_Sort() { consumers := tc.consumers(gormMq) ans := make([]*mq.Message, 0, len(tc.wantVal)) var wg sync.WaitGroup + locker := sync.RWMutex{} for _, c := range consumers { newc := c wg.Add(1) go func() { defer wg.Done() msgs := tc.consumerFunc(newc) + locker.Lock() ans = append(ans, msgs...) + locker.Unlock() }() } for _, msg := range tc.input { _, err := p.Produce(context.Background(), msg) require.NoError(t, err) } - time.Sleep(10 * time.Second) + time.Sleep(5 * time.Second) err = gormMq.Close() require.NoError(t, err) wg.Wait() @@ -465,7 +542,7 @@ func (g *GormMQSuite) TestConsumer_Sort() { actualMap := getMsgMap(ans) assert.Equal(t, wantMap, actualMap) // 清理测试环境 - g.TearDownTest() + tc.afterFunc() }) } } @@ -520,9 +597,3 @@ func TestMq(t *testing.T) { dsn: "root:root@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local", }) } - -func msgLog(msgs []*mq.Message) { - for i := 0; i < len(msgs); i++ { - log.Println(string(msgs[i].Key), string(msgs[i].Value)) - } -} diff --git a/gorm_mq/consumer.go b/gorm_mq/consumer.go index 19c845f..33d3753 100644 --- a/gorm_mq/consumer.go +++ b/gorm_mq/consumer.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "sync" "time" "github.com/ecodeclub/mq-api" @@ -35,6 +36,8 @@ type MqConsumer struct { msgCh chan *mq.Message name string groupId string + locker sync.RWMutex + corsor int64 // 每次至多消费多少 limit int // 抢占超时时间 @@ -43,6 +46,11 @@ type MqConsumer struct { interval time.Duration } +type msgRes struct { + msgs []*domain.Partition + err error +} + func (m *MqConsumer) Consume(ctx context.Context) (*mq.Message, error) { select { case <-ctx.Done(): @@ -58,7 +66,8 @@ func (m *MqConsumer) ConsumeMsgCh(ctx context.Context) (<-chan *mq.Message, erro func (m *MqConsumer) getMsgFromDB(ctx context.Context) ([]*mq.Message, error) { ans := make([]*mq.Message, 0, 64) - for _, p := range m.partitions { + partitions := m.getPartition() + for _, p := range partitions { // 抢占获取表的游标,如果当前 tableName := fmt.Sprintf("%s_%d", m.topic.Name, p) cursor, err := m.occupyCursor(tableName, m.groupId) @@ -245,7 +254,16 @@ func (m *MqConsumer) releaseCursor(tableName string, id string, cursor int64) er } -type msgRes struct { - msgs []*domain.Partition - err error +// 获取分区 +func (m *MqConsumer) getPartition() []int { + m.locker.Lock() + defer m.locker.Unlock() + return m.partitions +} + +// 设置分区 +func (m *MqConsumer) setPartition(partitions []int) { + m.locker.Lock() + defer m.locker.Unlock() + m.partitions = partitions } diff --git a/gorm_mq/mq.go b/gorm_mq/mq.go index aab9a62..412dfcf 100644 --- a/gorm_mq/mq.go +++ b/gorm_mq/mq.go @@ -18,18 +18,56 @@ import ( "context" "errors" "fmt" + "github.com/ecodeclub/mq-sql/gorm_mq/balancer/equal_divide" "log" "sync" "time" "github.com/ecodeclub/ekit/syncx" "github.com/ecodeclub/mq-api" - "github.com/ecodeclub/mq-sql/gorm_mq/balancer/equal_divide" "github.com/ecodeclub/mq-sql/gorm_mq/domain" "github.com/google/uuid" "gorm.io/gorm" ) +func NewMq(Db *gorm.DB, opts ...MqOption) (mq.MQ, error) { + err := Db.AutoMigrate(&domain.Cursors{}) + if err != nil { + return nil, err + } + m := &Mq{ + Db: Db, + topics: syncx.Map[string, *Topic]{}, + consumerBalancer: equal_divide.NewBalancer(), + producerGetter: NewGetter, + limit: 20, + timeout: 10 * time.Second, + interval: 2 * time.Second, + } + for _, opt := range opts { + opt(m) + } + return m, nil +} + +func WithLimit(limit int) MqOption { + return func(m *Mq) { + m.limit = limit + } +} + +func WithTimeout(timeout time.Duration) MqOption { + return func(m *Mq) { + m.timeout = timeout + } +} + +func WithInterval(interval time.Duration) MqOption { + return func(m *Mq) { + m.interval = interval + } +} + type Mq struct { Db *gorm.DB producerGetter NewProducerGetter @@ -82,16 +120,17 @@ func (m *Mq) Topic(name string, partition int) error { } func (tp *Topic) Close() error { + tp.lock.Lock() + defer tp.lock.Unlock() tp.once.Do(func() { - tp.lock.Lock() for _, ch := range tp.msgCh { close(ch) } for _, ch := range tp.closeChs { close(ch) } - tp.lock.Unlock() }) + return nil } @@ -149,9 +188,9 @@ func (m *Mq) Consumer(topic string, id string) (mq.Consumer, error) { } res := m.consumerBalancer.Balance(tp.partitionNum, len(consumers)+1) for i := 0; i < len(consumers); i++ { - consumers[i].partitions = res[i] + consumers[i].setPartition(res[i]) } - mqConsumer.partitions = res[len(consumers)] + mqConsumer.setPartition(res[len(consumers)]) consumers = append(consumers, mqConsumer) tp.consumerGroups[id] = consumers tp.closeChs = append(tp.closeChs, closeCh) @@ -181,41 +220,3 @@ func (m *Mq) Consumer(topic string, id string) (mq.Consumer, error) { }() return mqConsumer, nil } - -func NewMq(Db *gorm.DB, opts ...MqOption) (mq.MQ, error) { - err := Db.AutoMigrate(&domain.Cursors{}) - if err != nil { - return nil, err - } - m := &Mq{ - Db: Db, - topics: syncx.Map[string, *Topic]{}, - consumerBalancer: equal_divide.NewBalancer(), - producerGetter: NewGetter, - limit: 20, - timeout: 10 * time.Second, - interval: 2 * time.Second, - } - for _, opt := range opts { - opt(m) - } - return m, nil -} - -func WithLimit(limit int) MqOption { - return func(m *Mq) { - m.limit = limit - } -} - -func WithTimeout(timeout time.Duration) MqOption { - return func(m *Mq) { - m.timeout = timeout - } -} - -func WithInterval(interval time.Duration) MqOption { - return func(m *Mq) { - m.interval = interval - } -}