Skip to content

Commit

Permalink
make check
Browse files Browse the repository at this point in the history
  • Loading branch information
juniaoshaonian committed Sep 28, 2023
1 parent 29aa62b commit a894214
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 77 deletions.
131 changes: 101 additions & 30 deletions e2e/gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package e2e
import (
"context"
"fmt"
"log"
"sync"
"testing"
"time"
Expand All @@ -40,45 +39,66 @@ type GormMQSuite struct {
db *gorm.DB
}

func (g *GormMQSuite) SetupTest() {
db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
require.NoError(g.T(), err)
g.db = db
}
func (g *GormMQSuite) TearDownTest() {
sqlDB, err := g.db.DB()
require.NoError(g.T(), err)
_, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
require.NoError(g.T(), err)
_, err = sqlDB.Exec("CREATE DATABASE `test`")
require.NoError(g.T(), err)
}
//func (g *GormMQSuite) SetupTest() {
// db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
// //Logger: logger.Default.LogMode(logger.Info),
// })
// require.NoError(g.T(), err)
// g.db = db
//}

//func (g *GormMQSuite) TearDownTest() {
// sqlDB, err := g.db.DB()
// require.NoError(g.T(), err)
// _, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
// require.NoError(g.T(), err)
// _, err = sqlDB.Exec("CREATE DATABASE `test`")
// require.NoError(g.T(), err)
//
//}

func (g *GormMQSuite) TestTopic() {
testcases := []struct {
name string
topic string
input int
wantVal []string
name string
topic string
input int
wantVal []string
beforeFunc func()
afterFunc func()
}{
{
name: "建立含有4个分区的topic",
topic: "test_topic",
input: 4,
wantVal: []string{"test_topic_0", "test_topic_1", "test_topic_2", "test_topic_3"},
beforeFunc: func() {
db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
require.NoError(g.T(), err)
g.db = db
},
afterFunc: func() {
sqlDB, err := g.db.DB()
require.NoError(g.T(), err)
_, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
require.NoError(g.T(), err)
_, err = sqlDB.Exec("CREATE DATABASE `test`")
require.NoError(g.T(), err)
},
},
}
for _, tc := range testcases {
g.T().Run(tc.name, func(t *testing.T) {
tc.beforeFunc()
mq, err := gorm_mq.NewMq(g.db)
require.NoError(t, err)
err = mq.Topic(tc.topic, 4)
require.NoError(t, err)
tables, err := g.getTables(tc.topic)
require.NoError(t, err)
assert.Equal(t, tc.wantVal, tables)
tc.afterFunc()
})
}

Expand All @@ -94,6 +114,8 @@ func (g *GormMQSuite) TestConsumer() {
// 处理消息
consumerFunc func(c mq.Consumer) []*mq.Message
wantVal []*mq.Message
beforeFunc func()
afterFunc func()
}{
{
name: "一个消费组内多个消费者",
Expand Down Expand Up @@ -170,6 +192,21 @@ func (g *GormMQSuite) TestConsumer() {
Topic: "test_topic",
},
},
beforeFunc: func() {
db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
require.NoError(g.T(), err)
g.db = db
},
afterFunc: func() {
sqlDB, err := g.db.DB()
require.NoError(g.T(), err)
_, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
require.NoError(g.T(), err)
_, err = sqlDB.Exec("CREATE DATABASE `test`")
require.NoError(g.T(), err)
},
},
{
name: "多个消费组,多个消费者",
Expand Down Expand Up @@ -280,10 +317,26 @@ func (g *GormMQSuite) TestConsumer() {
Topic: "test_topic",
},
},
beforeFunc: func() {
db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
require.NoError(g.T(), err)
g.db = db
},
afterFunc: func() {
sqlDB, err := g.db.DB()
require.NoError(g.T(), err)
_, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
require.NoError(g.T(), err)
_, err = sqlDB.Exec("CREATE DATABASE `test`")
require.NoError(g.T(), err)
},
},
}
for _, tc := range testcases {
g.T().Run(tc.name, func(t *testing.T) {
tc.beforeFunc()
gormMq, err := gorm_mq.NewMq(g.db)
require.NoError(t, err)
err = gormMq.Topic(tc.topic, int(tc.partitions))
Expand All @@ -293,26 +346,29 @@ func (g *GormMQSuite) TestConsumer() {
consumers := tc.consumers(gormMq)
ans := make([]*mq.Message, 0, len(tc.wantVal))
var wg sync.WaitGroup
locker := sync.RWMutex{}
for _, c := range consumers {
newc := c
wg.Add(1)
go func() {
defer wg.Done()
msgs := tc.consumerFunc(newc)
locker.Lock()
ans = append(ans, msgs...)
locker.Unlock()
}()
}
for _, msg := range tc.input {
_, err := p.Produce(context.Background(), msg)
require.NoError(t, err)
}
time.Sleep(10 * time.Second)
time.Sleep(5 * time.Second)
err = gormMq.Close()
require.NoError(t, err)
wg.Wait()
assert.ElementsMatch(t, tc.wantVal, ans)
// 清理测试环境
g.TearDownTest()
tc.afterFunc()
})
}
}
Expand All @@ -328,6 +384,8 @@ func (g *GormMQSuite) TestConsumer_Sort() {
// 处理消息
consumerFunc func(c mq.Consumer) []*mq.Message
wantVal []*mq.Message
beforeFunc func()
afterFunc func()
}{
{
name: "消息有序",
Expand Down Expand Up @@ -431,10 +489,26 @@ func (g *GormMQSuite) TestConsumer_Sort() {
Topic: "test_topic",
},
},
beforeFunc: func() {
db, err := gorm.Open(mysql.Open(g.dsn), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
require.NoError(g.T(), err)
g.db = db
},
afterFunc: func() {
sqlDB, err := g.db.DB()
require.NoError(g.T(), err)
_, err = sqlDB.Exec("DROP DATABASE IF EXISTS `test`")
require.NoError(g.T(), err)
_, err = sqlDB.Exec("CREATE DATABASE `test`")
require.NoError(g.T(), err)
},
},
}
for _, tc := range testcases {
g.T().Run(tc.name, func(t *testing.T) {
tc.beforeFunc()
gormMq, err := gorm_mq.NewMq(g.db)
require.NoError(t, err)
err = gormMq.Topic(tc.topic, int(tc.partitions))
Expand All @@ -444,28 +518,31 @@ func (g *GormMQSuite) TestConsumer_Sort() {
consumers := tc.consumers(gormMq)
ans := make([]*mq.Message, 0, len(tc.wantVal))
var wg sync.WaitGroup
locker := sync.RWMutex{}
for _, c := range consumers {
newc := c
wg.Add(1)
go func() {
defer wg.Done()
msgs := tc.consumerFunc(newc)
locker.Lock()
ans = append(ans, msgs...)
locker.Unlock()
}()
}
for _, msg := range tc.input {
_, err := p.Produce(context.Background(), msg)
require.NoError(t, err)
}
time.Sleep(10 * time.Second)
time.Sleep(5 * time.Second)
err = gormMq.Close()
require.NoError(t, err)
wg.Wait()
wantMap := getMsgMap(tc.wantVal)
actualMap := getMsgMap(ans)
assert.Equal(t, wantMap, actualMap)
// 清理测试环境
g.TearDownTest()
tc.afterFunc()
})
}
}
Expand Down Expand Up @@ -520,9 +597,3 @@ func TestMq(t *testing.T) {
dsn: "root:root@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local",
})
}

