forked from vishvananda/go-netlink
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlistener.go
162 lines (147 loc) · 3.81 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package netlink
import (
"bufio"
"errors"
"fmt"
"log"
"os"
"runtime"
"sync"
)
// Listener multiplexes requests and replies over a single netlink socket.
// Requests queued with Query are written by one goroutine, while replies and
// other incoming messages are read by another; both are launched by Start.
type Listener struct {
	// sock is the netlink socket obtained from Dial in NewListener.
	sock *Socket
	// Messagechan receives every incoming message that is not an
	// NLMSG_ERROR reply.
	Messagechan chan Message
	// sendqueue holds messages queued by Query until the writer sends them.
	sendqueue chan *Message
	// currSeq is the sequence number most recently handed out by Seq.
	currSeq uint32
	// lock guards currSeq. socketLock is acquired by the writer before a
	// request is written and released by the reader when a reply whose
	// sequence number equals currSeq arrives, so only one request is
	// outstanding at a time.
	lock, socketLock sync.Mutex
	// echan is the error channel passed to Start; netlink error replies are
	// delivered on it.
	echan chan error
}
// Seq returns the next request sequence number. It is safe for concurrent use.
//
// Zero is never returned: a zero sequence number in a message header is
// treated as "unset" when messages are written, so handing out 0 after the
// 32-bit counter wraps around would be indistinguishable from an unassigned
// message. No check is made that a sequence number is not still in use after
// roll-over.
func (listener *Listener) Seq() (out uint32) {
	listener.lock.Lock()
	defer listener.lock.Unlock()
	listener.currSeq++
	if listener.currSeq == 0 {
		// Skip the "unset" sentinel on wrap-around.
		listener.currSeq = 1
	}
	out = listener.currSeq
	return out
}
// Close closes the underlying netlink socket. The reader and writer loops
// started by Start run only while sock.IsOpen() holds, so they exit the next
// time they re-check it.
//
// Messagechan is deliberately left open: the reader goroutine may still be
// delivering a message on it, and sending on a closed channel panics.
func (listener *Listener) Close() {
	sock := listener.sock
	sock.Close()
}
// Query queues msg to be sent by the writer goroutine launched by Start.
// If the message's sequence number is zero, the writer assigns one from Seq()
// before sending it.
//
// Query blocks while the send queue is full and always returns nil.
func (listener *Listener) Query(msg *Message) (err error) {
	queue := listener.sendqueue
	queue <- msg
	return nil
}
// NewListener dials a netlink socket for the given family and wraps it in a
// Listener. Messagechan and the internal send queue are each buffered to 10
// messages. Call Start to bind the socket and begin sending and receiving.
//
// If dialing fails, the error is returned (with a nil listener) instead of
// panicking.
func NewListener(nlfm NetlinkFamily) (listener *Listener, err error) {
	mysock, err := Dial(nlfm)
	if err != nil {
		return nil, fmt.Errorf("can't dial netlink socket: %w", err)
	}
	listener = &Listener{
		sock:        mysock,
		Messagechan: make(chan Message, 10),
		sendqueue:   make(chan *Message, 10),
		currSeq:     0,
	}
	return listener, nil
}
// startListening is the reader loop launched by Start. It reads netlink
// messages from the socket while the socket is open, groups each reply into a
// batch that ends at the message for which isLastMessage reports true, and
// hands every message of the batch to handleMessage.
//
// The writer acquires socketLock before sending a request. This loop releases
// it when the last message of a batch carries the sequence number of that
// request (currSeq), which lets the writer send the next queued message.
//
// NOTE(review): a request that yields two terminating messages with the same
// sequence number (e.g. a reply followed by a separate ACK) would release
// socketLock twice; confirm callers never request that.
func (listener *Listener) startListening() {
	r := bufio.NewReader(listener.sock)
	for listener.sock.IsOpen() {
		// Wait for data. Peek(1) never returns a nil slice or
		// bufio.ErrNegativeCount, so only the error carries information.
		if _, err := r.Peek(1); err != nil {
			// Most probably the socket is closed; the loop condition
			// re-checks that before the next attempt.
			log.Printf("Netlink socket seems to be closed, checking if socket is open: %v", err)
			continue
		}

		messages := make([]*Message, 0, 10)
		var readErr error
		for {
			msg, err := ReadMessage(r)
			if err != nil {
				readErr = err
				break
			}
			if msg == nil {
				// Nothing to dispatch; previously this was dereferenced below.
				break
			}
			messages = append(messages, msg)
			if isLastMessage(msg) {
				break
			}
		}
		if readErr != nil {
			if !listener.sock.IsOpen() {
				// Close raced with the read: normal shutdown, not a parse error.
				return
			}
			listener.sendError(readErr)
			continue
		}
		if len(messages) == 0 {
			continue
		}

		listener.lock.Lock()
		pending := listener.currSeq
		listener.lock.Unlock()
		// Requests are always sent with a non-zero sequence number (0 means
		// "unset" to the writer), while unsolicited messages such as
		// multicast notifications typically carry 0. Without the seq != 0
		// check, such a message arriving before the first Query would match
		// the initial currSeq of 0 and unlock an unlocked mutex, which is a
		// fatal runtime error.
		last := messages[len(messages)-1]
		if seq := last.Header.MessageSequence(); seq != 0 && seq == pending {
			// We received the response to a request, so the next Query can be sent.
			listener.socketLock.Unlock()
		}

		for _, msg := range messages {
			listener.handleMessage(msg)
		}
		runtime.Gosched()
	}
}
// sendError delivers err on the error channel passed to Start. When no channel
// was provided there is nowhere to report the error, so the process is
// terminated via log.Fatalf.
func (listener *Listener) sendError(err error) {
	if listener.echan == nil {
		log.Fatalf("Can't parse netlink message: %v", err)
	}
	listener.echan <- err
}
// handleMessage dispatches one received netlink message.
//
// An NLMSG_ERROR reply is decoded and delivered as an error on the channel
// passed to Start, or logged if no channel was provided. Every other message
// is copied onto Messagechan. A nil message terminates the process via
// log.Fatalf.
func (listener *Listener) handleMessage(msg *Message) {
	if msg == nil {
		log.Fatalf("Netlink message was null")
	}
	if msg.Header.MessageType() != NLMSG_ERROR {
		listener.Messagechan <- *msg
		return
	}

	errmsg := &Error{}
	if err := errmsg.UnmarshalNetlink(msg.Body); err != nil {
		// Report through the file's parse-error path instead of panicking
		// inside the reader goroutine.
		listener.sendError(fmt.Errorf("can't unmarshal netlink error message: %w", err))
		return
	}
	err := errors.New(fmt.Sprintf("Netlink Error (%d) for message with sequence (%d): %s. Header %v Body: %v", errmsg.Code(), msg.Header.MessageSequence(), errmsg.Error(), msg.Header, msg.Body))
	if listener.echan == nil {
		// A send on a nil channel blocks forever and would wedge the reader
		// goroutine, so log instead.
		log.Print(err)
		return
	}
	listener.echan <- err
}
// startWriting is the writer loop launched by Start. It takes messages queued
// by Query and writes them to the socket one at a time. socketLock is
// acquired before each write and released by the reader goroutine once the
// reply carrying the request's sequence number arrives, so at most one
// request is outstanding.
//
// A message whose sequence number is 0 gets one from Seq(). A caller-chosen
// sequence number is recorded as currSeq so that its reply is recognised and
// releases socketLock.
//
// The loop blocks on the send queue instead of polling it, so an idle
// listener does not spin a CPU core. As a consequence, after Close the
// goroutine only exits once it wakes up for another queued message.
func (listener *Listener) startWriting() {
	for listener.sock.IsOpen() {
		msg := <-listener.sendqueue
		listener.socketLock.Lock()

		seq := msg.Header.MessageSequence()
		if seq == 0 {
			seq = listener.Seq()
			msg.Header.SetMessageSequence(seq)
		} else {
			// Otherwise the reply would never match currSeq and socketLock
			// would stay held forever.
			listener.lock.Lock()
			listener.currSeq = seq
			listener.lock.Unlock()
		}

		ob, err := msg.MarshalNetlink()
		if err == nil {
			_, err = listener.sock.Write(ob)
		}
		if err != nil {
			// The request presumably never reached the kernel, so no reply
			// will release socketLock; release it here or the next send
			// blocks forever.
			listener.socketLock.Unlock()
			err = fmt.Errorf("sending netlink message with sequence %d: %w", seq, err)
			if listener.echan != nil {
				listener.echan <- err
			} else {
				log.Print(err)
			}
		}
	}
}
// Start binds the socket and launches the writer (startWriting) and reader
// (startListening) goroutines.
//
// echan is stored on the listener, and netlink error replies received by the
// reader are delivered on it. The socket is bound with uint32(os.Getpid()) and
// a group mask of ^uint32(0) (MaxUint32), i.e. listening to all multicast
// groups.
//
// A bind failure is returned to the caller instead of panicking, and no
// goroutines are started in that case.
func (listener *Listener) Start(echan chan error) (err error) {
	listener.echan = echan
	// ^uint32(0) is MaxUint32: every bit set, meaning all multicast groups.
	err = listener.sock.Bind(uint32(os.Getpid()), ^uint32(0))
	if err != nil {
		return fmt.Errorf("can't bind to netlink socket: %w", err)
	}
	go listener.startWriting()
	go listener.startListening()
	return nil
}
// isLastMessage reports whether msg ends a batch of received messages. A nil
// message, an NLMSG_DONE message, or any message without the NLM_F_MULTI flag
// (a single-part reply) terminates the batch.
func isLastMessage(msg *Message) bool {
	return msg == nil ||
		msg.Header.MessageType() == NLMSG_DONE ||
		msg.Header.MessageFlags()&NLM_F_MULTI == 0
}