diff --git a/rpc/sip.go b/rpc/sip.go index c260afc7b..516da473e 100644 --- a/rpc/sip.go +++ b/rpc/sip.go @@ -1,6 +1,10 @@ package rpc -import "github.com/livekit/protocol/livekit" +import ( + "strings" + + "github.com/livekit/protocol/livekit" +) // NewCreateSIPParticipantRequest fills InternalCreateSIPParticipantRequest from // livekit.CreateSIPParticipantRequest and livekit.SIPTrunkInfo. @@ -9,11 +13,23 @@ func NewCreateSIPParticipantRequest( req *livekit.CreateSIPParticipantRequest, trunk *livekit.SIPTrunkInfo, ) *InternalCreateSIPParticipantRequest { + // A sanity check for the number format for well-known providers. + outboundNumber := trunk.OutboundNumber + switch { + case strings.HasSuffix(trunk.OutboundAddress, "telnyx.com"): + // Telnyx omits leading '+' by default. + outboundNumber = strings.TrimPrefix(outboundNumber, "+") + case strings.HasSuffix(trunk.OutboundAddress, "twilio.com"): + // Twilio requires leading '+'. + if !strings.HasPrefix(outboundNumber, "+") { + outboundNumber = "+" + outboundNumber + } + } return &InternalCreateSIPParticipantRequest{ SipCallId: callID, Address: trunk.OutboundAddress, Transport: trunk.Transport, - Number: trunk.OutboundNumber, + Number: outboundNumber, Username: trunk.OutboundUsername, Password: trunk.OutboundPassword, CallTo: req.SipCallTo, diff --git a/sip/sip.go b/sip/sip.go index bd2c7d24f..6ca46f841 100644 --- a/sip/sip.go +++ b/sip/sip.go @@ -139,7 +139,7 @@ func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo) error { } for _, trunk := range trunks { for _, number := range numbers { - key := ruleKey{Pin: pin, Trunk: trunk, Number: number} + key := ruleKey{Pin: pin, Trunk: trunk, Number: normalizeNumber(number)} r2 := byRuleKey[key] if r2 != nil { return fmt.Errorf("Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q", @@ -190,6 +190,18 @@ func printNumber(s string) string { return strconv.Quote(s) } +func normalizeNumber(num string) string { + if num == "" { + return "" + } + // TODO: Always keep "number" as-is if it's not E.164. + // This will only matter for native SIP clients which have '+' in the username. + if !strings.HasPrefix(num, `+`) { + num = "+" + num + } + return num +} + // ValidateTrunks checks a set of trunks for conflicts. func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error { if len(trunks) == 0 { @@ -200,10 +212,11 @@ func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error { if len(t.InboundNumbersRegex) != 0 { continue // can't effectively validate these } - byInbound := byOutboundAndInbound[t.OutboundNumber] + outboundKey := normalizeNumber(t.OutboundNumber) + byInbound := byOutboundAndInbound[outboundKey] if byInbound == nil { byInbound = make(map[string]*livekit.SIPTrunkInfo) - byOutboundAndInbound[t.OutboundNumber] = byInbound + byOutboundAndInbound[outboundKey] = byInbound } if len(t.InboundNumbers) == 0 { if t2 := byInbound[""]; t2 != nil { @@ -213,19 +226,20 @@ func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error { byInbound[""] = t } else { for _, num := range t.InboundNumbers { - t2 := byInbound[num] + inboundKey := normalizeNumber(num) + t2 := byInbound[inboundKey] if t2 != nil { return fmt.Errorf("Conflicting SIP Trunks: %q and %q, using the same OutboundNumber %s and InboundNumber %q", printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumber(t.OutboundNumber), num) } - byInbound[num] = t + byInbound[inboundKey] = t } } } return nil } -func matchAddrs(addr string, mask string) bool { +func matchAddrMask(addr string, mask string) bool { if !strings.Contains(mask, "/") { return addr == mask } @@ -240,16 +254,29 @@ func matchAddrs(addr string, mask string) bool { return pref.Contains(ip) } -func matchAddr(addr string, masks []string) bool { - if addr == "" { +func matchAddrMasks(addr string, masks []string) bool { + if addr == "" || len(masks) == 0 { return true } for _, mask := range masks { - if !matchAddrs(addr, mask) { - return false + if matchAddrMask(addr, mask) { + return true + } + } + return false +} + +func matchNumbers(num string, allowed []string) bool { + if len(allowed) == 0 { + return true + } + num = normalizeNumber(num) + for _, allow := range allowed { + if num == normalizeNumber(allow) { + return true } } - return true + return false } // MatchTrunk finds a SIP Trunk definition matching the request. @@ -260,12 +287,13 @@ func MatchTrunk(trunks []*livekit.SIPTrunkInfo, srcIP, calling, called string) ( defaultTrunk *livekit.SIPTrunkInfo defaultTrunkCnt int // to error in case there are multiple ones ) + calledNorm := normalizeNumber(called) for _, tr := range trunks { // Do not consider it if number doesn't match. - if len(tr.InboundNumbers) != 0 && !slices.Contains(tr.InboundNumbers, calling) { + if !matchNumbers(calling, tr.InboundNumbers) { continue } - if !matchAddr(srcIP, tr.InboundAddresses) { + if !matchAddrMasks(srcIP, tr.InboundAddresses) { continue } // Deprecated, but we still check it for backward compatibility. @@ -289,7 +317,7 @@ func MatchTrunk(trunks []*livekit.SIPTrunkInfo, srcIP, calling, called string) ( // Default/wildcard trunk. defaultTrunk = tr defaultTrunkCnt++ - } else if tr.OutboundNumber == called { + } else if normalizeNumber(tr.OutboundNumber) == calledNorm { // Trunk specific to the number. if selectedTrunk != nil { return nil, fmt.Errorf("Multiple SIP Trunks matched for %q", called) diff --git a/sip/sip_test.go b/sip/sip_test.go index a2fa6fffc..93dc51dce 100644 --- a/sip/sip_test.go +++ b/sip/sip_test.go @@ -39,6 +39,9 @@ var trunkCases = []struct { exp int expErr bool invalid bool + from string + to string + src string }{ { name: "empty", @@ -151,13 +154,79 @@ var trunkCases = []struct { expErr: true, invalid: true, }, + { + name: "inbound with ip exact", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{ + "10.10.10.10", + "1.1.1.1", + }}, + }, + exp: 0, + }, + { + name: "inbound with ip exact miss", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{ + "10.10.10.10", + }}, + }, + exp: -1, + }, + { + name: "inbound with ip mask", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{ + "10.10.10.0/24", + "1.1.1.0/24", + }}, + }, + exp: 0, + }, + { + name: "inbound with ip mask miss", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "bbb", OutboundNumber: sipNumber2, InboundAddresses: []string{ + "10.10.10.0/24", + }}, + }, + exp: -1, + }, + { + name: "inbound with plus", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "aaa", OutboundNumber: "+" + sipNumber3}, + {SipTrunkId: "bbb", OutboundNumber: "+" + sipNumber2}, + }, + exp: 1, + }, + { + name: "inbound without plus", + trunks: []*livekit.SIPTrunkInfo{ + {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, + {SipTrunkId: "bbb", OutboundNumber: sipNumber2}, + }, + from: "+" + sipNumber1, + to: "+" + sipNumber2, + exp: 1, + }, } func TestSIPMatchTrunk(t *testing.T) { for _, c := range trunkCases { c := c t.Run(c.name, func(t *testing.T) { - got, err := MatchTrunk(c.trunks, "", sipNumber1, sipNumber2) + from, to, src := c.from, c.to, c.src + if from == "" { + from = sipNumber1 + } + if to == "" { + to = sipNumber2 + } + if src == "" { + src = "1.1.1.1" + } + got, err := MatchTrunk(c.trunks, src, from, to) if c.expErr { require.Error(t, err) require.Nil(t, got) @@ -530,7 +599,7 @@ func TestMatchIP(t *testing.T) { } for _, c := range cases { t.Run(c.mask, func(t *testing.T) { - got := matchAddrs(c.addr, c.mask) + got := matchAddrMask(c.addr, c.mask) require.Equal(t, c.exp, got) }) }