Skip to content

Commit

Permalink
Support proxy RTMP to backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Aug 28, 2024
1 parent 1bd1e48 commit 301405d
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 25 deletions.
3 changes: 2 additions & 1 deletion proxy/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,11 @@ func (v *systemAPI) Run(ctx context.Context) error {
srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid
srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api
srs.SRT, srs.RTC = srt, rtc
srs.UpdatedAt = time.Now()
})
srsLoadBalancer.Update(server)

logger.Df(ctx, "Register SRS media server, %v", server)
logger.Df(ctx, "Register SRS media server, %+v", server)
return nil
}(); err != nil {
apiError(ctx, w, r, err)
Expand Down
9 changes: 7 additions & 2 deletions proxy/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ func setupDefaultEnv(ctx context.Context) {
// The API server of proxy itself.
setEnvDefault("PROXY_SYSTEM_API", "2025")

// Default backend server IP.
//setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1")
// Default backend server port.
//setEnvDefault("PROXY_DEFAULT_BACKEND_PORT", "1935")

logger.Df(ctx, "load .env as GO_PPROF=%v, "+
"PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+
"PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+
"PROXY_SYSTEM_API=%v",
"PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_PORT=%v",
envGoPprof(),
envForceQuitTimeout(), envGraceQuitTimeout(),
envHttpAPI(), envHttpServer(), envRtmpServer(),
envSystemAPI(),
envSystemAPI(), envDefaultBackendIP(), envDefaultBackendPort(),
)
}
3 changes: 2 additions & 1 deletion proxy/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
"fmt"
"net/http"
"os"
"srs-proxy/logger"
"strings"
"sync"
"time"

"srs-proxy/logger"
)

type httpServer struct {
Expand Down
4 changes: 2 additions & 2 deletions proxy/logger/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type key string
var cidKey key = "cid.proxy.ossrs.org"

// generateContextID generates a random context id in string.
func generateContextID() string {
func GenerateContextID() string {
randomBytes := make([]byte, 32)
_, _ = rand.Read(randomBytes)
hash := sha256.Sum256(randomBytes)
Expand All @@ -26,7 +26,7 @@ func generateContextID() string {

// WithContext creates a new context with cid, which will be used for log.
func WithContext(ctx context.Context) context.Context {
return context.WithValue(ctx, cidKey, generateContextID())
return context.WithValue(ctx, cidKey, GenerateContextID())
}

// ContextID returns the cid in context, or empty string if not set.
Expand Down
4 changes: 4 additions & 0 deletions proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package main
import (
"context"
"os"

"srs-proxy/errors"
"srs-proxy/logger"
)
Expand Down Expand Up @@ -46,6 +47,9 @@ func doMain(ctx context.Context) error {
// Start the Go pprof if enabled.
handleGoPprof(ctx)

// Initialize SRS load balancers.
srsLoadBalancer.Initialize(ctx)

// Parse the gracefully quit timeout.
gracefulQuitTimeout, err := parseGracefullyQuitTimeout()
if err != nil {
Expand Down
233 changes: 226 additions & 7 deletions proxy/rtmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ package main

import (
"context"
"fmt"
"io"
"math/rand"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -147,7 +149,9 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error {
if err := client.WritePacket(ctx, connectRes, 0); err != nil {
return errors.Wrapf(err, "write connect res")
}
logger.Df(ctx, "RTMP connect app %v", connectReq.TcUrl())

tcUrl := connectReq.TcUrl()
logger.Df(ctx, "RTMP connect app %v", tcUrl)

// Expect RTMP command to identify the client, a publisher or viewer.
var currentStreamID int
Expand All @@ -166,8 +170,8 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error {
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
response = identifyRes

identifyRes.StreamID = 1
currentStreamID = int(identifyRes.StreamID)
currentStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(currentStreamID))
} else {
// For releaseStream, FCPublish, etc.
identifyRes := rtmp.NewCallPacket()
Expand Down Expand Up @@ -201,7 +205,20 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error {
}
}
}
logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
tcUrl, streamName, currentStreamID, clientType)

// Find a backend SRS server to proxy the RTMP stream.
backend := NewRTMPClient(func(client *RTMPClient) {
client.rd = v.rd
})
defer backend.Close()

if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
}

// Start the streaming.
if clientType == RTMPClientTypePublisher {
identifyRes := rtmp.NewCallPacket()

Expand All @@ -219,17 +236,32 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error {
return errors.Wrapf(err, "start publish")
}
}
logger.Df(ctx, "RTMP identify stream=%v, id=%v, type=%v",
streamName, currentStreamID, clientType)
logger.Df(ctx, "RTMP start streaming")

// Proxy all message from backend to client.
go func() {
for {
m, err := backend.client.ReadMessage(ctx)
if err != nil {
return
}

if err := client.WriteMessage(ctx, m); err != nil {
return
}
}
}()

// Proxy all messages from client to backend.
for {
m, err := client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}

_ = m
logger.Df(ctx, "Got message %v, %v bytes", m.MessageType, len(m.Payload))
if err := backend.client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}

return nil
Expand All @@ -240,3 +272,190 @@ type RTMPClientType string
const (
RTMPClientTypePublisher RTMPClientType = "publisher"
)

type RTMPClient struct {
// The random number generator.
rd *rand.Rand
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
}

func NewRTMPClient(opts ...func(*RTMPClient)) *RTMPClient {
v := &RTMPClient{}
for _, opt := range opts {
opt(v)
}
return v
}

func (v *RTMPClient) Close() error {
if v.tcpConn != nil {
v.tcpConn.Close()
}
return nil
}

func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) error {
// Pick a backend SRS server to proxy the RTMP stream.
streamURL := fmt.Sprintf("%v/%v", tcUrl, streamName)
backend, err := srsLoadBalancer.Pick(streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}

// Parse RTMP port from backend.
if len(backend.RTMP) == 0 {
return errors.Errorf("no rtmp server for %v", streamURL)
}

var rtmpPort int
if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse backend %v rtmp port %v", backend, backend.RTMP[0])
} else {
rtmpPort = int(iv)
}

// Connect to backend SRS server via TCP client.
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
c, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
}
v.tcpConn = c

hs := rtmp.NewHandshake(v.rd)
client := rtmp.NewProtocol(c)
v.client = client

// Simple RTMP handshake with server.
if err := hs.WriteC0S0(c); err != nil {
return errors.Wrapf(err, "write c0")
}
if err := hs.WriteC1S1(c); err != nil {
return errors.Wrapf(err, "write c1")
}

if _, err = hs.ReadC0S0(c); err != nil {
return errors.Wrapf(err, "read s0")
}
if _, err := hs.ReadC1S1(c); err != nil {
return errors.Wrapf(err, "read s1")
}
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrapf(err, "read c2")
}
logger.Df(ctx, "backend simple handshake done, server=%v", addr)

