Skip to content

Commit

Permalink
Simplify number matching rules. Make '+' optional.
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc committed Jun 12, 2024
1 parent 1dd95af commit b9a523c
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 18 deletions.
20 changes: 18 additions & 2 deletions rpc/sip.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package rpc

import "github.com/livekit/protocol/livekit"
import (
"strings"

"github.com/livekit/protocol/livekit"
)

// NewCreateSIPParticipantRequest fills InternalCreateSIPParticipantRequest from
// livekit.CreateSIPParticipantRequest and livekit.SIPTrunkInfo.
Expand All @@ -9,11 +13,23 @@ func NewCreateSIPParticipantRequest(
req *livekit.CreateSIPParticipantRequest,
trunk *livekit.SIPTrunkInfo,
) *InternalCreateSIPParticipantRequest {
// A sanity check for the number format for well-known providers.
outboundNumber := trunk.OutboundNumber
switch {
case strings.HasSuffix(trunk.OutboundAddress, "telnyx.com"):
// Telnyx omits leading '+' by default.
outboundNumber = strings.TrimPrefix(outboundNumber, "+")
case strings.HasSuffix(trunk.OutboundAddress, "twilio.com"):
// Twilio requires leading '+'.
if !strings.HasPrefix(outboundNumber, "+") {
outboundNumber = "+" + outboundNumber
}
}
return &InternalCreateSIPParticipantRequest{
SipCallId: callID,
Address: trunk.OutboundAddress,
Transport: trunk.Transport,
Number: trunk.OutboundNumber,
Number: outboundNumber,
Username: trunk.OutboundUsername,
Password: trunk.OutboundPassword,
CallTo: req.SipCallTo,
Expand Down
56 changes: 42 additions & 14 deletions sip/sip.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo) error {
}
for _, trunk := range trunks {
for _, number := range numbers {
key := ruleKey{Pin: pin, Trunk: trunk, Number: number}
key := ruleKey{Pin: pin, Trunk: trunk, Number: normalizeNumber(number)}
r2 := byRuleKey[key]
if r2 != nil {
return fmt.Errorf("Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q",
Expand Down Expand Up @@ -190,6 +190,18 @@ func printNumber(s string) string {
return strconv.Quote(s)
}

func normalizeNumber(num string) string {
if num == "" {
return ""
}
// TODO: Always keep "number" as-is if it's not E.164.
// This will only matter for native SIP clients which have '+' in the username.
if !strings.HasPrefix(num, `+`) {
num = "+" + num
}
return num
}

// ValidateTrunks checks a set of trunks for conflicts.
func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error {
if len(trunks) == 0 {
Expand All @@ -200,10 +212,11 @@ func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error {
if len(t.InboundNumbersRegex) != 0 {
continue // can't effectively validate these
}
byInbound := byOutboundAndInbound[t.OutboundNumber]
outboundKey := normalizeNumber(t.OutboundNumber)
byInbound := byOutboundAndInbound[outboundKey]
if byInbound == nil {
byInbound = make(map[string]*livekit.SIPTrunkInfo)
byOutboundAndInbound[t.OutboundNumber] = byInbound
byOutboundAndInbound[outboundKey] = byInbound
}
if len(t.InboundNumbers) == 0 {
if t2 := byInbound[""]; t2 != nil {
Expand All @@ -213,19 +226,20 @@ func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error {
byInbound[""] = t
} else {
for _, num := range t.InboundNumbers {
t2 := byInbound[num]
inboundKey := normalizeNumber(num)
t2 := byInbound[inboundKey]
if t2 != nil {
return fmt.Errorf("Conflicting SIP Trunks: %q and %q, using the same OutboundNumber %s and InboundNumber %q",
printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumber(t.OutboundNumber), num)
}
byInbound[num] = t
byInbound[inboundKey] = t
}
}
}
return nil
}

func matchAddrs(addr string, mask string) bool {
func matchAddrMask(addr string, mask string) bool {
if !strings.Contains(mask, "/") {
return addr == mask
}
Expand All @@ -240,16 +254,29 @@ func matchAddrs(addr string, mask string) bool {
return pref.Contains(ip)
}

func matchAddr(addr string, masks []string) bool {
if addr == "" {
func matchAddrMasks(addr string, masks []string) bool {
if addr == "" || len(masks) == 0 {
return true
}
for _, mask := range masks {
if !matchAddrs(addr, mask) {
return false
if matchAddrMask(addr, mask) {
return true
}
}
return false
}

func matchNumbers(num string, allowed []string) bool {
if len(allowed) == 0 {
return true
}
num = normalizeNumber(num)
for _, allow := range allowed {
if num == normalizeNumber(allow) {
return true
}
}
return true
return false
}

// MatchTrunk finds a SIP Trunk definition matching the request.
Expand All @@ -260,12 +287,13 @@ func MatchTrunk(trunks []*livekit.SIPTrunkInfo, srcIP, calling, called string) (
defaultTrunk *livekit.SIPTrunkInfo
defaultTrunkCnt int // to error in case there are multiple ones
)
calledNorm := normalizeNumber(called)
for _, tr := range trunks {
// Do not consider it if number doesn't match.
if len(tr.InboundNumbers) != 0 && !slices.Contains(tr.InboundNumbers, calling) {
if !matchNumbers(calling, tr.InboundNumbers) {
continue
}
if !matchAddr(srcIP, tr.InboundAddresses) {
if !matchAddrMasks(srcIP, tr.InboundAddresses) {
continue
}
// Deprecated, but we still check it for backward compatibility.
Expand All @@ -289,7 +317,7 @@ func MatchTrunk(trunks []*livekit.SIPTrunkInfo, srcIP, calling, called string) (
// Default/wildcard trunk.
defaultTrunk = tr
defaultTrunkCnt++
} else if tr.OutboundNumber == called {
} else if normalizeNumber(tr.OutboundNumber) == calledNorm {
// Trunk specific to the number.
if selectedTrunk != nil {
return nil, fmt.Errorf("Multiple SIP Trunks matched for %q", called)
Expand Down
73 changes: 71 additions & 2 deletions sip/sip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ var trunkCases = []struct {
exp int
expErr bool
invalid bool
from string
to string
src string
}{
{
name: "empty",
Expand Down Expand Up @@ -151,13 +154,79 @@ var trunkCases = []struct {
expErr: true,
invalid: true,
},
{
name: "inbound with ip exact",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{
"10.10.10.10",
"1.1.1.1",
}},
},
exp: 0,
},
{
name: "inbound with ip exact miss",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{
"10.10.10.10",
}},
},
exp: -1,
},
{
name: "inbound with ip mask",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{
"10.10.10.0/24",
"1.1.1.0/24",
}},
},
exp: 0,
},
{
name: "inbound with ip mask miss",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{
"10.10.10.0/24",
}},
},
exp: -1,
},
{
name: "inbound with plus",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "aaa", OutboundNumber: "+" + sipNumber3},
{SipTrunkId: "bbb", OutboundNumber: "+" + sipNumber2},
},
exp: 1,
},
{
name: "inbound without plus",
trunks: []*livekit.SIPTrunkInfo{
{SipTrunkId: "aaa", OutboundNumber: sipNumber3},
{SipTrunkId: "bbb", OutboundNumber: sipNumber2},
},
from: "+" + sipNumber1,
to: "+" + sipNumber2,
exp: 1,
},
}

func TestSIPMatchTrunk(t *testing.T) {
for _, c := range trunkCases {
c := c
t.Run(c.name, func(t *testing.T) {
got, err := MatchTrunk(c.trunks, "", sipNumber1, sipNumber2)
from, to, src := c.from, c.to, c.src
if from == "" {
from = sipNumber1
}
if to == "" {
to = sipNumber2
}
if src == "" {
src = "1.1.1.1"
}
got, err := MatchTrunk(c.trunks, src, from, to)
if c.expErr {
require.Error(t, err)
require.Nil(t, got)
Expand Down Expand Up @@ -530,7 +599,7 @@ func TestMatchIP(t *testing.T) {
}
for _, c := range cases {
t.Run(c.mask, func(t *testing.T) {
got := matchAddrs(c.addr, c.mask)
got := matchAddrMask(c.addr, c.mask)
require.Equal(t, c.exp, got)
})
}
Expand Down

0 comments on commit b9a523c

Please sign in to comment.