-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdial.go
104 lines (93 loc) · 2.39 KB
/
dial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
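/*
Package dialsrv provides a Dialer that wraps net.Dialer and can resolve
addresses through DNS SRV records before connecting.

Addresses of the form "srv+NAME" or "srv+SERVICE+NAME" are resolved via an
SRV lookup; any other address is dialed directly.
*/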
package dialsrv
import (
"context"
"fmt"
"net"
"strconv"
"strings"
)
// Dialer wraps net.Dialer with SRV lookup.
//
// Addresses with the "srv+" prefix are resolved through a DNS SRV lookup
// before dialing; other addresses are passed to the underlying dialer
// unchanged. Use New to construct a Dialer: the zero value has no driver set.
type Dialer struct {
	// drv performs the SRV lookups and the actual dials. New installs a
	// driver backed by a *net.Dialer.
	drv driver
}
// New returns a Dialer that performs SRV lookups and dials through the
// given *net.Dialer. When d is nil, a zero-value net.Dialer is used instead.
func New(d *net.Dialer) *Dialer {
	base := d
	if base == nil {
		base = new(net.Dialer)
	}
	return &Dialer{drv: &netDialerDriver{base}}
}
// Dial connects to the address on the named network. It behaves exactly
// like DialContext called with context.Background().
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
	ctx := context.Background()
	return d.DialContext(ctx, network, address)
}
// DialContext connects to the address on the named network using the
// provided context.
//
// An address beginning with "srv+" is first resolved through an SRV lookup
// and one of the returned targets is dialed. Any other address is handed
// to the underlying driver unchanged.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	fa := parseAddr(network, address)
	if fa == nil {
		// Not SRV-flavored: plain dial.
		return d.drv.DialContext(ctx, network, address)
	}
	return d.dialSRV(ctx, fa)
}
// dialSRV resolves fa through an SRV lookup and dials the returned targets
// in order until one connection succeeds.
//
// Any ":port" suffix on fa.Name is discarded; the port comes from the SRV
// record. Targets are tried in the order the driver returns them. The
// standard library's LookupSRV already sorts records by priority and
// randomizes them by weight, as RFC 2782 expects.
// NOTE(review): confirm the driver implementation preserves that ordering.
//
// It returns the lookup error if the lookup fails, an error if no records
// were found, or the first dial error if every target fails.
func (d *Dialer) dialSRV(ctx context.Context, fa *FlavoredAddr) (net.Conn, error) {
	host, err := splitHost(fa.Name)
	if err != nil {
		return nil, err
	}
	_, addrs, err := d.drv.LookupSRV(ctx, fa.Service, fa.Proto, host)
	if err != nil {
		return nil, err
	}
	if len(addrs) == 0 {
		return nil, fmt.Errorf("no SRV records for %s", fa.String())
	}
	var firstErr error
	for _, srv := range addrs {
		conn, err := d.drv.DialContext(ctx, fa.Network, address(srv))
		if err == nil {
			return conn, nil
		}
		// Report the first failure, like net.Dialer does when it falls
		// back across multiple addresses.
		if firstErr == nil {
			firstErr = err
		}
		// Once the caller has given up, further attempts would only fail
		// immediately; stop here.
		if ctx.Err() != nil {
			break
		}
	}
	return nil, firstErr
}
// splitHost returns the host part of hostport. A value without any ':' is
// treated as a bare host and returned unchanged. Anything else is parsed
// with net.SplitHostPort, and the port is dropped.
func splitHost(hostport string) (string, error) {
	if !strings.Contains(hostport, ":") {
		return hostport, nil
	}
	host, _, err := net.SplitHostPort(hostport)
	return host, err
}
// FlavoredAddr represents an SRV-flavored address, parsed from a dial
// address carrying the "srv+" prefix.
type FlavoredAddr struct {
	// Network is the network the caller passed to Dial (e.g. "tcp", "udp4").
	Network string
	// Service is the SRV service label without its leading underscore
	// (e.g. "http"). It is empty when the address had no service part.
	Service string
	// Proto is the SRV protocol label without its leading underscore. It is
	// derived from Network with the IP-version suffix removed ("tcp4" ->
	// "tcp"). It is empty when the address had no service part.
	Proto string
	// Name is the domain name to query. It may carry a ":port" suffix.
	Name string
}

// parseAddr parses an SRV-flavored dial address. It returns nil when
// address does not start with "srv+", meaning it should be dialed as-is.
// The accepted forms are:
//
//	srv+NAME          -> Service and Proto left empty, Name = NAME
//	srv+SERVICE+NAME  -> Service = SERVICE, Proto derived from network
func parseAddr(network, address string) *FlavoredAddr {
	const prefix = "srv+"
	if !strings.HasPrefix(address, prefix) {
		return nil
	}
	rest := address[len(prefix):]
	n := strings.Index(rest, "+")
	if n < 0 {
		return &FlavoredAddr{Network: network, Name: rest}
	}
	return &FlavoredAddr{
		Network: network,
		Service: rest[:n],
		// SRV records are published under _tcp/_udp only. Passing "tcp4"
		// etc. straight through would query a name that never exists.
		Proto: srvProto(network),
		Name:  rest[n+1:],
	}
}

// srvProto maps a Go dial network to the protocol label used in SRV names.
// The IP-version-specific variants share the records of their base protocol.
// Other networks are returned unchanged.
func srvProto(network string) string {
	switch network {
	case "tcp4", "tcp6":
		return "tcp"
	case "udp4", "udp6":
		return "udp"
	}
	return network
}

// String returns the DNS name that fa refers to. When a service is
// present this is "_SERVICE._PROTO.NAME"; when both Service and Proto are
// empty it is just Name.
func (fa *FlavoredAddr) String() string {
	if fa.Service == "" && fa.Proto == "" {
		return fa.Name
	}
	return "_" + fa.Service + "._" + fa.Proto + "." + fa.Name
}
func address(srv *net.SRV) string {
return srv.Target + ":" + strconv.FormatUint(uint64(srv.Port), 10)
}