Skip to content

Commit

Permalink
connection tracking and rule matching for UDP
Browse files Browse the repository at this point in the history
  • Loading branch information
glaslos committed Jul 23, 2023
1 parent 10f4961 commit 1f2ceb3
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 67 deletions.
36 changes: 18 additions & 18 deletions connection/conntable.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type CKey [2]uint64

func NewConnKey(clientAddr gopacket.Endpoint, clientPort gopacket.Endpoint) (CKey, error) {
func newConnKey(clientAddr gopacket.Endpoint, clientPort gopacket.Endpoint) (CKey, error) {
if clientAddr.EndpointType() != layers.EndpointIPv4 {
return CKey{}, errors.New("clientAddr endpoint must be of type layers.EndpointIPv4")
}
Expand All @@ -34,7 +34,7 @@ func NewConnKeyByString(host, port string) (CKey, error) {
return CKey{}, err
}
clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(p))
return NewConnKey(clientAddr, clientPort)
return newConnKey(clientAddr, clientPort)
}

func NewConnKeyFromNetConn(conn net.Conn) (CKey, error) {
Expand Down Expand Up @@ -62,44 +62,45 @@ func New() *ConnTable {
}

// RegisterConn a connection in the table
func (t *ConnTable) RegisterConn(conn net.Conn, rule *rules.Rule) error {
func (t *ConnTable) RegisterConn(conn net.Conn, rule *rules.Rule) (*Metadata, error) {
srcIP, srcPort, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return fmt.Errorf("failed to split remote address: %w", err)
return nil, fmt.Errorf("failed to split remote address: %w", err)
}

_, dstPort, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
return fmt.Errorf("failed to split local address: %w", err)
return nil, fmt.Errorf("failed to split local address: %w", err)
}

println(fmt.Sprintf("%s:%s->%s, %s", srcIP, srcPort, dstPort, rule.Target))

port, err := strconv.Atoi(dstPort)
if err != nil {
return fmt.Errorf("failed to parse dstPort: %w", err)
return nil, fmt.Errorf("failed to parse dstPort: %w", err)
}
return t.Register(srcIP, srcPort, uint16(port), rule)
}

