Skip to content

Commit

Permalink
feat: set deadlines on AP connection
Browse files Browse the repository at this point in the history
  • Loading branch information
devgianlu committed Nov 18, 2024
1 parent 5e60acb commit 355191f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
21 changes: 13 additions & 8 deletions ap/ap.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)
return err
}

if deadline, ok := ctx.Deadline(); ok {
_ = ap.conn.SetDeadline(deadline)
defer func() { _ = ap.conn.SetDeadline(time.Time{}) }()
}

// perform key exchange with diffiehellman
exchangeData, err := ap.performKeyExchange()
if err != nil {
Expand All @@ -198,7 +203,7 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)
}

// do authentication with credentials
if err := ap.authenticate(creds); err != nil {
if err := ap.authenticate(ctx, creds); err != nil {
return fmt.Errorf("failed authenticating: %w", err)
}

Expand All @@ -220,10 +225,10 @@ func (ap *Accesspoint) Close() {
_ = ap.conn.Close()
}

func (ap *Accesspoint) Send(pktType PacketType, payload []byte) error {
func (ap *Accesspoint) Send(ctx context.Context, pktType PacketType, payload []byte) error {
ap.connMu.RLock()
defer ap.connMu.RUnlock()
return ap.encConn.sendPacket(pktType, payload)
return ap.encConn.sendPacket(ctx, pktType, payload)
}

func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
Expand Down Expand Up @@ -261,7 +266,7 @@ loop:
break loop
default:
// no need to hold the connMu since reconnection happens in this routine
pkt, payload, err := ap.encConn.receivePacket()
pkt, payload, err := ap.encConn.receivePacket(context.TODO())
if err != nil {
log.WithError(err).Errorf("failed receiving packet")
break loop
Expand All @@ -270,7 +275,7 @@ loop:
switch pkt {
case PacketTypePing:
log.Tracef("received accesspoint ping")
if err := ap.encConn.sendPacket(PacketTypePong, payload); err != nil {
if err := ap.encConn.sendPacket(context.TODO(), PacketTypePong, payload); err != nil {
log.WithError(err).Errorf("failed sending Pong packet")
break loop
}
Expand Down Expand Up @@ -467,7 +472,7 @@ func (ap *Accesspoint) solveChallenge(exchangeData []byte) error {
return fmt.Errorf("failed login: %s", resp.LoginFailed.ErrorCode.String())
}

func (ap *Accesspoint) authenticate(credentials *pb.LoginCredentials) error {
func (ap *Accesspoint) authenticate(ctx context.Context, credentials *pb.LoginCredentials) error {
if ap.encConn == nil {
panic("accesspoint not connected")
}
Expand All @@ -488,12 +493,12 @@ func (ap *Accesspoint) authenticate(credentials *pb.LoginCredentials) error {
}

// send Login packet
if err := ap.encConn.sendPacket(PacketTypeLogin, payload); err != nil {
if err := ap.encConn.sendPacket(ctx, PacketTypeLogin, payload); err != nil {
return fmt.Errorf("failed sending Login packet: %w", err)
}

// receive APWelcome or AuthFailure
recvPkt, recvPayload, err := ap.encConn.receivePacket()
recvPkt, recvPayload, err := ap.encConn.receivePacket(ctx)
if err != nil {
return fmt.Errorf("failed recevining Login response packet: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions ap/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ package ap
import (
"encoding/binary"
"fmt"
"io"

"google.golang.org/protobuf/proto"
"io"
)

func writeMessage(w io.Writer, withHello bool, m proto.Message) error {
Expand Down
16 changes: 14 additions & 2 deletions ap/shannon.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package ap

import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"time"

"github.com/devgianlu/shannon"
)
Expand Down Expand Up @@ -33,7 +35,7 @@ func newShannonConn(conn net.Conn, sendKey []byte, recvKey []byte) *shannonConn
}
}

func (c *shannonConn) sendPacket(pktType PacketType, payload []byte) error {
func (c *shannonConn) sendPacket(ctx context.Context, pktType PacketType, payload []byte) error {
if len(payload) > 65535 {
return fmt.Errorf("payload too big: %d", len(payload))
}
Expand All @@ -58,6 +60,11 @@ func (c *shannonConn) sendPacket(pktType PacketType, payload []byte) error {
mac := make([]byte, 4)
c.sendCipher.Finish(mac)

if deadline, ok := ctx.Deadline(); ok {
_ = c.conn.SetDeadline(deadline)
defer func() { _ = c.conn.SetDeadline(time.Time{}) }()
}

// write it all out
if _, err := c.conn.Write(packet); err != nil {
return fmt.Errorf("failed writing packet: %w", err)
Expand All @@ -68,10 +75,15 @@ func (c *shannonConn) sendPacket(pktType PacketType, payload []byte) error {
return nil
}

func (c *shannonConn) receivePacket() (PacketType, []byte, error) {
func (c *shannonConn) receivePacket(ctx context.Context) (PacketType, []byte, error) {
c.recvLock.Lock()
defer c.recvLock.Unlock()

if deadline, ok := ctx.Deadline(); ok {
_ = c.conn.SetDeadline(deadline)
defer func() { _ = c.conn.SetDeadline(time.Time{}) }()
}

// set nonce on cipher and increment
c.recvCipher.NonceU32(c.recvNonce)
c.recvNonce++
Expand Down
2 changes: 1 addition & 1 deletion audio/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (p *KeyProvider) recvLoop() {

reqs[reqSeq] = req

if err := p.ap.Send(ap.PacketTypeRequestKey, buf.Bytes()); err != nil {
if err := p.ap.Send(context.TODO(), ap.PacketTypeRequestKey, buf.Bytes()); err != nil {
delete(reqs, reqSeq)
req.resp <- keyResponse{err: fmt.Errorf("failed sending key request for file %s, gid: %s: %w",
hex.EncodeToString(req.fileId), librespot.GidToBase62(req.gid), err)}
Expand Down

0 comments on commit 355191f

Please sign in to comment.