diff --git a/go/README.md b/go/README.md index f5f2d25d7f9..0cfe009a428 100644 --- a/go/README.md +++ b/go/README.md @@ -10,6 +10,10 @@ - postgres=# `create role chroma with login password 'chroma';` - postgres=# `alter role chroma with superuser;` - postgres=# `create database chroma;` +- Set postgres ENV Vars + Several tests (such as record_log_service_test.go) require the following environment variables to be set: + - `export POSTGRES_HOST=localhost` + - `export POSTGRES_PORT=5432` - Atlas schema migration - [~/chroma/go]: `atlas migrate diff --env dev` - [~/chroma/go]: `atlas --env dev migrate apply --url "postgres://chroma:chroma@localhost:5432/chroma?sslmode=disable"` diff --git a/go/pkg/logservice/grpc/record_log_service.go b/go/pkg/logservice/grpc/record_log_service.go index b1899a0b360..40d38a3e57b 100644 --- a/go/pkg/logservice/grpc/record_log_service.go +++ b/go/pkg/logservice/grpc/record_log_service.go @@ -28,6 +28,8 @@ func (s *Server) PushLogs(ctx context.Context, req *logservicepb.PushLogsRequest } var recordsContent [][]byte for _, record := range req.Records { + // We remove the collection id for space reasons, as its double stored in the wrapping database RecordLog object. + // PullLogs will rehydrate the collection id from the database. record.CollectionId = "" data, err := proto.Marshal(record) if err != nil { @@ -73,6 +75,8 @@ func (s *Server) PullLogs(ctx context.Context, req *logservicepb.PullLogsRequest } return nil, grpcError } + // Here we rehydrate the collection id from the database since in PushLogs we removed it for space reasons. + record.CollectionId = *recordLogs[index].CollectionID recordLog := &logservicepb.RecordLog{ LogId: recordLogs[index].ID, Record: record, diff --git a/go/pkg/logservice/grpc/record_log_service_test.go b/go/pkg/logservice/grpc/record_log_service_test.go index ed18e1f23a7..3e453e351ed 100644 --- a/go/pkg/logservice/grpc/record_log_service_test.go +++ b/go/pkg/logservice/grpc/record_log_service_test.go @@ -4,6 +4,9 @@ import ( "bytes" "context" "encoding/binary" + "testing" + "time" + "github.com/chroma-core/chroma/go/pkg/logservice/testutils" "github.com/chroma-core/chroma/go/pkg/metastore/db/dbcore" "github.com/chroma-core/chroma/go/pkg/metastore/db/dbmodel" @@ -16,8 +19,6 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "gorm.io/gorm" - "testing" - "time" ) type RecordLogServiceTestSuite struct { @@ -132,6 +133,11 @@ func (suite *RecordLogServiceTestSuite) TestServer_PushLogs() { func (suite *RecordLogServiceTestSuite) TestServer_PullLogs() { // push some records recordsToSubmit := GetTestEmbeddingRecords(suite.collectionId.String()) + // deep clone the records since PushLogs will mutate the records and we need a source of truth + recordsToSubmit_sot := make([]*coordinatorpb.SubmitEmbeddingRecord, len(recordsToSubmit)) + for i := range recordsToSubmit { + recordsToSubmit_sot[i] = proto.Clone(recordsToSubmit[i]).(*coordinatorpb.SubmitEmbeddingRecord) + } pushRequest := logservicepb.PushLogsRequest{ CollectionId: suite.collectionId.String(), Records: recordsToSubmit, @@ -150,13 +156,13 @@ func (suite *RecordLogServiceTestSuite) TestServer_PullLogs() { suite.Len(pullResponse.Records, 3) for index := range pullResponse.Records { suite.Equal(int64(index+1), pullResponse.Records[index].LogId) - suite.Equal(pullResponse.Records[index].Record.Id, recordsToSubmit[index].Id) - suite.Equal(pullResponse.Records[index].Record.Operation, recordsToSubmit[index].Operation) - suite.Equal(pullResponse.Records[index].Record.CollectionId, recordsToSubmit[index].CollectionId) - suite.Equal(pullResponse.Records[index].Record.Metadata, recordsToSubmit[index].Metadata) - suite.Equal(pullResponse.Records[index].Record.Vector.Dimension, recordsToSubmit[index].Vector.Dimension) - suite.Equal(pullResponse.Records[index].Record.Vector.Encoding, recordsToSubmit[index].Vector.Encoding) - suite.Equal(pullResponse.Records[index].Record.Vector.Vector, recordsToSubmit[index].Vector.Vector) + suite.Equal(recordsToSubmit_sot[index].Id, pullResponse.Records[index].Record.Id) + suite.Equal(recordsToSubmit_sot[index].Operation, pullResponse.Records[index].Record.Operation) + suite.Equal(recordsToSubmit_sot[index].CollectionId, pullResponse.Records[index].Record.CollectionId) + suite.Equal(recordsToSubmit_sot[index].Metadata, pullResponse.Records[index].Record.Metadata) + suite.Equal(recordsToSubmit_sot[index].Vector.Dimension, pullResponse.Records[index].Record.Vector.Dimension) + suite.Equal(recordsToSubmit_sot[index].Vector.Encoding, pullResponse.Records[index].Record.Vector.Encoding) + suite.Equal(recordsToSubmit_sot[index].Vector.Vector, pullResponse.Records[index].Record.Vector.Vector) } }