diff --git a/CHANGELOG.md b/CHANGELOG.md
index db391755b..d89762420 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Fix ISM Allocation field types ([#609](https://github.com/opensearch-project/opensearch-go/pull/609))
 - Fix ISM Error Notification types ([#612](https://github.com/opensearch-project/opensearch-go/pull/612))
 - Fix signer receiving drained body on retries ([#620](https://github.com/opensearch-project/opensearch-go/pull/620))
+- Fix Bulk Index Items not executing failure callbacks on bulk request failure ([#626](https://github.com/opensearch-project/opensearch-go/issues/626))
 
 ### Security
 
diff --git a/opensearchutil/bulk_indexer.go b/opensearchutil/bulk_indexer.go
index 380ed70e4..fe93dd58d 100644
--- a/opensearchutil/bulk_indexer.go
+++ b/opensearchutil/bulk_indexer.go
@@ -338,7 +338,8 @@ func (w *worker) run() {
 			w.mu.Lock()
 
 			if w.bi.config.DebugLogger != nil {
-				w.bi.config.DebugLogger.Printf("[worker-%03d] Received item [%s:%s]\n", w.id, item.Action, item.DocumentID)
+				w.bi.config.DebugLogger.Printf("[worker-%03d] Received item [%s:%s]\n", w.id, item.Action,
+					item.DocumentID)
 			}
 
 			if err := w.writeMeta(item); err != nil {
@@ -503,11 +504,7 @@ func (w *worker) flush(ctx context.Context) error {
 
 	blk, err = w.bi.config.Client.Bulk(ctx, req)
 	if err != nil {
-		atomic.AddUint64(&w.bi.stats.numFailed, uint64(len(w.items)))
-		if w.bi.config.OnError != nil {
-			w.bi.config.OnError(ctx, fmt.Errorf("flush: %w", err))
-		}
-		return fmt.Errorf("flush: %w", err)
+		return w.handleBulkError(ctx, fmt.Errorf("flush: %w", err))
 	}
 
 	for i, blkItem := range blk.Items {
@@ -520,7 +517,7 @@ func (w *worker) flush(ctx context.Context) error {
 		item = w.items[i]
 		// The OpenSearch bulk response contains an array of maps like this:
 		// [ { "index": { ... } }, { "create": { ... } }, ... ]
-		// We range over the map, to set the first key and value as "op" and "info".
+		// We range over the map, to set the last key and value as "op" and "info".
 		for k, v := range blkItem {
 			op = k
 			info = v
@@ -552,3 +549,17 @@ func (w *worker) flush(ctx context.Context) error {
 
 	return err
 }
+
+func (w *worker) handleBulkError(ctx context.Context, err error) error {
+	atomic.AddUint64(&w.bi.stats.numFailed, uint64(len(w.items)))
+
+	// info (the response item) will be empty since the bulk request failed
+	var info opensearchapi.BulkRespItem
+	for i := range w.items {
+		if item := w.items[i]; item.OnFailure != nil {
+			item.OnFailure(ctx, item, info, err)
+		}
+	}
+
+	return err
+}
diff --git a/opensearchutil/bulk_indexer_internal_test.go b/opensearchutil/bulk_indexer_internal_test.go
index ed8827fb4..743c6932f 100644
--- a/opensearchutil/bulk_indexer_internal_test.go
+++ b/opensearchutil/bulk_indexer_internal_test.go
@@ -31,6 +31,7 @@ package opensearchutil
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -248,11 +249,18 @@ func TestBulkIndexer(t *testing.T) {
 
 		client, _ := opensearchapi.NewClient(config)
 
-		var indexerError error
+		var (
+			expectedOnErrorCount = 1
+			indexerError         error
+			onErrorCount         int
+		)
 		biCfg := BulkIndexerConfig{
 			NumWorkers: 1,
 			Client:     client,
-			OnError:    func(ctx context.Context, err error) { indexerError = err },
+			OnError: func(ctx context.Context, err error) {
+				onErrorCount++
+				indexerError = err
+			},
 		}
 		if os.Getenv("DEBUG") != "" {
 			biCfg.DebugLogger = log.New(os.Stdout, "", 0)
@@ -266,11 +274,17 @@ func TestBulkIndexer(t *testing.T) {
 			t.Fatalf("Unexpected error: %s", err)
 		}
 
-		bi.Close(context.Background())
+		if err := bi.Close(context.Background()); err != nil {
+			t.Errorf("Unexpected error: %s", err)
+		}
 
 		if indexerError == nil {
 			t.Errorf("Expected indexerError to not be nil")
 		}
+
+		if onErrorCount != expectedOnErrorCount {
+			t.Errorf("Expected onErrorCount to be %d, got %d", expectedOnErrorCount, onErrorCount)
+		}
 	})
 
 	t.Run("Item Callbacks", func(t *testing.T) {
@@ -281,9 +295,14 @@ func TestBulkIndexer(t *testing.T) {
 			successfulItemBodies []string
 			failedItemBodies     []string
 
-			numItems       = 4
-			numFailed      = 2
-			bodyContent, _ = os.ReadFile("testdata/bulk_response_2.json")
+			numItems             = 4
+			numFailed            = 2
+			bodyContent, _       = os.ReadFile("testdata/bulk_response_2.json")
+			bodyFailureCount     = make(map[string]int)
+			bodiesExpectedToFail = map[string]struct{}{
+				`{"title":"bar"}`: {},
+				`{"title":"baz"}`: {},
+			}
 		)
 
 		client, _ := opensearchapi.NewClient(
@@ -323,18 +342,21 @@ func TestBulkIndexer(t *testing.T) {
 			}
 			successfulItemBodies = append(successfulItemBodies, string(buf))
 		}
+
 		failureFunc := func(ctx context.Context, item BulkIndexerItem, res opensearchapi.BulkRespItem, err error) {
 			if err != nil {
 				t.Fatalf("Unexpected error: %s", err)
 			}
-			atomic.AddUint64(&countFailed, 1)
-			failedIDs = append(failedIDs, item.DocumentID)
 
 			buf, err := io.ReadAll(item.Body)
 			if err != nil {
 				t.Fatalf("Unexpected error: %s", err)
 			}
+
+			countFailed++
+			failedIDs = append(failedIDs, item.DocumentID)
 			failedItemBodies = append(failedItemBodies, string(buf))
+			bodyFailureCount[string(buf)]++
 		}
 
 		if err := bi.Add(context.Background(), BulkIndexerItem{
@@ -392,6 +414,25 @@ func TestBulkIndexer(t *testing.T) {
 		// * Operation #2: document can't be created, because a document with the same ID already exists.
 		// * Operation #3: document can't be deleted, because it doesn't exist.
 
+		if stats.NumFailed != uint64(len(bodyFailureCount)) {
+			t.Errorf("Expected %v items in the bodyFailureCount map, got %v", numFailed, len(bodyFailureCount))
+		}
+
+		for k, v := range bodyFailureCount {
+			if _, ok := bodiesExpectedToFail[k]; !ok {
+				t.Errorf("Unexpected item body failure: %v", k)
+			}
+			delete(bodiesExpectedToFail, k)
+
+			if v != 1 {
+				t.Errorf("Expected 1 failure callback call for item %v, got %v", k, v)
+			}
+		}
+
+		if len(bodiesExpectedToFail) > 0 {
+			t.Errorf("Expected failure callbacks for the following item bodies: %v", bodiesExpectedToFail)
+		}
+
 		if stats.NumFailed != uint64(numFailed) {
 			t.Errorf("Unexpected NumFailed: %d", stats.NumFailed)
 		}
@@ -765,6 +806,108 @@ func TestBulkIndexer(t *testing.T) {
 			})
 		}
 	})
+
+	t.Run("Items Failure Callbacks Executed On Bulk Failure", func(t *testing.T) {
+		var (
+			numItems          = 5
+			idsExpectedToFail = make(map[string]struct{}, numItems)
+			idsFailureCount   = make(map[string]int)
+
+			onErrorCallCount int
+			wg               sync.WaitGroup
+		)
+
+		// if the test takes more than 5 seconds, it's a failure. Want to avoid infinitely waiting for the waitgroup in
+		// the edge case where a failure callback is not executed.
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+
+		client, _ := opensearchapi.NewClient(opensearchapi.Config{
+			Client: opensearch.Config{
+				Transport: &mockTransport{
+					RoundTripFunc: func(request *http.Request) (*http.Response, error) {
+						if request.URL.Path == "/" {
+							return &http.Response{
+								StatusCode: http.StatusOK,
+								Status:     "200 OK",
+								Body:       io.NopCloser(strings.NewReader(infoBody)),
+								Header:     http.Header{"Content-Type": []string{"application/json"}},
+							}, nil
+						}
+
+						return nil, errors.New("simulated bulk request error")
+					},
+				},
+			},
+		})
+
+		bi, _ := NewBulkIndexer(BulkIndexerConfig{
+			NumWorkers: 1,
+			FlushBytes: 1,
+			Client:     client,
+			OnError: func(ctx context.Context, err error) {
+				onErrorCallCount++
+				if err.Error() != "flush: simulated bulk request error" {
+					t.Errorf("Unexpected error: %v", err)
+				}
+			},
+		})
+
+		wg.Add(numItems)
+		for i := 0; i < numItems; i++ {
+			id := fmt.Sprintf("id_%d", i)
+			idsExpectedToFail[id] = struct{}{}
+			if err := bi.Add(ctx, BulkIndexerItem{
+				Action:     "index",
+				DocumentID: id,
+				Body:       strings.NewReader(fmt.Sprintf(`{"title":"doc_%d"}`, i)),
+				OnFailure: func(ctx context.Context, item BulkIndexerItem, resp opensearchapi.BulkRespItem, err error) {
+					if err.Error() != "flush: simulated bulk request error" {
+						t.Errorf("Unexpected error in OnFailure: %v", err)
+					}
+
+					idsFailureCount[item.DocumentID]++
+					wg.Done()
+				},
+			}); err != nil {
+				t.Fatalf("Unexpected error adding item: %v", err)
+			}
+		}
+
+		if err := bi.Close(ctx); err != nil {
+			t.Errorf("Unexpected error: %s", err)
+		}
+
+		wg.Wait()
+
+		if onErrorCallCount != numItems {
+			t.Errorf("Expected %d calls to OnError, got %d", numItems, onErrorCallCount)
+		}
+
+		stats := bi.Stats()
+		if stats.NumFailed != uint64(len(idsFailureCount)) {
+			t.Errorf("Expected %d items in the idsFailureCount map, got %d", numItems, len(idsFailureCount))
+		}
+
+		for k, v := range idsFailureCount {
+			if _, ok := idsExpectedToFail[k]; !ok {
+				t.Errorf("Unexpected item ID failure: %v", k)
+			}
+			delete(idsExpectedToFail, k)
+
+			if v != 1 {
+				t.Errorf("Expected 1 failure callback call for item %v, got %v", k, v)
+			}
+		}
+
+		if len(idsExpectedToFail) > 0 {
+			t.Errorf("Expected failure callbacks for the following item IDs: %v", idsExpectedToFail)
+		}
+
+		if stats.NumFailed != uint64(numItems) {
+			t.Errorf("Expected NumFailed to be %d, got %d", numItems, stats.NumFailed)
+		}
+	})
 }
 
 func strPointer(s string) *string {