Skip to content

Commit

Permalink
solr resource loader + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Nov 19, 2024
1 parent ec90258 commit 4cbafdc
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.llm.store.EmbeddingModelException;
import org.apache.solr.llm.store.rest.ManagedEmbeddingModelStore;

Expand All @@ -45,7 +46,11 @@ public class SolrEmbeddingModel implements Accountable {
private final Integer hashCode;

public static SolrEmbeddingModel getInstance(
String className, String name, Map<String, Object> params) throws EmbeddingModelException {
SolrResourceLoader solrResourceLoader,
String className,
String name,
Map<String, Object> params)
throws EmbeddingModelException {
try {
/*
* The idea here is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion
Expand All @@ -54,7 +59,7 @@ public static SolrEmbeddingModel getInstance(
* has its own builder that uses setters with the same name of the parameter in input.
* */
EmbeddingModel textToVector;
Class<?> modelClass = Class.forName(className);
Class<?> modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class);
var builder = modelClass.getMethod("builder").invoke(null);
if (params != null) {
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ private static List<Object> modelsAsManagedResources(List<SolrEmbeddingModel> mo
}

@SuppressWarnings("unchecked")
public static SolrEmbeddingModel fromEmbeddingModelMap(Map<String, Object> embeddingModel) {
public static SolrEmbeddingModel fromEmbeddingModelMap(
SolrResourceLoader solrResourceLoader, Map<String, Object> embeddingModel) {
return SolrEmbeddingModel.getInstance(
solrResourceLoader,
(String) embeddingModel.get(CLASS_KEY), // modelClassName
(String) embeddingModel.get(NAME_KEY), // modelName
(Map<String, Object>) embeddingModel.get(PARAMS_KEY));
Expand Down Expand Up @@ -136,7 +138,7 @@ public void loadStoredModels() {

private void addModelFromMap(Map<String, Object> modelMap) {
try {
addModel(fromEmbeddingModelMap(modelMap));
addModel(fromEmbeddingModelMap(solrResourceLoader, modelMap));
} catch (final EmbeddingModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
Expand Down
77 changes: 2 additions & 75 deletions solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.file.PathUtils;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.util.Utils;
import org.apache.solr.core.SolrCore;
import org.apache.solr.llm.embedding.SolrEmbeddingModel;
import org.apache.solr.llm.store.EmbeddingModelException;
import org.apache.solr.llm.store.rest.ManagedEmbeddingModelStore;
import org.apache.solr.util.RestTestBase;
import org.slf4j.Logger;
Expand All @@ -55,7 +50,7 @@ public class TestLlmBase extends RestTestBase {
protected static String vectorField2 = "vector2";
protected static String vectorFieldByteEncoding = "vector_byte_encoding";

public static void setupTest(
protected static void setupTest(
String solrconfig, String schema, boolean buildIndex, boolean persistModelStore)
throws Exception {
initFolders(persistModelStore);
Expand All @@ -64,12 +59,6 @@ public static void setupTest(
if (buildIndex) prepareIndex();
}

/**
 * Looks up the {@link ManagedEmbeddingModelStore} for the core named {@code
 * DEFAULT_TEST_CORENAME}.
 *
 * <p>The core is obtained from the test rule's core container and is closed again (via
 * try-with-resources) before this method returns.
 *
 * @return the managed embedding model store resolved for that core
 */
public static ManagedEmbeddingModelStore getManagedModelStore() {
  try (SolrCore testCore =
      solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) {
    final ManagedEmbeddingModelStore modelStore =
        ManagedEmbeddingModelStore.getManagedModelStore(testCore);
    return modelStore;
  }
}

protected static void initFolders(boolean isPersistent) throws Exception {
tmpSolrHome = createTempDir();
tmpConfDir = tmpSolrHome.resolve(CONF_DIR);
Expand Down Expand Up @@ -105,10 +94,6 @@ protected static void afterTest() throws Exception {
System.clearProperty("managed.schema.mutable");
}

/** Drops the current {@code restTestHarness} reference by setting it to {@code null}. */
public static void makeRestTestHarnessNull() {
restTestHarness = null;
}

/** Produces a model encoded in JSON. */
public static String getModelInJson(String name, String className, String params) {
final StringBuilder sb = new StringBuilder();
Expand All @@ -129,41 +114,14 @@ protected static void loadModel(String name, String className, String params) th
assertJPut(ManagedEmbeddingModelStore.REST_END_POINT, model, "/responseHeader/status==0");
}

public static void loadModels(String fileName) throws Exception {
public static void loadModel(String fileName) throws Exception {
final URL url = TestLlmBase.class.getResource("/modelExamples/" + fileName);
final String multipleModels = Files.readString(Path.of(url.toURI()), StandardCharsets.UTF_8);

assertJPut(
ManagedEmbeddingModelStore.REST_END_POINT, multipleModels, "/responseHeader/status==0");
}

/**
 * Creates a {@link SolrEmbeddingModel} from a model file by delegating to {@link
 * #createModelFromFiles(String)}.
 *
 * @param modelFileName name of the model file, passed on to the single-argument overload
 * @param featureFileName not used by this method; kept only so existing callers still compile
 * @return the model produced by the single-argument overload
 * @throws EmbeddingModelException propagated from the single-argument overload
 * @throws Exception propagated from the single-argument overload
 */
public static SolrEmbeddingModel createModelFromFiles(
    String modelFileName, String featureFileName) throws EmbeddingModelException, Exception {
  // The previous body passed both arguments along, which resolved to this very overload and
  // recursed until StackOverflowError. Delegate to the one-argument overload instead.
  return createModelFromFiles(modelFileName);
}

/**
 * Reads a JSON model definition from the {@code /modelExamples/} classpath resources, builds a
 * {@link SolrEmbeddingModel} from it and adds that model to the managed model store.
 *
 * @param modelFileName file name of the model definition, appended to {@code /modelExamples/}
 * @return the model that was added to the store
 * @throws Exception if the resource cannot be read or the model cannot be created
 */
public static SolrEmbeddingModel createModelFromFiles(String modelFileName) throws Exception {
  final URL modelResource = TestLlmBase.class.getResource("/modelExamples/" + modelFileName);
  final Path modelPath = Path.of(modelResource.toURI());
  final String json = Files.readString(modelPath, StandardCharsets.UTF_8);
  final ManagedEmbeddingModelStore store = getManagedModelStore();

  final SolrEmbeddingModel created =
      ManagedEmbeddingModelStore.fromEmbeddingModelMap(mapFromJson(json));
  store.addModel(created);
  return created;
}

/**
 * Parses a JSON string and returns the result cast to a map.
 *
 * @param json the JSON text to parse
 * @return the parsed value, cast to {@code Map<String, Object>}
 * @throws EmbeddingModelException if parsing throws any exception; the original exception is
 *     attached as the cause
 */
@SuppressWarnings("unchecked")
private static Map<String, Object> mapFromJson(String json) throws EmbeddingModelException {
  final Object parsed;
  try {
    parsed = Utils.fromJSONString(json);
  } catch (final Exception parseFailure) {
    throw new EmbeddingModelException("ObjectBuilder failed parsing json", parseFailure);
  }
  // The cast is deliberately kept outside the try block so that a failing cast is not wrapped.
  final Map<String, Object> asMap = (Map<String, Object>) parsed;
  return asMap;
}

protected static void prepareIndex() throws Exception {
List<SolrInputDocument> docsToIndex = prepareDocs();
for (SolrInputDocument doc : docsToIndex) {
Expand Down Expand Up @@ -230,35 +188,4 @@ private static List<SolrInputDocument> prepareDocs() {

return docs;
}

/**
 * Adds every document returned by {@code prepareTextualDocs()} to the index and then commits.
 *
 * @throws Exception declared for callers; propagated from the add/commit helpers
 */
protected static void indexWithEmbeddingGeneration() throws Exception {
  final List<SolrInputDocument> textualDocs = prepareTextualDocs();
  for (int position = 0; position < textualDocs.size(); position++) {
    assertU(adoc(textualDocs.get(position)));
  }

  assertU(commit());
}

/**
 * Builds the five textual test documents.
 *
 * <p>Each document gets a 1-based id in {@code IDField} followed by one sentence in {@code
 * stringField}.
 *
 * @return the documents, in id order
 */
private static List<SolrInputDocument> prepareTextualDocs() {
  final String[] sentences = {
    "Vegeta is the prince of all saiyans", // cosine distance vector1= 1.0
    "Goku is a saiyan raised on earth", // cosine distance vector1= 0.998
    "Gohan is a saiyaman, son of Goku",
    "Goten is a saiyaman, second son son of Goku",
    "Trunks is a saiyaman, second son son of Vegeta"
  };

  final List<SolrInputDocument> result = new ArrayList<>(sentences.length);
  for (int idx = 0; idx < sentences.length; idx++) {
    final SolrInputDocument textualDoc = new SolrInputDocument();
    textualDoc.addField(IDField, idx + 1); // ids start at 1
    textualDoc.addField(stringField, sentences[idx]);
    result.add(textualDoc);
  }

  return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class TextToVectorQParserTest extends TestLlmBase {
@BeforeClass
public static void init() throws Exception {
setupTest("solrconfig-llm.xml", "schema.xml", true, false);
loadModels("dummy-model.json");
loadModel("dummy-model.json");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void testRestManagerEndpoints() throws Exception {

@Test
public void loadModel_cohere_shouldLoadModelConfig() throws Exception {
loadModels("cohere-model.json");
loadModel("cohere-model.json");

final String modelName = "cohere-1";
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/[0]/name=='" + modelName + "'");
Expand All @@ -146,7 +146,7 @@ public void loadModel_cohere_shouldLoadModelConfig() throws Exception {

@Test
public void loadModel_openAi_shouldLoadModelConfig() throws Exception {
loadModels("openai-model.json");
loadModel("openai-model.json");

final String modelName = "openai-1";
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/[0]/name=='" + modelName + "'");
Expand All @@ -168,7 +168,7 @@ public void loadModel_openAi_shouldLoadModelConfig() throws Exception {

@Test
public void loadModel_mistralAi_shouldLoadModelConfig() throws Exception {
loadModels("mistralai-model.json");
loadModel("mistralai-model.json");

final String modelName = "mistralai-1";
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/[0]/name=='" + modelName + "'");
Expand All @@ -189,7 +189,7 @@ public void loadModel_mistralAi_shouldLoadModelConfig() throws Exception {

@Test
public void loadModel_huggingface_shouldLoadModelConfig() throws Exception {
loadModels("huggingface-model.json");
loadModel("huggingface-model.json");

final String modelName = "huggingface-1";
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/[0]/name=='" + modelName + "'");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void testModelStorePersistence() throws Exception {
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/==[]");

// load the model definition from file
loadModels("cohere-model.json");
loadModel("cohere-model.json");

final String modelName = "cohere-1";
assertJQ(ManagedEmbeddingModelStore.REST_END_POINT, "/models/[0]/name=='" + modelName + "'");
Expand Down

0 comments on commit 4cbafdc

Please sign in to comment.