diff --git a/pkg/tokenizers/sentencepiece/internal/sentencepiece/sentencepiece.go b/pkg/tokenizers/sentencepiece/internal/sentencepiece/sentencepiece.go index fd681ca..36bab23 100644 --- a/pkg/tokenizers/sentencepiece/internal/sentencepiece/sentencepiece.go +++ b/pkg/tokenizers/sentencepiece/internal/sentencepiece/sentencepiece.go @@ -29,23 +29,23 @@ type trieNode struct { score float32 index int32 end bool - children map[rune]trieNode + children map[rune]*trieNode } -func newTrieNode(text string, level int) trieNode { - return trieNode{ +func newTrieNode(text string, level int) *trieNode { + return &trieNode{ text: text, level: level, score: 0.0, index: 0, end: false, - children: make(map[rune]trieNode), + children: make(map[rune]*trieNode), } } // Sentencepiece holds the model type Sentencepiece struct { - root trieNode + root *trieNode lowercase bool unknown int32 controlWords map[string]int32 @@ -110,7 +110,7 @@ func (s *Sentencepiece) TokenizeToIDs(text string) []int32 { func (s *Sentencepiece) insert(word string, score float32, index int32) { _, size := utf8.DecodeLastRuneInString(word) charCount := len(word) - node := &s.root + node := s.root for i, r := range word { text := node.text cnode, ok := node.children[r] @@ -124,24 +124,22 @@ func (s *Sentencepiece) insert(word string, score float32, index int32) { cnode.index = index } node.children[r] = cnode - node = &cnode + node = cnode } } -func (s *Sentencepiece) commonPrefixSearch(runes []rune) []trieNode { - output := make([]trieNode, 0, len(runes)) - node := &s.root +func (s *Sentencepiece) commonPrefixSearch(runes []rune, cb func(*trieNode)) { + node := s.root for _, r := range runes { cnode, ok := node.children[r] if !ok { break } if cnode.end { - output = append(output, cnode) + cb(cnode) } - node = &cnode + node = cnode } - return output } func (s *Sentencepiece) decodeBackwards(slices []slice) []slice { @@ -166,15 +164,14 @@ func (s *Sentencepiece) decodeForwardToken(runes []rune) []slice { slices := s.initSlices(len(runes) + 1) scores[0] = 0.0 for i := range runes { - matches := s.commonPrefixSearch(runes[i:]) - for _, node := range matches { + s.commonPrefixSearch(runes[i:], func(node *trieNode) { localScore := scores[i] + node.score charEnd := i + node.level if localScore > scores[charEnd] { slices[charEnd] = slice{score: localScore, index: node.index, start: i, end: charEnd} scores[charEnd] = localScore } - } + }) if scores[i+1] <= minScore { slices[i+1] = slice{score: minScore, index: s.unknown, start: i, end: i + 1} scores[i+1] = 0.0