Skip to content

Commit

Permalink
Merge pull request Significant-Gravitas#1836 from cs0lar/fix/weaviate…
Browse files Browse the repository at this point in the history
…_index_to_classname

fixes Weaviate index name to classname conversion
  • Loading branch information
richbeales authored Apr 17, 2023
2 parents 6222b2d + 0b936a2 commit e849e4f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
11 changes: 10 additions & 1 deletion autogpt/memory/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,18 @@ def __init__(self, cfg):
else:
self.client = Client(url, auth_client_secret=auth_credentials)

self.index = cfg.memory_index
self.index = WeaviateMemory.format_classname(cfg.memory_index)
self._create_schema()

@staticmethod
def format_classname(index):
# weaviate uses capitalised index names
# The python client uses the following code to format
# index names before the corresponding class is created
if len(index) == 1:
return index.capitalize()
return index[0].capitalize() + index[1:]

def _create_schema(self):
schema = default_schema(self.index)
if not self.client.schema.contains(schema):
Expand Down
19 changes: 7 additions & 12 deletions tests/integration/weaviate_memory_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,10 @@
from autogpt.memory.base import get_ada_embedding


@mock.patch.dict(os.environ, {
"WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests"
})
class TestWeaviateMemory(unittest.TestCase):
cfg = None
client = None
index = None

@classmethod
def setUpClass(cls):
Expand All @@ -40,6 +33,8 @@ def setUpClass(cls):
else:
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")

cls.index = WeaviateMemory.format_classname(cls.cfg.memory_index)

"""
In order to run these tests you will need a local instance of
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
Expand All @@ -51,7 +46,7 @@ def setUpClass(cls):
"""
def setUp(self):
try:
self.client.schema.delete_class(self.cfg.memory_index)
self.client.schema.delete_class(self.index)
except:
pass

Expand All @@ -60,8 +55,8 @@ def setUp(self):
def test_add(self):
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones'
self.memory.add(doc)
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
actual = result['data']['Get'][self.cfg.memory_index]
result = self.client.query.get(self.index, ['raw_text']).do()
actual = result['data']['Get'][self.index]

self.assertEqual(len(actual), 1)
self.assertEqual(actual[0]['raw_text'], doc)
Expand All @@ -73,7 +68,7 @@ def test_get(self):
batch.add_data_object(
uuid=get_valid_uuid(uuid4()),
data_object={'raw_text': doc},
class_name=self.cfg.memory_index,
class_name=self.index,
vector=get_ada_embedding(doc)
)

Expand Down

0 comments on commit e849e4f

Please sign in to comment.