Skip to content

Commit

Permalink
tokenizers/sentencepiece: Improve performance by removing allocations
Browse files Browse the repository at this point in the history
```
goos: linux
goarch: amd64
pkg: github.com/nlpodyssey/cybertron/pkg/tokenizers/sentencepiece/internal/sentencepiece
cpu: Intel(R) Xeon(R) CPU E5-2697 v4 @ 2.30GHz
                                      │   old.txt    │               new.txt               │
                                      │    sec/op    │   sec/op     vs base                │
SentencePiece/compose_email_to_joh-36   2235.6µ ± 1%   147.1µ ± 1%  -93.42% (p=0.000 n=10)

                                      │    old.txt     │               new.txt                │
                                      │      B/op      │     B/op      vs base                │
SentencePiece/compose_email_to_joh-36   3289.19Ki ± 0%   27.02Ki ± 0%  -99.18% (p=0.000 n=10)

                                      │   old.txt    │              new.txt               │
                                      │  allocs/op   │ allocs/op   vs base                │
SentencePiece/compose_email_to_joh-36   1830.00 ± 0%   92.00 ± 0%  -94.97% (p=0.000 n=10)
```
  • Loading branch information
damz committed May 30, 2024
1 parent d0c62f8 commit 88b93f4
Showing 1 changed file with 13 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ type trieNode struct {
score float32
index int32
end bool
- children map[rune]trieNode
+ children map[rune]*trieNode
}

- func newTrieNode(text string, level int) trieNode {
- return trieNode{
+ func newTrieNode(text string, level int) *trieNode {
+ return &trieNode{
text: text,
level: level,
score: 0.0,
index: 0,
end: false,
- children: make(map[rune]trieNode),
+ children: make(map[rune]*trieNode),
}
}

// Sentencepiece holds the model
type Sentencepiece struct {
- root trieNode
+ root *trieNode
lowercase bool
unknown int32
controlWords map[string]int32
Expand Down Expand Up @@ -110,7 +110,7 @@ func (s *Sentencepiece) TokenizeToIDs(text string) []int32 {
func (s *Sentencepiece) insert(word string, score float32, index int32) {
_, size := utf8.DecodeLastRuneInString(word)
charCount := len(word)
- node := &s.root
+ node := s.root
for i, r := range word {
text := node.text
cnode, ok := node.children[r]
Expand All @@ -124,24 +124,22 @@ func (s *Sentencepiece) insert(word string, score float32, index int32) {
cnode.index = index
}
node.children[r] = cnode
- node = &cnode
+ node = cnode
}
}

- func (s *Sentencepiece) commonPrefixSearch(runes []rune) []trieNode {
- output := make([]trieNode, 0, len(runes))
- node := &s.root
+ func (s *Sentencepiece) commonPrefixSearch(runes []rune, cb func(*trieNode)) {
+ node := s.root
for _, r := range runes {
cnode, ok := node.children[r]
if !ok {
break
}
if cnode.end {
- output = append(output, cnode)
+ cb(cnode)
}
- node = &cnode
+ node = cnode
}
- return output
}

func (s *Sentencepiece) decodeBackwards(slices []slice) []slice {
Expand All @@ -166,15 +164,14 @@ func (s *Sentencepiece) decodeForwardToken(runes []rune) []slice {
slices := s.initSlices(len(runes) + 1)
scores[0] = 0.0
for i := range runes {
- matches := s.commonPrefixSearch(runes[i:])
- for _, node := range matches {
+ s.commonPrefixSearch(runes[i:], func(node *trieNode) {
localScore := scores[i] + node.score
charEnd := i + node.level
if localScore > scores[charEnd] {
slices[charEnd] = slice{score: localScore, index: node.index, start: i, end: charEnd}
scores[charEnd] = localScore
}
- }
+ })
if scores[i+1] <= minScore {
slices[i+1] = slice{score: minScore, index: s.unknown, start: i, end: i + 1}
scores[i+1] = 0.0
Expand Down

0 comments on commit 88b93f4

Please sign in to comment.