Skip to content

Commit

Permalink
Merge pull request #42 from damz/pr/sentencepiece-performance
Browse files Browse the repository at this point in the history
tokenizers/sentencepiece: Improve performance by removing allocations
  • Loading branch information
matteo-grella authored Jun 8, 2024
2 parents d0c62f8 + 88b93f4 commit a7ba5c1
Showing 1 changed file with 13 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ type trieNode struct {
score float32
index int32
end bool
children map[rune]trieNode
children map[rune]*trieNode
}

func newTrieNode(text string, level int) trieNode {
return trieNode{
func newTrieNode(text string, level int) *trieNode {
return &trieNode{
text: text,
level: level,
score: 0.0,
index: 0,
end: false,
children: make(map[rune]trieNode),
children: make(map[rune]*trieNode),
}
}

// Sentencepiece holds the model
type Sentencepiece struct {
root trieNode
root *trieNode
lowercase bool
unknown int32
controlWords map[string]int32
Expand Down Expand Up @@ -110,7 +110,7 @@ func (s *Sentencepiece) TokenizeToIDs(text string) []int32 {
func (s *Sentencepiece) insert(word string, score float32, index int32) {
_, size := utf8.DecodeLastRuneInString(word)
charCount := len(word)
node := &s.root
node := s.root
for i, r := range word {
text := node.text
cnode, ok := node.children[r]
Expand All @@ -124,24 +124,22 @@ func (s *Sentencepiece) insert(word string, score float32, index int32) {
cnode.index = index
}
node.children[r] = cnode
node = &cnode
node = cnode
}
}

func (s *Sentencepiece) commonPrefixSearch(runes []rune) []trieNode {
output := make([]trieNode, 0, len(runes))
node := &s.root
func (s *Sentencepiece) commonPrefixSearch(runes []rune, cb func(*trieNode)) {
node := s.root
for _, r := range runes {
cnode, ok := node.children[r]
if !ok {
break
}
if cnode.end {
output = append(output, cnode)
cb(cnode)
}
node = &cnode
node = cnode
}
return output
}

func (s *Sentencepiece) decodeBackwards(slices []slice) []slice {
Expand All @@ -166,15 +164,14 @@ func (s *Sentencepiece) decodeForwardToken(runes []rune) []slice {
slices := s.initSlices(len(runes) + 1)
scores[0] = 0.0
for i := range runes {
matches := s.commonPrefixSearch(runes[i:])
for _, node := range matches {
s.commonPrefixSearch(runes[i:], func(node *trieNode) {
localScore := scores[i] + node.score
charEnd := i + node.level
if localScore > scores[charEnd] {
slices[charEnd] = slice{score: localScore, index: node.index, start: i, end: charEnd}
scores[charEnd] = localScore
}
}
})
if scores[i+1] <= minScore {
slices[i+1] = slice{score: minScore, index: s.unknown, start: i, end: i + 1}
scores[i+1] = 0.0
Expand Down

0 comments on commit a7ba5c1

Please sign in to comment.