From aacb2563d21c52218e4bb54621686e3b73fd6974 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 30 Dec 2024 15:36:47 +0100 Subject: [PATCH 1/3] Add a flag to signal that failed query could have been executed and it might be not safe to retry it --- conn.go | 55 ++++++++++++++++++++++++++++++++++++-------------- conn_test.go | 12 +++++------ errors_test.go | 7 ++++--- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/conn.go b/conn.go index 1c1650abb..670ff59be 100644 --- a/conn.go +++ b/conn.go @@ -1101,13 +1101,13 @@ func (c *Conn) addCall(call *callReq) error { func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) { if ctxErr := ctx.Err(); ctxErr != nil { - return nil, ctxErr + return nil, &QueryError{err: ctxErr, potentiallyExecuted: false} } // TODO: move tracer onto conn stream, ok := c.streams.GetStream() if !ok { - return nil, ErrNoStreams + return nil, &QueryError{err: ErrNoStreams, potentiallyExecuted: false} } // resp is basically a waiting semaphore protecting the framer @@ -1125,7 +1125,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram } if err := c.addCall(call); err != nil { - return nil, err + return nil, &QueryError{err: err, potentiallyExecuted: false} } // After this point, we need to either read from call.resp or close(call.timeout) @@ -1157,7 +1157,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil // check above could fail. c.releaseStream(call) - return nil, err + return nil, &QueryError{err: err, potentiallyExecuted: false} } n, err := c.w.writeContext(ctx, framer.buf) @@ -1185,7 +1185,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // send a frame on, with all the streams used up and not returned. c.closeWithError(err) } - return nil, err + return nil, &QueryError{err: err, potentiallyExecuted: true} } var timeoutCh <-chan time.Time @@ -1222,7 +1222,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // connection to close. c.releaseStream(call) } - return nil, resp.err + return nil, &QueryError{err: resp.err, potentiallyExecuted: true} } // dont release the stream if detect a timeout as another request can reuse // that stream and get a response for the old request, which we have no @@ -1233,20 +1233,20 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram defer c.releaseStream(call) if v := resp.framer.header.version.version(); v != c.version { - return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) + return nil, &QueryError{err: NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version), potentiallyExecuted: true} } return resp.framer, nil case <-timeoutCh: close(call.timeout) c.handleTimeout() - return nil, ErrTimeoutNoResponse + return nil, &QueryError{err: ErrTimeoutNoResponse, potentiallyExecuted: true} case <-ctxDone: close(call.timeout) - return nil, ctx.Err() + return nil, &QueryError{err: ctx.Err(), potentiallyExecuted: true} case <-c.ctx.Done(): close(call.timeout) - return nil, ErrConnectionClosed + return nil, &QueryError{err: ErrConnectionClosed, potentiallyExecuted: true} } } @@ -1906,11 +1906,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } var ( - ErrQueryArgLength = errors.New("gocql: query argument length mismatch") - ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") - ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection") - ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") - ErrNoStreams = errors.New("gocql: no streams available on connection") + ErrQueryArgLength = errors.New("gocql: query argument length mismatch") + ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") + ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection") + ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") + ErrNoStreams = errors.New("gocql: no streams available on connection") + ErrHostDown = errors.New("gocql: host is nil or down") + ErrNoPool = errors.New("gocql: host does not have a pool") + ErrNoConnectionsInPool = errors.New("gocql: host pool does not have connections") ) type ErrSchemaMismatch struct { @@ -1920,3 +1923,25 @@ type ErrSchemaMismatch struct { func (e *ErrSchemaMismatch) Error() string { return fmt.Sprintf("gocql: cluster schema versions not consistent: %+v", e.schemas) } + +type QueryError struct { + err error + potentiallyExecuted bool + isIdempotent bool +} + +func (e *QueryError) IsIdempotent() bool { + return e.isIdempotent +} + +func (e *QueryError) PotentiallyExecuted() bool { + return e.potentiallyExecuted +} + +func (e *QueryError) Error() string { + return fmt.Sprintf("%s (potentially executed: %v)", e.err.Error(), e.potentiallyExecuted) +} + +func (e *QueryError) Unwrap() error { + return e.err +} diff --git a/conn_test.go b/conn_test.go index f1bdd4338..0d4a46885 100644 --- a/conn_test.go +++ b/conn_test.go @@ -311,7 +311,7 @@ func TestCancel(t *testing.T) { wg.Add(1) go func() { - if err := qry.Exec(); err != context.Canceled { + if err := qry.Exec(); !errors.Is(err, context.Canceled) { t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err) } wg.Done() @@ -573,7 +573,7 @@ func TestQueryTimeout(t *testing.T) { select { case err := <-ch: - if err != ErrTimeoutNoResponse { + if !errors.Is(err, ErrTimeoutNoResponse) { t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err) } case <-time.After(40*time.Millisecond + db.cfg.Timeout): @@ -667,8 +667,8 @@ func TestQueryTimeoutClose(t *testing.T) { t.Fatal("timedout waiting to get a response once cluster is closed") } - if err != ErrConnectionClosed { - t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err) + if !errors.Is(err, ErrConnectionClosed) { + t.Fatalf("expected to get %v or an error wrapping it, got %v", ErrConnectionClosed, err) } } @@ -721,7 +721,7 @@ func TestContext_Timeout(t *testing.T) { cancel() err = db.Query("timeout").WithContext(ctx).Exec() - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) } } @@ -838,7 +838,7 @@ func TestContext_CanceledBeforeExec(t *testing.T) { cancel() err = db.Query("timeout").WithContext(ctx).Exec() - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) } diff --git a/errors_test.go b/errors_test.go index 85246c0e5..76c34089e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -4,6 +4,7 @@ package gocql import ( + "errors" "testing" ) @@ -18,12 +19,12 @@ func TestErrorsParse(t *testing.T) { if err := createTable(session, `CREATE TABLE gocql_test.errors_parse (id int primary key)`); err == nil { t.Fatal("Should have gotten already exists error from cassandra server.") } else { - switch e := err.(type) { - case *RequestErrAlreadyExists: + e := &RequestErrAlreadyExists{} + if errors.As(err, &e) { if e.Table != "errors_parse" { t.Fatalf("expected error table to be 'errors_parse' but was %q", e.Table) } - default: + } else { t.Fatalf("expected to get RequestErrAlreadyExists instead got %T", e) } } From 46095d0b7261ab013d2446a24afd4006ac685c77 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 30 Dec 2024 15:32:59 +0100 Subject: [PATCH 2/3] Let retry policy to decide what to do with potentially executed non idempotent queries --- conn_test.go | 4 ++++ policies.go | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/conn_test.go b/conn_test.go index 0d4a46885..65dbadc54 100644 --- a/conn_test.go +++ b/conn_test.go @@ -456,6 +456,10 @@ func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool { return qry.Attempts() <= t.NumRetries } func (t *testRetryPolicy) GetRetryType(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } return Retry } diff --git a/policies.go b/policies.go index ca89aecba..d5ce45677 100644 --- a/policies.go +++ b/policies.go @@ -160,6 +160,10 @@ func (s *SimpleRetryPolicy) AttemptLWT(q RetryableQuery) bool { } func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } return RetryNextHost } @@ -168,6 +172,10 @@ func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType { // even timeouts if other clients send statements touching the same // partition to the original node at the same time. func (s *SimpleRetryPolicy) GetRetryTypeLWT(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } return Retry } @@ -208,6 +216,10 @@ func getExponentialTime(min time.Duration, max time.Duration, attempts int) time } func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } return RetryNextHost } @@ -216,6 +228,10 @@ func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType { // even timeouts if other clients send statements touching the same // partition to the original node at the same time. func (e *ExponentialBackoffRetryPolicy) GetRetryTypeLWT(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } return Retry } @@ -250,6 +266,14 @@ func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool { } func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType { + var executedErr *QueryError + if errors.As(err, &executedErr) { + err = executedErr.err + if executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() { + return Rethrow + } + } + switch t := err.(type) { case *RequestErrUnavailable: if t.Alive > 0 { From 4039926a7e35adc0ec098fdfa34e550b2e918532 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Wed, 18 Dec 2024 13:32:31 +0100 Subject: [PATCH 3/3] Always consider retry policy when executing query and use QueryError to signal potential execution --- query_executor.go | 116 +++++++++++++++++++++++++++++----------------- 1 file changed, 74 insertions(+), 42 deletions(-) diff --git a/query_executor.go b/query_executor.go index 7bbe7f6ab..354383981 100644 --- a/query_executor.go +++ b/query_executor.go @@ -2,6 +2,7 @@ package gocql import ( "context" + "errors" "sync" "time" ) @@ -107,74 +108,107 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { } func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { - selectedHost := hostIter() rt := qry.retryPolicy() - lwt_rt, use_lwt_rt := rt.(LWTRetryPolicy) - // We only want to apply LWT policy to LWT queries - use_lwt_rt = use_lwt_rt && qry.IsLWT() + if rt == nil { + rt = &SimpleRetryPolicy{3} + } - var lastErr error - var iter *Iter - for selectedHost != nil { + lwtRT, isRTSupportsLWT := rt.(LWTRetryPolicy) + + var getShouldRetry func(qry RetryableQuery) bool + var getRetryType func(error) RetryType + + if isRTSupportsLWT && qry.IsLWT() { + getShouldRetry = lwtRT.AttemptLWT + getRetryType = lwtRT.GetRetryTypeLWT + } else { + getShouldRetry = rt.Attempt + getRetryType = rt.GetRetryType + } + + var potentiallyExecuted bool + + execute := func(qry ExecutableQuery, selectedHost SelectedHost) (iter *Iter, retry RetryType) { host := selectedHost.Info() if host == nil || !host.IsUp() { - selectedHost = hostIter() - continue + return &Iter{ + err: &QueryError{ + err: ErrHostDown, + potentiallyExecuted: potentiallyExecuted, + }, + }, RetryNextHost } - pool, ok := q.pool.getPool(host) if !ok { - selectedHost = hostIter() - continue + return &Iter{ + err: &QueryError{ + err: ErrNoPool, + potentiallyExecuted: potentiallyExecuted, + }, + }, RetryNextHost } - conn := pool.Pick(selectedHost.Token(), qry) if conn == nil { - selectedHost = hostIter() - continue + return &Iter{ + err: &QueryError{ + err: ErrNoConnectionsInPool, + potentiallyExecuted: potentiallyExecuted, + }, + }, RetryNextHost } - iter = q.attemptQuery(ctx, qry, conn) iter.host = selectedHost.Info() // Update host - switch iter.err { - case context.Canceled, context.DeadlineExceeded, ErrNotFound: - // those errors represents logical errors, they should not count - // toward removing a node from the pool + if iter.err == nil { + return iter, RetryType(255) + } + + switch { + case errors.Is(iter.err, context.Canceled), + errors.Is(iter.err, context.DeadlineExceeded): selectedHost.Mark(nil) - return iter + potentiallyExecuted = true + retry = Rethrow default: selectedHost.Mark(iter.err) + retry = RetryType(255) // Don't enforce retry and get it from retry policy } - // Exit if the query was successful - // or no retry policy defined - if iter.err == nil || rt == nil { - return iter - } - - // or retry policy decides to not retry anymore - if use_lwt_rt { - if !lwt_rt.AttemptLWT(qry) { - return iter - } + var qErr *QueryError + if errors.As(iter.err, &qErr) { + potentiallyExecuted = potentiallyExecuted && qErr.PotentiallyExecuted() + qErr.potentiallyExecuted = potentiallyExecuted + qErr.isIdempotent = qry.IsIdempotent() + iter.err = qErr } else { - if !rt.Attempt(qry) { - return iter + iter.err = &QueryError{ + err: iter.err, + potentiallyExecuted: potentiallyExecuted, + isIdempotent: qry.IsIdempotent(), } } + return iter, retry + } + var lastErr error + selectedHost := hostIter() + for selectedHost != nil { + iter, retryType := execute(qry, selectedHost) + if iter.err == nil { + return iter + } lastErr = iter.err - var retry_type RetryType - if use_lwt_rt { - retry_type = lwt_rt.GetRetryTypeLWT(iter.err) - } else { - retry_type = rt.GetRetryType(iter.err) + // Exit if retry policy decides to not retry anymore + if retryType == RetryType(255) { + if !getShouldRetry(qry) { + return iter + } + retryType = getRetryType(iter.err) } // If query is unsuccessful, check the error with RetryPolicy to retry - switch retry_type { + switch retryType { case Retry: // retry on the same host continue @@ -189,11 +223,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne return &Iter{err: ErrUnknownRetryType} } } - if lastErr != nil { return &Iter{err: lastErr} } - return &Iter{err: ErrNoConnections} }