From 0b936a2bb82b108b6e995e8efac9f40bf2642b4d Mon Sep 17 00:00:00 2001 From: cs0lar Date: Sun, 16 Apr 2023 10:48:43 +0100 Subject: [PATCH] fixes index name to classname conversion --- autogpt/memory/weaviate.py | 11 ++++++++++- tests/integration/weaviate_memory_tests.py | 19 +++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/autogpt/memory/weaviate.py b/autogpt/memory/weaviate.py index 6fcce0a0216d..35e7844a2a24 100644 --- a/autogpt/memory/weaviate.py +++ b/autogpt/memory/weaviate.py @@ -37,9 +37,18 @@ def __init__(self, cfg): else: self.client = Client(url, auth_client_secret=auth_credentials) - self.index = cfg.memory_index + self.index = WeaviateMemory.format_classname(cfg.memory_index) self._create_schema() + @staticmethod + def format_classname(index): + # weaviate uses capitalised index names + # The python client uses the following code to format + # index names before the corresponding class is created + if len(index) == 1: + return index.capitalize() + return index[0].capitalize() + index[1:] + def _create_schema(self): schema = default_schema(self.index) if not self.client.schema.contains(schema): diff --git a/tests/integration/weaviate_memory_tests.py b/tests/integration/weaviate_memory_tests.py index 503fe9d22ed3..4acea0ffda1e 100644 --- a/tests/integration/weaviate_memory_tests.py +++ b/tests/integration/weaviate_memory_tests.py @@ -12,17 +12,10 @@ from autogpt.memory.base import get_ada_embedding -@mock.patch.dict(os.environ, { - "WEAVIATE_HOST": "127.0.0.1", - "WEAVIATE_PROTOCOL": "http", - "WEAVIATE_PORT": "8080", - "WEAVIATE_USERNAME": "", - "WEAVIATE_PASSWORD": "", - "MEMORY_INDEX": "AutogptTests" -}) class TestWeaviateMemory(unittest.TestCase): cfg = None client = None + index = None @classmethod def setUpClass(cls): @@ -40,6 +33,8 @@ def setUpClass(cls): else: cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}") + cls.index = WeaviateMemory.format_classname(cls.cfg.memory_index) + """ In order to run these tests you will need a local instance of Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose @@ -51,7 +46,7 @@ def setUpClass(cls): """ def setUp(self): try: - self.client.schema.delete_class(self.cfg.memory_index) + self.client.schema.delete_class(self.index) except: pass @@ -60,8 +55,8 @@ def setUp(self): def test_add(self): doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones' self.memory.add(doc) - result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do() - actual = result['data']['Get'][self.cfg.memory_index] + result = self.client.query.get(self.index, ['raw_text']).do() + actual = result['data']['Get'][self.index] self.assertEqual(len(actual), 1) self.assertEqual(actual[0]['raw_text'], doc) @@ -73,7 +68,7 @@ def test_get(self): batch.add_data_object( uuid=get_valid_uuid(uuid4()), data_object={'raw_text': doc}, - class_name=self.cfg.memory_index, + class_name=self.index, vector=get_ada_embedding(doc) )