Skip to content

Commit

Permalink
feat: add protocol proxy option (#16)
Browse files Browse the repository at this point in the history
* Add protocol proxy option

* docs: update config example

* chore: use Sprintf to construct proxy protocol msg
  • Loading branch information
taoky authored Aug 4, 2023
1 parent 05314ab commit 211b8dd
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ rsync-proxy 可以根据 module name 反向代理不同 host 上的 rsync daemon

```shell
mkdir /etc/rsync-proxy
cp cofig.example.toml /etc/rsync-proxy/config.toml
cp config.example.toml /etc/rsync-proxy/config.toml
vim /etc/rsync-proxy/config.toml # 根据实际情况修改配置
```

Expand Down
7 changes: 7 additions & 0 deletions dist/config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@ modules = ["bar"]
[upstreams.u3]
address = "rsync.example.internal:1235"
modules = ["baz"]

[upstreams.u4]
address = "rsync.example.internal:1236"
modules = ["pro"]
# This option requires rsync upstream to support and enable proxy protocol
# See: https://github.com/WayneD/rsync/blob/2f9b963abaa52e44891180fe6c0d1c2219f6686d/rsyncd.conf.5.md?plain=1#L268
use_proxy_protocol = true
5 changes: 3 additions & 2 deletions pkg/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

type Upstream struct {
Address string `toml:"address"`
Modules []string `toml:"modules"`
Address string `toml:"address"`
Modules []string `toml:"modules"`
UseProxyProtocol bool `toml:"use_proxy_protocol"`
}

type ProxySettings struct {
Expand Down
43 changes: 35 additions & 8 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"os"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -59,6 +60,8 @@ type Server struct {
bufPool sync.Pool
// name -> address
modules map[string]string
// address -> enable proxy protocol or not
proxyProtocol map[string]bool

activeConnCount atomic.Int64
connIndex atomic.Uint32
Expand Down Expand Up @@ -89,6 +92,7 @@ func (s *Server) loadConfig(c *Config) error {
}

modules := map[string]string{}
proxyProtocol := map[string]bool{}
for upstreamName, v := range c.Upstreams {
addr := v.Address
_, err := net.ResolveTCPAddr("tcp", addr)
Expand All @@ -101,6 +105,7 @@ func (s *Server) loadConfig(c *Config) error {
}
modules[moduleName] = addr
}
proxyProtocol[addr] = v.UseProxyProtocol
}

s.reloadLock.Lock()
Expand All @@ -119,6 +124,7 @@ func (s *Server) loadConfig(c *Config) error {
}
s.Motd = c.Proxy.Motd
s.modules = modules
s.proxyProtocol = proxyProtocol
return nil
}

Expand Down Expand Up @@ -159,42 +165,44 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn *net.TCPConn)
defer s.bufPool.Put(bufPtr)
buf := *bufPtr

ip := downConn.RemoteAddr().String()
addr := downConn.RemoteAddr().String()
ip := downConn.RemoteAddr().(*net.TCPAddr).IP.String()
port := downConn.RemoteAddr().(*net.TCPAddr).Port

writeTimeout := s.WriteTimeout
readTimeout := s.ReadTimeout

n, err := readLine(downConn, buf, readTimeout)
if err != nil {
return fmt.Errorf("read version from client %s: %w", ip, err)
return fmt.Errorf("read version from client %s: %w", addr, err)
}
rsyncdClientVersion := make([]byte, n)
copy(rsyncdClientVersion, buf[:n])
if !bytes.HasPrefix(rsyncdClientVersion, RsyncdVersionPrefix) {
return fmt.Errorf("unknown version from client %s: %q", ip, rsyncdClientVersion)
return fmt.Errorf("unknown version from client %s: %q", addr, rsyncdClientVersion)
}

_, err = writeWithTimeout(downConn, RsyncdServerVersion, writeTimeout)
if err != nil {
return fmt.Errorf("send version to client %s: %w", ip, err)
return fmt.Errorf("send version to client %s: %w", addr, err)
}

n, err = readLine(downConn, buf, readTimeout)
if err != nil {
return fmt.Errorf("read module from client %s: %w", ip, err)
return fmt.Errorf("read module from client %s: %w", addr, err)
}
if n == 0 {
return fmt.Errorf("empty request from client %s", ip)
return fmt.Errorf("empty request from client %s", addr)
}
data := buf[:n]
if s.Motd != "" {
_, err = writeWithTimeout(downConn, []byte(s.Motd+"\n"), writeTimeout)
if err != nil {
return fmt.Errorf("send motd to client %s: %w", ip, err)
return fmt.Errorf("send motd to client %s: %w", addr, err)
}
}
if len(data) == 1 { // single '\n'
s.accessLog.F("client %s requests listing all modules", ip)
s.accessLog.F("client %s requests listing all modules", addr)
return s.listAllModules(downConn)
}