if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write c2")
}

// Connect RTMP app on tcUrl with server.
if true {
connectApp := rtmp.NewConnectAppPacket()
connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
if err := client.WritePacket(ctx, connectApp, 1); err != nil {
return errors.Wrapf(err, "write connect app")
}
}

if true {
var connectAppRes *rtmp.ConnectAppResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
return errors.Wrapf(err, "expect connect app res")
}
logger.Df(ctx, "backend connect RTMP app, id=%v", connectAppRes.SrsID())
}

// Publish RTMP stream with server.
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
identifyReq.TransactionID = 2
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "releaseStream")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "_result" {
break
}
}

if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "FCPublish"
identifyReq.TransactionID = 3
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "FCPublish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect FCPublish res")
}
if identifyRes.CommandName == "_result" {
break
}
}

if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
var currentStreamID int
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}

if true {
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect publish res")
}
// Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
if identifyRes.CommandName == "onStatus" {
if data := rtmp.Amf0AnyToObject(identifyRes.Args); data == nil {
return errors.Errorf("onStatus args not object")
} else if code := rtmp.Amf0AnyToString(data.Get("code")); *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
}
break
}
}
logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)

return nil
}
19 changes: 19 additions & 0 deletions proxy/rtmp/amf0.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,25 @@ type amf0Any interface {
amf0Marker() amf0Marker
}

func Amf0AnyToObject(a amf0Any) *amf0Object {
return amf0AnyTo[*amf0Object](a)
}

func Amf0AnyToString(a amf0Any) *amf0String {
return amf0AnyTo[*amf0String](a)
}

// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
}
}
return to
}

// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
if len(p) < 1 {
Expand Down
Loading

0 comments on commit 301405d

Please sign in to comment.