From bdfabb74ab2d10013bfe262a85d1a719a706be9d Mon Sep 17 00:00:00 2001 From: John Watson Date: Tue, 10 Dec 2024 08:17:58 -0800 Subject: [PATCH] CAII endpoint discovery (#60) * "wip on endpoint listing" * "wip on list_endpoints typing" * "refactoring to endpoint object" * "wip filtering" * "endpoints queried!" * "refactoring" * "wip on cleaning up types" * "type cleanup complete" * "moving files" * "use a dummy embedding model for deletes" * fix some bits from merge, get evals working again with CAII, tests passing * formatting * clean up ruff stuff * use the chat llm for evals * fix mypy for reformatting * "wip on java reconciler" * "reconciler don't do no model; start python work" * "python - updating for summarization model" * "comment out batch embeddings to get it working again" * add handling for no summarization in the files table * finish up ui and python for summarization * make sure to update the time-updated fields on data sources and chat sessions * use no-op models when we don't need real ones for summary functionality * Update release version to dev-testing * use the summarization llm when summarizing summaries --------- Co-authored-by: Elijah Williams Co-authored-by: actions-user --- .env.example | 2 - .project-metadata.yaml | 8 -- .../main/java/com/cloudera/cai/rag/Types.java | 1 + .../datasources/RagDataSourceRepository.java | 9 +- .../rag/files/RagFileSummaryReconciler.java | 8 +- .../cai/rag/sessions/SessionRepository.java | 5 +- .../com/cloudera/cai/util/IdGenerator.java | 4 +- .../h2/15_add_summarization_model.down.sql | 45 ++++++ .../h2/15_add_summarization_model.up.sql | 45 ++++++ .../main/resources/migrations/migrations.txt | 4 +- .../15_add_summarization_model.down.sql | 45 ++++++ .../15_add_summarization_model.up.sql | 45 ++++++ .../java/com/cloudera/cai/rag/TestData.java | 1 + .../RagDataSourceControllerTest.java | 1 + .../RagDataSourceRepositoryTest.java | 16 ++- .../cai/rag/files/RagFileControllerTest.java | 1 + .../rag/files/RagFileIndexReconcilerTest.java | 2 + .../cai/rag/files/RagFileServiceTest.java | 8 +- .../files/RagFileSummaryReconcilerTest.java | 64 ++++++++- .../rag/sessions/SessionControllerTest.java | 6 +- .../cloudera/cai/util/IdGeneratorTest.java | 4 +- llm-service/app/ai/indexing/index.py | 4 +- llm-service/app/ai/indexing/readers/pdf.py | 18 ++- llm-service/app/ai/vector_stores/qdrant.py | 46 ++++-- .../app/ai/vector_stores/vector_store.py | 4 +- .../app/routers/index/data_source/__init__.py | 10 +- .../app/routers/index/models/__init__.py | 7 +- .../services/{ => caii}/CaiiEmbeddingModel.py | 74 +++++----- .../app/services/{ => caii}/CaiiModel.py | 0 llm-service/app/services/caii/__init__.py | 37 +++++ llm-service/app/services/{ => caii}/caii.py | 136 +++++++++--------- llm-service/app/services/caii/types.py | 122 ++++++++++++++++ llm-service/app/services/caii/utils.py | 47 ++++++ llm-service/app/services/chat.py | 8 +- .../app/services/data_sources_metadata_api.py | 6 +- llm-service/app/services/doc_summaries.py | 45 +++--- llm-service/app/services/evaluators.py | 6 +- llm-service/app/services/models.py | 98 +++++++------ llm-service/app/services/noop_models.py | 111 ++++++++++++++ llm-service/app/tests/ai/indexing/test_csv.py | 2 +- .../ai/indexing/test_pdf_page_tracker.py | 10 +- llm-service/app/tests/conftest.py | 82 ++--------- .../tests/routers/index/test_data_source.py | 2 +- llm-service/scripts/get_job_run_status.py | 5 +- llm-service/scripts/run_refresh_job.py | 6 +- scripts/01_install_base.py | 18 ++- scripts/refresh_project.py | 5 +- scripts/release_version.txt | 2 +- scripts/startup_app.py | 2 +- ui/src/api/dataSourceApi.ts | 1 + .../DataSourcesManagement/DataSourcesForm.tsx | 35 ++++- .../IndexSettingsTab/IndexSettings.tsx | 1 + .../ManageTab/UploadedFilesTable.tsx | 59 ++++++-- 53 files changed, 995 insertions(+), 338 deletions(-) create mode 100644 backend/src/main/resources/migrations/h2/15_add_summarization_model.down.sql create mode 100644 backend/src/main/resources/migrations/h2/15_add_summarization_model.up.sql create mode 100644 backend/src/main/resources/migrations/postgres/15_add_summarization_model.down.sql create mode 100644 backend/src/main/resources/migrations/postgres/15_add_summarization_model.up.sql rename llm-service/app/services/{ => caii}/CaiiEmbeddingModel.py (65%) rename llm-service/app/services/{ => caii}/CaiiModel.py (100%) create mode 100644 llm-service/app/services/caii/__init__.py rename llm-service/app/services/{ => caii}/caii.py (59%) create mode 100644 llm-service/app/services/caii/types.py create mode 100644 llm-service/app/services/caii/utils.py create mode 100644 llm-service/app/services/noop_models.py diff --git a/.env.example b/.env.example index 52c355ed..f3f3e540 100644 --- a/.env.example +++ b/.env.example @@ -12,8 +12,6 @@ DB_URL=jdbc:h2:../databases/rag # If using CAII, fill these in: CAII_DOMAIN= -CAII_INFERENCE_ENDPOINT_NAME= -CAII_EMBEDDING_ENDPOINT_NAME= # set this to true if you have uv installed on your system, other wise don't include this USE_SYSTEM_UV=true diff --git a/.project-metadata.yaml b/.project-metadata.yaml index cc1f0e36..7dffc1c7 100644 --- a/.project-metadata.yaml +++ b/.project-metadata.yaml @@ -34,14 +34,6 @@ environment_variables: default: "" description: "The domain of the CAII service. Setting this will enable CAII as the sole source for both inference and embedding models." required: false - CAII_INFERENCE_ENDPOINT_NAME: - default: "" - description: "The name of the inference endpoint for the CAII service. Required if CAII_DOMAIN is set." - required: false - CAII_EMBEDDING_ENDPOINT_NAME: - default: "" - description: "The name of the embedding endpoint for the CAII service. Required if CAII_DOMAIN is set." - required: false DB_URL: default: "jdbc:h2:file:~/databases/rag" description: "Internal DB URL. Do not change." diff --git a/backend/src/main/java/com/cloudera/cai/rag/Types.java b/backend/src/main/java/com/cloudera/cai/rag/Types.java index f85ae7d9..2ae38993 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/Types.java +++ b/backend/src/main/java/com/cloudera/cai/rag/Types.java @@ -81,6 +81,7 @@ public record RagDataSource( Long id, String name, String embeddingModel, + String summarizationModel, Integer chunkSize, Integer chunkOverlapPercent, Instant timeCreated, diff --git a/backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceRepository.java b/backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceRepository.java index ac28bbc0..edace1c7 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceRepository.java +++ b/backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceRepository.java @@ -41,6 +41,7 @@ import com.cloudera.cai.rag.Types.RagDataSource; import com.cloudera.cai.rag.configuration.JdbiConfiguration; import com.cloudera.cai.util.exceptions.NotFound; +import java.time.Instant; import java.util.List; import lombok.extern.slf4j.Slf4j; import org.jdbi.v3.core.Jdbi; @@ -62,8 +63,8 @@ public Long createRagDataSource(RagDataSource input) { handle -> { var sql = """ - INSERT INTO rag_data_source (name, chunk_size, chunk_overlap_percent, created_by_id, updated_by_id, connection_type, embedding_model) - VALUES (:name, :chunkSize, :chunkOverlapPercent, :createdById, :updatedById, :connectionType, :embeddingModel) + INSERT INTO rag_data_source (name, chunk_size, chunk_overlap_percent, created_by_id, updated_by_id, connection_type, embedding_model, summarization_model) + VALUES (:name, :chunkSize, :chunkOverlapPercent, :createdById, :updatedById, :connectionType, :embeddingModel, :summarizationModel) """; try (var update = handle.createUpdate(sql)) { update.bindMethods(input); @@ -78,7 +79,7 @@ public void updateRagDataSource(RagDataSource input) { var sql = """ UPDATE rag_data_source - SET name = :name, connection_type = :connectionType, updated_by_id = :updatedById + SET name = :name, connection_type = :connectionType, updated_by_id = :updatedById, summarization_model = :summarizationModel, time_updated = :now WHERE id = :id AND deleted IS NULL """; try (var update = handle.createUpdate(sql)) { @@ -87,6 +88,8 @@ public void updateRagDataSource(RagDataSource input) { .bind("updatedById", input.updatedById()) .bind("connectionType", input.connectionType()) .bind("id", input.id()) + .bind("summarizationModel", input.summarizationModel()) + .bind("now", Instant.now()) .execute(); } }); diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileSummaryReconciler.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileSummaryReconciler.java index f44395a7..e8e58c4e 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileSummaryReconciler.java +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileSummaryReconciler.java @@ -85,9 +85,11 @@ public void resync() { log.debug("checking for RAG documents to be summarized"); String sql = """ - SELECT * from rag_data_source_document - WHERE summary_creation_timestamp IS NULL - AND time_created > :yesterday + SELECT rdsd.* from rag_data_source_document rdsd + JOIN rag_data_source rds ON rdsd.data_source_id = rds.id + WHERE rdsd.summary_creation_timestamp IS NULL + AND (rdsd.time_created > :yesterday OR rds.time_updated > :yesterday) + AND rds.summarization_model IS NOT NULL """; jdbi.useHandle( handle -> { diff --git a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java index 3b806772..ff3d76c7 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java +++ b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java @@ -164,15 +164,16 @@ public void delete(Long id) { } public void update(Types.Session input) { + var updatedInput = input.withTimeUpdated(Instant.now()); jdbi.useHandle( handle -> { var sql = """ UPDATE CHAT_SESSION - SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks + SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks, time_updated = :timeUpdated WHERE id = :id """; - handle.createUpdate(sql).bindMethods(input).execute(); + handle.createUpdate(sql).bindMethods(updatedInput).execute(); }); } } diff --git a/backend/src/main/java/com/cloudera/cai/util/IdGenerator.java b/backend/src/main/java/com/cloudera/cai/util/IdGenerator.java index 5192431b..8f62f300 100644 --- a/backend/src/main/java/com/cloudera/cai/util/IdGenerator.java +++ b/backend/src/main/java/com/cloudera/cai/util/IdGenerator.java @@ -38,6 +38,7 @@ package com.cloudera.cai.util; +import java.util.Random; import java.util.UUID; import org.springframework.stereotype.Component; @@ -53,6 +54,7 @@ public static IdGenerator createNull(String... dummyIds) { } private static class NullIdGenerator extends IdGenerator { + private final Random random = new Random(); private final String[] dummyIds; @@ -62,7 +64,7 @@ private NullIdGenerator(String[] dummyIds) { @Override public String generateId() { - return dummyIds.length == 0 ? "StubbedId" : dummyIds[0]; + return dummyIds.length == 0 ? "StubbedId-" + random.nextInt() : dummyIds[0]; } } } diff --git a/backend/src/main/resources/migrations/h2/15_add_summarization_model.down.sql b/backend/src/main/resources/migrations/h2/15_add_summarization_model.down.sql new file mode 100644 index 00000000..a03d2e74 --- /dev/null +++ b/backend/src/main/resources/migrations/h2/15_add_summarization_model.down.sql @@ -0,0 +1,45 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE rag_data_source DROP COLUMN summarization_model; + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/h2/15_add_summarization_model.up.sql b/backend/src/main/resources/migrations/h2/15_add_summarization_model.up.sql new file mode 100644 index 00000000..6ada2d05 --- /dev/null +++ b/backend/src/main/resources/migrations/h2/15_add_summarization_model.up.sql @@ -0,0 +1,45 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE rag_data_source ADD COLUMN summarization_model varchar(255); + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/migrations.txt b/backend/src/main/resources/migrations/migrations.txt index caa1a622..57a295eb 100644 --- a/backend/src/main/resources/migrations/migrations.txt +++ b/backend/src/main/resources/migrations/migrations.txt @@ -26,4 +26,6 @@ 13_add_chat_configuration.down.sql 13_add_chat_configuration.up.sql 14_add_embedding_model.down.sql -14_add_embedding_model.up.sql \ No newline at end of file +14_add_embedding_model.up.sql +15_add_summarization_model.down.sql +15_add_summarization_model.up.sql \ No newline at end of file diff --git a/backend/src/main/resources/migrations/postgres/15_add_summarization_model.down.sql b/backend/src/main/resources/migrations/postgres/15_add_summarization_model.down.sql new file mode 100644 index 00000000..8cda71e4 --- /dev/null +++ b/backend/src/main/resources/migrations/postgres/15_add_summarization_model.down.sql @@ -0,0 +1,45 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE rag_data_source DROP COLUMN embedding_model; + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/postgres/15_add_summarization_model.up.sql b/backend/src/main/resources/migrations/postgres/15_add_summarization_model.up.sql new file mode 100644 index 00000000..f1fada4d --- /dev/null +++ b/backend/src/main/resources/migrations/postgres/15_add_summarization_model.up.sql @@ -0,0 +1,45 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE rag_data_source ADD COLUMN embedding_model varchar(255); + +COMMIT; \ No newline at end of file diff --git a/backend/src/test/java/com/cloudera/cai/rag/TestData.java b/backend/src/test/java/com/cloudera/cai/rag/TestData.java index d22a5b05..88a37c6f 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/TestData.java +++ b/backend/src/test/java/com/cloudera/cai/rag/TestData.java @@ -58,6 +58,7 @@ public static Types.RagDataSource createTestDataSourceInstance( null, name, "test_embedding_model", + "summarizationModel", chunkSize, chunkOverlapPercent, null, diff --git a/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceControllerTest.java index 32780a21..6638673c 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceControllerTest.java @@ -92,6 +92,7 @@ void updateName() { newDataSource.id(), "updated-name", "test_embedding_model", + "summarizationModel", newDataSource.chunkSize(), newDataSource.chunkOverlapPercent(), newDataSource.timeCreated(), diff --git a/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceRepositoryTest.java b/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceRepositoryTest.java index c1053a86..fa9c6695 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceRepositoryTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/datasources/RagDataSourceRepositoryTest.java @@ -42,10 +42,12 @@ import static com.cloudera.cai.rag.Types.ConnectionType.MANUAL; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; import com.cloudera.cai.rag.TestData; import com.cloudera.cai.rag.configuration.JdbiConfiguration; import com.cloudera.cai.util.exceptions.NotFound; +import java.time.Duration; import java.time.Instant; import org.junit.jupiter.api.Test; @@ -75,8 +77,11 @@ void update() { TestData.createTestDataSourceInstance("test-name", 512, 10, MANUAL) .withCreatedById("abc") .withUpdatedById("abc")); - assertThat(repository.getRagDataSourceById(id).name()).isEqualTo("test-name"); - assertThat(repository.getRagDataSourceById(id).updatedById()).isEqualTo("abc"); + var insertedDataSource = repository.getRagDataSourceById(id); + assertThat(insertedDataSource.name()).isEqualTo("test-name"); + assertThat(insertedDataSource.updatedById()).isEqualTo("abc"); + var timeInserted = insertedDataSource.timeUpdated(); + assertThat(timeInserted).isNotNull(); var expectedRagDataSource = TestData.createTestDataSourceInstance("new-name", 512, 10, API) @@ -84,12 +89,15 @@ void update() { .withUpdatedById("def") .withId(id) .withDocumentCount(0); - + // wait a moment so the updated time will always be later than insert time + await().atLeast(Duration.ofMillis(1)); repository.updateRagDataSource(expectedRagDataSource); - assertThat(repository.getRagDataSourceById(id)) + var updatedDataSource = repository.getRagDataSourceById(id); + assertThat(updatedDataSource) .usingRecursiveComparison() .ignoringFieldsOfTypes(Instant.class) .isEqualTo(expectedRagDataSource); + assertThat(updatedDataSource.timeUpdated()).isAfter(timeInserted); } @Test diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java index d2b7c0ae..da453916 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java @@ -116,6 +116,7 @@ void getRagDocuments() { null, "test_datasource", "test_embedding_model", + "summarizationModel", 1024, 20, null, diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java index 2d73236c..dd976d0d 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileIndexReconcilerTest.java @@ -75,6 +75,7 @@ void reconcile() { null, "test_datasource", "test_embedding_model", + "summarizationModel", 1024, 20, null, @@ -131,6 +132,7 @@ void reconcile_notFound() { null, "test_datasource", "test_embedding_model", + "summarizationModel", 1024, 20, null, diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java index 6e97dae9..cf306bc8 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java @@ -118,7 +118,7 @@ void saveRagFile_trailingPeriod() { byte[] bytes = "23243223423".getBytes(); MockMultipartFile mockMultipartFile = new MockMultipartFile(name, originalFilename, "text/plain", bytes); - String documentId = "TestID"; + String documentId = UUID.randomUUID().toString(); RagFileService ragFileService = createRagFileService(documentId, new Tracker<>()); Types.RagDocumentMetadata result = ragFileService.saveRagFile(mockMultipartFile, newDataSourceId(), "test-id"); @@ -134,7 +134,7 @@ void saveRagFile_removeDirectories() { byte[] bytes = "23243223423".getBytes(); MockMultipartFile mockMultipartFile = new MockMultipartFile(name, originalFilename, "text/plain", bytes); - String documentId = "TestID"; + String documentId = UUID.randomUUID().toString(); var dataSourceId = newDataSourceId(); String expectedS3Path = "prefix/" + dataSourceId + "/" + documentId; var requestTracker = new Tracker(); @@ -153,7 +153,7 @@ void saveRagFile_noFilename() { String name = "file"; byte[] bytes = "23243223423".getBytes(); MockMultipartFile mockMultipartFile = new MockMultipartFile(name, null, "text/plain", bytes); - String documentId = "TestID"; + String documentId = UUID.randomUUID().toString(); RagFileService ragFileService = createRagFileService(documentId, new Tracker<>()); assertThatThrownBy( () -> ragFileService.saveRagFile(mockMultipartFile, newDataSourceId(), "test-id")) @@ -166,7 +166,7 @@ void saveRagFile_noDataSource() { byte[] bytes = "23243223423".getBytes(); MockMultipartFile mockMultipartFile = new MockMultipartFile(name, "filename", "text/plain", bytes); - String documentId = "TestID"; + String documentId = UUID.randomUUID().toString(); RagFileService ragFileService = createRagFileService(documentId, new Tracker<>()); assertThatThrownBy(() -> ragFileService.saveRagFile(mockMultipartFile, -1L, "test-id")) .isInstanceOf(NotFound.class); diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java index e0d1fad2..de3569a7 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileSummaryReconcilerTest.java @@ -52,6 +52,7 @@ import com.cloudera.cai.util.reconcilers.ReconcilerConfig; import io.opentelemetry.api.OpenTelemetry; import java.time.Instant; +import java.util.List; import java.util.UUID; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.Test; @@ -61,6 +62,9 @@ class RagFileSummaryReconcilerTest { private final RagDataSourceRepository ragDataSourceRepository = RagDataSourceRepository.createNull(); + // todo: test for the time limit on how long we will retry document summarization (and also that + // updated the data source will re-trigger tries) + @Test void reconcile() { Tracker> requestTracker = new Tracker<>(); @@ -73,6 +77,7 @@ void reconcile() { null, "test_datasource", "test_embedding_model", + "summarizationModel", 1024, 20, null, @@ -94,7 +99,7 @@ void reconcile() { .createdById("test-id") .build(); Long id = ragFileRepository.saveDocumentMetadata(document); - assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) + assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); reconciler.submit(document.withId(id)); @@ -129,6 +134,7 @@ void reconcile_notFound() { null, "test_datasource", "test_embedding_model", + "summarizationModel", 1024, 20, null, @@ -150,7 +156,7 @@ void reconcile_notFound() { .createdById("test-id") .build(); Long id = ragFileRepository.saveDocumentMetadata(document); - assertThat(ragFileRepository.findDocumentByDocumentId(documentId).vectorUploadTimestamp()) + assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) .isNull(); reconciler.submit(document.withId(id)); @@ -169,6 +175,60 @@ void reconcile_notFound() { }); } + @Test + void reconcile_noSummarizationModel() { + Tracker> requestTracker = new Tracker<>(); + RagFileSummaryReconciler reconciler = createTestInstance(requestTracker); + + long dataSourceId = + ragDataSourceRepository.createRagDataSource( + new RagDataSource( + null, + "test_datasource", + "test_embedding_model", + null, + 1024, + 20, + null, + null, + "test-id", + "test-id", + Types.ConnectionType.API, + null, + null)); + String documentId = UUID.randomUUID().toString(); + RagDocument document = + RagDocument.builder() + .documentId(documentId) + .dataSourceId(dataSourceId) + .s3Path("path_in_s3_no_summarization_model") + .extension("pdf") + .filename("myfile.pdf") + .timeCreated(Instant.now()) + .timeUpdated(Instant.now()) + .createdById("test-id") + .build(); + ragFileRepository.saveDocumentMetadata(document); + assertThat(ragFileRepository.findDocumentByDocumentId(documentId).summaryCreationTimestamp()) + .isNull(); + + reconciler.resync(); + await().until(reconciler::isEmpty); + + RagDocument updatedDocument = ragFileRepository.findDocumentByDocumentId(documentId); + assertThat(updatedDocument.summaryCreationTimestamp()).isNull(); + List> values = requestTracker.getValues(); + var relevantSummarizationRequests = + values.stream() + .filter( + r -> { + var summaryRequest = (RagBackendClient.SummaryRequest) r.detail(); + return summaryRequest.s3DocumentKey().equals("path_in_s3_no_summarization_model"); + }) + .count(); + assertThat(relevantSummarizationRequests).isEqualTo(0); + } + private RagFileSummaryReconciler createTestInstance( Tracker> tracker, RuntimeException... exceptions) { Jdbi jdbi = new JdbiConfiguration().jdbi(); diff --git a/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java index 799a6114..22ff0dca 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java @@ -80,7 +80,7 @@ void update() throws JsonProcessingException { new MockCookie("_basusertoken", UserTokenCookieDecoderTest.encodeCookie("test-user"))); var sessionName = "test"; Types.Session input = TestData.createTestSessionInstance(sessionName); - Types.Session result = sessionController.create(input, request); + Types.Session insertedSession = sessionController.create(input, request); var updatedResponseChunks = 1; var updatedInferenceModel = "new-model-name"; @@ -93,7 +93,7 @@ void update() throws JsonProcessingException { var updatedSession = sessionController.update( - result + insertedSession .withInferenceModel(updatedInferenceModel) .withResponseChunks(updatedResponseChunks) .withName(updatedName), @@ -105,7 +105,7 @@ void update() throws JsonProcessingException { assertThat(updatedSession.responseChunks()).isEqualTo(updatedResponseChunks); assertThat(updatedSession.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L); assertThat(updatedSession.timeCreated()).isNotNull(); - assertThat(updatedSession.timeUpdated()).isNotNull(); + assertThat(updatedSession.timeUpdated()).isAfter(insertedSession.timeUpdated()); assertThat(updatedSession.createdById()).isEqualTo("test-user"); assertThat(updatedSession.updatedById()).isEqualTo("update-test-user"); assertThat(updatedSession.lastInteractionTime()).isNull(); diff --git a/backend/src/test/java/com/cloudera/cai/util/IdGeneratorTest.java b/backend/src/test/java/com/cloudera/cai/util/IdGeneratorTest.java index 32dd556f..d8457487 100644 --- a/backend/src/test/java/com/cloudera/cai/util/IdGeneratorTest.java +++ b/backend/src/test/java/com/cloudera/cai/util/IdGeneratorTest.java @@ -58,7 +58,7 @@ void nullIdWithOneValue() { @Test void nullIdWithNoValues() { IdGenerator idGenerator = IdGenerator.createNull(); - assertThat(idGenerator.generateId()).isEqualTo("StubbedId"); - assertThat(idGenerator.generateId()).isEqualTo("StubbedId"); + assertThat(idGenerator.generateId()).startsWith("StubbedId"); + assertThat(idGenerator.generateId()).startsWith("StubbedId"); } } diff --git a/llm-service/app/ai/indexing/index.py b/llm-service/app/ai/indexing/index.py index bb991eb3..21be0100 100644 --- a/llm-service/app/ai/indexing/index.py +++ b/llm-service/app/ai/indexing/index.py @@ -91,7 +91,9 @@ def __init__( self.chunks_vector_store = chunks_vector_store def index_file(self, file_path: Path, document_id: str) -> None: - logger.debug(f"Indexing file: {file_path} with embedding model: {self.embedding_model.model_name}") + logger.debug( + f"Indexing file: {file_path} with embedding model: {self.embedding_model.model_name}" + ) file_extension = os.path.splitext(file_path)[1] reader_cls = READERS.get(file_extension) diff --git a/llm-service/app/ai/indexing/readers/pdf.py b/llm-service/app/ai/indexing/readers/pdf.py index 6850134f..9760fbd8 100644 --- a/llm-service/app/ai/indexing/readers/pdf.py +++ b/llm-service/app/ai/indexing/readers/pdf.py @@ -68,7 +68,8 @@ def assert_correctness(self) -> None: document_length = len(self.document_text) if self.page_start_index[-1] != document_length + 1: raise Exception( - f"Start of page after last {self.page_start_index[-1]} does not match document text length {document_length + 1}") + f"Start of page after last {self.page_start_index[-1]} does not match document text length {document_length + 1}" + ) def _find_page_number(self, start_index: int) -> str: last_good_page_number = "" @@ -110,14 +111,25 @@ def load_chunks(self, file_path: Path) -> list[TextNode]: return chunks def process_with_docling(self, file_path: Path) -> list[TextNode] | None: - docling_enabled = os.getenv("USE_ENHANCED_PDF_PROCESSING", "false").lower() == "true" + docling_enabled = ( + os.getenv("USE_ENHANCED_PDF_PROCESSING", "false").lower() == "true" + ) if not docling_enabled: return None directory = file_path.parent logger.debug(f"{directory=}") with open("docling-output.txt", "a") as f: process: CompletedProcess[bytes] = subprocess.run( - ["docling", "-v", "--abort-on-error", f"--output={directory}", str(file_path)], stdout=f, stderr=f) + [ + "docling", + "-v", + "--abort-on-error", + f"--output={directory}", + str(file_path), + ], + stdout=f, + stderr=f, + ) logger.debug(f"docling return code = {process.returncode}") # todo: figure out page numbers & look into the docling llama-index integration markdown_file_path = file_path.with_suffix(".md") diff --git a/llm-service/app/ai/vector_stores/qdrant.py b/llm-service/app/ai/vector_stores/qdrant.py index 977b5b2c..2efdee1e 100644 --- a/llm-service/app/ai/vector_stores/qdrant.py +++ b/llm-service/app/ai/vector_stores/qdrant.py @@ -38,9 +38,9 @@ import logging import os from typing import Optional, Any -import umap import qdrant_client +import umap from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.indices import VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore @@ -64,24 +64,35 @@ def new_qdrant_client() -> qdrant_client.QdrantClient: class QdrantVectorStore(VectorStore): @staticmethod def for_chunks( - data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None + data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None ) -> "QdrantVectorStore": - return QdrantVectorStore(table_name=f"index_{data_source_id}", data_source_id=data_source_id, client=client) + return QdrantVectorStore( + table_name=f"index_{data_source_id}", + data_source_id=data_source_id, + client=client, + ) @staticmethod def for_summaries( - data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None + data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None ) -> "QdrantVectorStore": return QdrantVectorStore( - table_name=f"summary_index_{data_source_id}", data_source_id=data_source_id, client=client + table_name=f"summary_index_{data_source_id}", + data_source_id=data_source_id, + client=client, ) def __init__( - self, table_name: str, data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None + self, + table_name: str, + data_source_id: int, + client: Optional[qdrant_client.QdrantClient] = None, ): self.client = client or new_qdrant_client() self.table_name = table_name - self.data_source_metadata = data_sources_metadata_api.get_metadata(data_source_id) + self.data_source_metadata = data_sources_metadata_api.get_metadata( + data_source_id + ) def get_embedding_model(self) -> BaseEmbedding: return models.get_embedding_model(self.data_source_metadata.embedding_model) @@ -103,7 +114,7 @@ def delete_document(self, document_id: str) -> None: if self.exists(): index = VectorStoreIndex.from_vector_store( vector_store=self.llama_vector_store(), - embed_model=models.get_embedding_model(), + embed_model=models.get_noop_embedding_model(), ) index.delete_ref_doc(document_id) @@ -114,7 +125,9 @@ def llama_vector_store(self) -> BasePydanticVectorStore: vector_store = LlamaIndexQdrantVectorStore(self.table_name, self.client) return vector_store - def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float, float], str]]: + def visualize( + self, user_query: Optional[str] = None + ) -> list[tuple[tuple[float, float], str]]: records: list[Record] if not self.exists(): return [] @@ -125,7 +138,13 @@ def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float, if user_query: embedding_model = self.get_embedding_model() user_query_vector = embedding_model.get_query_embedding(user_query) - records.append(Record(vector=user_query_vector, id="abc123", payload={"file_name": "USER_QUERY"})) + records.append( + Record( + vector=user_query_vector, + id="abc123", + payload={"file_name": "USER_QUERY"}, + ) + ) record: Record filenames = [] @@ -139,8 +158,11 @@ def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float, try: reduced_embeddings = reducer.fit_transform(embeddings) # todo: figure out how to satisfy mypy on this line - return [(tuple(coordinate), filenames[i]) for i, coordinate in enumerate(reduced_embeddings.tolist())] # type: ignore + return [ + (tuple(coordinate), filenames[i]) # type: ignore + for i, coordinate in enumerate(reduced_embeddings.tolist()) + ] except Exception as e: # Log the error logger.error(f"Error during UMAP transformation: {e}") - return [] \ No newline at end of file + return [] diff --git a/llm-service/app/ai/vector_stores/vector_store.py b/llm-service/app/ai/vector_stores/vector_store.py index f9c54ab9..dfec8acf 100644 --- a/llm-service/app/ai/vector_stores/vector_store.py +++ b/llm-service/app/ai/vector_stores/vector_store.py @@ -69,7 +69,9 @@ def exists(self) -> bool: """Does the vector store exist?""" @abstractmethod - def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float,float], str]]: + def visualize( + self, user_query: Optional[str] = None + ) -> list[tuple[tuple[float, float], str]]: """get a 2-d visualization of the vectors in the store""" @abstractmethod diff --git a/llm-service/app/routers/index/data_source/__init__.py b/llm-service/app/routers/index/data_source/__init__.py index 9815e891..4f847ccb 100644 --- a/llm-service/app/routers/index/data_source/__init__.py +++ b/llm-service/app/routers/index/data_source/__init__.py @@ -99,23 +99,21 @@ def chunk_contents(self, chunk_id: str) -> ChunkContentsResponse: metadata=node.metadata, ) - @router.get("/visualize") @exceptions.propagates - def visualize(self) -> list[tuple[tuple[float,float], str]]: + def visualize(self) -> list[tuple[tuple[float, float], str]]: return self.chunks_vector_store.visualize() - class VisualizationRequest(BaseModel): user_query: str - @router.post("/visualize") @exceptions.propagates - def visualize_with_query(self, request: VisualizationRequest) -> list[tuple[tuple[float,float], str]]: + def visualize_with_query( + self, request: VisualizationRequest + ) -> list[tuple[tuple[float, float], str]]: return self.chunks_vector_store.visualize(request.user_query) - @router.delete( "/", summary="Deletes the data source from the index.", response_model=None ) diff --git a/llm-service/app/routers/index/models/__init__.py b/llm-service/app/routers/index/models/__init__.py index 7235c9c3..9a3487d5 100644 --- a/llm-service/app/routers/index/models/__init__.py +++ b/llm-service/app/routers/index/models/__init__.py @@ -35,11 +35,12 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. # -from typing import Any, Dict, List, Literal +from typing import List, Literal from fastapi import APIRouter from .... import exceptions +from ....services.caii.types import ModelResponse from ....services.models import ( ModelSource, get_available_embedding_models, @@ -54,13 +55,13 @@ @router.get("/llm", summary="Get LLM Inference models.") @exceptions.propagates -def get_llm_models() -> List[Dict[str, Any]]: +def get_llm_models() -> List[ModelResponse]: return get_available_llm_models() @router.get("/embeddings", summary="Get LLM Embedding models.") @exceptions.propagates -def get_llm_embedding_models() -> List[Dict[str, Any]]: +def get_llm_embedding_models() -> List[ModelResponse]: return get_available_embedding_models() diff --git a/llm-service/app/services/CaiiEmbeddingModel.py b/llm-service/app/services/caii/CaiiEmbeddingModel.py similarity index 65% rename from llm-service/app/services/CaiiEmbeddingModel.py rename to llm-service/app/services/caii/CaiiEmbeddingModel.py index 259cb757..7ce04b8a 100644 --- a/llm-service/app/services/CaiiEmbeddingModel.py +++ b/llm-service/app/services/caii/CaiiEmbeddingModel.py @@ -38,16 +38,21 @@ import http.client as http_client import json import os -from typing import Any, Dict, List +from typing import Any from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding from pydantic import Field +from .types import Endpoint +from .utils import build_auth_headers + class CaiiEmbeddingModel(BaseEmbedding): - endpoint: Any = Field(Any, description="The endpoint to use for embeddings") + endpoint: Endpoint = Field( + Endpoint, description="The endpoint to use for embeddings" + ) - def __init__(self, endpoint: Dict[str, Any]): + def __init__(self, endpoint: Endpoint): super().__init__() self.endpoint = endpoint @@ -61,12 +66,7 @@ def _get_query_embedding(self, query: str) -> Embedding: return self._get_embedding(query, "query") def _get_embedding(self, query: str, input_type: str) -> Embedding: - model = self.endpoint["endpointmetadata"]["model_name"] - domain = os.environ["CAII_DOMAIN"] - - connection = http_client.HTTPSConnection(domain, 443) - headers = self.build_auth_headers() - headers["Content-Type"] = "application/json" + model = self.endpoint.endpointmetadata.model_name body = json.dumps( { "input": query, @@ -75,47 +75,47 @@ def _get_embedding(self, query: str, input_type: str) -> Embedding: "model": model, } ) - connection.request("POST", self.endpoint["url"], body=body, headers=headers) - res = connection.getresponse() - data = res.read() - json_response = data.decode("utf-8") - structured_response = json.loads(json_response) + structured_response = self.make_embedding_request(body) embedding = structured_response["data"][0]["embedding"] assert isinstance(embedding, list) assert all(isinstance(x, float) for x in embedding) return embedding - def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: - model = self.endpoint["endpointmetadata"]["model_name"] + def make_embedding_request(self, body: str) -> Any: domain = os.environ["CAII_DOMAIN"] connection = http_client.HTTPSConnection(domain, 443) - headers = self.build_auth_headers() + headers = build_auth_headers() headers["Content-Type"] = "application/json" - body = json.dumps( - { - "input": texts, - "input_type": "passage", - "truncate": "END", - "model": model, - } - ) - connection.request("POST", self.endpoint["url"], body=body, headers=headers) + connection.request("POST", self.endpoint.url, body=body, headers=headers) res = connection.getresponse() data = res.read() json_response = data.decode("utf-8") structured_response = json.loads(json_response) - embeddings = structured_response["data"][0]["embedding"] - assert isinstance(embeddings, list) - assert all(isinstance(x, list) for x in embeddings) - assert all(all(isinstance(y, float) for y in x) for x in embeddings) + return structured_response - return embeddings +## TODO: get this working. At the moment, the shape of the data in the response isn't what the code is expecting - def build_auth_headers(self) -> Dict[str, str]: - with open("/tmp/jwt", "r") as file: - jwt_contents = json.load(file) - access_token = jwt_contents["access_token"] - headers = {"Authorization": f"Bearer {access_token}"} - return headers + # def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: + # if len(texts) == 1: + # return [self._get_text_embedding(texts[0])] + # + # print(f"Getting embeddings for {len(texts)} texts") + # model = self.endpoint.endpointmetadata.model_name + # body = json.dumps( + # { + # "input": texts, + # "input_type": "passage", + # "truncate": "END", + # "model": model, + # } + # ) + # structured_response = self.make_embedding_request(body) + # embeddings = structured_response["data"][0]["embedding"] + # print(f"Got embeddings for {len(embeddings)} texts") + # assert isinstance(embeddings, list) + # assert all(isinstance(x, list) for x in embeddings) + # assert all(all(isinstance(y, float) for y in x) for x in embeddings) + # + # return embeddings diff --git a/llm-service/app/services/CaiiModel.py b/llm-service/app/services/caii/CaiiModel.py similarity index 100% rename from llm-service/app/services/CaiiModel.py rename to llm-service/app/services/caii/CaiiModel.py diff --git a/llm-service/app/services/caii/__init__.py b/llm-service/app/services/caii/__init__.py new file mode 100644 index 00000000..e2b4ac6c --- /dev/null +++ b/llm-service/app/services/caii/__init__.py @@ -0,0 +1,37 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2024 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# diff --git a/llm-service/app/services/caii.py b/llm-service/app/services/caii/caii.py similarity index 59% rename from llm-service/app/services/caii.py rename to llm-service/app/services/caii/caii.py index 01d11b67..2c6e7627 100644 --- a/llm-service/app/services/caii.py +++ b/llm-service/app/services/caii/caii.py @@ -35,9 +35,10 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. # + import json import os -from typing import Any, Callable, Dict, List, Sequence +from typing import Callable, List, Sequence import requests from fastapi import HTTPException @@ -47,39 +48,54 @@ from .CaiiEmbeddingModel import CaiiEmbeddingModel from .CaiiModel import CaiiModel, CaiiModelMistral +from .types import Endpoint, ListEndpointEntry, ModelResponse +from .utils import build_auth_headers +DEFAULT_NAMESPACE = "serving-default" -def describe_endpoint(domain: str, endpoint_name: str) -> Any: - with open("/tmp/jwt", "r") as file: - jwt_contents = json.load(file) - access_token = jwt_contents["access_token"] - headers = {"Authorization": f"Bearer {access_token}"} +def describe_endpoint(endpoint_name: str) -> Endpoint: + domain = os.environ["CAII_DOMAIN"] + headers = build_auth_headers() describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint" - desc_json = {"name": endpoint_name, "namespace": "serving-default"} + desc_json = {"name": endpoint_name, "namespace": DEFAULT_NAMESPACE} desc = requests.post(describe_url, headers=headers, json=desc_json) if desc.status_code == 404: raise HTTPException( status_code=404, detail=f"Endpoint '{endpoint_name}' not found" ) - return json.loads(desc.content) + json_content = json.loads(desc.content) + return Endpoint(**json_content) + + +def list_endpoints() -> list[ListEndpointEntry]: + domain = os.environ["CAII_DOMAIN"] + try: + headers = build_auth_headers() + describe_url = f"https://{domain}/api/v1alpha1/listEndpoints" + desc_json = {"namespace": DEFAULT_NAMESPACE} + + desc = requests.post(describe_url, headers=headers, json=desc_json) + endpoints = json.loads(desc.content)["endpoints"] + return [ListEndpointEntry(**endpoint) for endpoint in endpoints] + except requests.exceptions.ConnectionError: + raise HTTPException( + status_code=421, + detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.", + ) def get_llm( - domain: str, endpoint_name: str, messages_to_prompt: Callable[[Sequence[ChatMessage]], str], completion_to_prompt: Callable[[str], str], ) -> LLM: - endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name) - api_base = endpoint["url"].removesuffix("/chat/completions") - with open("/tmp/jwt", "r") as file: - jwt_contents = json.load(file) - access_token = jwt_contents["access_token"] - headers = {"Authorization": f"Bearer {access_token}"} - - model = endpoint["endpointmetadata"]["model_name"] + endpoint = describe_endpoint(endpoint_name=endpoint_name) + api_base = endpoint.url.removesuffix("/chat/completions") + headers = build_auth_headers() + + model = endpoint.endpointmetadata.model_name if "mistral" in endpoint_name.lower(): llm = CaiiModelMistral( model=model, @@ -103,62 +119,50 @@ def get_llm( return llm -def get_embedding_model(domain: str, model_name: str) -> BaseEmbedding: +def get_embedding_model(model_name: str) -> BaseEmbedding: endpoint_name = model_name - endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name) + endpoint = describe_endpoint(endpoint_name=endpoint_name) return CaiiEmbeddingModel(endpoint=endpoint) -### metadata methods below here +# task types from the MLServing proto definition +# TASK_UNKNOWN = 0; +# INFERENCE = 1; +# TEXT_GENERATION = 2; +# EMBED = 3; +# TEXT_TO_TEXT_GENERATION = 4; +# CLASSIFICATION = 5; +# FILL_MASK = 6; +# RANK = 7; -def get_caii_llm_models() -> List[Dict[str, Any]]: - domain = os.environ["CAII_DOMAIN"] - endpoint_name = os.environ["CAII_INFERENCE_ENDPOINT_NAME"] - try: - models = describe_endpoint(domain=domain, endpoint_name=endpoint_name) - except requests.exceptions.ConnectionError as e: - print(e) - raise HTTPException( - status_code=421, - detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.", - ) - except HTTPException as e: - if e.status_code == 404: - return [{"model_id": endpoint_name}] - else: - raise e - return build_model_response(models) +def get_caii_llm_models() -> List[ModelResponse]: + return get_models_with_task("TEXT_GENERATION") -def get_caii_embedding_models() -> List[Dict[str, Any]]: - # notes: - # NameResolutionError is we can't contact the CAII_DOMAIN +def get_caii_embedding_models() -> List[ModelResponse]: + return get_models_with_task("EMBED") - domain = os.environ["CAII_DOMAIN"] - endpoint_name = os.environ["CAII_EMBEDDING_ENDPOINT_NAME"] - try: - models = describe_endpoint(domain=domain, endpoint_name=endpoint_name) - except requests.exceptions.ConnectionError as e: - print(e) - raise HTTPException( - status_code=421, - detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.", + +def get_models_with_task(task_type: str) -> List[ModelResponse]: + endpoints = list_endpoints() + endpoint_details = list( + map(lambda endpoint: describe_endpoint(endpoint.name), endpoints) + ) + llm_endpoints = list( + filter( + lambda endpoint: endpoint.task and endpoint.task == task_type, + endpoint_details, ) - except HTTPException as e: - if e.status_code == 404: - return [{"model_id": endpoint_name}] - else: - raise e - return build_model_response(models) - - -def build_model_response(models: Dict[str, Any]) -> List[Dict[str, Any]]: - return [ - { - "model_id": models["name"], - "name": models["name"], - "available": models["replica_count"] > 0, - "replica_count": models["replica_count"], - } - ] + ) + models = list(map(build_model_response, llm_endpoints)) + return models + + +def build_model_response(endpoint: Endpoint) -> ModelResponse: + return ModelResponse( + model_id=endpoint.name, + name=endpoint.name, + available=endpoint.replica_count > 0, + replica_count=endpoint.replica_count, + ) diff --git a/llm-service/app/services/caii/types.py b/llm-service/app/services/caii/types.py new file mode 100644 index 00000000..4bc46c30 --- /dev/null +++ b/llm-service/app/services/caii/types.py @@ -0,0 +1,122 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2024 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# +from dataclasses import dataclass +from typing import List, Dict, Any, Optional + +from pydantic import BaseModel, ConfigDict + + +# class EndpointCondition(BaseModel): +# status: str +# severity: str +# last_transition_time: str +# reason: str +# message: str + + +# class ReplicaMetadata(BaseModel): +# modelVersion: str +# replicaCount: int +# replicaNames: List[str] + + +# class RegistrySource(BaseModel): +# model_config = ConfigDict(protected_namespaces=()) +# model_id: Optional[str] +# version: Optional[int] + +# class EndpointStatus(BaseModel): +# failed_copies: int +# total_copies: int +# active_model_state: str +# target_model_state: str +# transition_status: str +# + + +class EndpointMetadata(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + # current_model: Optional[RegistrySource] + # previous_model: Optional[RegistrySource] + model_name: str + + +class Endpoint(BaseModel): + namespace: str + name: str + url: str + # conditions: List[EndpointCondition] + # status: EndpointStatus + observed_generation: int + replica_count: int + # replica_metadata: List[ReplicaMetadata] + created_by: str + description: str + created_at: str + resources: Dict[str, str] + # source: Dict[str, RegistrySource] + autoscaling: Dict[str, Any] + endpointmetadata: EndpointMetadata + traffic: Dict[str, str] + api_standard: str + has_chat_template: bool + metricFormat: str + task: str + instance_type: str + + +@dataclass +class ListEndpointEntry: + namespace: str + name: str + url: str + state: str + created_by: str + replica_count: int + replica_metadata: List[Any] + api_standard: str + has_chat_template: bool + metricFormat: str + + +@dataclass +class ModelResponse: + model_id: str + name: str + available: Optional[bool] = None + replica_count: Optional[int] = None diff --git a/llm-service/app/services/caii/utils.py b/llm-service/app/services/caii/utils.py new file mode 100644 index 00000000..b47de247 --- /dev/null +++ b/llm-service/app/services/caii/utils.py @@ -0,0 +1,47 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2024 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# +import json +from typing import Dict + + +def build_auth_headers() -> Dict[str, str]: + with open("/tmp/jwt", "r") as file: + jwt_contents = json.load(file) + access_token = jwt_contents["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} + return headers diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat.py index 2a299545..205c15a9 100644 --- a/llm-service/app/services/chat.py +++ b/llm-service/app/services/chat.py @@ -44,8 +44,6 @@ from llama_index.core.base.llms.types import MessageRole from llama_index.core.chat_engine.types import AgentChatResponse -from ..ai.vector_stores.qdrant import QdrantVectorStore -from ..rag_types import RagPredictConfiguration from . import evaluators, qdrant from .chat_store import ( Evaluation, @@ -54,6 +52,8 @@ RagStudioChatMessage, chat_store, ) +from ..ai.vector_stores.qdrant import QdrantVectorStore +from ..rag_types import RagPredictConfiguration def v2_chat( @@ -81,7 +81,9 @@ def v2_chat( configuration, retrieve_chat_history(session_id), ) - relevance, faithfulness = evaluators.evaluate_response(query, response) + relevance, faithfulness = evaluators.evaluate_response( + query, response, configuration.model_name + ) response_source_nodes = format_source_nodes(response) new_chat_message = RagStudioChatMessage( id=response_id, diff --git a/llm-service/app/services/data_sources_metadata_api.py b/llm-service/app/services/data_sources_metadata_api.py index 8d50f1f4..1d9c8987 100644 --- a/llm-service/app/services/data_sources_metadata_api.py +++ b/llm-service/app/services/data_sources_metadata_api.py @@ -42,6 +42,7 @@ import requests + @dataclass class RagDataSource: id: int @@ -54,6 +55,7 @@ class RagDataSource: created_by_id: str updated_by_id: str connection_type: str + summarization_model: Optional[str] = None document_count: Optional[int] = None total_doc_size: Optional[int] = None @@ -61,6 +63,7 @@ class RagDataSource: BACKEND_BASE_URL = os.getenv("API_URL", "http://localhost:8080") url_template = BACKEND_BASE_URL + "/api/v1/rag/dataSources/{}" + def get_metadata(data_source_id: int) -> RagDataSource: response = requests.get(url_template.format(data_source_id)) response.raise_for_status() @@ -69,6 +72,7 @@ def get_metadata(data_source_id: int) -> RagDataSource: id=data["id"], name=data["name"], embedding_model=data["embeddingModel"], + summarization_model=data.get("summarizationModel"), chunk_size=data["chunkSize"], chunk_overlap_percent=data["chunkOverlapPercent"], time_created=datetime.fromtimestamp(data["timeCreated"]), @@ -77,5 +81,5 @@ def get_metadata(data_source_id: int) -> RagDataSource: updated_by_id=data["updatedById"], connection_type=data["connectionType"], document_count=data.get("documentCount"), - total_doc_size=data.get("totalDocSize") + total_doc_size=data.get("totalDocSize"), ) diff --git a/llm-service/app/services/doc_summaries.py b/llm-service/app/services/doc_summaries.py index a56790cb..0bbc6e71 100644 --- a/llm-service/app/services/doc_summaries.py +++ b/llm-service/app/services/doc_summaries.py @@ -50,11 +50,12 @@ from llama_index.core.node_parser import SentenceSplitter from llama_index.core.readers import SimpleDirectoryReader -from ..ai.vector_stores.qdrant import QdrantVectorStore -from ..config import settings +from . import data_sources_metadata_api from . import models from .s3 import download from .utils import get_last_segment +from ..ai.vector_stores.qdrant import QdrantVectorStore +from ..config import settings SUMMARY_PROMPT = 'Summarize the document into a single sentence. If an adequate summary is not possible, please return "No summary available.".' @@ -73,7 +74,7 @@ def read_summary(data_source_id: int, document_id: str) -> str: raise HTTPException(status_code=404, detail="Knowledge base not found.") storage_context = make_storage_context(data_source_id) - doc_summary_index = load_document_summary_index(storage_context) + doc_summary_index = load_document_summary_index(storage_context, data_source_id) if document_id not in doc_summary_index.index_struct.doc_id_to_summary_id: return "No summary found for this document." @@ -82,9 +83,9 @@ def read_summary(data_source_id: int, document_id: str) -> str: def generate_summary( - data_source_id: int, - s3_bucket_name: str, - s3_document_key: str, + data_source_id: int, + s3_bucket_name: str, + s3_document_key: str, ) -> str: """Generate, persist, and return a summary for `s3_document_key`.""" with tempfile.TemporaryDirectory() as tmpdirname: @@ -101,7 +102,7 @@ def generate_summary( initialize_summary_index_storage(data_source_id) storage_context = make_storage_context(data_source_id) - doc_summary_index = load_document_summary_index(storage_context) + doc_summary_index = load_document_summary_index(storage_context, data_source_id, read_only_mode=False) for document in documents: doc_summary_index.insert(document) @@ -114,14 +115,19 @@ def generate_summary( ## todo: move to somewhere better; these are defaults to use when none are explicitly provided -def _set_settings_globals() -> None: - Settings.llm = models.get_llm() - Settings.embed_model = models.get_embedding_model() +def _set_settings_globals(data_source_id: int, read_only_mode: bool = True) -> None: + metadata = data_sources_metadata_api.get_metadata(data_source_id) + if read_only_mode: + Settings.llm = models.get_noop_llm_model() + Settings.embed_model = models.get_noop_embedding_model() + else: + Settings.llm = models.get_llm(metadata.summarization_model) + Settings.embed_model = models.get_embedding_model(metadata.embedding_model) Settings.text_splitter = SentenceSplitter(chunk_size=1024) def initialize_summary_index_storage(data_source_id: int) -> None: - _set_settings_globals() + _set_settings_globals(data_source_id) doc_summary_index = DocumentSummaryIndex.from_documents( [], summary_query=SUMMARY_PROMPT, @@ -129,10 +135,9 @@ def initialize_summary_index_storage(data_source_id: int) -> None: doc_summary_index.storage_context.persist(persist_dir=index_dir(data_source_id)) -def load_document_summary_index( - storage_context: StorageContext, -) -> DocumentSummaryIndex: - _set_settings_globals() +def load_document_summary_index(storage_context: StorageContext, data_source_id: int, + read_only_mode: bool = True) -> DocumentSummaryIndex: + _set_settings_globals(data_source_id, read_only_mode) doc_summary_index: DocumentSummaryIndex = cast( DocumentSummaryIndex, load_index_from_storage(storage_context, summary_query=SUMMARY_PROMPT), @@ -142,17 +147,21 @@ def load_document_summary_index( def summarize_data_source(data_source_id: int) -> str: """Return a summary of all documents in the data source.""" + metadata = data_sources_metadata_api.get_metadata(data_source_id) + if not metadata.summarization_model: + return "Summarization disabled. Please specify a summarization model in the knowledge base to enable." + index = index_dir(data_source_id) if not os.path.exists(index): return "" storage_context = make_storage_context(data_source_id) - doc_summary_index = load_document_summary_index(storage_context) + doc_summary_index = load_document_summary_index(storage_context, data_source_id) doc_ids = doc_summary_index.index_struct.doc_id_to_summary_id.keys() summaries = map(doc_summary_index.get_document_summary, doc_ids) prompt = 'I have summarized a list of documents that may or may not be related to each other. Please provide an overview of the document corpus as an executive summary. Do not start with "Here is...". The summary should be concise and not be frivolous' - response = Settings.llm.complete(prompt + "\n".join(summaries)) + response = models.get_llm(metadata.summarization_model).complete(prompt + "\n".join(summaries)) return response.text @@ -183,7 +192,7 @@ def delete_document(data_source_id: int, doc_id: str) -> None: if not os.path.exists(index): return storage_context = make_storage_context(data_source_id) - doc_summary_index = load_document_summary_index(storage_context) + doc_summary_index = load_document_summary_index(storage_context, data_source_id) if doc_id not in doc_summary_index.index_struct.doc_id_to_summary_id: return doc_summary_index.delete(doc_id) diff --git a/llm-service/app/services/evaluators.py b/llm-service/app/services/evaluators.py index a4282378..f91e820b 100644 --- a/llm-service/app/services/evaluators.py +++ b/llm-service/app/services/evaluators.py @@ -44,10 +44,10 @@ def evaluate_response( - query: str, - chat_response: AgentChatResponse, + query: str, chat_response: AgentChatResponse, model_name: str ) -> tuple[float, float]: - evaluator_llm = models.get_llm() + # todo: pass in the correct llm model and use it, rather than requiring querying for it like this. + evaluator_llm = models.get_llm(model_name) relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm) relevance = relevancy_evaluator.evaluate_response( diff --git a/llm-service/app/services/models.py b/llm-service/app/services/models.py index 5a925307..520ed008 100644 --- a/llm-service/app/services/models.py +++ b/llm-service/app/services/models.py @@ -37,7 +37,7 @@ # import os from enum import Enum -from typing import Any, Dict, List, Literal +from typing import List, Literal, Optional from fastapi import HTTPException from llama_index.core.base.embeddings.base import BaseEmbedding @@ -46,34 +46,43 @@ from llama_index.embeddings.bedrock import BedrockEmbedding from llama_index.llms.bedrock_converse import BedrockConverse -from .caii import get_caii_embedding_models, get_caii_llm_models -from .caii import get_embedding_model as caii_embedding -from .caii import get_llm as caii_llm +from .caii.caii import get_caii_embedding_models, get_caii_llm_models +from .caii.caii import get_embedding_model as caii_embedding +from .caii.caii import get_llm as caii_llm +from .caii.types import ModelResponse from .llama_utils import completion_to_prompt, messages_to_prompt +from .noop_models import DummyEmbeddingModel, DummyLlm DEFAULT_BEDROCK_LLM_MODEL = "meta.llama3-1-8b-instruct-v1:0" -def get_embedding_model(model_name: str = "cohere.embed-english-v3") -> BaseEmbedding: - if is_caii_enabled(): - return caii_embedding( - domain=os.environ["CAII_DOMAIN"], - model_name=os.environ["CAII_EMBEDDING_ENDPOINT_NAME"]) +def get_noop_embedding_model() -> BaseEmbedding: + return DummyEmbeddingModel() + + +def get_noop_llm_model() -> LLM: + return DummyLlm() + + +def get_embedding_model(model_name: str) -> BaseEmbedding: if model_name is None: - model_name = "cohere.embed-english-v3" + model_name = get_available_embedding_models()[0].model_id + + if is_caii_enabled(): + return caii_embedding(model_name=model_name) + return BedrockEmbedding(model_name=model_name) -def get_llm(model_name: str = DEFAULT_BEDROCK_LLM_MODEL) -> LLM: +def get_llm(model_name: Optional[str]) -> LLM: + if not model_name: + model_name = get_available_llm_models()[0].model_id if is_caii_enabled(): return caii_llm( - domain=os.environ["CAII_DOMAIN"], - endpoint_name=os.environ["CAII_INFERENCE_ENDPOINT_NAME"], + endpoint_name=model_name, messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, ) - if not model_name: - model_name = DEFAULT_BEDROCK_LLM_MODEL return BedrockConverse( model=model_name, @@ -82,13 +91,13 @@ def get_llm(model_name: str = DEFAULT_BEDROCK_LLM_MODEL) -> LLM: ) -def get_available_embedding_models() -> List[Dict[str, Any]]: +def get_available_embedding_models() -> List[ModelResponse]: if is_caii_enabled(): return get_caii_embedding_models() return _get_bedrock_embedding_models() -def get_available_llm_models() -> List[Dict[str, Any]]: +def get_available_llm_models() -> list[ModelResponse]: if is_caii_enabled(): return get_caii_llm_models() return _get_bedrock_llm_models() @@ -99,33 +108,28 @@ def is_caii_enabled() -> bool: return len(domain) > 0 -def _get_bedrock_llm_models() -> List[Dict[str, Any]]: +def _get_bedrock_llm_models() -> List[ModelResponse]: return [ - { - "model_id": DEFAULT_BEDROCK_LLM_MODEL, - "name": "Llama3.1 8B Instruct v1", - }, - { - "model_id": "meta.llama3-1-70b-instruct-v1:0", - "name": "Llama3.1 70B Instruct v1", - }, - { - "model_id": "cohere.command-r-plus-v1:0", - "name": "Cohere Command R Plus v1", - } + ModelResponse( + model_id=DEFAULT_BEDROCK_LLM_MODEL, name="Llama3.1 8B Instruct v1" + ), + ModelResponse( + model_id="meta.llama3-1-70b-instruct-v1:0", name="Llama3.1 70B Instruct v1" + ), + ModelResponse( + model_id="cohere.command-r-plus-v1:0", name="Cohere Command R Plus v1" + ), ] -def _get_bedrock_embedding_models() -> List[Dict[str, Any]]: +def _get_bedrock_embedding_models() -> List[ModelResponse]: return [ - { - "model_id": "cohere.embed-english-v3", - "name": "Cohere Embed English v3", - }, - { - "model_id": "cohere.embed-multilingual-v3", - "name": "Cohere Embed Multilingual v3", - }, + ModelResponse( + model_id="cohere.embed-english-v3", name="Cohere Embed English v3" + ), + ModelResponse( + model_id="cohere.embed-multilingual-v3", name="Cohere Embed Multilingual v3" + ), ] @@ -143,10 +147,16 @@ def get_model_source() -> ModelSource: def test_llm_model(model_name: str) -> Literal["ok"]: models = get_available_llm_models() for model in models: - if model["model_id"] == model_name: - if not is_caii_enabled() or model["available"]: + if model.model_id == model_name: + if not is_caii_enabled() or model.available: get_llm(model_name).chat( - messages=[ChatMessage(role=MessageRole.USER, content="Are you available to answer questions?")]) + messages=[ + ChatMessage( + role=MessageRole.USER, + content="Are you available to answer questions?", + ) + ] + ) return "ok" else: raise HTTPException(status_code=503, detail="Model not ready") @@ -157,8 +167,8 @@ def test_llm_model(model_name: str) -> Literal["ok"]: def test_embedding_model(model_name: str) -> str: models = get_available_embedding_models() for model in models: - if model["model_id"] == model_name: - if not is_caii_enabled() or model["available"]: + if model.model_id == model_name: + if not is_caii_enabled() or model.available: get_embedding_model(model_name).get_text_embedding("test") return "ok" else: diff --git a/llm-service/app/services/noop_models.py b/llm-service/app/services/noop_models.py new file mode 100644 index 00000000..c8c2f4d7 --- /dev/null +++ b/llm-service/app/services/noop_models.py @@ -0,0 +1,111 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2024 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# +from typing import Sequence, Any + +from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding +from llama_index.core.base.llms.types import LLMMetadata, ChatMessage, ChatResponse, CompletionResponse, \ + ChatResponseGen, CompletionResponseGen, ChatResponseAsyncGen, CompletionResponseAsyncGen +from llama_index.core.llms import LLM +from pydantic import Field + + +class DummyEmbeddingModel(BaseEmbedding): + def _get_query_embedding(self, query: str) -> Embedding: + return [] + + async def _aget_query_embedding(self, query: str) -> Embedding: + return [] + + def _get_text_embedding(self, text: str) -> Embedding: + return [] + + +class DummyLlm(LLM): + completion_response: str = Field("this is a completion response") + chat_response: str = Field("this is a chat response") + + def __init__( + self, + completion_response: str = "this is a completion response", + chat_response: str = "hello", + ): + super().__init__() + self.completion_response = completion_response + self.chat_response = chat_response + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata() + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + return ChatResponse(message=ChatMessage.from_str(self.chat_response)) + + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + return CompletionResponse(text=self.completion_response) + + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + raise NotImplementedError("Not implemented") + + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + raise NotImplementedError("Not implemented") + + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + raise NotImplementedError("Not implemented") + + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + raise NotImplementedError("Not implemented") + + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + raise NotImplementedError("Not implemented") + + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError("Not implemented") diff --git a/llm-service/app/tests/ai/indexing/test_csv.py b/llm-service/app/tests/ai/indexing/test_csv.py index 18bb757e..d55f259a 100644 --- a/llm-service/app/tests/ai/indexing/test_csv.py +++ b/llm-service/app/tests/ai/indexing/test_csv.py @@ -27,7 +27,7 @@ def test_csv_indexing() -> None: chunk_size=100, chunk_overlap=0, ), - embedding_model=models.get_embedding_model(), + embedding_model=models.get_embedding_model("dummy_model"), chunks_vector_store=vector_store, ) indexer.index_file(Path(temp_file.name), document_id) diff --git a/llm-service/app/tests/ai/indexing/test_pdf_page_tracker.py b/llm-service/app/tests/ai/indexing/test_pdf_page_tracker.py index ebe29e04..df7dc528 100644 --- a/llm-service/app/tests/ai/indexing/test_pdf_page_tracker.py +++ b/llm-service/app/tests/ai/indexing/test_pdf_page_tracker.py @@ -4,6 +4,7 @@ from app.ai.indexing.readers.pdf import PageTracker + class TestPageTracker: @staticmethod def test_initializes_correctly() -> None: @@ -44,7 +45,12 @@ def test_populates_chunk_page_numbers() -> None: pages[0].metadata["page_label"] = "1" pages[1].metadata["page_label"] = "2" page_counter = PageTracker(pages) - chunks = [TextNode(start_char_idx=0), TextNode(start_char_idx=4), TextNode(start_char_idx=7), TextNode(start_char_idx=10)] + chunks = [ + TextNode(start_char_idx=0), + TextNode(start_char_idx=4), + TextNode(start_char_idx=7), + TextNode(start_char_idx=10), + ] page_counter.populate_chunk_page_numbers(chunks) assert chunks[0].metadata["page_number"] == "1" assert chunks[1].metadata["page_number"] == "1" @@ -59,4 +65,4 @@ def test_populates_chunk_page_numbers_chunk_spans_2_pages() -> None: page_counter = PageTracker(pages) chunks = [TextNode(start_char_idx=0)] page_counter.populate_chunk_page_numbers(chunks) - assert chunks[0].metadata["page_number"] == "1" \ No newline at end of file + assert chunks[0].metadata["page_number"] == "1" diff --git a/llm-service/app/tests/conftest.py b/llm-service/app/tests/conftest.py index c7e1d85a..e0057ac7 100644 --- a/llm-service/app/tests/conftest.py +++ b/llm-service/app/tests/conftest.py @@ -42,7 +42,7 @@ from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Sequence +from typing import Any, Dict import boto3 import lipsum @@ -51,24 +51,14 @@ from boto3.resources.base import ServiceResource from fastapi.testclient import TestClient from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding -from llama_index.core.base.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, -) from llama_index.core.llms import LLM from moto import mock_aws -from pydantic import Field from app.ai.vector_stores.qdrant import QdrantVectorStore from app.main import app from app.services import models, data_sources_metadata_api from app.services.data_sources_metadata_api import RagDataSource +from app.services.noop_models import DummyLlm from app.services.utils import get_last_segment @@ -131,62 +121,6 @@ def index_document_request_body( } -class DummyLlm(LLM): - completion_response: str = Field("this is a completion response") - chat_response: str = Field("this is a chat response") - - def __init__( - self, - completion_response: str = "this is a completion response", - chat_response: str = "hello", - ): - super().__init__() - self.completion_response = completion_response - self.chat_response = chat_response - - @property - def metadata(self) -> LLMMetadata: - return LLMMetadata() - - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - return ChatResponse(message=ChatMessage.from_str(self.chat_response)) - - def complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - return CompletionResponse(text=self.completion_response) - - def stream_chat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponseGen: - raise NotImplementedError("Not implemented") - - def stream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseGen: - raise NotImplementedError("Not implemented") - - async def achat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponse: - raise NotImplementedError("Not implemented") - - async def acomplete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - raise NotImplementedError("Not implemented") - - async def astream_chat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponseAsyncGen: - raise NotImplementedError("Not implemented") - - async def astream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseAsyncGen: - raise NotImplementedError("Not implemented") - - class DummyEmbeddingModel(BaseEmbedding): def _get_query_embedding(self, query: str) -> Embedding: return [0.1] * 1024 @@ -221,6 +155,7 @@ def summary_vector_store( lambda ds_id: original(ds_id, qdrant_client), ) + @pytest.fixture(autouse=True) def datasource_metadata(monkeypatch: pytest.MonkeyPatch) -> None: def get_datasource_metadata(data_source_id: int) -> RagDataSource: @@ -228,6 +163,7 @@ def get_datasource_metadata(data_source_id: int) -> RagDataSource: id=data_source_id, name="test", embedding_model="test", + summarization_model="test", chunk_size=512, chunk_overlap_percent=10, time_created=datetime.now(), @@ -239,7 +175,9 @@ def get_datasource_metadata(data_source_id: int) -> RagDataSource: total_doc_size=1, ) - monkeypatch.setattr(data_sources_metadata_api, "get_metadata", get_datasource_metadata) + monkeypatch.setattr( + data_sources_metadata_api, "get_metadata", get_datasource_metadata + ) @pytest.fixture(autouse=True) @@ -252,6 +190,7 @@ def get_embedding_model(model_name: str = "dummy_value") -> BaseEmbedding: # Requires that the app usages import the file and not the function directly as python creates a copy when importing the function monkeypatch.setattr(models, "get_embedding_model", get_embedding_model) + monkeypatch.setattr(models, "get_noop_embedding_model", get_embedding_model) return model @@ -259,8 +198,11 @@ def get_embedding_model(model_name: str = "dummy_value") -> BaseEmbedding: def llm(monkeypatch: pytest.MonkeyPatch) -> LLM: model = DummyLlm() + def get_llm(model_name: str = "dummy_value") -> LLM: + return model + # Requires that the app usages import the file and not the function directly as python creates a copy when importing the function - monkeypatch.setattr(models, "get_llm", lambda : model) + monkeypatch.setattr(models, "get_llm", get_llm) return model diff --git a/llm-service/app/tests/routers/index/test_data_source.py b/llm-service/app/tests/routers/index/test_data_source.py index 0970aa3f..6aad0293 100644 --- a/llm-service/app/tests/routers/index/test_data_source.py +++ b/llm-service/app/tests/routers/index/test_data_source.py @@ -50,7 +50,7 @@ def get_vector_store_index(data_source_id: int) -> VectorStoreIndex: vector_store = QdrantVectorStore.for_chunks(data_source_id).llama_vector_store() index = VectorStoreIndex.from_vector_store( - vector_store, embed_model=models.get_embedding_model() + vector_store, embed_model=models.get_embedding_model("dummy_model") ) return index diff --git a/llm-service/scripts/get_job_run_status.py b/llm-service/scripts/get_job_run_status.py index 81edb8c2..d3c3fe7a 100644 --- a/llm-service/scripts/get_job_run_status.py +++ b/llm-service/scripts/get_job_run_status.py @@ -41,11 +41,10 @@ client = cmlapi.default_client() -project_id =os.environ['CDSW_PROJECT_ID'] +project_id = os.environ["CDSW_PROJECT_ID"] # ## todo: investigate if we can filter using wildcards or regex on the job name -jobs = client.list_jobs(project_id, search_filter="{\"name\": \"Update/build RAG Studio\"}") +jobs = client.list_jobs(project_id, search_filter='{"name": "Update/build RAG Studio"}') job_id = jobs.jobs[0].id job_runs = client.list_job_runs(project_id, job_id, sort="-created_at").job_runs[0] print(job_runs.status) - diff --git a/llm-service/scripts/run_refresh_job.py b/llm-service/scripts/run_refresh_job.py index f7bce5ba..e5a011a5 100644 --- a/llm-service/scripts/run_refresh_job.py +++ b/llm-service/scripts/run_refresh_job.py @@ -40,10 +40,10 @@ import os client = cmlapi.default_client() -project_id =os.environ['CDSW_PROJECT_ID'] +project_id = os.environ["CDSW_PROJECT_ID"] ## todo: investigate if we can filter using wildcards or regex on the job name -jobs = client.list_jobs(project_id, search_filter="{\"name\": \"Update/build RAG Studio\"}") +jobs = client.list_jobs(project_id, search_filter='{"name": "Update/build RAG Studio"}') print(jobs) job_id = jobs.jobs[0].id print(job_id) -client.create_job_run({}, project_id, job_id) \ No newline at end of file +client.create_job_run({}, project_id, job_id) diff --git a/scripts/01_install_base.py b/scripts/01_install_base.py index 6d16f37a..276292fb 100644 --- a/scripts/01_install_base.py +++ b/scripts/01_install_base.py @@ -38,11 +38,21 @@ import subprocess -print(subprocess.run(["bash /home/cdsw/scripts/install_java.sh"], shell=True, check=True)) +print( + subprocess.run(["bash /home/cdsw/scripts/install_java.sh"], shell=True, check=True) +) print("Installing Java 21 is complete") -print(subprocess.run(["bash /home/cdsw/scripts/install_qdrant.sh"], shell=True, check=True)) +print( + subprocess.run( + ["bash /home/cdsw/scripts/install_qdrant.sh"], shell=True, check=True + ) +) print("Installing Qdrant is complete") -print(subprocess.run(["bash /home/cdsw/scripts/install_easyocr_model.sh"], shell=True, check=True)) -print("Downloading EASYOCR models complete") \ No newline at end of file +print( + subprocess.run( + ["bash /home/cdsw/scripts/install_easyocr_model.sh"], shell=True, check=True + ) +) +print("Downloading EASYOCR models complete") diff --git a/scripts/refresh_project.py b/scripts/refresh_project.py index 8107257c..548dfae1 100644 --- a/scripts/refresh_project.py +++ b/scripts/refresh_project.py @@ -46,10 +46,11 @@ print(subprocess.run(["bash", "/home/cdsw/scripts/refresh_project.sh"], check=True)) print( - "Project refresh complete. Restarting the RagStudio Application to pick up changes, if this isn't the initial deployment.") + "Project refresh complete. Restarting the RagStudio Application to pick up changes, if this isn't the initial deployment." +) client = cmlapi.default_client() -project_id = os.environ['CDSW_PROJECT_ID'] +project_id = os.environ["CDSW_PROJECT_ID"] apps = client.list_applications(project_id=project_id) if len(apps.applications) > 0: # todo: handle case where there are multiple apps diff --git a/scripts/release_version.txt b/scripts/release_version.txt index a4c512d1..32150ea1 100644 --- a/scripts/release_version.txt +++ b/scripts/release_version.txt @@ -1 +1 @@ -export RELEASE_TAG=1.4.0-beta +export RELEASE_TAG=dev-testing diff --git a/scripts/startup_app.py b/scripts/startup_app.py index ff9ff7e0..cc1b2d46 100644 --- a/scripts/startup_app.py +++ b/scripts/startup_app.py @@ -40,4 +40,4 @@ while True: print(subprocess.run(["bash /home/cdsw/scripts/startup_app.sh"], shell=True)) - print("Application Restarting") \ No newline at end of file + print("Application Restarting") diff --git a/ui/src/api/dataSourceApi.ts b/ui/src/api/dataSourceApi.ts index ce5c2c32..22fc2638 100644 --- a/ui/src/api/dataSourceApi.ts +++ b/ui/src/api/dataSourceApi.ts @@ -70,6 +70,7 @@ export interface DataSourceBaseType { chunkOverlapPercent: number; connectionType: ConnectionType; embeddingModel: string; + summarizationModel?: string; } export type DataSourceType = DataSourceBaseType & { diff --git a/ui/src/pages/DataSources/DataSourcesManagement/DataSourcesForm.tsx b/ui/src/pages/DataSources/DataSourcesManagement/DataSourcesForm.tsx index 60013b9a..d7fdf4cc 100644 --- a/ui/src/pages/DataSources/DataSourcesManagement/DataSourcesForm.tsx +++ b/ui/src/pages/DataSources/DataSourcesManagement/DataSourcesForm.tsx @@ -36,12 +36,22 @@ * DATA. ******************************************************************************/ -import { Collapse, Form, FormInstance, Input, InputNumber, Select } from "antd"; +import { + Collapse, + Form, + FormInstance, + Input, + InputNumber, + Select, + Tooltip, + Typography, +} from "antd"; import { ConnectionType, DataSourceBaseType } from "src/api/dataSourceApi"; import RequestConfigureOptions from "pages/DataSources/DataSourcesManagement/RequestConfigureOptions.tsx"; -import { useGetEmbeddingModels } from "src/api/modelsApi.ts"; +import { useGetEmbeddingModels, useGetLlmModels } from "src/api/modelsApi.ts"; import { useEffect } from "react"; import { transformModelOptions } from "src/utils/modelUtils.ts"; +import { InfoCircleOutlined } from "@ant-design/icons"; export const distanceMetricOptions = [ { @@ -112,6 +122,7 @@ export const dataSourceCreationInitialValues = { connectionType: ConnectionType.MANUAL, chunkOverlapPercent: 10, embeddingModel: "", + summarizationModel: "", }; export interface DataSourcesFormProps { @@ -131,6 +142,7 @@ const DataSourcesForm = ({ initialValues = dataSourceCreationInitialValues, }: DataSourcesFormProps) => { const embeddingsModels = useGetEmbeddingModels(); + const llmModels = useGetLlmModels(); useEffect(() => { if (initialValues.embeddingModel) { @@ -187,6 +199,7 @@ const DataSourcesForm = ({ + + Summarization model + + + + + } + initialValue={initialValues.summarizationModel} + > +