Skip to content

Commit

Permalink
Merge pull request #376 from sylwiaszunejko/idempotent-flag-retry
Browse files Browse the repository at this point in the history
Let retry policy to decide about non idempotent queries
  • Loading branch information
dkropachev authored Dec 31, 2024
2 parents 1aab8a5 + 4039926 commit ead5781
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 66 deletions.
55 changes: 40 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,13 +1101,13 @@ func (c *Conn) addCall(call *callReq) error {

func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
return nil, &QueryError{err: ctxErr, potentiallyExecuted: false}
}

// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
return nil, ErrNoStreams
return nil, &QueryError{err: ErrNoStreams, potentiallyExecuted: false}
}

// resp is basically a waiting semaphore protecting the framer
Expand All @@ -1125,7 +1125,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
}

if err := c.addCall(call); err != nil {
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

// After this point, we need to either read from call.resp or close(call.timeout)
Expand Down Expand Up @@ -1157,7 +1157,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
// check above could fail.
c.releaseStream(call)
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

n, err := c.w.writeContext(ctx, framer.buf)
Expand Down Expand Up @@ -1185,7 +1185,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
}
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: true}
}

var timeoutCh <-chan time.Time
Expand Down Expand Up @@ -1222,7 +1222,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// connection to close.
c.releaseStream(call)
}
return nil, resp.err
return nil, &QueryError{err: resp.err, potentiallyExecuted: true}
}
// dont release the stream if detect a timeout as another request can reuse
// that stream and get a response for the old request, which we have no
Expand All @@ -1233,20 +1233,20 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
defer c.releaseStream(call)

if v := resp.framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
return nil, &QueryError{err: NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version), potentiallyExecuted: true}
}

return resp.framer, nil
case <-timeoutCh:
close(call.timeout)
c.handleTimeout()
return nil, ErrTimeoutNoResponse
return nil, &QueryError{err: ErrTimeoutNoResponse, potentiallyExecuted: true}
case <-ctxDone:
close(call.timeout)
return nil, ctx.Err()
return nil, &QueryError{err: ctx.Err(), potentiallyExecuted: true}
case <-c.ctx.Done():
close(call.timeout)
return nil, ErrConnectionClosed
return nil, &QueryError{err: ErrConnectionClosed, potentiallyExecuted: true}
}
}

Expand Down Expand Up @@ -1906,11 +1906,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
}

var (
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrHostDown = errors.New("gocql: host is nil or down")
ErrNoPool = errors.New("gocql: host does not have a pool")
ErrNoConnectionsInPool = errors.New("gocql: host pool does not have connections")
)

type ErrSchemaMismatch struct {
Expand All @@ -1920,3 +1923,25 @@ type ErrSchemaMismatch struct {
func (e *ErrSchemaMismatch) Error() string {
return fmt.Sprintf("gocql: cluster schema versions not consistent: %+v", e.schemas)
}

type QueryError struct {
err error
potentiallyExecuted bool
isIdempotent bool
}

func (e *QueryError) IsIdempotent() bool {
return e.isIdempotent
}

func (e *QueryError) PotentiallyExecuted() bool {
return e.potentiallyExecuted
}

func (e *QueryError) Error() string {
return fmt.Sprintf("%s (potentially executed: %v)", e.err.Error(), e.potentiallyExecuted)
}

func (e *QueryError) Unwrap() error {
return e.err
}
16 changes: 10 additions & 6 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func TestCancel(t *testing.T) {
wg.Add(1)

go func() {
if err := qry.Exec(); err != context.Canceled {
if err := qry.Exec(); !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
}
wg.Done()
Expand Down Expand Up @@ -456,6 +456,10 @@ func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool {
return qry.Attempts() <= t.NumRetries
}
func (t *testRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -573,7 +577,7 @@ func TestQueryTimeout(t *testing.T) {

select {
case err := <-ch:
if err != ErrTimeoutNoResponse {
if !errors.Is(err, ErrTimeoutNoResponse) {
t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
}
case <-time.After(40*time.Millisecond + db.cfg.Timeout):
Expand Down Expand Up @@ -667,8 +671,8 @@ func TestQueryTimeoutClose(t *testing.T) {
t.Fatal("timedout waiting to get a response once cluster is closed")
}

if err != ErrConnectionClosed {
t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
if !errors.Is(err, ErrConnectionClosed) {
t.Fatalf("expected to get %v or an error wrapping it, got %v", ErrConnectionClosed, err)
}
}

Expand Down Expand Up @@ -721,7 +725,7 @@ func TestContext_Timeout(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}
}
Expand Down Expand Up @@ -838,7 +842,7 @@ func TestContext_CanceledBeforeExec(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}

Expand Down
7 changes: 4 additions & 3 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package gocql

import (
"errors"
"testing"
)

Expand All @@ -18,12 +19,12 @@ func TestErrorsParse(t *testing.T) {
if err := createTable(session, `CREATE TABLE gocql_test.errors_parse (id int primary key)`); err == nil {
t.Fatal("Should have gotten already exists error from cassandra server.")
} else {
switch e := err.(type) {
case *RequestErrAlreadyExists:
e := &RequestErrAlreadyExists{}
if errors.As(err, &e) {
if e.Table != "errors_parse" {
t.Fatalf("expected error table to be 'errors_parse' but was %q", e.Table)
}
default:
} else {
t.Fatalf("expected to get RequestErrAlreadyExists instead got %T", e)
}
}
Expand Down
24 changes: 24 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func (s *SimpleRetryPolicy) AttemptLWT(q RetryableQuery) bool {
}

func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return RetryNextHost
}

Expand All @@ -168,6 +172,10 @@ func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
// even timeouts if other clients send statements touching the same
// partition to the original node at the same time.
func (s *SimpleRetryPolicy) GetRetryTypeLWT(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -208,6 +216,10 @@ func getExponentialTime(min time.Duration, max time.Duration, attempts int) time
}

func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return RetryNextHost
}

Expand All @@ -216,6 +228,10 @@ func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
// even timeouts if other clients send statements touching the same
// partition to the original node at the same time.
func (e *ExponentialBackoffRetryPolicy) GetRetryTypeLWT(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -250,6 +266,14 @@ func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool {
}

func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) {
err = executedErr.err
if executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
}

switch t := err.(type) {
case *RequestErrUnavailable:
if t.Alive > 0 {
Expand Down
Loading

0 comments on commit ead5781

Please sign in to comment.