Skip to content

Commit

Permalink
implement normal function iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
koron committed Feb 25, 2024
1 parent 29260bd commit 1b6d6ca
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 86 deletions.
79 changes: 0 additions & 79 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@ package trietree

import (
"iter"
"unicode/utf8"
)

// Prediction is identifier of a key.
type Prediction struct {
Start int // Start is start index of key in query.
End int // End is end index of key in query.
ID int // ID is for edge node identifier.
}

// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (dt *DTree) Predict(query string) iter.Seq[Prediction] {
Expand All @@ -26,77 +18,6 @@ func (st *STree) Predict(query string) iter.Seq[Prediction] {
return predict[int](st, query)
}

type predictableTree[T comparable] interface {
root() T
nextNode(T, rune) T
nodeId(T) int
nodeLevel(T) int
nodeFail(T) T
}

// methods DTree satisfies predictableTree[*DNode]
func (dt *DTree) root() *DNode { return &dt.Root }
func (dt *DTree) nodeId(n *DNode) int { return n.EdgeID }
func (dt *DTree) nodeLevel(n *DNode) int { return n.Level }
func (dt *DTree) nodeFail(n *DNode) *DNode { return n.Failure }

// methods STree satisfies predictableTree[int]
func (st *STree) root() int { return 0 }
func (st *STree) nodeId(n int) int { return st.Nodes[n].EdgeID }
func (st *STree) nodeLevel(n int) int { return st.Levels[st.nodeId(n)-1] }
func (st *STree) nodeFail(n int) int { return st.Nodes[n].Fail }

type traverser[T comparable] struct {
tree predictableTree[T]
query string
pivot T
index int
}

func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] {
return traverser[T]{
tree: tree,
query: query,
pivot: tree.root(),
index: 0,
}
}

// next consumes a rune from query, and determine next node to travese tree.
// this returns next node, and end index of last parsed rune in query.
func (tr *traverser[T]) next() (node T, end int) {
var zero T
if tr.query == "" {
return zero, 0
}
r, sz := utf8.DecodeRuneInString(tr.query)
if sz == 0 {
return zero, 0
}
tr.query = tr.query[sz:]
tr.index += sz
tr.pivot = tr.tree.nextNode(tr.pivot, r)
return tr.pivot, tr.index
}

func (tr *traverser[T]) close() {
tr.query = ""
}

// trailingIndex returns the index of the n'th character from the end of string s.
func trailingIndex(s string, n int) int {
x := len(s)
for n > 0 && x > 0 {
_, sz := utf8.DecodeLastRuneInString(s[:x])
if sz == 0 {
break
}
x -= sz
n--
}
return x
}

func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] {
var zero T
tr := newTraverser[T](tree, query)
Expand Down
7 changes: 0 additions & 7 deletions iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@ import (
"github.com/koron-go/trietree"
)

type prediction struct {
Start int
End int
ID int
Key string
}

type predictor interface {
Predict(string) iter.Seq[trietree.Prediction]
}
Expand Down
124 changes: 124 additions & 0 deletions predict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package trietree

import (
"unicode/utf8"
)

// Prediction identifies one key found in a query: the byte range the key
// occupies in the query and the identifier of the key.
type Prediction struct {
	// Start is the byte index in the query where the key starts.
	Start int
	// End is the byte index in the query just past the end of the key.
	End int
	// ID is the edge node identifier of the key.
	ID int
}

// PredictIter returns an iterator function over the keys of the dynamic tree
// that match the query. Each call yields the next Prediction; nil is returned
// once there are no more predictions.
func (dt *DTree) PredictIter(query string) func() *Prediction {
	next := predictIter[*DNode](dt, query)
	return next
}

// PredictIter returns an iterator function over the keys of the static tree
// that match the query. Each call yields the next Prediction; nil is returned
// once there are no more predictions.
func (st *STree) PredictIter(query string) func() *Prediction {
	next := predictIter[int](st, query)
	return next
}

// predictableTree is the set of node operations that predictIter and
// traverser need from a tree. T is the tree's node handle type.
type predictableTree[T comparable] interface {
	// root returns the node where a traversal starts.
	root() T

	// nextNode returns the node to move to from the given node when the given
	// rune is read from the query.
	nextNode(T, rune) T

	// nodeId returns the key identifier stored on the node. predictIter
	// reports a prediction only for nodes whose identifier is positive.
	nodeId(T) int

	// nodeLevel returns the length, in characters, of the key that ends at
	// the node. predictIter uses it to locate the start of the key.
	nodeLevel(T) int

	// nodeFail returns the node to inspect after the given one. predictIter
	// follows these links until it reaches the zero value of T.
	nodeFail(T) T
}

// The methods below make *DTree satisfy predictableTree[*DNode].

// root returns a pointer to the tree's root node.
func (dt *DTree) root() *DNode {
	return &dt.Root
}

// nodeId returns the EdgeID of the node.
func (dt *DTree) nodeId(n *DNode) int {
	return n.EdgeID
}

// nodeLevel returns the Level of the node.
func (dt *DTree) nodeLevel(n *DNode) int {
	return n.Level
}

// nodeFail returns the node's Failure link, which may be nil.
func (dt *DTree) nodeFail(n *DNode) *DNode {
	return n.Failure
}

// The methods below make *STree satisfy predictableTree[int]. A node is
// identified by its index in st.Nodes.

// root returns node 0, where a traversal of the static tree starts.
func (st *STree) root() int {
	return 0
}

// nodeId returns the EdgeID of node n.
func (st *STree) nodeId(n int) int {
	return st.Nodes[n].EdgeID
}

// nodeLevel returns the level recorded for the key that ends at node n.
// st.Levels is indexed by EdgeID-1, so n must have a positive EdgeID;
// predictIter checks the identifier before calling this.
func (st *STree) nodeLevel(n int) int {
	edge := st.Nodes[n].EdgeID
	return st.Levels[edge-1]
}

// nodeFail returns the Fail link of node n.
func (st *STree) nodeFail(n int) int {
	return st.Nodes[n].Fail
}

// traverser walks a predictableTree along the runes of a query, one rune per
// call to next.
type traverser[T comparable] struct {
	// tree is the tree being walked.
	tree predictableTree[T]
	// query is the part of the query that has not been consumed yet.
	query string
	// pivot is the node reached by the runes consumed so far.
	pivot T
	// index is the number of bytes of the query consumed so far.
	index int
}

// newTraverser returns a traverser positioned on the root of tree with the
// whole query still to be consumed.
func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] {
	var tr traverser[T] // index starts at its zero value, 0
	tr.tree = tree
	tr.query = query
	tr.pivot = tree.root()
	return tr
}

