Skip to content

Commit

Permalink
Make it possible for NewSession not to block
Browse files Browse the repository at this point in the history
  • Loading branch information
sylwiaszunejko committed Jan 2, 2025
1 parent ead5781 commit f366782
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 36 deletions.
21 changes: 17 additions & 4 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,14 @@ func TestInvalidKeyspace(t *testing.T) {
t.Fatalf("Expected ErrNoConnections but got %v", err)
}
} else {
session.Close() //Clean up the session
t.Fatal("expected err, got nil.")
if err = session.WaitUntilReady(); err != nil {
if err != ErrNoConnectionsStarted {
t.Fatalf("Expected ErrNoConnections but got %v", err)
}
} else {
session.Close() //Clean up the session
t.Fatal("expected err, got nil.")
}
}
}

Expand Down Expand Up @@ -907,8 +913,11 @@ func TestCreateSessionTimeout(t *testing.T) {
cluster := createCluster()
cluster.Hosts = []string{"127.0.0.1:1"}
session, err := cluster.CreateSession()
if err == nil {
session.Close()
defer session.Close()
if err != nil {
t.Fatal(err)
}
if err = session.WaitUntilReady(); err == nil {
t.Fatal("expected ErrNoConnectionsStarted, but no error was returned.")
}
}
Expand Down Expand Up @@ -2585,6 +2594,10 @@ func TestControl_DiscoverProtocol(t *testing.T) {
}
defer session.Close()

if err = session.WaitUntilReady(); err != nil {
t.Fatal(err)
}

if session.cfg.ProtoVersion == 0 {
t.Fatal("did not discovery protocol")
}
Expand Down
9 changes: 9 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
panic(err)
}
defer session.Close()
if err = session.WaitUntilReady(); err != nil {
panic(err)
}

