From 9ec9ad200ab62726de20a02e9625ecba38986a73 Mon Sep 17 00:00:00 2001 From: Claudio Costa Date: Wed, 21 Jun 2023 11:05:50 -0600 Subject: [PATCH] [MM-52268] Add initial support for connecting through TCP candidates (#103) * Add support for TCP candidates generation * Use single TCP mux * Improve NAT pairs mapping * Fix shutdown order * Update test * Fix sample config * Improve config documentation * Initial IPv6 support (#106) * Update env variables --- .github/workflows/ci.yml | 2 + Makefile | 5 ++ config/config.sample.toml | 27 ++++++- docs/env_config.md | 3 + go.mod | 16 ++-- go.sum | 20 +++++ service/config.go | 1 + service/helper_test.go | 3 +- service/rtc/config.go | 14 ++++ service/rtc/config_test.go | 15 ++++ service/rtc/net.go | 46 +++++++---- service/rtc/net_test.go | 79 ++++++++++++++++--- service/rtc/server.go | 151 ++++++++++++++++++++++++++----------- service/rtc/server_test.go | 97 +++++++++++++++++++++++- service/rtc/sfu.go | 12 ++- service/rtc/stun.go | 6 +- service/rtc/utils.go | 74 +++++++++++++----- service/rtc/utils_test.go | 88 +++++++++++++-------- service/service.go | 7 +- 19 files changed, 526 insertions(+), 140 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3d06dd..ef1e908 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,3 +29,5 @@ jobs: go mod download go mod verify make test + env: + CI: true diff --git a/Makefile b/Makefile index f9928f5..a184c02 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,10 @@ # Variables ## General Variables + +# CI +CI ?= false + # Branch Variables PROTECTED_BRANCH := master CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD) @@ -290,6 +294,7 @@ go-test: ## to run tests $(AT)$(DOCKER) run ${DOCKER_OPTS} \ -v $(PWD):/app -w /app \ -e GOCACHE="/tmp" \ + -e CI=${CI} \ $(DOCKER_IMAGE_GO) \ /bin/sh -c \ "cd /app && \ diff --git a/config/config.sample.toml b/config/config.sample.toml index f1bc089..c736b36 100644 --- a/config/config.sample.toml +++ b/config/config.sample.toml @@ -23,15 +23,34 @@ security.admin_secret_key = "" security.session_cache.expiration_minutes = 1440 [rtc] -# The IP address used to listen for UDP packets. If left empty it has the same -# effect as using the 0.0.0.0 address, causing the server to listen on all available network interfaces. +# The IP address used to listen for UDP packets and generate UDP candidates. +# +# If left empty it has the same effect as using the 0.0.0.0 catch-all address, +# causing the server to listen on all available network interfaces. ice_address_udp = "" -# The UDP port used to route all the media (audio/screen/video tracks). +# The UDP port used to route media (audio/screen/video tracks). ice_port_udp = 8443 +# The IP address used to listen for TCP connections and generate TCP candidates. This is used to generate +# TCP candidates which may be used by client in case UDP connectivity is not available. +# +# If left empty it has the same effect as using the 0.0.0.0 catch-all address, +# causing the server to listen on all available network interfaces. +ice_address_tcp = "" +# The TCP port used to route media (audio/screen/video tracks). This is used to +# generate TCP candidates. +ice_port_tcp = 8443 +# Enables experimental IPv6 support. When this setting is true the RTC service +# will work in dual-stack mode, listening for IPv6 connections and generating +# candidates in addition to IPv4 ones. +enable_ipv6 = false # An optional hostname used to override the default value. By default, the # service will try to guess its own public IP through STUN (if configured). -# Depending on the network setup, it may be necessary to set an override. +# +# Depending on the network setup, it may be necessary to set an override. # This is the host that gets advertised to clients and that will be used to connect them to calls. +# +# For more advanced usage this value can also be a comma separated list of +# NAT mappings in the form of "external IP / internal IP" pairs, e.g. "8.8.8.8/10.0.2.2,8.8.4.4/10.0.2.1". ice_host_override = "" # A list of ICE servers (STUN/TURN) to be used by the service. It supports # advanced configurations. diff --git a/docs/env_config.md b/docs/env_config.md index 9691e10..d20372e 100644 --- a/docs/env_config.md +++ b/docs/env_config.md @@ -12,10 +12,13 @@ RTCD_API_SECURITY_ALLOWSELFREGISTRATION True or False RTCD_API_SECURITY_SESSIONCACHE_EXPIRATIONMINUTES Integer RTCD_RTC_ICEADDRESSUDP String RTCD_RTC_ICEPORTUDP Integer +RTCD_RTC_ICEADDRESSTCP String +RTCD_RTC_ICEPORTTCP Integer RTCD_RTC_ICEHOSTOVERRIDE String RTCD_RTC_ICESERVERS Comma-separated list of RTCD_RTC_TURNCONFIG_STATICAUTHSECRET String RTCD_RTC_TURNCONFIG_CREDENTIALSEXPIRATIONMINUTES Integer +RTCD_RTC_ENABLEIPV6 True or False RTCD_STORE_DATASOURCE String RTCD_LOGGER_ENABLECONSOLE True or False RTCD_LOGGER_CONSOLEJSON True or False diff --git a/go.mod b/go.mod index 44edf9c..e9602e2 100644 --- a/go.mod +++ b/go.mod @@ -9,19 +9,19 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 github.com/mattermost/mattermost/server/public v0.0.0-20230613002302-62a3ee8adcb5 github.com/pborman/uuid v1.2.1 - github.com/pion/ice/v2 v2.3.2 + github.com/pion/ice/v2 v2.3.3 github.com/pion/interceptor v0.1.12 github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.10 github.com/pion/rtp v1.7.13 - github.com/pion/stun v0.4.0 + github.com/pion/stun v0.6.0 github.com/pion/webrtc/v3 v3.1.60 github.com/prometheus/client_golang v1.15.0 github.com/pyroscope-io/godeltaprof v0.1.1 - github.com/stretchr/testify v1.8.2 + github.com/stretchr/testify v1.8.3 github.com/vmihailenco/msgpack/v5 v5.3.5 - golang.org/x/crypto v0.8.0 - golang.org/x/sys v0.7.0 + golang.org/x/crypto v0.9.0 + golang.org/x/sys v0.8.0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) @@ -39,13 +39,13 @@ require ( github.com/mattermost/logr/v2 v2.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/pion/datachannel v1.5.5 // indirect - github.com/pion/dtls/v2 v2.2.6 // indirect + github.com/pion/dtls/v2 v2.2.7 // indirect github.com/pion/mdns v0.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/sctp v1.8.6 // indirect github.com/pion/sdp/v3 v3.0.6 // indirect github.com/pion/srtp/v2 v2.0.12 // indirect - github.com/pion/transport/v2 v2.2.0 // indirect + github.com/pion/transport/v2 v2.2.1 // indirect github.com/pion/turn/v2 v2.1.0 // indirect github.com/pion/udp/v2 v2.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -60,7 +60,7 @@ require ( github.com/wiggin77/merror v1.0.4 // indirect github.com/wiggin77/srslog v1.0.1 // indirect golang.org/x/exp v0.0.0-20200908183739-ae8ad444f925 // indirect - golang.org/x/net v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 2a99054..47a85c0 100644 --- a/go.sum +++ b/go.sum @@ -308,8 +308,12 @@ github.com/pion/datachannel v1.5.5 h1:10ef4kwdjije+M9d7Xm9im2Y3O6A6ccQb0zcqZcJew github.com/pion/datachannel v1.5.5/go.mod h1:iMz+lECmfdCMqFRhXhcA/219B0SQlbpoR2V118yimL0= github.com/pion/dtls/v2 v2.2.6 h1:yXMxKr0Skd+Ub6A8UqXTRLSywskx93ooMRHsQUtd+Z4= github.com/pion/dtls/v2 v2.2.6/go.mod h1:t8fWJCIquY5rlQZwA2yWxUS1+OCrAdXrhVKXB5oD/wY= +github.com/pion/dtls/v2 v2.2.7 h1:cSUBsETxepsCSFSxC3mc/aDo14qQLMSL+O6IjG28yV8= +github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/ice/v2 v2.3.2 h1:vh+fi4RkZ8H5fB4brZ/jm3j4BqFgMmNs+aB3X52Hu7M= github.com/pion/ice/v2 v2.3.2/go.mod h1:AMIpuJqcpe+UwloocNebmTSWhCZM1TUCo9v7nW50jX0= +github.com/pion/ice/v2 v2.3.3 h1:uGrUwn0DanTmXgFiKDTot7iQzo0J9NjTjHPG+Kt+kNE= +github.com/pion/ice/v2 v2.3.3/go.mod h1:jVbxqPWQDK5+/V/YqpinUcP0YtDGYqd24n2lusVdX80= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.7 h1:P0UB4Sr6xDWEox0kTVxF0LmQihtCbSAdW0H2nEgkA3U= @@ -329,12 +333,17 @@ github.com/pion/srtp/v2 v2.0.12 h1:WrmiVCubGMOAObBU1vwWjG0H3VSyQHawKeer2PVA5rY= github.com/pion/srtp/v2 v2.0.12/go.mod h1:C3Ep44hlOo2qEYaq4ddsmK5dL63eLehXFbHaZ9F5V9Y= github.com/pion/stun v0.4.0 h1:vgRrbBE2htWHy7l3Zsxckk7rkjnjOsSM7PHZnBwo8rk= github.com/pion/stun v0.4.0/go.mod h1:QPsh1/SbXASntw3zkkrIk3ZJVKz4saBY2G7S10P3wCw= +github.com/pion/stun v0.5.2/go.mod h1:TNo1HjyjaFVpMZsvowqPeV8TfwRytympQC0//neaksA= +github.com/pion/stun v0.6.0 h1:JHT/2iyGDPrFWE8NNC15wnddBN8KifsEDw8swQmrEmU= +github.com/pion/stun v0.6.0/go.mod h1:HPqcfoeqQn9cuaet7AOmB5e5xkObu9DwBdurwLKO9oA= github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0= github.com/pion/transport/v2 v2.2.0 h1:u5lFqFHkXLMXMzai8tixZDfVjb8eOjH35yCunhPeb1c= github.com/pion/transport/v2 v2.2.0/go.mod h1:AdSw4YBZVDkZm8fpoz+fclXyQwANWmZAlDuQdctTThQ= +github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c= +github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/turn/v2 v2.1.0 h1:5wGHSgGhJhP/RpabkUb/T9PdsAjkGLS6toYz5HNzoSI= github.com/pion/turn/v2 v2.1.0/go.mod h1:yrT5XbXSGX1VFSF31A3c1kCNB5bBZgk/uu5LET162qs= github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54= @@ -451,6 +460,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tidwall/btree v0.4.2/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8= @@ -511,6 +522,8 @@ golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -606,6 +619,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -697,12 +712,16 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -716,6 +735,7 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/service/config.go b/service/config.go index 08b2b88..5fd2e7a 100644 --- a/service/config.go +++ b/service/config.go @@ -80,6 +80,7 @@ func (c *Config) SetDefaults() { c.API.HTTP.ListenAddress = ":8045" c.API.Security.SessionCache.ExpirationMinutes = 1440 c.RTC.ICEPortUDP = 8443 + c.RTC.ICEPortTCP = 8443 c.RTC.TURNConfig.CredentialsExpirationMinutes = 1440 c.Store.DataSource = "/tmp/rtcd_db" c.Logger.EnableConsole = true diff --git a/service/helper_test.go b/service/helper_test.go index 8601bed..81babdf 100644 --- a/service/helper_test.go +++ b/service/helper_test.go @@ -80,7 +80,8 @@ func MakeDefaultCfg(tb testing.TB) *Config { }, }, RTC: rtc.ServerConfig{ - ICEPortUDP: 30443, + ICEPortUDP: 30444, + ICEPortTCP: 30444, }, Store: StoreConfig{ DataSource: dbDir, diff --git a/service/rtc/config.go b/service/rtc/config.go index 1b6049a..ca3be09 100644 --- a/service/rtc/config.go +++ b/service/rtc/config.go @@ -15,12 +15,18 @@ type ServerConfig struct { ICEAddressUDP string `toml:"ice_address_udp"` // ICEPortUDP specifies the UDP port the RTC service should listen to. ICEPortUDP int `toml:"ice_port_udp"` + // ICEAddressTCP specifies the TCP address the RTC service should listen on. + ICEAddressTCP string `toml:"ice_address_tcp"` + // ICEPortTCP specifies the TCP port the RTC service should listen to. + ICEPortTCP int `toml:"ice_port_tcp"` // ICEHostOverride optionally specifies an IP address (or hostname) // to be used as the main host ICE candidate. ICEHostOverride string `toml:"ice_host_override"` // A list of ICE server (STUN/TURN) configurations to use. ICEServers ICEServers `toml:"ice_servers"` TURNConfig TURNConfig `toml:"turn"` + // EnableIPv6 specifies whether or not IPv6 should be used. + EnableIPv6 bool `toml:"enable_ipv6"` } func (c ServerConfig) IsValid() error { @@ -28,10 +34,18 @@ func (c ServerConfig) IsValid() error { return fmt.Errorf("invalid ICEAddressUDP value: not a valid address") } + if c.ICEAddressTCP != "" && net.ParseIP(c.ICEAddressTCP) == nil { + return fmt.Errorf("invalid ICEAddressTCP value: not a valid address") + } + if c.ICEPortUDP < 80 || c.ICEPortUDP > 49151 { return fmt.Errorf("invalid ICEPortUDP value: %d is not in allowed range [80, 49151]", c.ICEPortUDP) } + if c.ICEPortTCP < 80 || c.ICEPortTCP > 49151 { + return fmt.Errorf("invalid ICEPortTCP value: %d is not in allowed range [80, 49151]", c.ICEPortTCP) + } + if err := c.ICEServers.IsValid(); err != nil { return fmt.Errorf("invalid ICEServers value: %w", err) } diff --git a/service/rtc/config_test.go b/service/rtc/config_test.go index f7efabb..5f47a20 100644 --- a/service/rtc/config_test.go +++ b/service/rtc/config_test.go @@ -41,9 +41,23 @@ func TestServerConfigIsValid(t *testing.T) { require.Equal(t, "invalid ICEPortUDP value: 65000 is not in allowed range [80, 49151]", err.Error()) }) + t.Run("invalid ICEPortTCP", func(t *testing.T) { + var cfg ServerConfig + cfg.ICEPortUDP = 8443 + cfg.ICEPortTCP = 22 + err := cfg.IsValid() + require.Error(t, err) + require.Equal(t, "invalid ICEPortTCP value: 22 is not in allowed range [80, 49151]", err.Error()) + cfg.ICEPortTCP = 65000 + err = cfg.IsValid() + require.Error(t, err) + require.Equal(t, "invalid ICEPortTCP value: 65000 is not in allowed range [80, 49151]", err.Error()) + }) + t.Run("invalid TURNCredentialsExpirationMinutes", func(t *testing.T) { var cfg ServerConfig cfg.ICEPortUDP = 8443 + cfg.ICEPortTCP = 8443 cfg.TURNConfig.StaticAuthSecret = "secret" err := cfg.IsValid() require.Error(t, err) @@ -59,6 +73,7 @@ func TestServerConfigIsValid(t *testing.T) { var cfg ServerConfig cfg.ICEAddressUDP = "127.0.0.1" cfg.ICEPortUDP = 8443 + cfg.ICEPortTCP = 8443 cfg.TURNConfig.CredentialsExpirationMinutes = 1440 err := cfg.IsValid() require.NoError(t, err) diff --git a/service/rtc/net.go b/service/rtc/net.go index 3f4c586..0a7c110 100644 --- a/service/rtc/net.go +++ b/service/rtc/net.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" "syscall" "time" @@ -16,9 +17,15 @@ import ( "github.com/mattermost/mattermost/server/public/shared/mlog" ) -// getSystemIPs returns a list of all the available IPv4 addresses. -func getSystemIPs(log mlog.LoggerIFace) ([]string, error) { - var ips []string +const ( + udpSocketBufferSize = 1024 * 1024 * 16 // 16MB + tcpConnReadBufferLength = 64 + tcpSocketWriteBufferSize = 1024 * 1024 * 4 // 4MB +) + +// getSystemIPs returns a list of all the available local addresses. +func getSystemIPs(log mlog.LoggerIFace, dualStack bool) ([]netip.Addr, error) { + var ips []netip.Addr interfaces, err := net.Interfaces() if err != nil { @@ -28,36 +35,43 @@ func getSystemIPs(log mlog.LoggerIFace) ([]string, error) { for _, iface := range interfaces { // filter out inactive interfaces if iface.Flags&net.FlagUp == 0 { - log.Debug("skipping inactive interface", mlog.String("interface", iface.Name)) + log.Info("skipping inactive interface", mlog.String("interface", iface.Name)) continue } addrs, err := iface.Addrs() if err != nil { - log.Debug("failed to get addresses for interface", mlog.String("interface", iface.Name)) + log.Warn("failed to get addresses for interface", mlog.String("interface", iface.Name)) continue } for _, addr := range addrs { - ip, _, err := net.ParseCIDR(addr.String()) + prefix, err := netip.ParsePrefix(addr.String()) if err != nil { - log.Debug("failed to parse address", mlog.Err(err), mlog.String("addr", addr.String())) + log.Warn("failed to parse prefix", mlog.Err(err), mlog.String("prefix", prefix.String())) + continue + } + + ip := prefix.Addr() + + if !dualStack && ip.Is6() { + log.Debug("ignoring IPv6 address: dual stack support is disabled by config", mlog.String("addr", ip.String())) continue } - // IPv4 only (for the time being at least, see MM-50294) - if ip.To4() == nil { + if ip.Is6() && !ip.IsGlobalUnicast() { + log.Debug("ignoring non global IPv6 address", mlog.String("addr", ip.String())) continue } - ips = append(ips, ip.String()) + ips = append(ips, ip) } } return ips, nil } -func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.PacketConn, error) { +func createUDPConnsForAddr(log mlog.LoggerIFace, network, listenAddress string) ([]net.PacketConn, error) { var conns []net.PacketConn for i := 0; i < runtime.NumCPU(); i++ { @@ -78,7 +92,7 @@ func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.Pa }, } - udpConn, err := listenConfig.ListenPacket(context.Background(), "udp4", listenAddress) + udpConn, err := listenConfig.ListenPacket(context.Background(), network, listenAddress) if err != nil { return nil, fmt.Errorf("failed to listen on udp: %w", err) } @@ -126,12 +140,12 @@ func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.Pa return conns, nil } -func resolveHost(host string, timeout time.Duration) (string, error) { +func resolveHost(host, network string, timeout time.Duration) (string, error) { var ip string r := net.Resolver{} ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - addrs, err := r.LookupIP(ctx, "ip4", host) + addrs, err := r.LookupIP(ctx, network, host) if err != nil { return ip, fmt.Errorf("failed to resolve host %q: %w", host, err) } @@ -140,3 +154,7 @@ func resolveHost(host string, timeout time.Duration) (string, error) { } return ip, err } + +func areAddressesSameStack(addrA, addrB netip.Addr) bool { + return (addrA.Is4() && addrB.Is4()) || (addrA.Is6() && addrB.Is6()) +} diff --git a/service/rtc/net_test.go b/service/rtc/net_test.go index 3ce3ff3..8ec1d98 100644 --- a/service/rtc/net_test.go +++ b/service/rtc/net_test.go @@ -4,6 +4,8 @@ package rtc import ( + "net/netip" + "os" "runtime" "testing" @@ -20,9 +22,40 @@ func TestGetSystemIPs(t *testing.T) { require.NoError(t, err) }() - ips, err := getSystemIPs(log) - require.NoError(t, err) - require.NotEmpty(t, ips) + t.Run("ipv4", func(t *testing.T) { + ips, err := getSystemIPs(log, false) + require.NoError(t, err) + require.NotEmpty(t, ips) + + for _, ip := range ips { + require.True(t, ip.Is4()) + } + }) + + t.Run("dual stack", func(t *testing.T) { + // Skipping this test in CI since IPv6 is not yet supported by Github actions. + if os.Getenv("CI") != "" { + t.Skip() + } + + ips, err := getSystemIPs(log, true) + require.NoError(t, err) + require.NotEmpty(t, ips) + + var hasIPv4 bool + var hasIPv6 bool + for _, ip := range ips { + if ip.Is4() { + hasIPv4 = true + } + if ip.Is6() { + hasIPv6 = true + } + } + + require.True(t, hasIPv4) + require.True(t, hasIPv6) + }) } func TestCreateUDPConnsForAddr(t *testing.T) { @@ -33,16 +66,38 @@ func TestCreateUDPConnsForAddr(t *testing.T) { require.NoError(t, err) }() - ips, err := getSystemIPs(log) - require.NoError(t, err) - require.NotEmpty(t, ips) + t.Run("IPv4", func(t *testing.T) { + ips, err := getSystemIPs(log, false) + require.NoError(t, err) + require.NotEmpty(t, ips) - for _, ip := range ips { - conns, err := createUDPConnsForAddr(log, ip+":30443") + for _, ip := range ips { + conns, err := createUDPConnsForAddr(log, "udp4", netip.AddrPortFrom(ip, 30443).String()) + require.NoError(t, err) + require.Len(t, conns, runtime.NumCPU()) + for _, conn := range conns { + require.NoError(t, conn.Close()) + } + } + }) + + t.Run("dual stack", func(t *testing.T) { + // Skipping this test in CI since IPv6 is not yet supported by Github actions. + if os.Getenv("CI") != "" { + t.Skip() + } + + ips, err := getSystemIPs(log, false) require.NoError(t, err) - require.Len(t, conns, runtime.NumCPU()) - for _, conn := range conns { - require.NoError(t, conn.Close()) + require.NotEmpty(t, ips) + + for _, ip := range ips { + conns, err := createUDPConnsForAddr(log, "udp", netip.AddrPortFrom(ip, 30443).String()) + require.NoError(t, err) + require.Len(t, conns, runtime.NumCPU()) + for _, conn := range conns { + require.NoError(t, conn.Close()) + } } - } + }) } diff --git a/service/rtc/server.go b/service/rtc/server.go index 889a1c2..d297727 100644 --- a/service/rtc/server.go +++ b/service/rtc/server.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net" + "net/netip" "sync" "time" @@ -17,9 +18,9 @@ import ( ) const ( - udpSocketBufferSize = 1024 * 1024 * 16 // 16MB - msgChSize = 256 - signalingTimeout = 10 * time.Second + msgChSize = 256 + signalingTimeout = 10 * time.Second + catchAllIP = "0.0.0.0" ) type Server struct { @@ -31,8 +32,9 @@ type Server struct { sessions map[string]SessionConfig udpMux ice.UDPMux - publicAddrsMap map[string]string - localIPs []string + tcpMux ice.TCPMux + publicAddrsMap map[netip.Addr]string + localIPs []netip.Addr sendCh chan Message receiveCh chan Message @@ -62,7 +64,7 @@ func NewServer(cfg ServerConfig, log mlog.LoggerIFace, metrics Metrics) (*Server sendCh: make(chan Message, msgChSize), receiveCh: make(chan Message, msgChSize), bufPool: &sync.Pool{New: func() interface{} { return make([]byte, receiveMTU) }}, - publicAddrsMap: make(map[string]string), + publicAddrsMap: make(map[netip.Addr]string), } return s, nil @@ -82,62 +84,56 @@ func (s *Server) ReceiveCh() <-chan Message { } func (s *Server) Start() error { - var err error - var muxes []ice.UDPMux - if s.cfg.ICEAddressUDP == "" || s.cfg.ICEAddressUDP == "0.0.0.0" { - s.log.Debug("going to listen on all supported interfaces") - s.localIPs, err = getSystemIPs(s.log) - if err != nil { - return fmt.Errorf("failed to get system IPs: %w", err) - } - if len(s.localIPs) == 0 { - return fmt.Errorf("no valid address to listen on was found") - } - } else { - s.localIPs = append(s.localIPs, s.cfg.ICEAddressUDP) + udpNetwork := "udp4" + tcpNetwork := "tcp4" + + if s.cfg.EnableIPv6 { + s.log.Info("rtc: experimental IPv6 support enabled") + udpNetwork = "udp" + tcpNetwork = "tcp" + } + + localIPs, err := getSystemIPs(s.log, s.cfg.EnableIPv6) + if err != nil { + return fmt.Errorf("failed to get system IPs: %w", err) + } + if len(localIPs) == 0 { + return fmt.Errorf("no valid address to listen on was found") } - for _, ip := range s.localIPs { - listenAddress := fmt.Sprintf("%s:%d", ip, s.cfg.ICEPortUDP) + s.localIPs = localIPs - if s.cfg.ICEHostOverride == "" && len(s.cfg.ICEServers) > 0 { - udpAddr, err := net.ResolveUDPAddr("udp4", listenAddress) + s.log.Debug("rtc: found local IPs", mlog.Any("ips", s.localIPs)) + + // Populate public IP addresses map if override is not set and STUN is provided. + if s.cfg.ICEHostOverride == "" && len(s.cfg.ICEServers) > 0 { + for _, ip := range localIPs { + udpListenAddr := netip.AddrPortFrom(ip, uint16(s.cfg.ICEPortUDP)).String() + udpAddr, err := net.ResolveUDPAddr(udpNetwork, udpListenAddr) if err != nil { - return fmt.Errorf("failed to resolve UDP address: %w", err) + s.log.Error("failed to resolve UDP address", mlog.Err(err)) + continue } // TODO: consider making this logic concurrent to lower total time taken // in case of multiple interfaces. - addr, err := getPublicIP(udpAddr, s.cfg.ICEServers.getSTUN()) + addr, err := getPublicIP(udpAddr, udpNetwork, s.cfg.ICEServers.getSTUN()) if err != nil { - s.log.Warn("failed to get public IP address for local interface", mlog.String("localAddr", ip), mlog.Err(err)) + s.log.Warn("failed to get public IP address for local interface", mlog.String("localAddr", ip.String()), mlog.Err(err)) } else { - s.log.Info("got public IP address for local interface", mlog.String("localAddr", ip), mlog.String("remoteAddr", addr)) + s.log.Info("got public IP address for local interface", mlog.String("localAddr", ip.String()), mlog.String("remoteAddr", addr)) } s.publicAddrsMap[ip] = addr } + } - conns, err := createUDPConnsForAddr(s.log, listenAddress) - if err != nil { - return fmt.Errorf("failed to create UDP connections: %w", err) - } - - udpConn, err := newMultiConn(conns) - if err != nil { - return fmt.Errorf("failed to create multiconn: %w", err) - } - - muxes = append(muxes, ice.NewUDPMuxDefault(ice.UDPMuxParams{ - Logger: newPionLeveledLogger(s.log), - UDPConn: udpConn, - })) + if err := s.initUDP(localIPs, udpNetwork); err != nil { + return err } - if len(muxes) == 1 { - s.udpMux = muxes[0] - } else { - s.udpMux = ice.NewMultiUDPMuxDefault(muxes...) + if err := s.initTCP(tcpNetwork); err != nil { + return err } go s.msgReader() @@ -167,6 +163,12 @@ func (s *Server) Stop() error { } } + if s.tcpMux != nil { + if err := s.tcpMux.Close(); err != nil { + return fmt.Errorf("failed to close udp mux: %w", err) + } + } + close(s.receiveCh) close(s.sendCh) @@ -298,3 +300,62 @@ func (s *Server) msgReader() { } } } + +func (s *Server) initUDP(localIPs []netip.Addr, network string) error { + var udpMuxes []ice.UDPMux + + initUDPMux := func(addr string) error { + conns, err := createUDPConnsForAddr(s.log, network, addr) + if err != nil { + return fmt.Errorf("failed to create UDP connections: %w", err) + } + + udpConn, err := newMultiConn(conns) + if err != nil { + return fmt.Errorf("failed to create multiconn: %w", err) + } + + udpMuxes = append(udpMuxes, ice.NewUDPMuxDefault(ice.UDPMuxParams{ + Logger: newPionLeveledLogger(s.log), + UDPConn: udpConn, + })) + + return nil + } + + // If an address is specified we create a single udp mux. + if s.cfg.ICEAddressUDP != "" { + if err := initUDPMux(net.JoinHostPort(s.cfg.ICEAddressUDP, fmt.Sprintf("%d", s.cfg.ICEPortUDP))); err != nil { + return err + } + s.udpMux = udpMuxes[0] + return nil + } + + // If no address is specified we create a mux for each interface we find. + for _, ip := range localIPs { + if err := initUDPMux(netip.AddrPortFrom(ip, uint16(s.cfg.ICEPortUDP)).String()); err != nil { + return err + } + } + + s.udpMux = ice.NewMultiUDPMuxDefault(udpMuxes...) + + return nil +} + +func (s *Server) initTCP(network string) error { + tcpListener, err := net.Listen(network, net.JoinHostPort(s.cfg.ICEAddressTCP, fmt.Sprintf("%d", s.cfg.ICEPortTCP))) + if err != nil { + return fmt.Errorf("failed to create TCP listener: %w", err) + } + + s.tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{ + Logger: newPionLeveledLogger(s.log), + Listener: tcpListener, + ReadBufferSize: tcpConnReadBufferLength, + WriteBufferSize: tcpSocketWriteBufferSize, + }) + + return nil +} diff --git a/service/rtc/server_test.go b/service/rtc/server_test.go index 5afbcf3..fc1d64b 100644 --- a/service/rtc/server_test.go +++ b/service/rtc/server_test.go @@ -16,6 +16,7 @@ import ( "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/rtcd/logger" + "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" ) @@ -31,6 +32,7 @@ func setupServer(t *testing.T) (*Server, func()) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, log, metrics) @@ -64,6 +66,7 @@ func TestNewServer(t *testing.T) { t.Run("missing logger", func(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, nil, metrics) require.Error(t, err) @@ -73,6 +76,7 @@ func TestNewServer(t *testing.T) { t.Run("missing metrics", func(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, log, nil) require.Error(t, err) @@ -82,6 +86,7 @@ func TestNewServer(t *testing.T) { t.Run("valid", func(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, log, metrics) require.NoError(t, err) @@ -102,6 +107,7 @@ func TestStartServer(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } t.Run("port unavailable", func(t *testing.T) { @@ -115,7 +121,7 @@ func TestStartServer(t *testing.T) { require.NoError(t, err) defer udpConn.Close() - ips, err := getSystemIPs(log) + ips, err := getSystemIPs(log, false) require.NoError(t, err) require.NotEmpty(t, ips) @@ -153,6 +159,7 @@ func TestDraining(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } metrics := perf.NewMetrics("rtcd", nil) @@ -214,6 +221,7 @@ func TestInitSession(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, log, metrics) @@ -372,6 +380,7 @@ func TestCalls(t *testing.T) { cfg := ServerConfig{ ICEPortUDP: 30433, + ICEPortTCP: 30433, } s, err := NewServer(cfg, log, metrics) @@ -433,3 +442,89 @@ func TestCalls(t *testing.T) { err = s.Stop() require.NoError(t, err) } + +func TestTCPCandidates(t *testing.T) { + log, err := logger.New(logger.Config{ + EnableConsole: true, + ConsoleLevel: "INFO", + }) + require.NoError(t, err) + defer func() { + err := log.Shutdown() + require.NoError(t, err) + }() + + metrics := perf.NewMetrics("rtcd", nil) + require.NotNil(t, metrics) + + serverCfg := ServerConfig{ + ICEPortUDP: 30433, + ICEPortTCP: 30433, + } + + s, err := NewServer(serverCfg, log, metrics) + require.NoError(t, err) + require.NotNil(t, s) + + err = s.Start() + require.NoError(t, err) + + cfg := SessionConfig{ + GroupID: random.NewID(), + CallID: random.NewID(), + UserID: random.NewID(), + SessionID: random.NewID(), + } + err = s.InitSession(cfg, nil) + require.NoError(t, err) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + + dc, err := pc.CreateDataChannel("calls-dc", nil) + require.NoError(t, err) + require.NotNil(t, dc) + + offer, err := pc.CreateOffer(nil) + require.NoError(t, err) + + err = pc.SetLocalDescription(offer) + require.NoError(t, err) + + offerData, err := json.Marshal(&offer) + require.NoError(t, err) + + err = s.Send(Message{ + GroupID: cfg.GroupID, + CallID: cfg.CallID, + UserID: cfg.UserID, + SessionID: cfg.SessionID, + Type: SDPMessage, + Data: offerData, + }) + require.NoError(t, err) + + for msg := range s.ReceiveCh() { + if msg.Type == ICEMessage { + data := make(map[string]any) + err := json.Unmarshal(msg.Data, &data) + require.NoError(t, err) + + iceString := data["candidate"].(map[string]interface{})["candidate"].(string) + + candidate, err := ice.UnmarshalCandidate(iceString) + require.NoError(t, err) + + require.Equal(t, ice.CandidateTypeHost, candidate.Type()) + require.Equal(t, serverCfg.ICEPortTCP, candidate.Port()) + + if candidate.NetworkType() == ice.NetworkTypeTCP4 { + break + } + } + } + + err = s.CloseSession(cfg.SessionID) + require.NoError(t, err) +} diff --git a/service/rtc/sfu.go b/service/rtc/sfu.go index f47ccfa..010cea0 100644 --- a/service/rtc/sfu.go +++ b/service/rtc/sfu.go @@ -187,9 +187,19 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { sEngine := webrtc.SettingEngine{} sEngine.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + networkTypes := []webrtc.NetworkType{ + webrtc.NetworkTypeUDP4, + webrtc.NetworkTypeTCP4, + } + if s.cfg.EnableIPv6 { + networkTypes = append(networkTypes, webrtc.NetworkTypeUDP6, webrtc.NetworkTypeTCP6) + } + sEngine.SetNetworkTypes(networkTypes) sEngine.SetICEUDPMux(s.udpMux) + sEngine.SetICETCPMux(s.tcpMux) + sEngine.SetIncludeLoopbackCandidate(true) - pairs, err := generateAddrsPairs(s.localIPs, s.publicAddrsMap, s.cfg.ICEHostOverride) + pairs, err := generateAddrsPairs(s.localIPs, s.publicAddrsMap, s.cfg.ICEHostOverride, s.cfg.EnableIPv6) if err != nil { return fmt.Errorf("failed to generate addresses pairs: %w", err) } else if len(pairs) > 0 { diff --git a/service/rtc/stun.go b/service/rtc/stun.go index ffc0558..635f9a3 100644 --- a/service/rtc/stun.go +++ b/service/rtc/stun.go @@ -12,19 +12,19 @@ import ( "github.com/pion/stun" ) -func getPublicIP(addr *net.UDPAddr, stunURL string) (string, error) { +func getPublicIP(addr *net.UDPAddr, network, stunURL string) (string, error) { if stunURL == "" { return "", fmt.Errorf("no STUN server URL was provided") } - conn, err := net.ListenUDP("udp4", addr) + conn, err := net.ListenUDP(network, addr) if err != nil { return "", err } defer conn.Close() serverURL := stunURL[strings.Index(stunURL, ":")+1:] - serverAddr, err := net.ResolveUDPAddr("udp", serverURL) + serverAddr, err := net.ResolveUDPAddr(network, serverURL) if err != nil { return "", fmt.Errorf("failed to resolve stun host: %w", err) } diff --git a/service/rtc/utils.go b/service/rtc/utils.go index 68d3439..e868555 100644 --- a/service/rtc/utils.go +++ b/service/rtc/utils.go @@ -5,6 +5,8 @@ package rtc import ( "fmt" + "net/netip" + "strings" "time" "github.com/mattermost/rtcd/service/random" @@ -28,43 +30,77 @@ func getTrackType(kind webrtc.RTPCodecType) string { return "unknown" } -func generateAddrsPairs(localIPs []string, publicAddrsMap map[string]string, hostOverride string) ([]string, error) { +func generateAddrsPairs(localIPs []netip.Addr, publicAddrsMap map[netip.Addr]string, hostOverride string, dualStack bool) ([]string, error) { var err error var pairs []string var hostOverrideIP string + // If the override is in full NAT mapping format (e.g. "EA/IA,EB/IB") we return + // that directly. + if strings.Contains(hostOverride, "/") { + return strings.Split(hostOverride, ","), nil + } + + ipNetwork := "ip4" + if dualStack { + ipNetwork = "ip" + } + + // If the override is set we resolve it in case it's a hostname. if hostOverride != "" { - hostOverrideIP, err = resolveHost(hostOverride, time.Second) + hostOverrideIP, err = resolveHost(hostOverride, ipNetwork, time.Second) if err != nil { return pairs, fmt.Errorf("failed to resolve host: %w", err) } } - usedPublicAddrs := map[string]bool{} - for _, localAddr := range localIPs { - publicAddr := publicAddrsMap[localAddr] + // Nothing to do at this point if no local IP was found. + if len(localIPs) == 0 { + return nil, nil + } - // If an override was explicitly provided we enforce that. - if hostOverrideIP != "" { - publicAddr = hostOverrideIP + // If the override is set but no explicit mapping is given, we try to + // generate one. + if hostOverrideIP != "" { + hostOverrideAddr, err := netip.ParseAddr(hostOverrideIP) + if err != nil { + return nil, fmt.Errorf("failed to parse hostOverrideIP: %w", err) } - if publicAddr != "" && !usedPublicAddrs[publicAddr] { - // if a public IP has not been used yet we map it to - // the first matching local ip. - pairs = append(pairs, fmt.Sprintf("%s/%s", publicAddr, localAddr)) - usedPublicAddrs[publicAddr] = true - } else { - // if a public IP has been used already we map - // any successive matching local ips to themselves. - pairs = append(pairs, fmt.Sprintf("%s/%s", localAddr, localAddr)) + // If only one local interface is found, we map that to the given public ip + // override. + if len(localIPs) == 1 && areAddressesSameStack(hostOverrideAddr, localIPs[0]) { + return []string{ + fmt.Sprintf("%s/%s", hostOverrideAddr.String(), localIPs[0].String()), + }, nil } + + // Otherwise we map the override to any non-loopback IP. + for _, localAddr := range localIPs { + if localAddr.IsLoopback() { + pairs = append(pairs, fmt.Sprintf("%s/%s", localAddr.String(), localAddr.String())) + } else if areAddressesSameStack(hostOverrideAddr, localAddr) { + pairs = append(pairs, fmt.Sprintf("%s/%s", hostOverrideAddr.String(), localAddr.String())) + } + } + + return pairs, nil } - // If no public address was found/set there's no point in generating pairs. - if len(usedPublicAddrs) == 0 { + // Nothing to do if no public address was found. + if len(publicAddrsMap) == 0 { return nil, nil } + // We finally try to generate a mapping from any public IP we have + // found through STUN. + for _, localAddr := range localIPs { + publicAddr := publicAddrsMap[localAddr] + if publicAddr == "" { + publicAddr = localAddr.String() + } + pairs = append(pairs, fmt.Sprintf("%s/%s", publicAddr, localAddr.String())) + } + return pairs, nil } diff --git a/service/rtc/utils_test.go b/service/rtc/utils_test.go index f879767..d4fe045 100644 --- a/service/rtc/utils_test.go +++ b/service/rtc/utils_test.go @@ -4,6 +4,7 @@ package rtc import ( + "net/netip" "testing" "github.com/stretchr/testify/require" @@ -11,56 +12,80 @@ import ( func TestGenerateAddrsPairs(t *testing.T) { t.Run("nil/empty inputs", func(t *testing.T) { - pairs, err := generateAddrsPairs(nil, nil, "") + pairs, err := generateAddrsPairs(nil, nil, "", false) require.NoError(t, err) require.Empty(t, pairs) - pairs, err = generateAddrsPairs([]string{}, map[string]string{}, "") + pairs, err = generateAddrsPairs([]netip.Addr{}, map[netip.Addr]string{}, "", false) require.NoError(t, err) require.Empty(t, pairs) }) t.Run("no public addresses", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "", + }, "", false) require.NoError(t, err) - require.Empty(t, pairs) + require.Equal(t, []string{"127.0.0.1/127.0.0.1", "10.1.1.1/10.1.1.1"}, pairs) + }) + + t.Run("full NAT mapping", func(t *testing.T) { + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{}, "1.1.1.1/127.0.0.1,1.1.1.1/10.1.1.1", false) + require.NoError(t, err) + require.Equal(t, []string{"1.1.1.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("no public addresses with override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "", - }, "1.1.1.1") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "", + }, "1.1.1.1", false) require.NoError(t, err) - require.Equal(t, []string{"1.1.1.1/127.0.0.1", "10.1.1.1/10.1.1.1"}, pairs) + require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("single public address for multiple local addrs, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "1.1.1.1", - "10.1.1.1": "1.1.1.1", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "1.1.1.1", + }, "", false) require.NoError(t, err) - require.Equal(t, []string{"1.1.1.1/127.0.0.1", "10.1.1.1/10.1.1.1"}, pairs) + require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("single local/public address map, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "1.1.1.1", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "1.1.1.1", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("multiple public addresses, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "1.1.1.1", - "10.1.1.1": "1.1.1.2", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "1.1.1.1", + netip.MustParseAddr("10.1.1.1"): "1.1.1.2", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"1.1.1.1/127.0.0.1", "1.1.1.2/10.1.1.1"}, pairs) }) @@ -68,11 +93,14 @@ func TestGenerateAddrsPairs(t *testing.T) { // This is not a case that would happen in the application because the // override would prevent us from finding public IPs. t.Run("multiple public addresses, with override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "1.1.1.1", - "10.1.1.1": "1.1.1.2", - }, "8.8.8.8") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "1.1.1.1", + netip.MustParseAddr("10.1.1.1"): "1.1.1.2", + }, "8.8.8.8", false) require.NoError(t, err) - require.Equal(t, []string{"8.8.8.8/127.0.0.1", "10.1.1.1/10.1.1.1"}, pairs) + require.Equal(t, []string{"127.0.0.1/127.0.0.1", "8.8.8.8/10.1.1.1"}, pairs) }) } diff --git a/service/service.go b/service/service.go index 5e430c6..022e195 100644 --- a/service/service.go +++ b/service/service.go @@ -122,6 +122,8 @@ func New(cfg Config) (*Service, error) { } func (s *Service) Start() error { + defer s.log.Flush() + if err := s.apiServer.Start(); err != nil { return fmt.Errorf("failed to start api server: %w", err) } @@ -185,18 +187,19 @@ func (s *Service) Start() error { } func (s *Service) Stop() error { + defer s.log.Flush() s.log.Info("rtcd: shutting down") if err := s.rtcServer.Stop(); err != nil { return fmt.Errorf("failed to stop rtc server: %w", err) } + s.wsServer.Close() + if err := s.apiServer.Stop(); err != nil { return fmt.Errorf("failed to stop api server: %w", err) } - s.wsServer.Close() - if err := s.store.Close(); err != nil { return fmt.Errorf("failed to close store: %w", err) }