Skip to content

Commit

Permalink
Merge pull request #6 from percipia/close_fix
Browse files Browse the repository at this point in the history
Fix connections not properly stopping for invalid inbound authentication
  • Loading branch information
winsock authored Feb 16, 2021
2 parents 4b56517 + f1cee8d commit 6e56fc1
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 36 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ eslgo was written from the ground up in idiomatic Go for use in our production p
go get github.com/percipia/eslgo
```
```
github.com/percipia/eslgo v1.3.2
github.com/percipia/eslgo v1.3.3
```

## Overview
Expand Down Expand Up @@ -90,6 +90,6 @@ func main() {

// Close the connection after sleeping for a bit
time.Sleep(60 * time.Second)
conn.Close()
conn.ExitAndClose()
}
```
73 changes: 51 additions & 22 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func newConnection(c net.Conn, outbound bool) *Conn {
TypeEventPlain: make(chan *RawResponse),
TypeEventXML: make(chan *RawResponse),
TypeEventJSON: make(chan *RawResponse),
TypeAuthRequest: make(chan *RawResponse),
TypeAuthRequest: make(chan *RawResponse, 1), // Buffered to ensure we do not lose the initial auth request before we are setup to respond
TypeDisconnect: make(chan *RawResponse),
},
runningContext: runningContext,
Expand Down Expand Up @@ -124,19 +124,32 @@ func (c *Conn) SendCommand(ctx context.Context, command command.Command) (*RawRe
}
}

func (c *Conn) ExitAndClose() {
c.closeOnce.Do(func() {
// Attempt a graceful closing of the connection with FreeSWITCH
ctx, cancel := context.WithTimeout(c.runningContext, time.Second)
_, _ = c.SendCommand(ctx, command.Exit{})
cancel()
c.close()
})
}

func (c *Conn) Close() {
c.closeOnce.Do(c.close)
}

func (c *Conn) close() {
// Allow users to do anything they need to do before we tear everything down
c.stopFunc()
_ = c.conn.Close()
c.responseChanMutex.Lock()
defer c.responseChanMutex.Unlock()
for key, responseChan := range c.responseChannels {
close(responseChan)
delete(c.responseChannels, key)
}

// Close the connection only after we have the response channel lock and we have deleted all response channels to ensure we don't receive on a closed channel
_ = c.conn.Close()
}

func (c *Conn) callEventListener(event *Event) {
Expand Down Expand Up @@ -224,30 +237,46 @@ func (c *Conn) eventLoop() {
}

func (c *Conn) receiveLoop() {
for {
response, err := c.readResponse()
for c.runningContext.Err() == nil {
err := c.doMessage()
if err != nil {
log.Println("Error receiving message", err)
break
}
}
}

c.responseChanMutex.RLock()
responseChan, ok := c.responseChannels[response.GetHeader("Content-Type")]
if !ok && len(c.responseChannels) <= 0 {
// We must have shutdown!
break
}
c.responseChanMutex.RUnlock()
if ok {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
select {
case responseChan <- response:
case <-c.runningContext.Done():
cancel()
return
case <-ctx.Done():
log.Printf("No one to handle response %v\n", response)
}
cancel()
func (c *Conn) doMessage() error {
response, err := c.readResponse()
if err != nil {
return err
}

c.responseChanMutex.RLock()
defer c.responseChanMutex.RUnlock()
responseChan, ok := c.responseChannels[response.GetHeader("Content-Type")]
if !ok && len(c.responseChannels) <= 0 {
// We must have shutdown!
return errors.New("no response channels")
}

// We have a handler
if ok {
// Only allow 5 seconds to allow the handler to receive hte message on the channel
ctx, cancel := context.WithTimeout(c.runningContext, 5*time.Second)
defer cancel()

select {
case responseChan <- response:
case <-c.runningContext.Done():
// Parent connection context has stopped we most likely shutdown in the middle of waiting for a handler to handle the message
return c.runningContext.Err()
case <-ctx.Done():
// Do not return an error since this is not fatal but log since it could be a indication of problems
log.Printf("No one to handle response\nIs the connection overloaded or stopping?\n%v\n\n", response)
}
} else {
return errors.New("no response channel for Content-Type: " + response.GetHeader("Content-Type"))
}
return nil
}
6 changes: 3 additions & 3 deletions example/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func main() {
})

// Ensure all events are enabled
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = conn.EnableEvents(ctx)
cancel()

Expand All @@ -48,7 +48,7 @@ func main() {
}
}

// Remove the listener and close the connection
// Remove the listener and close the connection gracefully
conn.RemoveEventListener(eslgo.EventListenAll, listenerID)
conn.Close()
conn.ExitAndClose()
}
2 changes: 1 addition & 1 deletion example/inbound/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ func main() {

// Close the connection after sleeping for a bit
time.Sleep(60 * time.Second)
conn.Close()
conn.ExitAndClose()
}
18 changes: 12 additions & 6 deletions inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ func Dial(address, password string, onDisconnect func()) (*Conn, error) {
<-connection.responseChannels[TypeAuthRequest]
err = connection.doAuth(connection.runningContext, command.Auth{Password: password})
if err != nil {
// Try to gracefully disconnect
log.Printf("Failed to auth %e\n", err)
_, _ = connection.SendCommand(connection.runningContext, command.Exit{})
// Try to gracefully disconnect, we have the wrong password.
connection.ExitAndClose()
if onDisconnect != nil {
go onDisconnect()
}
return nil, err
} else {
log.Printf("Sucessfully authenticated %s\n", connection.conn.RemoteAddr())
}
Expand All @@ -47,7 +50,9 @@ func (c *Conn) disconnectLoop(onDisconnect func()) {
select {
case <-c.responseChannels[TypeDisconnect]:
c.Close()
defer onDisconnect()
if onDisconnect != nil {
onDisconnect()
}
return
case <-c.runningContext.Done():
return
Expand All @@ -60,9 +65,10 @@ func (c *Conn) authLoop(auth command.Auth) {
case <-c.responseChannels[TypeAuthRequest]:
err := c.doAuth(c.runningContext, auth)
if err != nil {
// Try to gracefully disconnect
log.Printf("Failed to auth %e\n", err)
_, _ = c.SendCommand(c.runningContext, command.Exit{})
// Close the connection, we have the wrong password
c.ExitAndClose()
return
} else {
log.Printf("Sucessfully authenticated %s\n", c.conn.RemoteAddr())
}
Expand Down
4 changes: 2 additions & 2 deletions outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (c *Conn) outboundHandle(handler OutboundHandler) {
if err != nil {
log.Printf("Error connecting to %s error %s", c.conn.RemoteAddr().String(), err.Error())
// Try closing cleanly first
c.Close()
c.Close() // Not ExitAndClose since this error connection is most likely from communication failure
return
}
handler(c.runningContext, c, response)
Expand All @@ -67,7 +67,7 @@ func (c *Conn) outboundHandle(handler OutboundHandler) {
ctx, cancel = context.WithTimeout(c.runningContext, 5*time.Second)
_, _ = c.SendCommand(ctx, command.Exit{})
cancel()
c.Close()
c.ExitAndClose()
}

func (c *Conn) dummyLoop() {
Expand Down

0 comments on commit 6e56fc1

Please sign in to comment.