Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type safe IP checks for SIP Trunks #857

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/gentle-forks-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"github.com/livekit/protocol": patch
---

Type safe IP checks for SIP Trunks.
18 changes: 9 additions & 9 deletions sip/sip.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error {
return nil
}

func matchAddrMask(addr string, mask string) bool {
func matchAddrMask(ip netip.Addr, mask string) bool {
if !strings.Contains(mask, "/") {
return addr == mask
}
ip, err := netip.ParseAddr(addr)
if err != nil {
return false
expIP, err := netip.ParseAddr(mask)
if err != nil {
return false
}
return ip == expIP
}
pref, err := netip.ParsePrefix(mask)
if err != nil {
Expand All @@ -282,8 +282,8 @@ func matchAddrMask(addr string, mask string) bool {
return pref.Contains(ip)
}

func matchAddrMasks(addr string, masks []string) bool {
if addr == "" || len(masks) == 0 {
func matchAddrMasks(addr netip.Addr, masks []string) bool {
if !addr.IsValid() || len(masks) == 0 {
return true
}
for _, mask := range masks {
Expand All @@ -309,7 +309,7 @@ func matchNumbers(num string, allowed []string) bool {

// MatchTrunk finds a SIP Trunk definition matching the request.
// Returns nil if no rules matched or an error if there are conflicting definitions.
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
var (
selectedTrunk *livekit.SIPInboundTrunkInfo
defaultTrunk *livekit.SIPInboundTrunkInfo
Expand Down
13 changes: 11 additions & 2 deletions sip/sip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package sip

import (
"fmt"
"net/netip"
"strconv"
"testing"

Expand Down Expand Up @@ -219,7 +220,13 @@ func TestSIPMatchTrunk(t *testing.T) {
src = "1.1.1.1"
}
trunks := toInboundTrunks(c.trunks)
got, err := MatchTrunk(trunks, src, from, to)
var srcIP netip.Addr
if src != "" {
var err error
srcIP, err = netip.ParseAddr(src)
require.NoError(t, err)
}
got, err := MatchTrunk(trunks, srcIP, from, to)
if c.expErr {
require.Error(t, err)
require.Nil(t, got)
Expand Down Expand Up @@ -657,7 +664,9 @@ func TestMatchIP(t *testing.T) {
}
for _, c := range cases {
t.Run(c.mask, func(t *testing.T) {
got := matchAddrMask(c.addr, c.mask)
ip, err := netip.ParseAddr(c.addr)
require.NoError(t, err)
got := matchAddrMask(ip, c.mask)
require.Equal(t, c.exp, got)
})
}
Expand Down
Loading