diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index 4fe1d894df..ca039023f1 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -759,18 +759,13 @@ out: // result, unmarshalling it, and delivering the unmarshalled result to the // provided response channel. func (c *Client) handleSendPostMessage(jReq *jsonRequest) { - protocol := "http" - if !c.config.DisableTLS { - protocol = "https" - } - var ( - err, lastErr error + lastErr error backoff time.Duration httpResponse *http.Response ) - parsedAddr, err := ParseAddressString(c.config.Host) + httpURL, err := c.config.httpURL() if err != nil { jReq.responseChan <- &Response{ err: fmt.Errorf("failed to parse address %v", err), @@ -778,22 +773,12 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { return } - var url string - switch parsedAddr.Network() { - case "unix", "unixpacket": - // Using a placeholder URL because a non-empty URL is required. - // The Unix domain socket is specified in the DialContext. - url = protocol + "://unix" - default: - url = protocol + "://" + c.config.Host - } - tries := 10 for i := 0; i < tries; i++ { var httpReq *http.Request bodyReader := bytes.NewReader(jReq.marshalledJSON) - httpReq, err = http.NewRequest("POST", url, bodyReader) + httpReq, err = http.NewRequest("POST", httpURL, bodyReader) if err != nil { jReq.responseChan <- &Response{result: nil, err: err} return @@ -1355,7 +1340,7 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { } } - parsedAddr, err := ParseAddressString(config.Host) + parsedDialAddr, err := ParseAddressString(config.Host) if err != nil { return nil, err } @@ -1363,8 +1348,13 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { Transport: &http.Transport{ Proxy: proxyFunc, TLSClientConfig: tlsConfig, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial(parsedAddr.Network(), parsedAddr.String()) + DialContext: func(_ context.Context, _, + _ string) (net.Conn, error) { + + return net.Dial( + parsedDialAddr.Network(), + parsedDialAddr.String(), + ) }, }, Timeout: defaultHTTPTimeout, @@ -1373,6 +1363,32 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { return &client, nil } +// httpURL returns the URL to use for HTTP POST requests. +func (config *ConnConfig) httpURL() (string, error) { + protocol := "http" + if !config.DisableTLS { + protocol = "https" + } + + parsedAddr, err := ParseAddressString(config.Host) + if err != nil { + return "", fmt.Errorf("error parsing host '%v': %v", + config.Host, err) + } + + var httpURL string + switch parsedAddr.Network() { + case "unix", "unixpacket": + // Using a placeholder URL because a non-empty URL is required. + // The Unix domain socket is specified in the DialContext. + httpURL = protocol + "://unix" + default: + httpURL = protocol + "://" + config.Host + } + + return httpURL, nil +} + // dial opens a websocket connection using the passed connection configuration // details. func dial(config *ConnConfig) (*websocket.Conn, error) { @@ -1733,53 +1749,48 @@ func (c *Client) Send() error { return nil } +// cutPrefix returns s without the provided leading prefix string +// and reports whether it found the prefix. +// If s doesn't start with prefix, cutPrefix returns s, false. +// If prefix is the empty string, cutPrefix returns s, true. +// Copied from go1.20 version. +func cutPrefix(s, prefix string) (after string, found bool) { + if !strings.HasPrefix(s, prefix) { + return s, false + } + return s[len(prefix):], true +} + // ParseAddressString converts an address in string format to a net.Addr that is // compatible with btcd. UDP is not supported because btcd needs reliable -// connections. We accept a custom function to resolve any TCP addresses so -// that caller is able control exactly how resolution is performed. +// connections. func ParseAddressString(strAddress string) (net.Addr, error) { - var parsedNetwork, parsedAddr string + // Addresses can either be in unix://address, unixpacket://address URL + // format, or just address:port host format for tcp. + if after, ok := cutPrefix(strAddress, "unix://"); ok { + return net.ResolveUnixAddr("unix", after) + } + if after, ok := cutPrefix(strAddress, "unixpacket://"); ok { + return net.ResolveUnixAddr("unixpacket", after) + } - // Addresses can either be in network://address:port format, - // network:address:port, address:port, or just port. We want to support - // all possible types. if strings.Contains(strAddress, "://") { - parts := strings.Split(strAddress, "://") - parsedNetwork, parsedAddr = parts[0], parts[1] - } else if strings.Contains(strAddress, ":") { - parts := strings.Split(strAddress, ":") - parsedNetwork = parts[0] - parsedAddr = strings.Join(parts[1:], ":") - } else { - parsedAddr = strAddress + // Not supporting :// anywhere in the host or path. + return nil, fmt.Errorf("unsupported protocol in address: %s", + strAddress) } - // Only TCP and Unix socket addresses are valid. We can't use IP or - // UDP only connections for anything we do in lnd. - switch parsedNetwork { - case "unix", "unixpacket": - return net.ResolveUnixAddr(parsedNetwork, parsedAddr) - - case "tcp", "tcp4", "tcp6": - return net.ResolveTCPAddr(parsedNetwork, verifyPort(parsedAddr)) - - case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram": - return nil, fmt.Errorf("only TCP or unix socket "+ - "addresses are supported: %s", parsedAddr) - - default: - // We'll now possibly use the local host short circuit - // or parse out an all interfaces listen. - addrWithPort := verifyPort(strAddress) - - // Otherwise, we'll attempt to resolve the host. - return net.ResolveTCPAddr("tcp", addrWithPort) + // Parse it as a dummy URL to get the host and port. + u, err := url.Parse("dummy://" + strAddress) + if err != nil { + return nil, err } + return net.ResolveTCPAddr("tcp", verifyPort(u.Host)) } // verifyPort makes sure that an address string has both a host and a port. // If the address is just a port, then we'll assume that the user is using the -// short cut to specify a localhost:port address. +// shortcut to specify a localhost:port address. func verifyPort(address string) string { host, port, err := net.SplitHostPort(address) if err != nil { @@ -1801,8 +1812,8 @@ func verifyPort(address string) string { return net.JoinHostPort(address, "") } - // In the case that both the host and port are empty, we'll use the - // an empty port. + // In the case that both the host and port are empty, we'll use an empty + // port. if host == "" && port == "" { return ":" } diff --git a/rpcclient/infrastructure_test.go b/rpcclient/infrastructure_test.go new file mode 100644 index 0000000000..8416b7ad3c --- /dev/null +++ b/rpcclient/infrastructure_test.go @@ -0,0 +1,110 @@ +package rpcclient + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestParseAddressString checks different variation of supported and +// unsupported addresses. +func TestParseAddressString(t *testing.T) { + t.Parallel() + + // Using localhost only to avoid network calls. + testCases := []struct { + name string + addressString string + expNetwork string + expAddress string + expErrStr string + }{ + { + name: "localhost", + addressString: "localhost", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost ip", + addressString: "127.0.0.1", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost ipv6", + addressString: "::1", + expNetwork: "tcp", + expAddress: "[::1]:0", + }, + { + name: "localhost and port", + addressString: "localhost:80", + expNetwork: "tcp", + expAddress: "127.0.0.1:80", + }, + { + name: "localhost ipv6 and port", + addressString: "[::1]:80", + expNetwork: "tcp", + expAddress: "[::1]:80", + }, + { + name: "colon and port", + addressString: ":80", + expNetwork: "tcp", + expAddress: ":80", + }, + { + name: "colon only", + addressString: ":", + expNetwork: "tcp", + expAddress: ":0", + }, + { + name: "localhost and path", + addressString: "localhost/path", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost port and path", + addressString: "localhost:80/path", + expNetwork: "tcp", + expAddress: "127.0.0.1:80", + }, + { + name: "unix prefix", + addressString: "unix://the/rest/of/the/path", + expNetwork: "unix", + expAddress: "the/rest/of/the/path", + }, + { + name: "unix prefix", + addressString: "unixpacket://the/rest/of/the/path", + expNetwork: "unixpacket", + expAddress: "the/rest/of/the/path", + }, + { + name: "error http prefix", + addressString: "http://localhost:1010", + expErrStr: "unsupported protocol in address", + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + addr, err := ParseAddressString(tc.addressString) + if tc.expErrStr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrStr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expNetwork, addr.Network()) + require.Equal(t, tc.expAddress, addr.String()) + }) + } +}