Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dns host provider lookup timeout #143

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti
ec := make(chan Event, eventChanSize)
conn := &Conn{
dialer: net.DialTimeout,
hostProvider: &DNSHostProvider{},
hostProvider: NewDNSHostProvider(),
conn: nil,
state: StateDisconnected,
eventChan: ec,
Expand Down
60 changes: 53 additions & 7 deletions dnshostprovider.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,57 @@
package zk

import (
"context"
"fmt"
"net"
"sync"
"time"
)

const _defaultLookupTimeout = 3 * time.Second

type lookupHostFn func(context.Context, string) ([]string, error)

// DNSHostProviderOption is an option for the DNSHostProvider.
type DNSHostProviderOption interface {
apply(*DNSHostProvider)
}

type lookupTimeoutOption struct {
timeout time.Duration
}

// WithLookupTimeout returns a DNSHostProviderOption that sets the lookup timeout.
func WithLookupTimeout(timeout time.Duration) DNSHostProviderOption {
return lookupTimeoutOption{
timeout: timeout,
}
}

func (o lookupTimeoutOption) apply(provider *DNSHostProvider) {
provider.lookupTimeout = o.timeout
}

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
curr int
last int
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
curr int
last int
lookupTimeout time.Duration
lookupHost lookupHostFn // Override of net.LookupHost, for testing.
}

// NewDNSHostProvider creates a new DNSHostProvider with the given options.
func NewDNSHostProvider(options ...DNSHostProviderOption) *DNSHostProvider {
var provider DNSHostProvider
for _, option := range options {
option.apply(&provider)
}
return &provider
}

// Init is called first, with the servers specified in the connection
Expand All @@ -27,16 +63,26 @@ func (hp *DNSHostProvider) Init(servers []string) error {

lookupHost := hp.lookupHost
if lookupHost == nil {
lookupHost = net.LookupHost
var resolver net.Resolver
lookupHost = resolver.LookupHost
}

timeout := hp.lookupTimeout
if timeout == 0 {
timeout = _defaultLookupTimeout
}

// TODO: consider using a context from the caller.
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

found := []string{}
for _, server := range servers {
host, port, err := net.SplitHostPort(server)
if err != nil {
return err
}
addrs, err := lookupHost(host)
addrs, err := lookupHost(ctx, host)
if err != nil {
return err
}
Expand Down
51 changes: 40 additions & 11 deletions dnshostprovider_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package zk

import (
"context"
"fmt"
"log"
"testing"
"time"
)

// localhostLookupHost is a test replacement for net.LookupHost that
// always returns 127.0.0.1
func localhostLookupHost(host string) ([]string, error) {
return []string{"127.0.0.1"}, nil
type lookupHostOption struct {
lookupFn lookupHostFn
}

func withLookupHost(lookupFn lookupHostFn) DNSHostProviderOption {
return lookupHostOption{
lookupFn: lookupFn,
}
}

func (o lookupHostOption) apply(provider *DNSHostProvider) {
provider.lookupHost = o.lookupFn
}

// TestDNSHostProviderCreate is just like TestCreate, but with an
Expand All @@ -24,7 +33,15 @@ func TestIntegration_DNSHostProviderCreate(t *testing.T) {

port := ts.Servers[0].Port
server := fmt.Sprintf("foo.example.com:%d", port)
hostProvider := &DNSHostProvider{lookupHost: localhostLookupHost}
hostProvider := NewDNSHostProvider(
withLookupHost(func(ctx context.Context, host string) ([]string, error) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("No lookup context deadline set")
}
return []string{"127.0.0.1"}, nil
}),
)

zk, _, err := Connect([]string{server}, time.Second*15, WithHostProvider(hostProvider))
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
Expand Down Expand Up @@ -103,9 +120,11 @@ func TestIntegration_DNSHostProviderReconnect(t *testing.T) {
}
defer ts.Stop()

innerHp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}}
innerHp := NewDNSHostProvider(
withLookupHost(func(_ context.Context, host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}),
)
ports := make([]int, 0, len(ts.Servers))
for _, server := range ts.Servers {
ports = append(ports, server.Port)
Expand Down Expand Up @@ -172,9 +191,11 @@ func TestIntegration_DNSHostProviderReconnect(t *testing.T) {
func TestDNSHostProviderRetryStart(t *testing.T) {
t.Parallel()

hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}}
hp := NewDNSHostProvider(
withLookupHost(func(_ context.Context, host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}),
)

if err := hp.Init([]string{"foo.example.com:12345"}); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -222,3 +243,11 @@ func TestDNSHostProviderRetryStart(t *testing.T) {
}
}
}

func TestNewDNSHostProvider(t *testing.T) {
want := 5 * time.Second
provider := NewDNSHostProvider(WithLookupTimeout(want))
if provider.lookupTimeout != want {
t.Fatalf("expected lookup timeout to be %v, got %v", want, provider.lookupTimeout)
}
}
Loading