Skip to content

Commit

Permalink
change name Predict from PredictSeq
Browse files Browse the repository at this point in the history
  • Loading branch information
koron committed Oct 9, 2024
1 parent 9c1b7ac commit 28e9b9e
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 151 deletions.
36 changes: 36 additions & 0 deletions predict.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trietree

import (
"iter"
"unicode/utf8"
)

Expand Down Expand Up @@ -134,3 +135,38 @@ func predictIter[T comparable](tree predictableTree[T], query string) func() *Pr
return p
}
}

// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (dt *DTree) Predict(query string) iter.Seq[Prediction] {
return predict[*DNode](dt, query)
}

// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (st *STree) Predict(query string) iter.Seq[Prediction] {
return predict[int](st, query)
}

func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] {
var zero T
tr := newTraverser[T](tree, query)
return func(yield func(Prediction) bool) {
for {
node, end, valid := tr.next()
if !valid {
return
}
for node != zero {
if id := tree.nodeId(node); id > 0 {
st := trailingIndex(query[:end], tree.nodeLevel(node))
if !yield(Prediction{Start: st, End: end, ID: id}) {
tr.close()
return
}
}
node = tree.nodeFail(node)
}
}
}
}
103 changes: 103 additions & 0 deletions predict_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trietree_test

import (
"iter"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -137,3 +138,105 @@ func TestSTree_PredictMultiple(t *testing.T) {
{Start: 2, End: 3, ID: 4, Key: "d"},
})
}

type predictor interface {
Predict(string) iter.Seq[trietree.Prediction]
}

func testPredict(t *testing.T, ptor predictor, q string, want []prediction) {
t.Helper()
got := make([]prediction, 0, 10)
for p := range ptor.Predict(q) {
got = append(got, prediction{
Start: p.Start,
End: p.End,
ID: p.ID,
Key: q[p.Start:p.End],
})
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("unexpected predictions: -want +got\n%s", d)
}
}

type predictorBuilder func(t *testing.T, keys ...string) predictor

func testPredictSingle(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
})
testPredict(t, ptor, "2", []prediction{
{Start: 0, End: 1, ID: 2, Key: "2"},
})
testPredict(t, ptor, "3", []prediction{
{Start: 0, End: 1, ID: 3, Key: "3"},
})
testPredict(t, ptor, "4", []prediction{
{Start: 0, End: 1, ID: 4, Key: "4"},
})
testPredict(t, ptor, "5", []prediction{
{Start: 0, End: 1, ID: 5, Key: "5"},
})
testPredict(t, ptor, "6", []prediction{})
}

func testPredictMultiple(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1234567890", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
{Start: 1, End: 2, ID: 2, Key: "2"},
{Start: 2, End: 3, ID: 3, Key: "3"},
{Start: 3, End: 4, ID: 4, Key: "4"},
{Start: 4, End: 5, ID: 5, Key: "5"},
})
}

func testPredictBasic(t *testing.T, build predictorBuilder) {
ptor := build(t, "ab", "bc", "bab", "d", "abcde")
testPredict(t, ptor, "ab", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "bc", []prediction{
{Start: 0, End: 2, ID: 2, Key: "bc"},
})
testPredict(t, ptor, "bab", []prediction{
{Start: 0, End: 3, ID: 3, Key: "bab"},
{Start: 1, End: 3, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "d", []prediction{
{Start: 0, End: 1, ID: 4, Key: "d"},
})
testPredict(t, ptor, "abcde", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
{Start: 1, End: 3, ID: 2, Key: "bc"},
{Start: 3, End: 4, ID: 4, Key: "d"},
{Start: 0, End: 5, ID: 5, Key: "abcde"},
})
}

func testPredictAll(t *testing.T, builder predictorBuilder) {
t.Run("single", func(t *testing.T) {
testPredictSingle(t, builder)
})
t.Run("multiple", func(t *testing.T) {
testPredictMultiple(t, builder)
})
t.Run("basic", func(t *testing.T) {
testPredictBasic(t, builder)
})
}

func TestPredictSeq(t *testing.T) {
t.Run("dynamic", func(t *testing.T) {
testPredictAll(t, func(t *testing.T, keys ...string) predictor {
return testDTreePut(t, &trietree.DTree{}, keys...)
})
})
t.Run("static", func(t *testing.T) {
testPredictAll(t, func(t *testing.T, keys ...string) predictor {
dt := testDTreePut(t, &trietree.DTree{}, keys...)
return trietree.Freeze(dt)
})
})
}
40 changes: 0 additions & 40 deletions seq.go

This file was deleted.

111 changes: 0 additions & 111 deletions seq_test.go

This file was deleted.

0 comments on commit 28e9b9e

Please sign in to comment.