-
Notifications
You must be signed in to change notification settings - Fork 0
/
easytokenizer.go
118 lines (107 loc) · 3.79 KB
/
easytokenizer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package tokenizer
/*
#cgo CXXFLAGS: -std=c++11
#cgo LDFLAGS: -L${SRCDIR} -ltokenizer -lm
#include "easytokenizer_wrapper.h"
*/
import "C"
import (
"unsafe"
)
// EasyTokenizer wraps a native tokenizer handle created by NewTokenizer,
// together with the settings it was initialized with. Call Close to release
// the native handle.
type EasyTokenizer struct {
	// easyTokenizer is the opaque handle returned by C.initTokenizer.
	easyTokenizer C.EasyTokenizer
	// vocabPath is the vocabulary file path given at construction.
	vocabPath string
	// doLowerCase and codePointLevel mirror the flags passed to
	// C.initTokenizer; NewTokenizer always sets codePointLevel to true.
	doLowerCase, codePointLevel bool
}
// NewTokenizer creates an EasyTokenizer backed by the native tokenizer
// library, initialized with the vocabulary at vocabPath.
//
// doLowerCase is forwarded to the native initTokenizer call; codePointLevel
// is always set to true by this wrapper.
//
// The returned tokenizer owns a native handle; call Close when done with it.
func NewTokenizer(vocabPath string, doLowerCase bool) *EasyTokenizer {
	tokenizer := &EasyTokenizer{
		vocabPath:      vocabPath,
		doLowerCase:    doLowerCase,
		codePointLevel: true, // always enabled by this wrapper
	}

	// C.CString mallocs a C copy of the path; release it once initTokenizer
	// has returned instead of leaking it on every construction.
	// NOTE(review): assumes initTokenizer copies the path (e.g. into a
	// std::string) and does not retain the pointer — confirm in the wrapper.
	cVocabPath := C.CString(vocabPath)
	defer C.free(unsafe.Pointer(cVocabPath))

	tokenizer.easyTokenizer = C.initTokenizer(cVocabPath, C.bool(tokenizer.doLowerCase), C.bool(tokenizer.codePointLevel))
	return tokenizer
}
// Close releases the native tokenizer handle and clears the stored handle.
// A second Close therefore passes NULL to free, which is a no-op.
//
// NOTE(review): the handle is released with C.free. If initTokenizer
// allocates it with C++ `new`, a matching wrapper-side destroy function
// should be used instead — confirm against easytokenizer_wrapper.h.
func (t *EasyTokenizer) Close() {
	handle := t.easyTokenizer
	t.easyTokenizer = nil
	C.free(unsafe.Pointer(handle))
}
// Encode tokenizes text with the native encode function and returns a
// buffer of exactly maxSeqLength int32 values that the native call writes
// into. The two boolean flags are both passed as true; their meaning is
// defined by the C wrapper.
//
// A non-positive maxSeqLength returns an empty slice without calling into C.
func (t *EasyTokenizer) Encode(text string, maxSeqLength int) []int32 {
	// Without this guard, &outputData[0] panics for length 0 and make panics
	// for a negative length.
	if maxSeqLength <= 0 {
		return []int32{}
	}

	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	// The buffer is Go-allocated and contains no Go pointers, so cgo permits
	// passing it to C. Reinterpreting []int32 as *C.int assumes a 32-bit
	// C int.
	outputData := make([]int32, maxSeqLength)
	C.encode(t.easyTokenizer, cText, C.bool(true), C.bool(true), C.int(maxSeqLength), (*C.int)(unsafe.Pointer(&outputData[0])))
	return outputData
}
// EncodeWithIds tokenizes text with the native encodeWithIds function and
// returns Go-owned copies of the four arrays it produces, in this order:
// input ids, token type ids, attention mask, and offsets. The three boolean
// flags are all passed as true and maxSeqLength is forwarded unchanged.
//
// Each returned slice has exactly the length the C side reported for that
// array. All C buffers are freed before the function returns.
func (t *EasyTokenizer) EncodeWithIds(text string, maxSeqLength int) ([]int32, []int32, []int32, []int32) {
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	// Output arrays and their element counts, filled in by the C call.
	var inputIds, tokenTypeIds, attentionMask, offsets *C.int
	var numInputIds, numTokenTypeIds, numAttentionMask, numOffsets C.int
	C.encodeWithIds(t.easyTokenizer, cText,
		&inputIds, &numInputIds,
		&tokenTypeIds, &numTokenTypeIds,
		&attentionMask, &numAttentionMask,
		&offsets, &numOffsets,
		C.bool(true), C.bool(true), C.bool(true), C.int(maxSeqLength))

	// Register the frees immediately so that a panic during conversion cannot
	// leak the C buffers. free(NULL) is a no-op.
	defer C.free(unsafe.Pointer(inputIds))
	defer C.free(unsafe.Pointer(tokenTypeIds))
	defer C.free(unsafe.Pointer(attentionMask))
	defer C.free(unsafe.Pointer(offsets))

	// Copy each array with its own length; they are not assumed to be the same
	// size. The return values are evaluated before the deferred frees run.
	return copyCIntArray(inputIds, numInputIds),
		copyCIntArray(tokenTypeIds, numTokenTypeIds),
		copyCIntArray(attentionMask, numAttentionMask),
		copyCIntArray(offsets, numOffsets)
}

