diff --git a/cluster.go b/cluster.go index e0abd2b6d..3d795322b 100644 --- a/cluster.go +++ b/cluster.go @@ -340,6 +340,10 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) { return NewSession(*cfg) } +func (cfg *ClusterConfig) CreateSessionNonBlocking() (*Session, error) { + return NewSessionNonBlocking(*cfg) +} + // translateAddressPort is a helper method that will use the given AddressTranslator // if defined, to translate the given address and port into a possibly new address // and port, If no AddressTranslator or if an error occurs, the given address and diff --git a/scylla_shard_aware_port_common_test.go b/scylla_shard_aware_port_common_test.go index 538937162..f15ea0aed 100644 --- a/scylla_shard_aware_port_common_test.go +++ b/scylla_shard_aware_port_common_test.go @@ -46,6 +46,11 @@ func testShardAwarePortNoReconnections(t *testing.T, makeCluster makeClusterTest } defer sess.Close() + if err = sess.WaitUntilReady(); err != nil { + cancel() + return + } + if err := waitUntilPoolsStopFilling(ctx, sess, 10*time.Second); err != nil { cancel() return @@ -118,6 +123,10 @@ func testShardAwarePortMaliciousNAT(t *testing.T, makeCluster makeClusterTestFun } defer sess.Close() + if err = sess.WaitUntilReady(); err != nil { + t.Fatalf("an error occurred while initializing a session: %s", err) + } + // In this situation we are guaranteed that the connection will miss one // shard at this point. The first connection receives a random shard, // then we establish N-1 connections, targeting remaining shards. @@ -150,6 +159,10 @@ func testShardAwarePortUnreachable(t *testing.T, makeCluster makeClusterTestFunc } defer sess.Close() + if err = sess.WaitUntilReady(); err != nil { + t.Fatalf("an error occurred while initializing a session: %s", err) + } + // In this situation, the connecting to the shard-aware port will fail, // but connections to the non-shard-aware port will succeed. This test // checks that we detect that the shard-aware-port is unreachable and @@ -187,6 +200,10 @@ func testShardAwarePortUnusedIfNotEnabled(t *testing.T, makeCluster makeClusterT } defer sess.Close() + if err = sess.WaitUntilReady(); err != nil { + t.Fatalf("an error occurred while initializing a session: %s", err) + } + if err := waitUntilPoolsStopFilling(context.Background(), sess, 10*time.Second); err != nil { t.Fatal(err) } diff --git a/session.go b/session.go index ad73f27e8..326f0ed0f 100644 --- a/session.go +++ b/session.go @@ -79,6 +79,8 @@ type Session struct { // isInitialized is true once Session.init succeeds. // you can use initialized() to read the value. isInitialized bool + initErr error + readyCh chan error logger StdLogger @@ -114,8 +116,7 @@ func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInf return hosts, nil } -// NewSession wraps an existing Node. -func NewSession(cfg ClusterConfig) (*Session, error) { +func newSessionCommon(cfg ClusterConfig) (*Session, error) { if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("gocql: unable to create session: cluster config validation failed: %v", err) } @@ -132,6 +133,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) { ctx: ctx, cancel: cancel, logger: cfg.logger(), + readyCh: make(chan error, 1), } // Close created resources on error otherwise they'll leak @@ -181,6 +183,16 @@ func NewSession(cfg ClusterConfig) (*Session, error) { } s.connCfg = connCfg + return s, nil +} + +// NewSession wraps an existing Node. +func NewSession(cfg ClusterConfig) (*Session, error) { + s, err := newSessionCommon(cfg) + if err != nil { + return nil, err + } + if err = s.init(); err != nil { if err == ErrNoConnectionsStarted { //This error used to be generated inside NewSession & returned directly @@ -192,6 +204,29 @@ func NewSession(cfg ClusterConfig) (*Session, error) { } } + s.readyCh <- nil + close(s.readyCh) + + return s, nil +} + +func NewSessionNonBlocking(cfg ClusterConfig) (*Session, error) { + s, err := newSessionCommon(cfg) + if err != nil { + return nil, err + } + + go func() { + if initErr := s.init(); initErr != nil { + s.sessionStateMu.Lock() + s.initErr = fmt.Errorf("gocql: unable to create session: %v", initErr) + s.sessionStateMu.Unlock() + } + + s.readyCh <- s.initErr + close(s.readyCh) + }() + return s, nil } @@ -404,6 +439,9 @@ func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { if s.cfg.disableControlConn { return errNoControl } + if err := s.Ready(); err != nil { + return err + } ch := s.control.getConn() return (&Iter{err: ch.conn.awaitSchemaAgreement(ctx)}).err } @@ -570,11 +608,32 @@ func (s *Session) initialized() bool { return initialized } +func (s *Session) Ready() error { + s.sessionStateMu.RLock() + err := ErrSessionNotReady + if s.isInitialized || s.initErr != nil { + err = s.initErr + } + s.sessionStateMu.RUnlock() + return err +} + +func (s *Session) WaitUntilReady() error { + err, ok := <-s.readyCh + if !ok { + return nil + } + return err +} + func (s *Session) executeQuery(qry *Query) (it *Iter) { // fail fast if s.Closed() { return &Iter{err: ErrSessionClosed} } + if err := s.Ready(); err != nil { + return &Iter{err: err} + } iter, err := s.executor.executeQuery(qry) if err != nil { @@ -599,6 +658,8 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) { // fail fast if s.Closed() { return nil, ErrSessionClosed + } else if err := s.Ready(); err != nil { + return nil, err } else if keyspace == "" { return nil, ErrNoKeyspace } @@ -611,6 +672,8 @@ func (s *Session) TabletsMetadata() (TabletInfoList, error) { // fail fast if s.Closed() { return nil, ErrSessionClosed + } else if err := s.Ready(); err != nil { + return nil, err } else if !s.tabletsRoutingV1 { return nil, ErrTabletsNotUsed } @@ -798,6 +861,9 @@ func (s *Session) executeBatch(batch *Batch) *Iter { if s.Closed() { return &Iter{err: ErrSessionClosed} } + if err := s.Ready(); err != nil { + return &Iter{err: err} + } // Prevent the execution of the batch if greater than the limit // Currently batches have a limit of 65536 queries. @@ -2364,6 +2430,7 @@ var ( ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist") ErrNoMetadata = errors.New("no metadata available") ErrTabletsNotUsed = errors.New("tablets not used") + ErrSessionNotReady = errors.New("session is not ready yet") ) type ErrProtocol struct{ error }