-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
246 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package trietree | ||
|
||
import ( | ||
"unicode/utf8" | ||
) | ||
|
||
// Prediction identifies one key of the tree found inside a query: the span of
// the query covered by the key and the identifier of the key.
type Prediction struct {
	// Start is the byte index in the query where the key begins.
	Start int
	// End is the byte index in the query just past the end of the key, so
	// that query[Start:End] is the key.
	End int
	// ID is the identifier of the edge node that represents the key.
	ID int
}
|
||
// PredictIter returns an iterator function which enumerates Prediction: key | ||
// suggestions that match the query in the tree. | ||
func (dt *DTree) PredictIter(query string) func() *Prediction { | ||
return predictIter[*DNode](dt, query) | ||
} | ||
|
||
// PredictIter returns an iterator function which enumerates Prediction: key | ||
// suggestions that match the query in the tree. | ||
func (st *STree) PredictIter(query string) func() *Prediction { | ||
return predictIter[int](st, query) | ||
} | ||
|
||
// predictableTree is the view of a tree that the prediction iterator needs.
// T is the tree's node handle; its zero value is treated by the iterator as
// "no node".
type predictableTree[T comparable] interface {
	// root returns the node a traversal starts from.
	root() T
	// nextNode returns the node reached from the given node by the rune.
	nextNode(T, rune) T
	// nodeFail returns the node to inspect after the given one.
	nodeFail(T) T
	// nodeId returns the identifier of the node; only values greater than
	// zero are reported as predictions.
	nodeId(T) int
	// nodeLevel returns how many trailing characters of the consumed query
	// belong to the key of the node.
	nodeLevel(T) int
}
|
||
// methods DTree satisfies predictableTree[*DNode] | ||
func (dt *DTree) root() *DNode { return &dt.Root } | ||
func (dt *DTree) nodeId(n *DNode) int { return n.EdgeID } | ||
func (dt *DTree) nodeLevel(n *DNode) int { return n.Level } | ||
func (dt *DTree) nodeFail(n *DNode) *DNode { return n.Failure } | ||
|
||
// methods STree satisfies predictableTree[int] | ||
func (st *STree) root() int { return 0 } | ||
func (st *STree) nodeId(n int) int { return st.Nodes[n].EdgeID } | ||
func (st *STree) nodeLevel(n int) int { return st.Levels[st.nodeId(n)-1] } | ||
func (st *STree) nodeFail(n int) int { return st.Nodes[n].Fail } | ||
|
||
type traverser[T comparable] struct { | ||
tree predictableTree[T] | ||
query string | ||
pivot T | ||
index int | ||
} | ||
|
||
func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] { | ||
return traverser[T]{ | ||
tree: tree, | ||
query: query, | ||
pivot: tree.root(), | ||
index: 0, | ||
} | ||
} | ||
|
||
// next consumes a rune from query, and determine next node to travese tree. | ||
// this returns next node, and end index of last parsed rune in query. | ||
func (tr *traverser[T]) next() (node T, end int) { | ||
var zero T | ||
if tr.query == "" { | ||
return zero, 0 | ||
} | ||
r, sz := utf8.DecodeRuneInString(tr.query) | ||
if sz == 0 { | ||
return zero, 0 | ||
} | ||
tr.query = tr.query[sz:] | ||
tr.index += sz | ||
tr.pivot = tr.tree.nextNode(tr.pivot, r) | ||
return tr.pivot, tr.index | ||
} | ||
|
||
func (tr *traverser[T]) close() { | ||
tr.query = "" | ||
} | ||
|
||
// trailingIndex returns the byte index in s at which the last n runes of s
// begin. It returns 0 when s holds fewer than n runes, and len(s) when n is
// not positive. An invalid UTF-8 byte counts as one rune of one byte.
func trailingIndex(s string, n int) int {
	x := len(s)
	// x > 0 keeps s[:x] non-empty, so every decode consumes at least one
	// byte and the loop always makes progress.
	for ; n > 0 && x > 0; n-- {
		_, sz := utf8.DecodeLastRuneInString(s[:x])
		x -= sz
	}
	return x
}
|
||
func predictIter[T comparable](tree predictableTree[T], query string) func() *Prediction { | ||
var ( | ||
zero T | ||
tr = newTraverser[T](tree, query) | ||
node T | ||
end int | ||
) | ||
return func() *Prediction { | ||
var p *Prediction | ||
for p == nil { | ||
if node == zero { | ||
node, end = tr.next() | ||
if node == zero { | ||
tr.close() | ||
return nil | ||
} | ||
} | ||
for node != zero && p == nil { | ||
if id := tree.nodeId(node); id > 0 { | ||
st := trailingIndex(query[:end], tree.nodeLevel(node)) | ||
p = &Prediction{Start: st, End: end, ID: id} | ||
} | ||
node = tree.nodeFail(node) | ||
} | ||
} | ||
return p | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package trietree_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/koron-go/trietree" | ||
) | ||
|
||
// prediction is the form in which the tests compare results: the fields of a
// trietree.Prediction plus the text of the query the prediction covers.
type prediction struct {
	Start, End, ID int
	Key            string
}
|
||
type predictIterator interface { | ||
PredictIter(string) func() *trietree.Prediction | ||
} | ||
|
||
func testPredictIter(t *testing.T, ptor predictIterator, q string, want []prediction) { | ||
t.Helper() | ||
got := make([]prediction, 0, 10) | ||
iter := ptor.PredictIter(q) | ||
for { | ||
p := iter() | ||
if p == nil { | ||
break | ||
} | ||
got = append(got, prediction{ | ||
Start: p.Start, | ||
End: p.End, | ||
ID: p.ID, | ||
Key: q[p.Start:p.End], | ||
}) | ||
} | ||
if d := cmp.Diff(want, got); d != "" { | ||
t.Errorf("unexpected predictions: -want +got\n%s", d) | ||
} | ||
} | ||
|
||
type predictIteratorBuilder func(t *testing.T, keys ...string) predictIterator | ||
|
||
func testPredictIterSingle(t *testing.T, build predictIteratorBuilder) { | ||
ptor := build(t, "1", "2", "3", "4", "5") | ||
testPredictIter(t, ptor, "1", []prediction{ | ||
{Start: 0, End: 1, ID: 1, Key: "1"}, | ||
}) | ||
testPredictIter(t, ptor, "2", []prediction{ | ||
{Start: 0, End: 1, ID: 2, Key: "2"}, | ||
}) | ||
testPredictIter(t, ptor, "3", []prediction{ | ||
{Start: 0, End: 1, ID: 3, Key: "3"}, | ||
}) | ||
testPredictIter(t, ptor, "4", []prediction{ | ||
{Start: 0, End: 1, ID: 4, Key: "4"}, | ||
}) | ||
testPredictIter(t, ptor, "5", []prediction{ | ||
{Start: 0, End: 1, ID: 5, Key: "5"}, | ||
}) | ||
testPredictIter(t, ptor, "6", []prediction{}) | ||
} | ||
|
||
func testPredictIterMultiple(t *testing.T, build predictIteratorBuilder) { | ||
ptor := build(t, "1", "2", "3", "4", "5") | ||
testPredictIter(t, ptor, "1234567890", []prediction{ | ||
{Start: 0, End: 1, ID: 1, Key: "1"}, | ||
{Start: 1, End: 2, ID: 2, Key: "2"}, | ||
{Start: 2, End: 3, ID: 3, Key: "3"}, | ||
{Start: 3, End: 4, ID: 4, Key: "4"}, | ||
{Start: 4, End: 5, ID: 5, Key: "5"}, | ||
}) | ||
} | ||
|
||
func testPredictIterBasic(t *testing.T, build predictIteratorBuilder) { | ||
ptor := build(t, "ab", "bc", "bab", "d", "abcde") | ||
testPredictIter(t, ptor, "ab", []prediction{ | ||
{Start: 0, End: 2, ID: 1, Key: "ab"}, | ||
}) | ||
testPredictIter(t, ptor, "bc", []prediction{ | ||
{Start: 0, End: 2, ID: 2, Key: "bc"}, | ||
}) | ||
testPredictIter(t, ptor, "bab", []prediction{ | ||
{Start: 0, End: 3, ID: 3, Key: "bab"}, | ||
{Start: 1, End: 3, ID: 1, Key: "ab"}, | ||
}) | ||
testPredictIter(t, ptor, "d", []prediction{ | ||
{Start: 0, End: 1, ID: 4, Key: "d"}, | ||
}) | ||
testPredictIter(t, ptor, "abcde", []prediction{ | ||
{Start: 0, End: 2, ID: 1, Key: "ab"}, | ||
{Start: 1, End: 3, ID: 2, Key: "bc"}, | ||
{Start: 3, End: 4, ID: 4, Key: "d"}, | ||
{Start: 0, End: 5, ID: 5, Key: "abcde"}, | ||
}) | ||
} | ||
|
||
func testPredictIterAll(t *testing.T, builder predictIteratorBuilder) { | ||
t.Run("single", func(t *testing.T) { | ||
testPredictIterSingle(t, builder) | ||
}) | ||
t.Run("multiple", func(t *testing.T) { | ||
testPredictIterMultiple(t, builder) | ||
}) | ||
t.Run("basic", func(t *testing.T) { | ||
testPredictIterBasic(t, builder) | ||
}) | ||
} | ||
|
||
func TestPredictIter(t *testing.T) { | ||
t.Run("dynamic", func(t *testing.T) { | ||
testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator { | ||
return testDTreePut(t, &trietree.DTree{}, keys...) | ||
}) | ||
}) | ||
t.Run("static", func(t *testing.T) { | ||
testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator { | ||
dt := testDTreePut(t, &trietree.DTree{}, keys...) | ||
return trietree.Freeze(dt) | ||
}) | ||
}) | ||
} |