// copyCIntArray copies n C ints starting at p into a newly allocated Go slice.
// A nil p or non-positive n returns an empty, non-nil slice.
func copyCIntArray(p *C.int, n C.int) []int32 {
	// Slicing a nil array pointer panics even for a zero length, so a NULL
	// array must be handled before taking the view.
	if p == nil || n <= 0 {
		return []int32{}
	}
	// Take a zero-copy view of the C memory, then copy it out so the result
	// stays valid after the C buffer is freed.
	view := (*[1 << 30]int32)(unsafe.Pointer(p))[:n:n]
	out := make([]int32, n)
	copy(out, view)
	return out
}
// WordPieceTokenize splits text into word-piece tokens with the native
// wordPieceTokenize function. It returns Go copies of the token strings and
// of the offsets array the C side reports. Both C arrays are freed before
// returning.
//
// A NULL array or non-positive count from C yields an empty, non-nil slice.
func (t *EasyTokenizer) WordPieceTokenize(text string) ([]string, []int32) {
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	// Output arrays and their element counts, filled in by the C call.
	var tokens **C.char
	var offsets *C.int
	var numTokens, numOffsets C.int
	C.wordPieceTokenize(t.easyTokenizer, cText, &tokens, &numTokens, &offsets, &numOffsets)
	// NOTE(review): only the tokens array itself is freed, not the strings it
	// points to. If the wrapper mallocs each token string, those leak — confirm
	// against the wrapper and free each element if so.
	defer C.free(unsafe.Pointer(tokens))
	defer C.free(unsafe.Pointer(offsets))

	// Slicing a nil array pointer panics even for a zero length, so each
	// array is checked for NULL before the view is taken.
	goTokens := []string{}
	if tokens != nil && numTokens > 0 {
		cTokens := (*[1 << 30]*C.char)(unsafe.Pointer(tokens))[:numTokens:numTokens]
		goTokens = make([]string, len(cTokens))
		for i, cTok := range cTokens {
			goTokens[i] = C.GoString(cTok)
		}
	}

	goOffsets := []int32{}
	if offsets != nil && numOffsets > 0 {
		cOffsets := (*[1 << 30]int32)(unsafe.Pointer(offsets))[:numOffsets:numOffsets]
		goOffsets = make([]int32, len(cOffsets))
		copy(goOffsets, cOffsets)
	}
	return goTokens, goOffsets
}