Skip to content

Commit

Permalink
add tests for the intermediate state change on document reconciliation
Browse files Browse the repository at this point in the history
  • Loading branch information
jkwatson committed Dec 20, 2024
1 parent c37b75f commit e63df42
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

Expand Down Expand Up @@ -152,11 +154,10 @@ public static RagBackendClient createNull() {
return createNull(new Tracker<>());
}

public static RagBackendClient createNull(
Tracker<TrackedRequest<?>> tracker, RuntimeException... t) {
public static RagBackendClient createNull(Tracker<TrackedRequest<?>> tracker, List<Runnable> r) {
return new RagBackendClient(SimpleHttpClient.createNull()) {
private final RuntimeException[] exceptions = t;
private int exceptionIndex = 0;
private final List<Runnable> runnables = r;
private int runnableIndex = 0;

@Override
public void indexFile(
Expand All @@ -170,8 +171,8 @@ public void indexFile(
}

private void checkForException() {
if (exceptionIndex < exceptions.length) {
throw exceptions[exceptionIndex++];
if (runnableIndex < runnables.size()) {
runnables.get(runnableIndex++).run();
}
}

Expand Down Expand Up @@ -209,6 +210,20 @@ public void deleteDocument(long dataSourceId, String documentId) {
};
}

public static RagBackendClient createNull(
Tracker<TrackedRequest<?>> tracker, RuntimeException... t) {
return RagBackendClient.createNull(
tracker,
Arrays.stream(t)
.map(
e ->
(Runnable)
() -> {
throw e;
})
.toList());
}

public record TrackedIndexRequest(
String bucketName, String s3Path, long dataSourceId, IndexConfiguration configuration) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
import com.cloudera.cai.util.reconcilers.ReconcilerConfig;
import io.opentelemetry.api.OpenTelemetry;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.Test;

Expand All @@ -69,33 +72,8 @@ void reconcile() {
RagFileIndexReconciler reconciler = createTestInstance(requestTracker);

String documentId = UUID.randomUUID().toString();
long dataSourceId =
ragDataSourceRepository.createRagDataSource(
new RagDataSource(
null,
"test_datasource",
"test_embedding_model",
"summarizationModel",
1024,
20,
null,
null,
"test-id",
"test-id",
Types.ConnectionType.API,
null,
null));
RagDocument document =
RagDocument.builder()
.documentId(documentId)
.dataSourceId(dataSourceId)
.s3Path("path_in_s3")
.extension("pdf")
.filename("myfile.pdf")
.timeCreated(Instant.now())
.timeUpdated(Instant.now())
.createdById("test-id")
.build();
var dataSourceId = newDataSource();
var document = createTestDoc(documentId, dataSourceId);
Long id = ragFileRepository.insertDocumentMetadata(document);
assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp())
.isNull();
Expand All @@ -120,39 +98,101 @@ void reconcile() {
"rag-files", "path_in_s3", dataSourceId, new IndexConfiguration(1024, 20))));
}

@Test
void reconcile_stateChanges() {
Tracker<RagBackendClient.TrackedRequest<?>> requestTracker = new Tracker<>();
var waiter = new CountDownLatch(1);
RagFileIndexReconciler reconciler =
createTestInstance(
requestTracker,
List.of(
() -> {
try {
waiter.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}));

String documentId = UUID.randomUUID().toString();
var dataSourceId = newDataSource();
var document = createTestDoc(documentId, dataSourceId);
Long id = ragFileRepository.insertDocumentMetadata(document);
assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp())
.isNull();

reconciler.submit(document.withId(id));
// verify the doc is in a in-progress state
await()
.untilAsserted(
() -> {
RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId);
assertThat(updatedDocument.vectorUploadTimestamp()).isNull();
assertThat(updatedDocument.indexingStatus())
.isEqualTo(Types.RagDocumentStatus.IN_PROGRESS);
});

waiter.countDown();
await().until(reconciler::isEmpty);
await()
.untilAsserted(
() -> {
assertThat(reconciler.isEmpty()).isTrue();
RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId);
assertThat(updatedDocument.vectorUploadTimestamp()).isNotNull();
assertThat(updatedDocument.indexingStatus())
.isEqualTo(Types.RagDocumentStatus.SUCCESS);
});
assertThat(requestTracker.getValues())
.hasSize(1)
.contains(
new RagBackendClient.TrackedRequest<>(
new TrackedIndexRequest(
"rag-files",
document.s3Path(),
dataSourceId,
new IndexConfiguration(1024, 20))));
}