err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
if err != nil {
Expand Down Expand Up @@ -175,6 +178,9 @@ func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
if err != nil {
tb.Fatal("createSession:", err)
}
if err = session.WaitUntilReady(); err != nil {
tb.Fatal(err)
}

if err := session.control.awaitSchemaAgreement(); err != nil {
tb.Fatal(err)
Expand All @@ -190,6 +196,9 @@ func createSessionFromMultiNodeCluster(cluster *ClusterConfig, tb testing.TB) *S
if err != nil {
tb.Fatal("createSession:", err)
}
if err = session.WaitUntilReady(); err != nil {
tb.Fatal(err)
}

initOnce.Do(func() {
if err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace); err != nil {
Expand Down
91 changes: 82 additions & 9 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func TestSimple(t *testing.T) {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err := db.Query("void").Exec(); err != nil {
t.Fatalf("0x%x: %v", defaultProto, err)
}
Expand All @@ -102,6 +106,10 @@ func TestSSLSimple(t *testing.T) {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err := db.Query("void").Exec(); err != nil {
t.Fatalf("0x%x: %v", defaultProto, err)
}
Expand All @@ -116,6 +124,10 @@ func TestSSLSimpleNoClientCert(t *testing.T) {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
}

if err := db.Query("void").Exec(); err != nil {
t.Fatalf("0x%x: %v", defaultProto, err)
}
Expand Down Expand Up @@ -176,11 +188,15 @@ func TestDNSLookupConnected(t *testing.T) {

// CreateSession() should attempt to resolve the DNS name "cassandraX.invalid"
// and fail, but continue to connect via srv.Address
_, err := cluster.CreateSession()
s, err := cluster.CreateSession()
if err != nil {
t.Fatal("CreateSession() should have connected")
}

if err = s.WaitUntilReady(); err != nil {
t.Fatalf("WaitUntilReady() returned an error: %v", err)
}

if !strings.Contains(log.String(), "gocql: dns error") {
t.Fatalf("Expected to receive dns error log message - got '%s' instead", log.String())
}
Expand All @@ -200,9 +216,13 @@ func TestDNSLookupError(t *testing.T) {

// CreateSession() should attempt to resolve each DNS name "cassandraX.invalid"
// and fail since it could not resolve any dns entries
_, err := cluster.CreateSession()
if err == nil {
t.Fatal("CreateSession() should have returned an error")
s, err := cluster.CreateSession()
if err != nil {
t.Fatalf("CreateSession() have returned an error: %v", err)
}

if err = s.WaitUntilReady(); err == nil {
t.Fatalf("WaitUntilReady() should have returned an error: %v", err)
}

if !strings.Contains(log.String(), "gocql: dns error") {
Expand Down Expand Up @@ -235,9 +255,13 @@ func TestStartupTimeout(t *testing.T) {
cluster.ConnectTimeout = 600 * time.Millisecond

// Create session should timeout during connect attempt
_, err := cluster.CreateSession()
if err == nil {
t.Fatal("CreateSession() should have returned a timeout error")
s, err := cluster.CreateSession()
if err != nil {
t.Fatal("CreatedSession() returned an error: ", err)
}

if err = s.WaitUntilReady(); err == nil {
t.Fatal("Session initialization should have returned a timeout error")
}

elapsed := time.Since(startTime)
Expand Down Expand Up @@ -304,6 +328,10 @@ func TestCancel(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("WaitUntilReady: %v", err)
}

qry := db.Query("timeout").WithContext(ctx)

// Make sure we finish the query without leftovers
Expand Down Expand Up @@ -356,6 +384,10 @@ func TestQueryRetry(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("WaitUntilReady: %v", err)
}

go func() {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -411,6 +443,10 @@ func TestQueryMultinodeWithMetrics(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("WaitUntilReady: %v", err)
}

// 1 retry per host
rt := &SimpleRetryPolicy{NumRetries: 3}
observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false, logger: log}
Expand Down Expand Up @@ -490,6 +526,10 @@ func TestSpeculativeExecution(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

// Create a test retry policy, 6 retries will cover 2 executions
rt := &testRetryPolicy{NumRetries: 8}
// test Speculative policy with 1 additional execution
Expand Down Expand Up @@ -536,6 +576,10 @@ func TestPolicyConnPoolSSL(t *testing.T) {
t.Fatalf("failed to create new session: %v", err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

if err := db.Query("void").Exec(); err != nil {
t.Fatalf("query failed due to error: %v", err)
}
Expand Down Expand Up @@ -564,6 +608,10 @@ func TestQueryTimeout(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

ch := make(chan error, 1)

go func() {
Expand Down Expand Up @@ -600,6 +648,9 @@ func BenchmarkSingleConn(b *testing.B) {
b.Fatalf("NewCluster: %v", err)
}
defer db.Close()
if err = db.WaitUntilReady(); err != nil {
b.Fatal(err)
}

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
Expand Down Expand Up @@ -632,6 +683,9 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
t.Fatalf("NewCluster: %v", err)
}
defer db.Close()
if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

db.Query("slow").Exec()

Expand All @@ -656,6 +710,10 @@ func TestQueryTimeoutClose(t *testing.T) {
t.Fatalf("NewCluster: %v", err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatalf("WaitUntilReady: %v", err)
}

ch := make(chan error)
go func() {
err := db.Query("timeout").Exec()
Expand Down Expand Up @@ -721,6 +779,10 @@ func TestContext_Timeout(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

ctx, cancel = context.WithCancel(ctx)
cancel()

Expand Down Expand Up @@ -794,8 +856,11 @@ func TestInitialRetryPolicy(t *testing.T) {
policy := &TestReconnectionPolicy{NumRetries: tc.NumRetries}
cluster.InitialReconnectionPolicy = policy
cluster.ProtoVersion = tc.ProtoVersion
_, err := cluster.CreateSession()
if err == nil {
s, err := cluster.CreateSession()
if err != nil {
t.Fatal(err)
}
if err = s.WaitUntilReady(); err == nil {
t.Fatal("expected to get an error")
}
if !strings.Contains(err.Error(), tc.ExpectedErr) {
Expand Down Expand Up @@ -836,6 +901,10 @@ func TestContext_CanceledBeforeExec(t *testing.T) {
}
defer db.Close()

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

startupRequestCount := atomic.LoadUint64(&reqCount)

ctx, cancel = context.WithCancel(ctx)
Expand Down Expand Up @@ -1045,6 +1114,10 @@ func TestFrameHeaderObserver(t *testing.T) {
t.Fatal(err)
}

if err = db.WaitUntilReady(); err != nil {
t.Fatal(err)
}

if err := db.Query("void").Exec(); err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 4 additions & 0 deletions control_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func TestUnixSockets(t *testing.T) {

defer sess.Close()

if err = sess.WaitUntilReady(); err != nil {
panic(fmt.Sprintf("unable to initialize the session: %v", err))
}

keyspace := "test1"

err = createTable(sess, `DROP KEYSPACE IF EXISTS `+keyspace)
Expand Down
4 changes: 4 additions & 0 deletions keyspace_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ func TestKeyspaceTable(t *testing.T) {
t.Fatal("createSession:", err)
}

if err = session.WaitUntilReady(); err != nil {
t.Fatal("waitUntilReady: ", err)
}

cluster.Keyspace = "wrong_keyspace"

keyspace := "test1"
Expand Down
21 changes: 15 additions & 6 deletions policies_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ func TestDCValidationTokenAware(t *testing.T) {
fallback := DCAwareRoundRobinPolicy("WRONG_DC")
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

_, err := cluster.CreateSession()
if err == nil {
session, err := cluster.CreateSession()
if err != nil {
t.Fatal(err)
}
if err = session.WaitUntilReady(); err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
Expand All @@ -24,8 +27,11 @@ func TestDCValidationDCAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC")

_, err := cluster.CreateSession()
if err == nil {
session, err := cluster.CreateSession()
if err != nil {
t.Fatal(err)
}
if err = session.WaitUntilReady(); err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
Expand All @@ -34,8 +40,11 @@ func TestDCValidationRackAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK")

_, err := cluster.CreateSession()
if err == nil {
session, err := cluster.CreateSession()
if err != nil {
t.Fatal(err)
}
if err = session.WaitUntilReady(); err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
Loading

0 comments on commit f366782

Please sign in to comment.