Skip to content

Commit

Permalink
backend: move chunking to llm package
Browse files Browse the repository at this point in the history
  • Loading branch information
NickSavage committed Nov 21, 2024
1 parent fb1f88a commit caf0d10
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 58 deletions.
29 changes: 1 addition & 28 deletions go-backend/handlers/cards.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ func (s *Handler) ChunkCard(card models.Card) error {
db := s.DB

tx, err := db.Begin()
chunks := s.GenerateChunks(card.Body)
chunks := llms.GenerateChunks(card.Body)
query := `DELETE FROM card_chunks WHERE card_pk = $1 AND user_id = $2`
_, err = tx.Exec(query, card.ID, card.UserID)
if err != nil {
Expand All @@ -986,33 +986,6 @@ func (s *Handler) ChunkCard(card models.Card) error {
return nil

}

func (s *Handler) GenerateChunks(input string) []string {
results := []string{}

// Only trim leading/trailing spaces
input = strings.TrimSpace(input)

// Split by periods but add them back
sentences := strings.Split(input+".", ".")
for i, sentence := range sentences {
// Skip the last empty element caused by our added period
if i == len(sentences)-1 && sentence == "" {
break
}

// Only trim leading spaces, preserve newlines and trailing spaces
sentence = strings.TrimLeft(sentence, " ")
if sentence == "" {
continue
}

results = append(results, sentence+".")
}

return results
}

func (s *Handler) GetCardChunks(userID, cardPK int) ([]models.CardChunk, error) {
query := `SELECT
id, card_pk, user_id, chunk_text
Expand Down
30 changes: 0 additions & 30 deletions go-backend/handlers/cards_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,33 +771,3 @@ func TestCheckCardLinkedOrRelated(t *testing.T) {
t.Errorf("expected card to not be linked, returned true")
}
}

func TestChunkCardBody(t *testing.T) {
s := setup()
defer tests.Teardown()

input := `Lorem ipsum odor amet, consectetuer adipiscing elit. Luctus egestas lobortis cursus mollis facilisi. Scelerisque vel litora rhoncus porttitor eros. Lacus orci morbi a varius lobortis rutrum interdum per. Nostra commodo phasellus etiam morbi metus porttitor. Mauris a fermentum habitasse sollicitudin semper porta. Fermentum phasellus hendrerit purus, etiam erat litora.
Lorem cubilia cubilia dis iaculis, odio vivamus interdum adipiscing dolor.`

results := s.GenerateChunks(input)
for _, result := range results {
log.Printf(result)
}

if len(results) != 8 {
t.Errorf("wrong number of chunks returned, got %v want %v", len(results), 8)
}
string := "Fermentum phasellus hendrerit purus, etiam erat litora."
if len(results) > 6 && results[6] != string {
t.Errorf("wrong chunk return, %v, got %v want %v", results[6] == string, results[6], string)
t.Errorf("one: %v", results[6])
t.Errorf("two: %v", string)

}
last := "\n\nLorem cubilia cubilia dis iaculis, odio vivamus interdum adipiscing dolor."
if len(results) > 7 && results[7] != last {
t.Errorf("wrong chunk return, got %v want %v", results[7], last)

}
}
31 changes: 31 additions & 0 deletions go-backend/llms/chunking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package llms

import (
"strings"
)

func GenerateChunks(input string) []string {
results := []string{}

// Only trim leading/trailing spaces
input = strings.TrimSpace(input)

// Split by periods but add them back
sentences := strings.Split(input+".", ".")
for i, sentence := range sentences {
// Skip the last empty element caused by our added period
if i == len(sentences)-1 && sentence == "" {
break
}

// Only trim leading spaces, preserve newlines and trailing spaces
sentence = strings.TrimLeft(sentence, " ")
if sentence == "" {
continue
}

results = append(results, sentence+".")
}

return results
}
33 changes: 33 additions & 0 deletions go-backend/llms/chunking_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package llms

import (
"log"
"testing"
)

func TestChunkCardBody(t *testing.T) {
input := `Lorem ipsum odor amet, consectetuer adipiscing elit. Luctus egestas lobortis cursus mollis facilisi. Scelerisque vel litora rhoncus porttitor eros. Lacus orci morbi a varius lobortis rutrum interdum per. Nostra commodo phasellus etiam morbi metus porttitor. Mauris a fermentum habitasse sollicitudin semper porta. Fermentum phasellus hendrerit purus, etiam erat litora.
Lorem cubilia cubilia dis iaculis, odio vivamus interdum adipiscing dolor.`

results := GenerateChunks(input)
for _, result := range results {
log.Printf(result)
}

if len(results) != 8 {
t.Errorf("wrong number of chunks returned, got %v want %v", len(results), 8)
}
string := "Fermentum phasellus hendrerit purus, etiam erat litora."
if len(results) > 6 && results[6] != string {
t.Errorf("wrong chunk return, %v, got %v want %v", results[6] == string, results[6], string)
t.Errorf("one: %v", results[6])
t.Errorf("two: %v", string)

}
last := "\n\nLorem cubilia cubilia dis iaculis, odio vivamus interdum adipiscing dolor."
if len(results) > 7 && results[7] != last {
t.Errorf("wrong chunk return, got %v want %v", results[7], last)

}
}

0 comments on commit caf0d10

Please sign in to comment.