Expand All @@ -204,6 +212,10 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn *net.TCPConn)

s.reloadLock.RLock()
upstreamAddr, ok := s.modules[moduleName]
var useProxyProtocol bool
if ok {
useProxyProtocol = s.proxyProtocol[upstreamAddr]
}
s.reloadLock.RUnlock()

if !ok {
Expand All @@ -220,6 +232,21 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn *net.TCPConn)
upConn := conn.(*net.TCPConn)
defer upConn.Close()
upIp := upConn.RemoteAddr().(*net.TCPAddr).IP.String()
upPort := upConn.RemoteAddr().(*net.TCPAddr).Port

if useProxyProtocol {
var IPVersion string
if strings.Contains(ip, ":") {
IPVersion = "TCP6"
} else {
IPVersion = "TCP4"
}
proxyHeader := fmt.Sprintf("PROXY %s %s %s %d %d\r\n", IPVersion, ip, upIp, port, upPort)
_, err = writeWithTimeout(upConn, []byte(proxyHeader), writeTimeout)
if err != nil {
return fmt.Errorf("send proxy protocol header to upstream %s: %w", upIp, err)
}
}

_, err = writeWithTimeout(upConn, rsyncdClientVersion, writeTimeout)
if err != nil {
Expand Down
28 changes: 28 additions & 0 deletions test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,31 @@ func TestReloadConfigWithDuplicatedModules(t *testing.T) {
r.Error(err)
r.Contains(reloadOutput.String(), "Failed to reload config")
}

func TestProxyProtocol(t *testing.T) {
r := require.New(t)
dst, err := os.CreateTemp("", "rsync-proxy-e2e-*")
r.NoError(err)
r.NoError(dst.Close())

r.NoError(copyFile(getProxyConfigPath("config4.toml"), dst.Name()))

proxy := startProxy(t, func(s *server.Server) {
s.ConfigPath = dst.Name()
})

tmpFile, err := os.CreateTemp("", "rsync-proxy-e2e-*")
r.NoError(err)
r.NoError(tmpFile.Close())
defer os.Remove(tmpFile.Name())

outputBytes, err := newRsyncCommand(getRsyncPath(proxy, "/pro/v3.5/data"), tmpFile.Name()).CombinedOutput()
if err != nil {
t.Log(string(outputBytes))
r.NoError(err)
}

got, err := os.ReadFile(tmpFile.Name())
r.NoError(err)
r.Equal("3.5", string(got))
}
6 changes: 6 additions & 0 deletions test/e2e/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func testMain(m *testing.M) (int, error) {
port: 1235,
name: "bar.conf",
},
{
port: 1236,
name: "proxyprotocol.conf",
},
} {
prog, err := runRsyncd(ctx, cfg.port, filepath.Join(rsyncdConfDir, cfg.name))
if err != nil {
Expand Down Expand Up @@ -222,6 +226,8 @@ func setupDataDirs() error {
"/tmp/rsync-proxy-e2e/bar/v3.2/data": []byte("3.2"),
"/tmp/rsync-proxy-e2e/bar/v3.3/data": []byte("3.3"),
"/tmp/rsync-proxy-e2e/baz/v3.4/data": []byte("3.4"),
"/tmp/rsync-proxy-e2e/pro/v3.5/data": []byte("3.5"),
"/tmp/rsync-proxy-e2e/pro/v3.6/data": []byte("3.6"),
}
for fp, data := range files {
err := writeFile(fp, data)
Expand Down
8 changes: 8 additions & 0 deletions test/fixtures/proxy/config4.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[upstreams.u1]
address = "127.0.0.1:1236"
modules = ["pro"]
use_proxy_protocol = true

[upstreams.u2]
address = "127.0.0.1:1235"
modules = ["bar", "baz"]
8 changes: 8 additions & 0 deletions test/fixtures/rsyncd/proxyprotocol.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use chroot = false
proxy protocol = true

[pro]
path = /tmp/rsync-proxy-e2e/pro/
comment = PRO FILES
read only = true
timeout = 300

0 comments on commit 211b8dd

Please sign in to comment.