// Register a connection in the table
func (t *ConnTable) Register(srcIP, srcPort string, targetPort uint16, rule *rules.Rule) error {
func (t *ConnTable) Register(srcIP, srcPort string, dstPort uint16, rule *rules.Rule) (*Metadata, error) {
t.mtx.Lock()
defer t.mtx.Unlock()

ck, err := NewConnKeyByString(srcIP, srcPort)
if err != nil {
return err
return nil, err
}
if _, ok := t.table[ck]; ok {
return nil
if md, ok := t.table[ck]; ok {
return md, nil
}
t.table[ck] = &Metadata{

println(fmt.Sprintf("%s:%s->%d, %s", srcIP, srcPort, dstPort, rule.Target))

md := &Metadata{
Added: time.Now(),
TargetPort: targetPort,
TargetPort: dstPort,
Rule: rule,
}
return nil
t.table[ck] = md
return md, nil
}

func (t *ConnTable) FlushOlderThan(s time.Duration) {
Expand All @@ -115,8 +116,7 @@ func (t *ConnTable) FlushOlderThan(s time.Duration) {
}
}

// TODO: what happens when I return a *Metadata and then FlushOlderThan()
// deletes it?
// TODO: what happens when I return a *Metadata and then FlushOlderThan() deletes it?
func (t *ConnTable) Get(ck CKey) *Metadata {
t.mtx.RLock()
defer t.mtx.RUnlock()
Expand Down
16 changes: 10 additions & 6 deletions connection/conntable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ func TestNewConnTable(t *testing.T) {
func TestRegister(t *testing.T) {
table := New()
targetPort := 4321
err := table.Register("127.0.0.1", "1234", uint16(targetPort), &rules.Rule{})
m1, err := table.Register("127.0.0.1", "1234", uint16(targetPort), &rules.Rule{})
require.NoError(t, err)
m := table.Get(testck)
require.NotNil(t, m)
require.Equal(t, targetPort, int(m.TargetPort))
require.NotNil(t, m1)
m2 := table.Get(testck)
require.NotNil(t, m1)
require.Equal(t, targetPort, int(m2.TargetPort))
require.Equal(t, m1, m2)
}

func TestRegisterConn(t *testing.T) {
Expand All @@ -56,8 +58,9 @@ func TestRegisterConn(t *testing.T) {
require.NotNil(t, conn)
defer conn.Close()
table := New()
err = table.RegisterConn(conn, &rules.Rule{Target: "default"})
md, err := table.RegisterConn(conn, &rules.Rule{Target: "default"})
require.NoError(t, err)
require.NotNil(t, md)
m := table.Get(testck)
require.NotNil(t, m)
require.Equal(t, "default", m.Rule.Target)
Expand All @@ -66,8 +69,9 @@ func TestRegisterConn(t *testing.T) {
func TestFlushOlderThan(t *testing.T) {
table := New()
targetPort := 4321
err := table.Register("127.0.0.1", "1234", uint16(targetPort), &rules.Rule{})
md, err := table.Register("127.0.0.1", "1234", uint16(targetPort), &rules.Rule{})
require.NoError(t, err)
require.NotNil(t, md)
table.FlushOlderThan(time.Duration(0))
m := table.Get(testck)
require.Nil(t, m)
Expand Down
31 changes: 25 additions & 6 deletions glutton.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,17 @@ func (g *Glutton) udpListen() {
if err != nil {
g.Logger.Error("failed to read UDP packet", zap.Error(err))
}

rule, err := g.applyRules("udp", srcAddr, dstAddr)
if err != nil {
g.Logger.Error("failed to apply rules", zap.Error(err))
}
md, err := g.conntable.Register(srcAddr.IP.String(), strconv.Itoa(int(srcAddr.AddrPort().Port())), dstAddr.AddrPort().Port(), rule)
if err != nil {
g.Logger.Error("failed to register UDP packet", zap.Error(err))
}
g.Logger.Info(fmt.Sprintf("UDP payload:\n%s", hex.Dump(buffer[:n%1024])))
println(srcAddr.String(), dstAddr.String())
if err := g.ProduceUDP("udp", srcAddr, dstAddr, nil, buffer[:n%1024], nil); err != nil {
if err := g.ProduceUDP("udp", srcAddr, dstAddr, md, buffer[:n%1024], nil); err != nil {
g.Logger.Error("failed to produce UDP payload", zap.Error(err))
}
}
Expand Down Expand Up @@ -183,15 +191,15 @@ func (g *Glutton) Start() error {
return err
}

rule, err := g.applyRules(conn)
rule, err := g.applyRulesOnConn(conn)
if err != nil {
return fmt.Errorf("failed to apply rules: %w", err)
}
if rule == nil {
rule = &rules.Rule{Target: "default"}
}

if err := g.conntable.RegisterConn(conn, rule); err != nil {
if _, err := g.conntable.RegisterConn(conn, rule); err != nil {
return err
}

Expand Down Expand Up @@ -389,8 +397,19 @@ func (g *Glutton) Shutdown() error {
return g.Server.Shutdown()
}

func (g *Glutton) applyRules(conn net.Conn) (*rules.Rule, error) {
match, err := g.rules.Match(conn)
func (g *Glutton) applyRulesOnConn(conn net.Conn) (*rules.Rule, error) {
match, err := g.rules.Match("tcp", conn.RemoteAddr(), conn.LocalAddr())
if err != nil {
return nil, err
}
if match != nil {
return match, err
}
return nil, nil
}

func (g *Glutton) applyRules(network string, srcAddr, dstAddr net.Addr) (*rules.Rule, error) {
match, err := g.rules.Match(network, srcAddr, dstAddr)
if err != nil {
return nil, err
}
Expand Down
2 changes: 0 additions & 2 deletions producer/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,11 @@ func makeEventUDP(handler string, srcAddr, dstAddr *net.UDPAddr, md *connection.
Timestamp: time.Now().UTC(),
SrcHost: srcAddr.IP.String(),
SrcPort: strconv.Itoa(int(srcAddr.AddrPort().Port())),
DstPort: dstAddr.AddrPort().Port(),
SensorID: sensorID,
Handler: handler,
Payload: base64.StdEncoding.EncodeToString(payload),
Scanner: scannerName,
Decoded: decoded,
Rule: "Rule: udp",
}
if md != nil {
event.DstPort = uint16(md.TargetPort)
Expand Down
75 changes: 44 additions & 31 deletions rules/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,49 +138,54 @@ func InitRule(idx int, rule *Rule) error {
return nil
}

func splitAddr(addr string) (net.IP, layers.TCPPort, error) {
func splitAddr(addr string) (string, uint16, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, 0, err
return "", 0, err
}
sIP := net.ParseIP(ip)

dPort, err := strconv.Atoi(port)
if err != nil {
return nil, 0, err
return "", 0, err
}
return sIP, layers.TCPPort(dPort), nil
return ip, uint16(dPort), nil
}

func fakePacketBytes(conn net.Conn) ([]byte, error) {
func fakePacketBytes(network, srcIP, dstIP string, srcPort, dstPort uint16) ([]byte, error) {
buf := gopacket.NewSerializeBuffer()

sIP, sPort, err := splitAddr(conn.LocalAddr().String())
if err != nil {
return nil, err
}
dIP, dPort, err := splitAddr(conn.RemoteAddr().String())
if err != nil {
return nil, err
}

eth := &layers.Ethernet{
SrcMAC: net.HardwareAddr{0x0, 0x11, 0x22, 0x33, 0x44, 0x55},
DstMAC: net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
EthernetType: layers.EthernetTypeIPv4,
}
ipv4 := &layers.IPv4{
SrcIP: sIP,
DstIP: dIP,
Version: 4,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: sPort,
DstPort: dPort,
}
if err := tcp.SetNetworkLayerForChecksum(ipv4); err != nil {
return nil, err
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
Version: 4,
}

var transport gopacket.SerializableLayer
switch network {
case "tcp":
ipv4.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
if err := tcp.SetNetworkLayerForChecksum(ipv4); err != nil {
return nil, err
}
transport = tcp

case "udp":
ipv4.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
return nil, err
}
transport = udp
}

if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
Expand All @@ -189,7 +194,7 @@ func fakePacketBytes(conn net.Conn) ([]byte, error) {
},
eth,
ipv4,
tcp,
transport,
gopacket.Payload([]byte{})); err != nil {
return nil, err
}
Expand All @@ -198,8 +203,16 @@ func fakePacketBytes(conn net.Conn) ([]byte, error) {

type Rules []*Rule

func (rs Rules) Match(conn net.Conn) (*Rule, error) {
b, err := fakePacketBytes(conn)
func (rs Rules) Match(network string, srcAddr, dstAddr net.Addr) (*Rule, error) {
srcIP, srcPort, err := splitAddr(srcAddr.String())
if err != nil {
return nil, err
}
dstIP, dstPort, err := splitAddr(dstAddr.String())
if err != nil {
return nil, err
}
b, err := fakePacketBytes(network, srcIP, dstIP, srcPort, dstPort)
if err != nil {
return nil, fmt.Errorf("failed to fake packet: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions rules/rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestInitRule(t *testing.T) {
func TestSplitAddr(t *testing.T) {
ip, port, err := splitAddr("192.168.1.1:8080")
require.NoError(t, err)
require.True(t, net.ParseIP("192.168.1.1").Equal(ip))
require.Equal(t, layers.TCPPort(8080), port)
require.Equal(t, "192.168.1.1", ip)
require.Equal(t, uint16(8080), port)
}

func testConn(t *testing.T) (net.Conn, net.Listener) {
Expand All @@ -63,7 +63,7 @@ func TestFakePacketBytes(t *testing.T) {
conn.Close()
ln.Close()
}()
b, err := fakePacketBytes(conn)
b, err := fakePacketBytes("tcp", "1.1.1.1", "2.2.2.2", 12, 21)
require.NoError(t, err)
require.NotEmpty(t, b)
}
Expand All @@ -85,7 +85,7 @@ func TestRunMatch(t *testing.T) {
err error
)

match, err = rules.Match(conn)
match, err = rules.Match("tcp", conn.LocalAddr(), conn.RemoteAddr())
require.NoError(t, err)
require.NotNil(t, match)
require.Equal(t, "test", match.Target)
Expand Down

0 comments on commit 1f2ceb3

Please sign in to comment.