Skip to content

Commit

Permalink
fix race data timer close bug
Browse files Browse the repository at this point in the history
  • Loading branch information
bobohume committed Oct 18, 2022
1 parent 0dc6050 commit 1ab2073
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 49 deletions.
44 changes: 25 additions & 19 deletions actor/Actor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"strings"
"sync/atomic"
"time"
"unsafe"
)

var (
Expand All @@ -26,9 +25,9 @@ const (
ASF_STOP = iota //已经关闭
)

//********************************************************
// ********************************************************
// actor 核心actor模式
//********************************************************
// ********************************************************
type (
ActorBase struct {
actorName string
Expand All @@ -48,7 +47,8 @@ type (
mailIn [8]int64
mailChan chan bool
timerId *int64
pool IActorPool //ACTOR_TYPE_VIRTUAL,ACTOR_TYPE_POOL
pool IActorPool //ACTOR_TYPE_VIRTUAL,ACTOR_TYPE_POOL
timerMap map[uintptr]func() //成员方法转func()会是闭包函数,定时器释放会有问题
}

IActor interface {
Expand Down Expand Up @@ -151,11 +151,13 @@ func (a *Actor) Init() {
a.mailChan = make(chan bool, 1)
a.mailBox = mpsc.New()
a.acotrChan = make(chan int, 1)
a.timerMap = make(map[uintptr]func())
//trance
a.trace.Init()
if a.id == 0 {
a.id = AssignActorId()
}
a.timerId = new(int64)
}

func (a *Actor) register(ac IActor, op Op) {
Expand All @@ -164,13 +166,12 @@ func (a *Actor) register(ac IActor, op Op) {
}

func (a *Actor) RegisterTimer(duration time.Duration, fun func(), opts ...timer.OpOption) {
if a.timerId == nil {
a.timerId = new(int64)
*a.timerId = a.id
}

timer.StoreTimerId(a.timerId, a.id)
//&fun这里有问题,会产生一对闭包函数,再释放的释放有问题
ptr := uintptr(reflect.ValueOf(fun).Pointer())
a.timerMap[ptr] = fun
timer.RegisterTimer(a.timerId, duration, func() {
a.SendMsg(rpc.RpcHead{ActorName: a.actorName}, "UpdateTimer", (*int64)(unsafe.Pointer(&fun)))
a.SendMsg(rpc.RpcHead{ActorName: a.actorName}, "UpdateTimer", ptr)
}, opts...)
}

Expand All @@ -183,9 +184,12 @@ func (a *Actor) clear() {
}

func (a *Actor) Stop() {
if atomic.CompareAndSwapInt32(&a.state, ASF_RUN, ASF_STOP) {
a.acotrChan <- DESDORY_EVENT
}
timer.RegisterTimer(a.timerId, timer.TICK_INTERVAL, func() {
timer.StopTimer(a.timerId)
if atomic.CompareAndSwapInt32(&a.state, ASF_RUN, ASF_STOP) {
a.acotrChan <- DESDORY_EVENT
}
})
}

func (a *Actor) Start() {
Expand Down Expand Up @@ -225,7 +229,7 @@ func (a *Actor) call(io CallIO) {
head := io.RpcHead
funcName := rpcPacket.FuncName
m, bEx := a.rType.MethodByName(funcName)
if !bEx{
if !bEx {
log.Printf("func [%s] has no method", funcName)
return
}
Expand Down Expand Up @@ -254,11 +258,13 @@ func (a *Actor) call(io CallIO) {
}
}

func (a *Actor) UpdateTimer(ctx context.Context, p *int64) {
func1 := (*func())(unsafe.Pointer(p))
a.Trace("timer")
(*func1)()
a.Trace("")
func (a *Actor) UpdateTimer(ctx context.Context, ptr uintptr) {
fun, isEx := a.timerMap[ptr]
if isEx {
a.Trace("timer")
(fun)()
a.Trace("")
}
}

func (a *Actor) consume() {
Expand Down
48 changes: 26 additions & 22 deletions common/timer/Timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ const (
TICK_INTERVAL = 10 * time.Millisecond
)

//先搞清楚下面的单位
//1秒=1000毫秒 milliseconds
//1毫秒=1000微秒 microseconds
//1微秒=1000纳秒 nanoseconds
//整个timer中毫秒的精度都是10ms,
//也就是说毫秒的一个三个位,但是最小的位被丢弃
// 先搞清楚下面的单位
// 1秒=1000毫秒 milliseconds
// 1毫秒=1000微秒 microseconds
// 1微秒=1000纳秒 nanoseconds
// 整个timer中毫秒的精度都是10ms,
// 也就是说毫秒的一个三个位,但是最小的位被丢弃
type (
TimerHandle func()
TimerNode struct {
Expand Down Expand Up @@ -69,6 +69,10 @@ func (t *TimerNode) LoadId() int64 {
return atomic.LoadInt64(t.id)
}

func StoreTimerId(id *int64, val int64) bool {
return atomic.LoadInt64(id) == 0 && atomic.CompareAndSwapInt64(id, 0, val)
}

func (op *Op) applyOpts(opts []OpOption) {
for _, opt := range opts {
opt(op)
Expand All @@ -90,22 +94,22 @@ func uuid() int64 {
return atomic.AddInt64(&g_Id, 1)
}

//清空链表,返回链表第一个结点
// 清空链表,返回链表第一个结点
func linkClear(list *LinkList) *TimerNode {
ret := list.head.next
list.head.next = nil
list.tail = &list.head
return ret
}

//将结点放入链表
// 将结点放入链表
func link(list *LinkList, node *TimerNode) {
list.tail.next = node
list.tail = node
node.next = nil
}

//创建一个定时器
// 创建一个定时器
func (t *Timer) Init() {
for i := 0; i < TIME_NEAR; i++ {
linkClear(&t.near[i])
Expand All @@ -123,7 +127,7 @@ func (t *Timer) Init() {
go t.run()
}

//添加一个定时器结点
// 添加一个定时器结点
func (t *Timer) addNode(node *TimerNode) {
time := node.expire //去看一下它是在哪赋值的
current_time := t.time //当前计数
Expand All @@ -144,7 +148,7 @@ func (t *Timer) addNode(node *TimerNode) {
}
}

//添加一个定时器
// 添加一个定时器
func (t *Timer) Add(id *int64, time uint32, handle TimerHandle, opts ...OpOption) *TimerNode {
op := Op{}
op.applyOpts(opts)
Expand All @@ -156,12 +160,12 @@ func (t *Timer) Add(id *int64, time uint32, handle TimerHandle, opts ...OpOption
return node
}

//删除一个定时器
// 删除一个定时器
func (t *Timer) Delete(id *int64) {
atomic.StoreInt64(id, 0)
atomic.StoreInt64(id, -1)
}

//移动某个级别的链表内容
// 移动某个级别的链表内容
func (t *Timer) moveList(level int, idx int) {
current := linkClear(&t.t[level][idx])
for current != nil {
Expand All @@ -171,8 +175,8 @@ func (t *Timer) moveList(level int, idx int) {
}
}

//这是一个非常重要的函数
//定时器的移动都在这里
// 这是一个非常重要的函数
// 定时器的移动都在这里
func (t *Timer) shift() {
mask := uint32(TIME_NEAR)
t.time += 1
Expand All @@ -196,11 +200,11 @@ func (t *Timer) shift() {
}
}

//派发消息到目标服务消息队列
// 派发消息到目标服务消息队列
func (t *Timer) dispatch(current *TimerNode) {
for current != nil {
id := current.LoadId()
if id != 0 {
if id > 0 {
current.handle()
if !current.bOnce {
t.loop_node = append(t.loop_node, current)
Expand All @@ -210,7 +214,7 @@ func (t *Timer) dispatch(current *TimerNode) {
}
}

//派发消息
// 派发消息
func (t *Timer) execute() {
idx := t.time & TIME_NEAR_MASK

Expand All @@ -228,7 +232,7 @@ func (t *Timer) execute() {
}
}

//时间更新好了以后,这里检查调用各个定时器
// 时间更新好了以后,这里检查调用各个定时器
func (t *Timer) advace() {
t.lock.Lock()
// try to dispatch timeout 0 (rare condition)
Expand All @@ -239,8 +243,8 @@ func (t *Timer) advace() {
t.lock.Unlock()
}

//在线程中不断被调用
//调用时间 间隔为微秒
// 在线程中不断被调用
// 调用时间 间隔为微秒
func (t *Timer) update() {
cp := uint64(time.Now().UnixNano()) / uint64(TICK_INTERVAL)
if cp < t.current_point {
Expand Down
6 changes: 4 additions & 2 deletions network/Isocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ func (this *Socket) Start() bool {
}

func (this *Socket) Stop() bool {
if this.conn != nil && atomic.CompareAndSwapInt32(&this.state, SSF_RUN, SSF_STOP) {
this.conn.Close()
if atomic.CompareAndSwapInt32(&this.state, SSF_RUN, SSF_STOP) {
if this.conn != nil {
this.conn.Close()
}
}
return false
}
Expand Down
18 changes: 15 additions & 3 deletions network/ServerSocketClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"hash/crc32"
"io"
"log"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -42,6 +43,7 @@ func handleError(err error) {

func (s *ServerSocketClient) Init(ip string, port int, params ...OpOption) bool {
s.Socket.Init(ip, port, params...)
s.timerId = new(int64)
return true
}

Expand All @@ -52,8 +54,7 @@ func (s *ServerSocketClient) Start() bool {

if s.connectType == CLIENT_CONNECT {
s.sendChan = make(chan []byte, MAX_SEND_CHAN)
s.timerId = new(int64)
*s.timerId = int64(s.clientId)
timer.StoreTimerId(s.timerId, int64(s.clientId)+1<<32)
timer.RegisterTimer(s.timerId, (HEART_TIME_OUT/3)*time.Second, func() {
s.Update()
})
Expand Down Expand Up @@ -119,11 +120,22 @@ func (s *ServerSocketClient) OnNetFail(error int) {
}
}

func (s *ServerSocketClient) Stop() bool {
timer.RegisterTimer(s.timerId, timer.TICK_INTERVAL, func() {
timer.StopTimer(s.timerId)
if atomic.CompareAndSwapInt32(&s.state, SSF_RUN, SSF_STOP) {
if s.conn != nil {
s.conn.Close()
}
}
})
return false
}

func (s *ServerSocketClient) Close() {
if s.connectType == CLIENT_CONNECT {
s.sendChan <- nil
//close(s.sendChan)
timer.StopTimer(s.timerId)
}
s.Socket.Close()
if s.server != nil {
Expand Down
19 changes: 16 additions & 3 deletions network/WebSocketClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"gonet/common/timer"
"gonet/rpc"
"io"
"sync/atomic"
"time"
)

Expand All @@ -22,6 +23,7 @@ type WebSocketClient struct {

func (w *WebSocketClient) Init(ip string, port int, params ...OpOption) bool {
w.Socket.Init(ip, port, params...)
w.timerId = new(int64)
return true
}

Expand All @@ -32,8 +34,7 @@ func (w *WebSocketClient) Start() bool {

if w.connectType == CLIENT_CONNECT {
w.sendChan = make(chan []byte, MAX_SEND_CHAN)
w.timerId = new(int64)
*w.timerId = int64(w.clientId)
timer.StoreTimerId(w.timerId, int64(w.clientId)+1<<32)
timer.RegisterTimer(w.timerId, (HEART_TIME_OUT/3)*time.Second, func() {
w.Update()
})
Expand All @@ -49,6 +50,18 @@ func (w *WebSocketClient) Start() bool {
return true
}

func (w *WebSocketClient) Stop() bool {
timer.RegisterTimer(w.timerId, timer.TICK_INTERVAL, func() {
timer.StopTimer(w.timerId)
if atomic.CompareAndSwapInt32(&w.state, SSF_RUN, SSF_STOP) {
if w.conn != nil {
w.conn.Close()
}
}
})
return false
}

func (w *WebSocketClient) Send(head rpc.RpcHead, packet rpc.Packet) int {
defer func() {
if err := recover(); err != nil {
Expand Down Expand Up @@ -101,7 +114,6 @@ func (w *WebSocketClient) OnNetFail(error int) {
func (w *WebSocketClient) Close() {
if w.connectType == CLIENT_CONNECT {
//close(w.sendChan)
timer.StopTimer(w.timerId)
}
w.Socket.Close()
if w.server != nil {
Expand Down Expand Up @@ -137,6 +149,7 @@ func (w *WebSocketClient) Run() bool {
if n > 0 {
w.packetParser.Read(buff[:n])
}
w.heartTime = int(time.Now().Unix()) + HEART_TIME_OUT
return true
}

Expand Down

0 comments on commit 1ab2073

Please sign in to comment.