diff --git a/.changeset/gentle-forks-hide.md b/.changeset/gentle-forks-hide.md new file mode 100644 index 00000000..ea337391 --- /dev/null +++ b/.changeset/gentle-forks-hide.md @@ -0,0 +1,5 @@ +--- +"github.com/livekit/protocol": patch +--- + +Type safe IP checks for SIP Trunks. diff --git a/sip/sip.go b/sip/sip.go index 8d9e3b16..a812c5bb 100644 --- a/sip/sip.go +++ b/sip/sip.go @@ -267,13 +267,13 @@ func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error { return nil } -func matchAddrMask(addr string, mask string) bool { +func matchAddrMask(ip netip.Addr, mask string) bool { if !strings.Contains(mask, "/") { - return addr == mask - } - ip, err := netip.ParseAddr(addr) - if err != nil { - return false + expIP, err := netip.ParseAddr(mask) + if err != nil { + return false + } + return ip == expIP } pref, err := netip.ParsePrefix(mask) if err != nil { @@ -282,8 +282,8 @@ func matchAddrMask(addr string, mask string) bool { return pref.Contains(ip) } -func matchAddrMasks(addr string, masks []string) bool { - if addr == "" || len(masks) == 0 { +func matchAddrMasks(addr netip.Addr, masks []string) bool { + if !addr.IsValid() || len(masks) == 0 { return true } for _, mask := range masks { @@ -309,7 +309,7 @@ func matchNumbers(num string, allowed []string) bool { // MatchTrunk finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. -func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP, calling, called string) (*livekit.SIPInboundTrunkInfo, error) { +func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) { var ( selectedTrunk *livekit.SIPInboundTrunkInfo defaultTrunk *livekit.SIPInboundTrunkInfo diff --git a/sip/sip_test.go b/sip/sip_test.go index 4c233b42..fa14174e 100644 --- a/sip/sip_test.go +++ b/sip/sip_test.go @@ -16,6 +16,7 @@ package sip import ( "fmt" + "net/netip" "strconv" "testing" @@ -219,7 +220,13 @@ func TestSIPMatchTrunk(t *testing.T) { src = "1.1.1.1" } trunks := toInboundTrunks(c.trunks) - got, err := MatchTrunk(trunks, src, from, to) + var srcIP netip.Addr + if src != "" { + var err error + srcIP, err = netip.ParseAddr(src) + require.NoError(t, err) + } + got, err := MatchTrunk(trunks, srcIP, from, to) if c.expErr { require.Error(t, err) require.Nil(t, got) @@ -657,7 +664,9 @@ func TestMatchIP(t *testing.T) { } for _, c := range cases { t.Run(c.mask, func(t *testing.T) { - got := matchAddrMask(c.addr, c.mask) + ip, err := netip.ParseAddr(c.addr) + require.NoError(t, err) + got := matchAddrMask(ip, c.mask) require.Equal(t, c.exp, got) }) }