func msgLog(msgs []*mq.Message) {
for i := 0; i < len(msgs); i++ {
log.Println(string(msgs[i].Key), string(msgs[i].Value))
}
}
26 changes: 22 additions & 4 deletions gorm_mq/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"

"github.com/ecodeclub/mq-api"
Expand All @@ -35,6 +36,8 @@ type MqConsumer struct {
msgCh chan *mq.Message
name string
groupId string
locker sync.RWMutex
corsor int64
// 每次至多消费多少
limit int
// 抢占超时时间
Expand All @@ -43,6 +46,11 @@ type MqConsumer struct {
interval time.Duration
}

type msgRes struct {
msgs []*domain.Partition
err error
}

func (m *MqConsumer) Consume(ctx context.Context) (*mq.Message, error) {
select {
case <-ctx.Done():
Expand All @@ -58,7 +66,8 @@ func (m *MqConsumer) ConsumeMsgCh(ctx context.Context) (<-chan *mq.Message, erro

func (m *MqConsumer) getMsgFromDB(ctx context.Context) ([]*mq.Message, error) {
ans := make([]*mq.Message, 0, 64)
for _, p := range m.partitions {
partitions := m.getPartition()
for _, p := range partitions {
// 抢占获取表的游标,如果当前
tableName := fmt.Sprintf("%s_%d", m.topic.Name, p)
cursor, err := m.occupyCursor(tableName, m.groupId)
Expand Down Expand Up @@ -245,7 +254,16 @@ func (m *MqConsumer) releaseCursor(tableName string, id string, cursor int64) er

}

type msgRes struct {
msgs []*domain.Partition
err error
// 获取分区
func (m *MqConsumer) getPartition() []int {
m.locker.Lock()
defer m.locker.Unlock()
return m.partitions
}

// 设置分区
func (m *MqConsumer) setPartition(partitions []int) {
m.locker.Lock()
defer m.locker.Unlock()
m.partitions = partitions
}
Loading

0 comments on commit a894214

Please sign in to comment.