diff --git a/iter.go b/iter.go index a4e2a15..2d5c0ef 100644 --- a/iter.go +++ b/iter.go @@ -4,16 +4,8 @@ package trietree import ( "iter" - "unicode/utf8" ) -// Prediction is identifier of a key. -type Prediction struct { - Start int // Start is start index of key in query. - End int // End is end index of key in query. - ID int // ID is for edge node identifier. -} - // Predict returns an iterator which enumerates Prediction: key suggestions // that match the query in the tree. func (dt *DTree) Predict(query string) iter.Seq[Prediction] { @@ -26,77 +18,6 @@ func (st *STree) Predict(query string) iter.Seq[Prediction] { return predict[int](st, query) } -type predictableTree[T comparable] interface { - root() T - nextNode(T, rune) T - nodeId(T) int - nodeLevel(T) int - nodeFail(T) T -} - -// methods DTree satisfies predictableTree[*DNode] -func (dt *DTree) root() *DNode { return &dt.Root } -func (dt *DTree) nodeId(n *DNode) int { return n.EdgeID } -func (dt *DTree) nodeLevel(n *DNode) int { return n.Level } -func (dt *DTree) nodeFail(n *DNode) *DNode { return n.Failure } - -// methods STree satisfies predictableTree[int] -func (st *STree) root() int { return 0 } -func (st *STree) nodeId(n int) int { return st.Nodes[n].EdgeID } -func (st *STree) nodeLevel(n int) int { return st.Levels[st.nodeId(n)-1] } -func (st *STree) nodeFail(n int) int { return st.Nodes[n].Fail } - -type traverser[T comparable] struct { - tree predictableTree[T] - query string - pivot T - index int -} - -func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] { - return traverser[T]{ - tree: tree, - query: query, - pivot: tree.root(), - index: 0, - } -} - -// next consumes a rune from query, and determine next node to travese tree. -// this returns next node, and end index of last parsed rune in query. -func (tr *traverser[T]) next() (node T, end int) { - var zero T - if tr.query == "" { - return zero, 0 - } - r, sz := utf8.DecodeRuneInString(tr.query) - if sz == 0 { - return zero, 0 - } - tr.query = tr.query[sz:] - tr.index += sz - tr.pivot = tr.tree.nextNode(tr.pivot, r) - return tr.pivot, tr.index -} - -func (tr *traverser[T]) close() { - tr.query = "" -} - -// trailingIndex returns the index of the n'th character from the end of string s. -func trailingIndex(s string, n int) int { - x := len(s) - for n > 0 && x > 0 { - _, sz := utf8.DecodeLastRuneInString(s[:x]) - if sz == 0 { - break - } - x -= sz - n-- - } - return x -} - func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] { var zero T tr := newTraverser[T](tree, query) diff --git a/iter_test.go b/iter_test.go index 110d3b1..7bffea8 100644 --- a/iter_test.go +++ b/iter_test.go @@ -10,13 +10,6 @@ import ( "github.com/koron-go/trietree" ) -type prediction struct { - Start int - End int - ID int - Key string -} - type predictor interface { Predict(string) iter.Seq[trietree.Prediction] } diff --git a/predict.go b/predict.go new file mode 100644 index 0000000..971571c --- /dev/null +++ b/predict.go @@ -0,0 +1,124 @@ +package trietree + +import ( + "unicode/utf8" +) + +// Prediction is identifier of a key. +type Prediction struct { + Start int // Start is start index of key in query. + End int // End is end index of key in query. + ID int // ID is for edge node identifier. +} + +// PredictIter returns an iterator function which enumerates Prediction: key +// suggestions that match the query in the tree. +func (dt *DTree) PredictIter(query string) func() *Prediction { + return predictIter[*DNode](dt, query) +} + +// PredictIter returns an iterator function which enumerates Prediction: key +// suggestions that match the query in the tree. +func (st *STree) PredictIter(query string) func() *Prediction { + return predictIter[int](st, query) +} + +type predictableTree[T comparable] interface { + root() T + nextNode(T, rune) T + nodeId(T) int + nodeLevel(T) int + nodeFail(T) T +} + +// methods DTree satisfies predictableTree[*DNode] +func (dt *DTree) root() *DNode { return &dt.Root } +func (dt *DTree) nodeId(n *DNode) int { return n.EdgeID } +func (dt *DTree) nodeLevel(n *DNode) int { return n.Level } +func (dt *DTree) nodeFail(n *DNode) *DNode { return n.Failure } + +// methods STree satisfies predictableTree[int] +func (st *STree) root() int { return 0 } +func (st *STree) nodeId(n int) int { return st.Nodes[n].EdgeID } +func (st *STree) nodeLevel(n int) int { return st.Levels[st.nodeId(n)-1] } +func (st *STree) nodeFail(n int) int { return st.Nodes[n].Fail } + +type traverser[T comparable] struct { + tree predictableTree[T] + query string + pivot T + index int +} + +func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] { + return traverser[T]{ + tree: tree, + query: query, + pivot: tree.root(), + index: 0, + } +} + +// next consumes a rune from query, and determine next node to travese tree. +// this returns next node, and end index of last parsed rune in query. +func (tr *traverser[T]) next() (node T, end int) { + var zero T + if tr.query == "" { + return zero, 0 + } + r, sz := utf8.DecodeRuneInString(tr.query) + if sz == 0 { + return zero, 0 + } + tr.query = tr.query[sz:] + tr.index += sz + tr.pivot = tr.tree.nextNode(tr.pivot, r) + return tr.pivot, tr.index +} + +func (tr *traverser[T]) close() { + tr.query = "" +} + +// trailingIndex returns the index of the n'th character from the end of string s. +func trailingIndex(s string, n int) int { + x := len(s) + for n > 0 && x > 0 { + _, sz := utf8.DecodeLastRuneInString(s[:x]) + if sz == 0 { + break + } + x -= sz + n-- + } + return x +} + +func predictIter[T comparable](tree predictableTree[T], query string) func() *Prediction { + var ( + zero T + tr = newTraverser[T](tree, query) + node T + end int + ) + return func() *Prediction { + var p *Prediction + for p == nil { + if node == zero { + node, end = tr.next() + if node == zero { + tr.close() + return nil + } + } + for node != zero && p == nil { + if id := tree.nodeId(node); id > 0 { + st := trailingIndex(query[:end], tree.nodeLevel(node)) + p = &Prediction{Start: st, End: end, ID: id} + } + node = tree.nodeFail(node) + } + } + return p + } +} diff --git a/predict_test.go b/predict_test.go new file mode 100644 index 0000000..b29ecda --- /dev/null +++ b/predict_test.go @@ -0,0 +1,122 @@ +package trietree_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/koron-go/trietree" +) + +type prediction struct { + Start int + End int + ID int + Key string +} + +type predictIterator interface { + PredictIter(string) func() *trietree.Prediction +} + +func testPredictIter(t *testing.T, ptor predictIterator, q string, want []prediction) { + t.Helper() + got := make([]prediction, 0, 10) + iter := ptor.PredictIter(q) + for { + p := iter() + if p == nil { + break + } + got = append(got, prediction{ + Start: p.Start, + End: p.End, + ID: p.ID, + Key: q[p.Start:p.End], + }) + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("unexpected predictions: -want +got\n%s", d) + } +} + +type predictIteratorBuilder func(t *testing.T, keys ...string) predictIterator + +func testPredictIterSingle(t *testing.T, build predictIteratorBuilder) { + ptor := build(t, "1", "2", "3", "4", "5") + testPredictIter(t, ptor, "1", []prediction{ + {Start: 0, End: 1, ID: 1, Key: "1"}, + }) + testPredictIter(t, ptor, "2", []prediction{ + {Start: 0, End: 1, ID: 2, Key: "2"}, + }) + testPredictIter(t, ptor, "3", []prediction{ + {Start: 0, End: 1, ID: 3, Key: "3"}, + }) + testPredictIter(t, ptor, "4", []prediction{ + {Start: 0, End: 1, ID: 4, Key: "4"}, + }) + testPredictIter(t, ptor, "5", []prediction{ + {Start: 0, End: 1, ID: 5, Key: "5"}, + }) + testPredictIter(t, ptor, "6", []prediction{}) +} + +func testPredictIterMultiple(t *testing.T, build predictIteratorBuilder) { + ptor := build(t, "1", "2", "3", "4", "5") + testPredictIter(t, ptor, "1234567890", []prediction{ + {Start: 0, End: 1, ID: 1, Key: "1"}, + {Start: 1, End: 2, ID: 2, Key: "2"}, + {Start: 2, End: 3, ID: 3, Key: "3"}, + {Start: 3, End: 4, ID: 4, Key: "4"}, + {Start: 4, End: 5, ID: 5, Key: "5"}, + }) +} + +func testPredictIterBasic(t *testing.T, build predictIteratorBuilder) { + ptor := build(t, "ab", "bc", "bab", "d", "abcde") + testPredictIter(t, ptor, "ab", []prediction{ + {Start: 0, End: 2, ID: 1, Key: "ab"}, + }) + testPredictIter(t, ptor, "bc", []prediction{ + {Start: 0, End: 2, ID: 2, Key: "bc"}, + }) + testPredictIter(t, ptor, "bab", []prediction{ + {Start: 0, End: 3, ID: 3, Key: "bab"}, + {Start: 1, End: 3, ID: 1, Key: "ab"}, + }) + testPredictIter(t, ptor, "d", []prediction{ + {Start: 0, End: 1, ID: 4, Key: "d"}, + }) + testPredictIter(t, ptor, "abcde", []prediction{ + {Start: 0, End: 2, ID: 1, Key: "ab"}, + {Start: 1, End: 3, ID: 2, Key: "bc"}, + {Start: 3, End: 4, ID: 4, Key: "d"}, + {Start: 0, End: 5, ID: 5, Key: "abcde"}, + }) +} + +func testPredictIterAll(t *testing.T, builder predictIteratorBuilder) { + t.Run("single", func(t *testing.T) { + testPredictIterSingle(t, builder) + }) + t.Run("multiple", func(t *testing.T) { + testPredictIterMultiple(t, builder) + }) + t.Run("basic", func(t *testing.T) { + testPredictIterBasic(t, builder) + }) +} + +func TestPredictIter(t *testing.T) { + t.Run("dynamic", func(t *testing.T) { + testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator { + return testDTreePut(t, &trietree.DTree{}, keys...) + }) + }) + t.Run("static", func(t *testing.T) { + testPredictIterAll(t, func(t *testing.T, keys ...string) predictIterator { + dt := testDTreePut(t, &trietree.DTree{}, keys...) + return trietree.Freeze(dt) + }) + }) +}