diff --git a/backend/src/main/java/com/cloudera/cai/rag/external/RagBackendClient.java b/backend/src/main/java/com/cloudera/cai/rag/external/RagBackendClient.java index acc0e42d..ad0c5cd7 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/external/RagBackendClient.java +++ b/backend/src/main/java/com/cloudera/cai/rag/external/RagBackendClient.java @@ -50,6 +50,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.opentelemetry.instrumentation.annotations.WithSpan; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; @@ -152,11 +154,10 @@ public static RagBackendClient createNull() { return createNull(new Tracker<>()); } - public static RagBackendClient createNull( - Tracker> tracker, RuntimeException... t) { + public static RagBackendClient createNull(Tracker> tracker, List r) { return new RagBackendClient(SimpleHttpClient.createNull()) { - private final RuntimeException[] exceptions = t; - private int exceptionIndex = 0; + private final List runnables = r; + private int runnableIndex = 0; @Override public void indexFile( @@ -170,8 +171,8 @@ public void indexFile( } private void checkForException() { - if (exceptionIndex < exceptions.length) { - throw exceptions[exceptionIndex++]; + if (runnableIndex < runnables.size()) { + runnables.get(runnableIndex++).run(); } } @@ -209,6 +210,20 @@ public void deleteDocument(long dataSourceId, String documentId) { }; } + public static RagBackendClient createNull( + Tracker> tracker, RuntimeException... t) { + return RagBackendClient.createNull( + tracker, + Arrays.stream(t) + .map( + e -> + (Runnable) + () -> { + throw e; + }) + .toList()); + } + public record TrackedIndexRequest( String bucketName, String s3Path, long dataSourceId, IndexConfiguration configuration) {} diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java index 4ec78aa4..ec1d58d7 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java @@ -54,7 +54,10 @@ import com.cloudera.cai.util.reconcilers.ReconcilerConfig; import io.opentelemetry.api.OpenTelemetry; import java.time.Instant; +import java.util.Arrays; +import java.util.List; import java.util.UUID; +import java.util.concurrent.CountDownLatch; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.Test; @@ -69,33 +72,8 @@ void reconcile() { RagFileIndexReconciler reconciler = createTestInstance(requestTracker); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = newDataSource(); + var document = createTestDoc(documentId, dataSourceId); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) .isNull(); @@ -120,39 +98,101 @@ void reconcile() { "rag-files", "path_in_s3", dataSourceId, new IndexConfiguration(1024, 20)))); } + @Test + void reconcile_stateChanges() { + Tracker> requestTracker = new Tracker<>(); + var waiter = new CountDownLatch(1); + RagFileIndexReconciler reconciler = + createTestInstance( + requestTracker, + List.of( + () -> { + try { + waiter.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + })); + + String documentId = UUID.randomUUID().toString(); + var dataSourceId = newDataSource(); + var document = createTestDoc(documentId, dataSourceId); + Long id = ragFileRepository.insertDocumentMetadata(document); + assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) + .isNull(); + + reconciler.submit(document.withId(id)); + // verify the doc is in a in-progress state + await() + .untilAsserted( + () -> { + RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); + assertThat(updatedDocument.vectorUploadTimestamp()).isNull(); + assertThat(updatedDocument.indexingStatus()) + .isEqualTo(Types.RagDocumentStatus.IN_PROGRESS); + }); + + waiter.countDown(); + await().until(reconciler::isEmpty); + await() + .untilAsserted( + () -> { + assertThat(reconciler.isEmpty()).isTrue(); + RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); + assertThat(updatedDocument.vectorUploadTimestamp()).isNotNull(); + assertThat(updatedDocument.indexingStatus()) + .isEqualTo(Types.RagDocumentStatus.SUCCESS); + }); + assertThat(requestTracker.getValues()) + .hasSize(1) + .contains( + new RagBackendClient.TrackedRequest<>( + new TrackedIndexRequest( + "rag-files", + document.s3Path(), + dataSourceId, + new IndexConfiguration(1024, 20)))); + } + + private static RagDocument createTestDoc(String documentId, long dataSourceId) { + return RagDocument.builder() + .documentId(documentId) + .dataSourceId(dataSourceId) + .s3Path("path_in_s3") + .extension("pdf") + .filename("myfile.pdf") + .timeCreated(Instant.now()) + .timeUpdated(Instant.now()) + .createdById("test-id") + .build(); + } + + private long newDataSource() { + return ragDataSourceRepository.createRagDataSource( + new RagDataSource( + null, + "test_datasource", + "test_embedding_model", + "summarizationModel", + 1024, + 20, + null, + null, + "test-id", + "test-id", + Types.ConnectionType.API, + null, + null)); + } + @Test void reconcile_notFound() { var requestTracker = new Tracker>(); RagFileIndexReconciler reconciler = createTestInstance(requestTracker, new NotFound("datasource not found in the rag backend")); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = newDataSource(); + var document = createTestDoc(documentId, dataSourceId); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) .isNull(); @@ -182,33 +222,8 @@ void reconcile_exception() { RagFileIndexReconciler reconciler = createTestInstance(requestTracker, new RuntimeException("document indexing failed")); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = newDataSource(); + var document = createTestDoc(documentId, dataSourceId); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) .isNull(); @@ -233,6 +248,20 @@ void reconcile_exception() { private RagFileIndexReconciler createTestInstance( Tracker> tracker, RuntimeException... exceptions) { + return createTestInstance( + tracker, + Arrays.stream(exceptions) + .map( + e -> + (Runnable) + () -> { + throw e; + }) + .toList()); + } + + private RagFileIndexReconciler createTestInstance( + Tracker> tracker, List runnables) { Jdbi jdbi = new JdbiConfiguration().jdbi(); var reconcilerConfig = ReconcilerConfig.builder().isTestReconciler(true).workerCount(1).build(); @@ -240,7 +269,7 @@ private RagFileIndexReconciler createTestInstance( new RagFileIndexReconciler( "rag-files", jdbi, - RagBackendClient.createNull(tracker, exceptions), + RagBackendClient.createNull(tracker, runnables), RagDataSourceRepository.createNull(), reconcilerConfig, RagFileRepository.createNull(), diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java index 653a591e..c9f35bab 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java @@ -52,8 +52,10 @@ import com.cloudera.cai.util.reconcilers.ReconcilerConfig; import io.opentelemetry.api.OpenTelemetry; import java.time.Instant; +import java.util.Arrays; import java.util.List; import java.util.UUID; +import java.util.concurrent.CountDownLatch; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.Test; @@ -72,42 +74,17 @@ void reconcile() { RagFileSummaryReconciler reconciler = createTestInstance(requestTracker); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = createDataSource("summarizationModel"); + var document = createTestDoc(documentId, dataSourceId, "path_in_s3"); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); reconciler.submit(document.withId(id)); // add a copy that has already been summarized to make sure we don't try to - // re-summarize with - // long-running summarizations + // re-summarize with long-running summarizations reconciler.submit(document.withId(id).withSummaryCreationTimestamp(Instant.now())); + await().until(reconciler::isEmpty); await() .untilAsserted( @@ -115,15 +92,100 @@ void reconcile() { assertThat(reconciler.isEmpty()).isTrue(); RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); assertThat(updatedDocument.summaryCreationTimestamp()).isNotNull(); + assertThat(updatedDocument.summaryStatus()) + .isEqualTo(Types.RagDocumentStatus.SUCCESS); assertThat(requestTracker.getValues()) .hasSize(1) .contains( new RagBackendClient.TrackedRequest<>( new RagBackendClient.SummaryRequest( - "rag-files", "path_in_s3", "myfile.pdf"))); + "rag-files", document.s3Path(), document.filename()))); + }); + } + + private static RagDocument createTestDoc( + String documentId, long dataSourceId, String path_in_s3) { + return RagDocument.builder() + .documentId(documentId) + .dataSourceId(dataSourceId) + .s3Path(path_in_s3) + .extension("pdf") + .filename("myfile.pdf") + .timeCreated(Instant.now()) + .timeUpdated(Instant.now()) + .createdById("test-id") + .build(); + } + + @Test + void reconcile_stateChanges() { + Tracker> requestTracker = new Tracker<>(); + var waiter = new CountDownLatch(1); + RagFileSummaryReconciler reconciler = + createTestInstance( + requestTracker, + List.of( + () -> { + try { + waiter.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + })); + + String documentId = UUID.randomUUID().toString(); + var dataSourceId = createDataSource("summarizationModel"); + var document = createTestDoc(documentId, dataSourceId, "path_in_s3"); + Long id = ragFileRepository.insertDocumentMetadata(document); + assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) + .isNull(); + + reconciler.submit(document.withId(id)); + + await() + .untilAsserted( + () -> { + RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); + assertThat(updatedDocument.summaryCreationTimestamp()).isNull(); + assertThat(updatedDocument.summaryStatus()) + .isEqualTo(Types.RagDocumentStatus.IN_PROGRESS); + }); + + waiter.countDown(); + await().until(reconciler::isEmpty); + await() + .untilAsserted( + () -> { + assertThat(reconciler.isEmpty()).isTrue(); + RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); + assertThat(updatedDocument.summaryCreationTimestamp()).isNotNull(); + assertThat(requestTracker.getValues()) + .hasSize(1) + .contains( + new RagBackendClient.TrackedRequest<>( + new RagBackendClient.SummaryRequest( + "rag-files", document.s3Path(), document.filename()))); }); } + private long createDataSource(String summarizationModel) { + return ragDataSourceRepository.createRagDataSource( + new RagDataSource( + null, + "test_datasource", + "test_embedding_model", + summarizationModel, + 1024, + 20, + null, + null, + "test-id", + "test-id", + Types.ConnectionType.API, + null, + null)); + } + @Test void reconcile_notFound() { Tracker> requestTracker = new Tracker<>(); @@ -131,33 +193,8 @@ void reconcile_notFound() { createTestInstance(requestTracker, new NotFound("not found")); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = createDataSource("summarizationModel"); + var document = createTestDoc(documentId, dataSourceId, "path_in_s3"); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); @@ -188,33 +225,8 @@ void reconcile_exception() { createTestInstance(requestTracker, new RuntimeException("document summarization failed")); String documentId = UUID.randomUUID().toString(); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - "summarizationModel", - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var dataSourceId = createDataSource("summarizationModel"); + var document = createTestDoc(documentId, dataSourceId, "path_in_s3"); Long id = ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); @@ -243,34 +255,9 @@ void reconcile_noSummarizationModel() { Tracker> requestTracker = new Tracker<>(); RagFileSummaryReconciler reconciler = createTestInstance(requestTracker); - long dataSourceId = - ragDataSourceRepository.createRagDataSource( - new RagDataSource( - null, - "test_datasource", - "test_embedding_model", - null, - 1024, - 20, - null, - null, - "test-id", - "test-id", - Types.ConnectionType.API, - null, - null)); + var dataSourceId = createDataSource(null); String documentId = UUID.randomUUID().toString(); - RagDocument document = - RagDocument.builder() - .documentId(documentId) - .dataSourceId(dataSourceId) - .s3Path("path_in_s3_no_summarization_model") - .extension("pdf") - .filename("myfile.pdf") - .timeCreated(Instant.now()) - .timeUpdated(Instant.now()) - .createdById("test-id") - .build(); + var document = createTestDoc(documentId, dataSourceId, "path_in_s3_no_summarization_model"); ragFileRepository.insertDocumentMetadata(document); assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); @@ -294,13 +281,27 @@ void reconcile_noSummarizationModel() { private RagFileSummaryReconciler createTestInstance( Tracker> tracker, RuntimeException... exceptions) { + return createTestInstance( + tracker, + Arrays.stream(exceptions) + .map( + e -> + (Runnable) + () -> { + throw e; + }) + .toList()); + } + + private RagFileSummaryReconciler createTestInstance( + Tracker> tracker, List runnables) { Jdbi jdbi = new JdbiConfiguration().jdbi(); var reconcilerConfig = ReconcilerConfig.builder().isTestReconciler(true).workerCount(1).build(); RagFileSummaryReconciler reconciler = new RagFileSummaryReconciler( "rag-files", jdbi, - RagBackendClient.createNull(tracker, exceptions), + RagBackendClient.createNull(tracker, runnables), ragFileRepository, reconcilerConfig, OpenTelemetry.noop());