private static RagDocument createTestDoc(String documentId, long dataSourceId) {
return RagDocument.builder()
.documentId(documentId)
.dataSourceId(dataSourceId)
.s3Path("path_in_s3")
.extension("pdf")
.filename("myfile.pdf")
.timeCreated(Instant.now())
.timeUpdated(Instant.now())
.createdById("test-id")
.build();
}

private long newDataSource() {
return ragDataSourceRepository.createRagDataSource(
new RagDataSource(
null,
"test_datasource",
"test_embedding_model",
"summarizationModel",
1024,
20,
null,
null,
"test-id",
"test-id",
Types.ConnectionType.API,
null,
null));
}

@Test
void reconcile_notFound() {
var requestTracker = new Tracker<RagBackendClient.TrackedRequest<?>>();
RagFileIndexReconciler reconciler =
createTestInstance(requestTracker, new NotFound("datasource not found in the rag backend"));
String documentId = UUID.randomUUID().toString();
long dataSourceId =
ragDataSourceRepository.createRagDataSource(
new RagDataSource(
null,
"test_datasource",
"test_embedding_model",
"summarizationModel",
1024,
20,
null,
null,
"test-id",
"test-id",
Types.ConnectionType.API,
null,
null));
RagDocument document =
RagDocument.builder()
.documentId(documentId)
.dataSourceId(dataSourceId)
.s3Path("path_in_s3")
.extension("pdf")
.filename("myfile.pdf")
.timeCreated(Instant.now())
.timeUpdated(Instant.now())
.createdById("test-id")
.build();
var dataSourceId = newDataSource();
var document = createTestDoc(documentId, dataSourceId);
Long id = ragFileRepository.insertDocumentMetadata(document);
assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp())
.isNull();
Expand Down Expand Up @@ -182,33 +222,8 @@ void reconcile_exception() {
RagFileIndexReconciler reconciler =
createTestInstance(requestTracker, new RuntimeException("document indexing failed"));
String documentId = UUID.randomUUID().toString();
long dataSourceId =
ragDataSourceRepository.createRagDataSource(
new RagDataSource(
null,
"test_datasource",
"test_embedding_model",
"summarizationModel",
1024,
20,
null,
null,
"test-id",
"test-id",
Types.ConnectionType.API,
null,
null));
RagDocument document =
RagDocument.builder()
.documentId(documentId)
.dataSourceId(dataSourceId)
.s3Path("path_in_s3")
.extension("pdf")
.filename("myfile.pdf")
.timeCreated(Instant.now())
.timeUpdated(Instant.now())
.createdById("test-id")
.build();
var dataSourceId = newDataSource();
var document = createTestDoc(documentId, dataSourceId);
Long id = ragFileRepository.insertDocumentMetadata(document);
assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp())
.isNull();
Expand All @@ -233,14 +248,28 @@ void reconcile_exception() {

private RagFileIndexReconciler createTestInstance(
Tracker<RagBackendClient.TrackedRequest<?>> tracker, RuntimeException... exceptions) {
return createTestInstance(
tracker,
Arrays.stream(exceptions)
.map(
e ->
(Runnable)
() -> {
throw e;
})
.toList());
}

private RagFileIndexReconciler createTestInstance(
Tracker<RagBackendClient.TrackedRequest<?>> tracker, List<Runnable> runnables) {
Jdbi jdbi = new JdbiConfiguration().jdbi();
var reconcilerConfig = ReconcilerConfig.builder().isTestReconciler(true).workerCount(1).build();

RagFileIndexReconciler reconciler =
new RagFileIndexReconciler(
"rag-files",
jdbi,
RagBackendClient.createNull(tracker, exceptions),
RagBackendClient.createNull(tracker, runnables),
RagDataSourceRepository.createNull(),
reconcilerConfig,
RagFileRepository.createNull(),
Expand Down
Loading

0 comments on commit e63df42

Please sign in to comment.