Skip to content

Commit

Permalink
Type safe IP checks for SIP Trunks. (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc authored Oct 16, 2024
1 parent c0c2e6e commit d16f740
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .changeset/gentle-forks-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"github.com/livekit/protocol": patch
---

Type safe IP checks for SIP Trunks.
18 changes: 9 additions & 9 deletions sip/sip.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error {
return nil
}

func matchAddrMask(addr string, mask string) bool {
func matchAddrMask(ip netip.Addr, mask string) bool {
if !strings.Contains(mask, "/") {
return addr == mask
}
ip, err := netip.ParseAddr(addr)
if err != nil {
return false
expIP, err := netip.ParseAddr(mask)
if err != nil {
return false
}
return ip == expIP
}
pref, err := netip.ParsePrefix(mask)
if err != nil {
Expand All @@ -282,8 +282,8 @@ func matchAddrMask(addr string, mask string) bool {
return pref.Contains(ip)
}

func matchAddrMasks(addr string, masks []string) bool {
if addr == "" || len(masks) == 0 {
func matchAddrMasks(addr netip.Addr, masks []string) bool {
if !addr.IsValid() || len(masks) == 0 {
return true
}
for _, mask := range masks {
Expand All @@ -309,7 +309,7 @@ func matchNumbers(num string, allowed []string) bool {

// MatchTrunk finds a SIP Trunk definition matching the request.
// Returns nil if no rules matched or an error if there are conflicting definitions.
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
var (
selectedTrunk *livekit.SIPInboundTrunkInfo
defaultTrunk *livekit.SIPInboundTrunkInfo
Expand Down
13 changes: 11 additions & 2 deletions sip/sip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package sip

import (
"fmt"
"net/netip"
"strconv"
"testing"

Expand Down Expand Up @@ -219,7 +220,13 @@ func TestSIPMatchTrunk(t *testing.T) {
src = "1.1.1.1"
}
trunks := toInboundTrunks(c.trunks)
got, err := MatchTrunk(trunks, src, from, to)
var srcIP netip.Addr
if src != "" {
var err error
srcIP, err = netip.ParseAddr(src)
require.NoError(t, err)
}
got, err := MatchTrunk(trunks, srcIP, from, to)
if c.expErr {
require.Error(t, err)
require.Nil(t, got)
Expand Down Expand Up @@ -657,7 +664,9 @@ func TestMatchIP(t *testing.T) {
}
for _, c := range cases {
t.Run(c.mask, func(t *testing.T) {
got := matchAddrMask(c.addr, c.mask)
ip, err := netip.ParseAddr(c.addr)
require.NoError(t, err)
got := matchAddrMask(ip, c.mask)
require.Equal(t, c.exp, got)
})
}
Expand Down

0 comments on commit d16f740

Please sign in to comment.