// next consumes one rune from the query and moves the pivot to the node the
// tree selects for that rune. It returns the new pivot and the byte index
// just past the consumed rune. Once the query is exhausted it returns the
// zero value of T and 0, and leaves the traverser untouched.
func (tr *traverser[T]) next() (node T, end int) {
	r, width := utf8.DecodeRuneInString(tr.query)
	if width == 0 {
		// DecodeRuneInString reports a width of 0 only for an empty string.
		var none T
		return none, 0
	}
	tr.query, tr.index = tr.query[width:], tr.index+width
	tr.pivot = tr.tree.nextNode(tr.pivot, r)
	return tr.pivot, tr.index
}

// close discards the part of the query that has not been consumed yet, so
// every later call to next reports that the query is exhausted.
func (tr *traverser[T]) close() {
tr.query = ""
}

// trailingIndex returns the byte index in s at which the n'th character
// counted from the end of s starts. It returns len(s) when n is not positive
// and 0 when s holds fewer than n characters. An invalid UTF-8 byte counts as
// one character.
func trailingIndex(s string, n int) int {
	pos := len(s)
	for ; n > 0 && pos > 0; n-- {
		// s[:pos] is never empty here, so the decoded width is at least 1.
		_, width := utf8.DecodeLastRuneInString(s[:pos])
		pos -= width
	}
	return pos
}

// predictIter returns an iterator function that enumerates the keys of tree
// found in query. Each call returns the next Prediction, or nil once the
// query has been consumed completely; further calls keep returning nil.
//
// For every rune of the query the traverser moves to a node, and that node
// and the nodes reachable through nodeFail are inspected in order. Each node
// with a positive identifier yields one Prediction ending at the current rune.
//
// The end of the query is detected by the end index returned from
// traverser.next, not by comparing the node with the zero value of T. A tree
// whose root is the zero value of T (STree uses 0) may legitimately return
// that value from nextNode when a rune matches nothing; the traversal must
// then go on with the next rune instead of stopping.
func predictIter[T comparable](tree predictableTree[T], query string) func() *Prediction {
	var (
		zero T
		tr   = newTraverser[T](tree, query)
		node T   // next node of the current failure chain to inspect
		end  int // byte index just past the last consumed rune
	)
	return func() *Prediction {
		for {
			// Drain the failure chain of the current position first, so that
			// consecutive calls resume where the previous one stopped.
			for node != zero {
				curr := node
				node = tree.nodeFail(curr)
				if id := tree.nodeId(curr); id > 0 {
					start := trailingIndex(query[:end], tree.nodeLevel(curr))
					return &Prediction{Start: start, End: end, ID: id}
				}
			}
			// The chain is exhausted: consume one more rune. next always
			// advances by at least one byte, so an end index of 0 can only
			// mean that the query is exhausted.
			n, pos := tr.next()
			if pos == 0 {
				tr.close()
				node = zero
				return nil
			}
			node, end = n, pos
		}
	}
}
122 changes: 122 additions & 0 deletions predict_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package trietree_test

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/koron-go/trietree"
)

// prediction is the comparable form of a trietree.Prediction used in test
// expectations. It adds the matched text so that failures are easy to read.
type prediction struct {
	// Start is the expected start index of the key in the query.
	Start int
	// End is the expected end index of the key in the query.
	End int
	// ID is the expected identifier of the key.
	ID int
	// Key is the expected matched text, i.e. query[Start:End].
	Key string
}

