Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a secret token to authenticate relay connections #186

Merged
merged 4 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func main() {
&cli.StringFlag{
Name: "token",
},
&cli.StringFlag{
Name: "relay-token",
},
&cli.StringFlag{
Name: "ws-url",
},
Expand Down Expand Up @@ -248,7 +251,7 @@ func runHandler(c *cli.Context) error {

var handler interface {
Kill()
HandleIngress(ctx context.Context, info *livekit.IngressInfo, wsUrl, token string, extraParams any)
HandleIngress(ctx context.Context, info *livekit.IngressInfo, wsUrl, token, relayToken string, extraParams any)
}

bus := psrpc.NewRedisMessageBus(rc)
Expand Down Expand Up @@ -278,7 +281,7 @@ func runHandler(c *cli.Context) error {
wsUrl = c.String("ws-url")
}

handler.HandleIngress(ctx, info, wsUrl, token, ep)
handler.HandleIngress(ctx, info, wsUrl, token, c.String("relay-token"), ep)
return nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
ErrDuplicateTrack = psrpc.NewErrorf(psrpc.NotAcceptable, "more than 1 track with given media kind")
ErrUnableToAddPad = psrpc.NewErrorf(psrpc.Internal, "could not add pads to bin")
ErrMissingResourceId = psrpc.NewErrorf(psrpc.InvalidArgument, "missing resource ID")
ErrInvalidRelayToken = psrpc.NewErrorf(psrpc.PermissionDenied, "invalid token")
ErrIngressNotFound = psrpc.NewErrorf(psrpc.NotFound, "ingress not found")
ErrServerCapacityExceeded = psrpc.NewErrorf(psrpc.ResourceExhausted, "server capacity exceeded")
ErrServerShuttingDown = psrpc.NewErrorf(psrpc.Unavailable, "server shutting down")
Expand Down
4 changes: 3 additions & 1 deletion pkg/media/rtmp/appsrc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package rtmp

import (
"context"
"fmt"
"io"
"net/http"

Expand Down Expand Up @@ -74,7 +75,8 @@ func (s *RTMPRelaySource) Start(ctx context.Context) error {

s.result = make(chan error, 1)

resp, err := http.Get(s.params.RelayUrl)
relayUrl := fmt.Sprintf("%s?token=%s", s.params.RelayUrl, s.params.RelayToken)
resp, err := http.Get(relayUrl)
switch {
case err != nil:
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/media/whip/whipsrc.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,5 @@ func (s *WHIPSource) ValidateCaps(*gst.Caps) error {
}

func (s *WHIPSource) getRelayUrl(kind types.StreamKind) string {
return fmt.Sprintf("%s/%s", s.params.RelayUrl, kind)
return fmt.Sprintf("%s/%s?token=%s", s.params.RelayUrl, kind, s.params.RelayToken)
}
12 changes: 9 additions & 3 deletions pkg/params/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ type Params struct {
Token string

// relay info
RelayUrl string
TmpDir string
RelayUrl string
RelayToken string
TmpDir string

// Input type specific private parameters
ExtraParams any
Expand All @@ -73,7 +74,7 @@ func InitLogger(conf *config.Config, info *livekit.IngressInfo) error {

return nil
}
func GetParams(ctx context.Context, psrpcClient rpc.IOInfoClient, conf *config.Config, info *livekit.IngressInfo, wsUrl, token string, ep any) (*Params, error) {
func GetParams(ctx context.Context, psrpcClient rpc.IOInfoClient, conf *config.Config, info *livekit.IngressInfo, wsUrl, token, relayToken string, ep any) (*Params, error) {
var err error

// The state should have been created by the service, before launching the hander, but be defensive here.
Expand All @@ -89,6 +90,10 @@ func GetParams(ctx context.Context, psrpcClient rpc.IOInfoClient, conf *config.C
relayUrl = getWHIPRelayUrlPrefix(conf, info.State.ResourceId)
}

if relayToken == "" {
relayToken = utils.NewGuid("")
}

l := logger.GetLogger().WithValues(getLoggerFields(info)...)

tmpDir := path.Join(os.TempDir(), info.State.ResourceId)
Expand Down Expand Up @@ -139,6 +144,7 @@ func GetParams(ctx context.Context, psrpcClient rpc.IOInfoClient, conf *config.C
VideoEncodingOptions: videoEncodingOptions,
Token: token,
WsUrl: wsUrl,
RelayToken: relayToken,
RelayUrl: relayUrl,
TmpDir: tmpDir,
ExtraParams: ep,
Expand Down
3 changes: 2 additions & 1 deletion pkg/rtmp/relay_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (h *RTMPRelayHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()

resourceId := strings.TrimLeft(r.URL.Path, "/rtmp/")
token := r.URL.Query().Get("token")

log := logger.Logger(logger.GetLogger().WithValues("resourceID", resourceId))
log.Infow("relaying ingress")
Expand All @@ -63,7 +64,7 @@ func (h *RTMPRelayHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
close(done)
}()

err = h.rtmpServer.AssociateRelay(resourceId, pw)
err = h.rtmpServer.AssociateRelay(resourceId, token, pw)
if err != nil {
return
}
Expand Down
26 changes: 17 additions & 9 deletions pkg/rtmp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"github.com/livekit/ingress/pkg/config"
"github.com/livekit/ingress/pkg/errors"
"github.com/livekit/ingress/pkg/params"
"github.com/livekit/ingress/pkg/stats"
"github.com/livekit/ingress/pkg/types"
"github.com/livekit/ingress/pkg/utils"
Expand All @@ -46,7 +47,7 @@ func NewRTMPServer() *RTMPServer {
return &RTMPServer{}
}

func (s *RTMPServer) Start(conf *config.Config, onPublish func(streamKey, resourceId string) (*stats.MediaStatsReporter, error)) error {
func (s *RTMPServer) Start(conf *config.Config, onPublish func(streamKey, resourceId string) (*params.Params, *stats.MediaStatsReporter, error)) error {
port := conf.RTMPPort

tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf(":%d", port))
Expand All @@ -70,19 +71,20 @@ func (s *RTMPServer) Start(conf *config.Config, onPublish func(streamKey, resour
lf := l.WithFields(conf.GetLoggerFields())

h := NewRTMPHandler()
h.OnPublishCallback(func(streamKey, resourceId string) (*stats.MediaStatsReporter, error) {
h.OnPublishCallback(func(streamKey, resourceId string) (*params.Params, *stats.MediaStatsReporter, error) {
var params *params.Params
var stats *stats.MediaStatsReporter
var err error
if onPublish != nil {
stats, err = onPublish(streamKey, resourceId)
params, stats, err = onPublish(streamKey, resourceId)
if err != nil {
return nil, err
return nil, nil, err
}
}

s.handlers.Store(resourceId, h)

return stats, nil
return params, stats, nil
})
h.OnCloseCallback(func(resourceId string) {
if h.stats != nil {
Expand Down Expand Up @@ -114,9 +116,13 @@ func (s *RTMPServer) Start(conf *config.Config, onPublish func(streamKey, resour
return nil
}

func (s *RTMPServer) AssociateRelay(resourceId string, w io.WriteCloser) error {
func (s *RTMPServer) AssociateRelay(resourceId string, token string, w io.WriteCloser) error {
h, ok := s.handlers.Load(resourceId)
if ok && h != nil {
if h.(*RTMPHandler).params.RelayToken != token {
return errors.ErrInvalidRelayToken
}

err := h.(*RTMPHandler).SetWriter(w)
if err != nil {
return err
Expand Down Expand Up @@ -151,6 +157,7 @@ type RTMPHandler struct {

flvEnc *flv.Encoder
stats *stats.MediaStatsReporter
params *params.Params
resourceId string
videoInit *flvtag.VideoData
audioInit *flvtag.AudioData
Expand All @@ -159,7 +166,7 @@ type RTMPHandler struct {

log logger.Logger

onPublish func(streamKey, resourceId string) (*stats.MediaStatsReporter, error)
onPublish func(streamKey, resourceId string) (*params.Params, *stats.MediaStatsReporter, error)
onClose func(resourceId string)
}

Expand All @@ -178,7 +185,7 @@ func NewRTMPHandler() *RTMPHandler {
return h
}

func (h *RTMPHandler) OnPublishCallback(cb func(streamKey, resourceId string) (*stats.MediaStatsReporter, error)) {
func (h *RTMPHandler) OnPublishCallback(cb func(streamKey, resourceId string) (*params.Params, *stats.MediaStatsReporter, error)) {
h.onPublish = cb
}

Expand All @@ -196,11 +203,12 @@ func (h *RTMPHandler) OnPublish(_ *rtmp.StreamContext, timestamp uint32, cmd *rt
h.resourceId = protoutils.NewGuid(protoutils.RTMPResourcePrefix)
h.log = logger.GetLogger().WithValues("streamKey", streamKey, "resourceID", h.resourceId)
if h.onPublish != nil {
stats, err := h.onPublish(streamKey, h.resourceId)
params, stats, err := h.onPublish(streamKey, h.resourceId)
if err != nil {
return err
}
h.stats = stats
h.params = params
}

h.log.Infow("Received a new published stream")
Expand Down
8 changes: 4 additions & 4 deletions pkg/service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ func NewHandler(conf *config.Config, rpcClient rpc.IOInfoClient) *Handler {
}
}

func (h *Handler) HandleIngress(ctx context.Context, info *livekit.IngressInfo, wsUrl, token string, extraParams any) {
func (h *Handler) HandleIngress(ctx context.Context, info *livekit.IngressInfo, wsUrl, token, relayToken string, extraParams any) {
ctx, span := tracer.Start(ctx, "Handler.HandleRequest")
defer span.End()

params.InitLogger(h.conf, info)

p, err := h.buildPipeline(ctx, info, wsUrl, token, extraParams)
p, err := h.buildPipeline(ctx, info, wsUrl, token, relayToken, extraParams)
if err != nil {
span.RecordError(err)
return
Expand Down Expand Up @@ -204,13 +204,13 @@ func (h *Handler) UpdateMediaStats(ctx context.Context, in *ipc.UpdateMediaStats
return &google_protobuf2.Empty{}, nil
}

func (h *Handler) buildPipeline(ctx context.Context, info *livekit.IngressInfo, wsUrl, token string, extraParams any) (*media.Pipeline, error) {
func (h *Handler) buildPipeline(ctx context.Context, info *livekit.IngressInfo, wsUrl, token, relayToken string, extraParams any) (*media.Pipeline, error) {
ctx, span := tracer.Start(ctx, "Handler.buildPipeline")
defer span.End()

// build/verify params
var p *media.Pipeline
params, err := params.GetParams(ctx, h.rpcClient, h.conf, info, wsUrl, token, extraParams)
params, err := params.GetParams(ctx, h.rpcClient, h.conf, info, wsUrl, token, relayToken, extraParams)
if err == nil {
// create the pipeline
p, err = media.New(ctx, h.conf, params)
Expand Down
1 change: 1 addition & 0 deletions pkg/service/process_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func (s *ProcessManager) launchHandler(ctx context.Context, p *params.Params) er
"run-handler",
"--config-body", string(confString),
"--info", string(infoString),
"--relay-token", p.RelayToken,
}

if p.WsUrl != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/service/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (r *Relay) Start(conf *config.Config) error {

r.server = &http.Server{
Handler: mux,
Addr: fmt.Sprintf(":%d", port),
Addr: fmt.Sprintf("localhost:%d", port),
}

go func() {
Expand Down
14 changes: 7 additions & 7 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func NewService(conf *config.Config, psrpcClient rpc.IOInfoClient, bus psrpc.Mes
return s
}

func (s *Service) HandleRTMPPublishRequest(streamKey, resourceId string) (*stats.MediaStatsReporter, error) {
func (s *Service) HandleRTMPPublishRequest(streamKey, resourceId string) (*params.Params, *stats.MediaStatsReporter, error) {
ctx, span := tracer.Start(context.Background(), "Service.HandleRTMPPublishRequest")
defer span.End()

Expand All @@ -129,26 +129,26 @@ func (s *Service) HandleRTMPPublishRequest(streamKey, resourceId string) (*stats
var pRes publishResponse
select {
case <-s.shutdown.Watch():
return nil, errors.ErrServerShuttingDown
return nil, nil, errors.ErrServerShuttingDown
case s.publishRequests <- r:
pRes = <-res
if pRes.err != nil {
return nil, pRes.err
return nil, nil, pRes.err
}
}

err := s.manager.launchHandler(ctx, pRes.params)
if err != nil {
return nil, err
return nil, nil, err
}

api, err := s.sm.GetIngressSessionAPI(resourceId)
if err != nil {
return nil, err
return nil, nil, err
}
stats := stats.NewMediaStats(api)

return stats, nil
return pRes.params, stats, nil
}

func (s *Service) HandleWHIPPublishRequest(streamKey, resourceId string, ihs rpc.IngressHandlerServerImpl) (p *params.Params, ready func(mimeTypes map[types.StreamKind]string, err error) *stats.MediaStatsReporter, ended func(err error), err error) {
Expand Down Expand Up @@ -295,7 +295,7 @@ func (s *Service) handleNewPublisher(ctx context.Context, resourceId string, inp
}

// This validates the ingress info
p, err := params.GetParams(ctx, s.psrpcClient, conf, info, wsUrl, token, nil)
p, err := params.GetParams(ctx, s.psrpcClient, conf, info, wsUrl, token, "", nil)
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/whip/relay_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (h *WHIPRelayHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
resourceId := v[0]
kind := types.StreamKind(v[1])
token := r.URL.Query().Get("token")

log := logger.Logger(logger.GetLogger().WithValues("resourceId", resourceId, "kind", kind))
log.Infow("relaying whip ingress")
Expand Down Expand Up @@ -92,14 +93,14 @@ func (h *WHIPRelayHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
close(done)
}()

err = h.whipServer.AssociateRelay(resourceId, kind, pw)
if err != nil {
return
}

defer func() {
pw.Close()
}()

err = h.whipServer.AssociateRelay(resourceId, kind, token, pw)
if err != nil {
return
}

err = <-done
}
4 changes: 2 additions & 2 deletions pkg/whip/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ func (s *WHIPServer) Stop() {
s.cancel()
}

func (s *WHIPServer) AssociateRelay(resourceId string, kind types.StreamKind, w io.WriteCloser) error {
func (s *WHIPServer) AssociateRelay(resourceId string, kind types.StreamKind, token string, w io.WriteCloser) error {
s.handlersLock.Lock()
h, ok := s.handlers[resourceId]
s.handlersLock.Unlock()
if ok && h != nil {
err := h.AssociateRelay(kind, w)
err := h.AssociateRelay(kind, token, w)
if err != nil {
return err
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/whip/whip_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,14 @@ func (h *whipHandler) WaitForSessionEnd(ctx context.Context) error {
}
}

func (h *whipHandler) AssociateRelay(kind types.StreamKind, w io.WriteCloser) error {
func (h *whipHandler) AssociateRelay(kind types.StreamKind, token string, w io.WriteCloser) error {
h.trackLock.Lock()
defer h.trackLock.Unlock()

if token != h.params.RelayToken {
return errors.ErrInvalidRelayToken
}

th := h.trackRelayMediaSink[kind]
if th == nil {
return errors.ErrIngressNotFound
Expand Down
Loading