diff --git a/CHANGELOG.md b/CHANGELOG.md index db391755b..d89762420 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fix ISM Allocation field types ([#609](https://github.com/opensearch-project/opensearch-go/pull/609)) - Fix ISM Error Notification types ([#612](https://github.com/opensearch-project/opensearch-go/pull/612)) - Fix signer receiving drained body on retries ([#620](https://github.com/opensearch-project/opensearch-go/pull/620)) +- Fix Bulk Index Items not executing failure callbacks on bulk request failure ([#626](https://github.com/opensearch-project/opensearch-go/issues/626)) ### Security diff --git a/opensearchutil/bulk_indexer.go b/opensearchutil/bulk_indexer.go index 380ed70e4..816098bf9 100644 --- a/opensearchutil/bulk_indexer.go +++ b/opensearchutil/bulk_indexer.go @@ -503,11 +503,7 @@ func (w *worker) flush(ctx context.Context) error { blk, err = w.bi.config.Client.Bulk(ctx, req) if err != nil { - atomic.AddUint64(&w.bi.stats.numFailed, uint64(len(w.items))) - if w.bi.config.OnError != nil { - w.bi.config.OnError(ctx, fmt.Errorf("flush: %w", err)) - } - return fmt.Errorf("flush: %w", err) + return w.handleBulkError(ctx, fmt.Errorf("flush: %w", err)) } for i, blkItem := range blk.Items { @@ -520,7 +516,7 @@ func (w *worker) flush(ctx context.Context) error { item = w.items[i] // The OpenSearch bulk response contains an array of maps like this: // [ { "index": { ... } }, { "create": { ... } }, ... ] - // We range over the map, to set the first key and value as "op" and "info". + // We range over the map, to set the last key and value as "op" and "info". for k, v := range blkItem { op = k info = v @@ -552,3 +548,21 @@ func (w *worker) flush(ctx context.Context) error { return err } + +func (w *worker) handleBulkError(ctx context.Context, err error) error { + atomic.AddUint64(&w.bi.stats.numFailed, uint64(len(w.items))) + + // info (the response item) will be empty since the bulk request failed + var info opensearchapi.BulkRespItem + for i := range w.items { + if item := w.items[i]; item.OnFailure != nil { + item.OnFailure(ctx, item, info, err) + } + } + + if w.bi.config.OnError != nil { + w.bi.config.OnError(ctx, err) + } + + return err +} diff --git a/opensearchutil/bulk_indexer_internal_test.go b/opensearchutil/bulk_indexer_internal_test.go index ed8827fb4..134b0a195 100644 --- a/opensearchutil/bulk_indexer_internal_test.go +++ b/opensearchutil/bulk_indexer_internal_test.go @@ -31,6 +31,7 @@ package opensearchutil import ( "bytes" "context" + "errors" "fmt" "io" "log" @@ -765,6 +766,78 @@ func TestBulkIndexer(t *testing.T) { }) } }) + + t.Run("Items Failure Callbacks Executed On Bulk Failure", func(t *testing.T) { + client, _ := opensearchapi.NewClient(opensearchapi.Config{ + Client: opensearch.Config{ + Transport: &mockTransport{ + RoundTripFunc: func(request *http.Request) (*http.Response, error) { + if request.URL.Path == "/" { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(strings.NewReader(infoBody)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + } + + return nil, errors.New("simulated bulk request error") + }, + }, + }, + }) + + bi, _ := NewBulkIndexer(BulkIndexerConfig{ + NumWorkers: 1, + FlushBytes: 1, + Client: client, + OnError: func(ctx context.Context, err error) { + if err.Error() != "flush: simulated bulk request error" { + t.Errorf("Unexpected error: %v", err) + } + }, + }) + + var ( + numItems = 5 + wg sync.WaitGroup + failedItems int + ) + + wg.Add(numItems) + for i := 0; i < numItems; i++ { + err := bi.Add(context.Background(), BulkIndexerItem{ + Action: "index", + DocumentID: fmt.Sprintf("id_%d", i), + Body: strings.NewReader(fmt.Sprintf(`{"title":"doc_%d"}`, i)), + OnFailure: func(ctx context.Context, item BulkIndexerItem, resp opensearchapi.BulkRespItem, err error) { + if err.Error() != "flush: simulated bulk request error" { + t.Errorf("Unexpected error in OnFailure: %v", err) + } + + failedItems++ + wg.Done() + }, + }) + if err != nil { + t.Fatalf("Unexpected error adding item: %v", err) + } + } + + if err := bi.Close(context.Background()); err != nil { + t.Errorf("Unexpected error: %s", err) + } + + wg.Wait() + + if failedItems != numItems { + t.Errorf("Expected %d failed items, got %d", numItems, failedItems) + } + + if stats := bi.Stats(); stats.NumFailed != uint64(numItems) { + t.Errorf("Expected NumFailed to be %d, got %d", numItems, stats.NumFailed) + } + }) } func strPointer(s string) *string {