// predictIterator is implemented by every tree under test: anything that can
// produce a prediction iterator function for a query.
type predictIterator interface {
	// PredictIter returns a function that yields the predictions for the
	// query one by one and nil when there are no more.
	PredictIter(string) func() *trietree.Prediction
}

// testPredictIter drains the iterator that ptor returns for q and reports a
// test error when the predictions differ from want.
func testPredictIter(t *testing.T, ptor predictIterator, q string, want []prediction) {
	t.Helper()
	// got is allocated up front so that it is an empty, non-nil slice when
	// nothing matches; callers pass []prediction{} for that case and cmp.Diff
	// distinguishes nil from empty slices.
	got := make([]prediction, 0, 10)
	next := ptor.PredictIter(q)
	for p := next(); p != nil; p = next() {
		got = append(got, prediction{
			Start: p.Start,
			End:   p.End,
			ID:    p.ID,
			Key:   q[p.Start:p.End],
		})
	}
	if d := cmp.Diff(want, got); d != "" {
		t.Errorf("unexpected predictions: -want +got\n%s", d)
	}
}

// predictIteratorBuilder builds the tree under test from the given keys and
// returns it as a predictIterator.
type predictIteratorBuilder func(t *testing.T, keys ...string) predictIterator

// testPredictIterSingle checks queries that consist of exactly one key: each
// key must be predicted once, covering the whole query, with the identifier
// matching its 1-based position in the key list. A query that is not a key
// must yield no prediction.
func testPredictIterSingle(t *testing.T, build predictIteratorBuilder) {
	keys := []string{"1", "2", "3", "4", "5"}
	ptor := build(t, keys...)
	for i, key := range keys {
		testPredictIter(t, ptor, key, []prediction{
			{Start: 0, End: 1, ID: i + 1, Key: key},
		})
	}
	testPredictIter(t, ptor, "6", []prediction{})
}

// testPredictIterMultiple checks a query that contains several keys one after
// another, followed by characters that match no key.
func testPredictIterMultiple(t *testing.T, build predictIteratorBuilder) {
	ptor := build(t, "1", "2", "3", "4", "5")
	want := []prediction{
		{Start: 0, End: 1, ID: 1, Key: "1"},
		{Start: 1, End: 2, ID: 2, Key: "2"},
		{Start: 2, End: 3, ID: 3, Key: "3"},
		{Start: 3, End: 4, ID: 4, Key: "4"},
		{Start: 4, End: 5, ID: 5, Key: "5"},
	}
	testPredictIter(t, ptor, "1234567890", want)
}

// testPredictIterBasic checks keys that overlap and nest inside each other.
// Predictions are expected in order of their end index; for the same end
// index the longer key comes first.
func testPredictIterBasic(t *testing.T, build predictIteratorBuilder) {
	ptor := build(t, "ab", "bc", "bab", "d", "abcde")
	cases := []struct {
		query string
		want  []prediction
	}{
		{"ab", []prediction{
			{Start: 0, End: 2, ID: 1, Key: "ab"},
		}},
		{"bc", []prediction{
			{Start: 0, End: 2, ID: 2, Key: "bc"},
		}},
		{"bab", []prediction{
			{Start: 0, End: 3, ID: 3, Key: "bab"},
			{Start: 1, End: 3, ID: 1, Key: "ab"},
		}},
		{"d", []prediction{
			{Start: 0, End: 1, ID: 4, Key: "d"},
		}},
		{"abcde", []prediction{
			{Start: 0, End: 2, ID: 1, Key: "ab"},
			{Start: 1, End: 3, ID: 2, Key: "bc"},
			{Start: 3, End: 4, ID: 4, Key: "d"},
			{Start: 0, End: 5, ID: 5, Key: "abcde"},
		}},
	}
	for _, c := range cases {
		testPredictIter(t, ptor, c.query, c.want)
	}
}

// testPredictIterAll runs every PredictIter scenario as a named subtest
// against trees produced by builder.
func testPredictIterAll(t *testing.T, builder predictIteratorBuilder) {
	suites := []struct {
		name string
		run  func(*testing.T, predictIteratorBuilder)
	}{
		{"single", testPredictIterSingle},
		{"multiple", testPredictIterMultiple},
		{"basic", testPredictIterBasic},
	}
	for _, suite := range suites {
		// t.Run returns only after the subtest has finished, so using the
		// loop variable inside the closure is safe.
		t.Run(suite.name, func(t *testing.T) {
			suite.run(t, builder)
		})
	}
}

func TestPredictIter(t *testing.T) {
t.Run("dynamic", func(t *testing.T) {
testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator {
return testDTreePut(t, &trietree.DTree{}, keys...)
})
})
t.Run("static", func(t *testing.T) {
testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator {
dt := testDTreePut(t, &trietree.DTree{}, keys...)
return trietree.Freeze(dt)
})
})
}

0 comments on commit 1b6d6ca

Please sign in to comment.