diff --git a/predict.go b/predict.go index 9118f34..3d47ab2 100644 --- a/predict.go +++ b/predict.go @@ -1,6 +1,7 @@ package trietree import ( + "iter" "unicode/utf8" ) @@ -134,3 +135,38 @@ func predictIter[T comparable](tree predictableTree[T], query string) func() *Pr return p } } + +// Predict returns an iterator which enumerates Prediction: key suggestions +// that match the query in the tree. +func (dt *DTree) Predict(query string) iter.Seq[Prediction] { + return predict[*DNode](dt, query) +} + +// Predict returns an iterator which enumerates Prediction: key suggestions +// that match the query in the tree. +func (st *STree) Predict(query string) iter.Seq[Prediction] { + return predict[int](st, query) +} + +func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] { + var zero T + tr := newTraverser[T](tree, query) + return func(yield func(Prediction) bool) { + for { + node, end, valid := tr.next() + if !valid { + return + } + for node != zero { + if id := tree.nodeId(node); id > 0 { + st := trailingIndex(query[:end], tree.nodeLevel(node)) + if !yield(Prediction{Start: st, End: end, ID: id}) { + tr.close() + return + } + } + node = tree.nodeFail(node) + } + } + } +} diff --git a/predict_test.go b/predict_test.go index 0744fef..c70093c 100644 --- a/predict_test.go +++ b/predict_test.go @@ -1,6 +1,7 @@ package trietree_test import ( + "iter" "testing" "github.com/google/go-cmp/cmp" @@ -137,3 +138,105 @@ func TestSTree_PredictMultiple(t *testing.T) { {Start: 2, End: 3, ID: 4, Key: "d"}, }) } + +type predictor interface { + Predict(string) iter.Seq[trietree.Prediction] +} + +func testPredict(t *testing.T, ptor predictor, q string, want []prediction) { + t.Helper() + got := make([]prediction, 0, 10) + for p := range ptor.Predict(q) { + got = append(got, prediction{ + Start: p.Start, + End: p.End, + ID: p.ID, + Key: q[p.Start:p.End], + }) + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("unexpected predictions: -want +got\n%s", d) + } +} + +type predictorBuilder func(t *testing.T, keys ...string) predictor + +func testPredictSingle(t *testing.T, build predictorBuilder) { + ptor := build(t, "1", "2", "3", "4", "5") + testPredict(t, ptor, "1", []prediction{ + {Start: 0, End: 1, ID: 1, Key: "1"}, + }) + testPredict(t, ptor, "2", []prediction{ + {Start: 0, End: 1, ID: 2, Key: "2"}, + }) + testPredict(t, ptor, "3", []prediction{ + {Start: 0, End: 1, ID: 3, Key: "3"}, + }) + testPredict(t, ptor, "4", []prediction{ + {Start: 0, End: 1, ID: 4, Key: "4"}, + }) + testPredict(t, ptor, "5", []prediction{ + {Start: 0, End: 1, ID: 5, Key: "5"}, + }) + testPredict(t, ptor, "6", []prediction{}) +} + +func testPredictMultiple(t *testing.T, build predictorBuilder) { + ptor := build(t, "1", "2", "3", "4", "5") + testPredict(t, ptor, "1234567890", []prediction{ + {Start: 0, End: 1, ID: 1, Key: "1"}, + {Start: 1, End: 2, ID: 2, Key: "2"}, + {Start: 2, End: 3, ID: 3, Key: "3"}, + {Start: 3, End: 4, ID: 4, Key: "4"}, + {Start: 4, End: 5, ID: 5, Key: "5"}, + }) +} + +func testPredictBasic(t *testing.T, build predictorBuilder) { + ptor := build(t, "ab", "bc", "bab", "d", "abcde") + testPredict(t, ptor, "ab", []prediction{ + {Start: 0, End: 2, ID: 1, Key: "ab"}, + }) + testPredict(t, ptor, "bc", []prediction{ + {Start: 0, End: 2, ID: 2, Key: "bc"}, + }) + testPredict(t, ptor, "bab", []prediction{ + {Start: 0, End: 3, ID: 3, Key: "bab"}, + {Start: 1, End: 3, ID: 1, Key: "ab"}, + }) + testPredict(t, ptor, "d", []prediction{ + {Start: 0, End: 1, ID: 4, Key: "d"}, + }) + testPredict(t, ptor, "abcde", []prediction{ + {Start: 0, End: 2, ID: 1, Key: "ab"}, + {Start: 1, End: 3, ID: 2, Key: "bc"}, + {Start: 3, End: 4, ID: 4, Key: "d"}, + {Start: 0, End: 5, ID: 5, Key: "abcde"}, + }) +} + +func testPredictAll(t *testing.T, builder predictorBuilder) { + t.Run("single", func(t *testing.T) { + testPredictSingle(t, builder) + }) + t.Run("multiple", func(t *testing.T) { + testPredictMultiple(t, builder) + }) + t.Run("basic", func(t *testing.T) { + testPredictBasic(t, builder) + }) +} + +func TestPredictSeq(t *testing.T) { + t.Run("dynamic", func(t *testing.T) { + testPredictAll(t, func(t *testing.T, keys ...string) predictor { + return testDTreePut(t, &trietree.DTree{}, keys...) + }) + }) + t.Run("static", func(t *testing.T) { + testPredictAll(t, func(t *testing.T, keys ...string) predictor { + dt := testDTreePut(t, &trietree.DTree{}, keys...) + return trietree.Freeze(dt) + }) + }) +} diff --git a/seq.go b/seq.go deleted file mode 100644 index bd01769..0000000 --- a/seq.go +++ /dev/null @@ -1,40 +0,0 @@ -package trietree - -import ( - "iter" -) - -// Predict returns an iterator which enumerates Prediction: key suggestions -// that match the query in the tree. -func (dt *DTree) PredictSeq(query string) iter.Seq[Prediction] { - return predict[*DNode](dt, query) -} - -// Predict returns an iterator which enumerates Prediction: key suggestions -// that match the query in the tree. -func (st *STree) PredictSeq(query string) iter.Seq[Prediction] { - return predict[int](st, query) -} - -func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] { - var zero T - tr := newTraverser[T](tree, query) - return func(yield func(Prediction) bool) { - for { - node, end, valid := tr.next() - if !valid { - return - } - for node != zero { - if id := tree.nodeId(node); id > 0 { - st := trailingIndex(query[:end], tree.nodeLevel(node)) - if !yield(Prediction{Start: st, End: end, ID: id}) { - tr.close() - return - } - } - node = tree.nodeFail(node) - } - } - } -} diff --git a/seq_test.go b/seq_test.go deleted file mode 100644 index 362f16b..0000000 --- a/seq_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package trietree_test - -import ( - "iter" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/koron-go/trietree" -) - -type predictor interface { - PredictSeq(string) iter.Seq[trietree.Prediction] -} - -func testPredict(t *testing.T, ptor predictor, q string, want []prediction) { - t.Helper() - got := make([]prediction, 0, 10) - for p := range ptor.PredictSeq(q) { - got = append(got, prediction{ - Start: p.Start, - End: p.End, - ID: p.ID, - Key: q[p.Start:p.End], - }) - } - if d := cmp.Diff(want, got); d != "" { - t.Errorf("unexpected predictions: -want +got\n%s", d) - } -} - -type predictorBuilder func(t *testing.T, keys ...string) predictor - -func testPredictSingle(t *testing.T, build predictorBuilder) { - ptor := build(t, "1", "2", "3", "4", "5") - testPredict(t, ptor, "1", []prediction{ - {Start: 0, End: 1, ID: 1, Key: "1"}, - }) - testPredict(t, ptor, "2", []prediction{ - {Start: 0, End: 1, ID: 2, Key: "2"}, - }) - testPredict(t, ptor, "3", []prediction{ - {Start: 0, End: 1, ID: 3, Key: "3"}, - }) - testPredict(t, ptor, "4", []prediction{ - {Start: 0, End: 1, ID: 4, Key: "4"}, - }) - testPredict(t, ptor, "5", []prediction{ - {Start: 0, End: 1, ID: 5, Key: "5"}, - }) - testPredict(t, ptor, "6", []prediction{}) -} - -func testPredictMultiple(t *testing.T, build predictorBuilder) { - ptor := build(t, "1", "2", "3", "4", "5") - testPredict(t, ptor, "1234567890", []prediction{ - {Start: 0, End: 1, ID: 1, Key: "1"}, - {Start: 1, End: 2, ID: 2, Key: "2"}, - {Start: 2, End: 3, ID: 3, Key: "3"}, - {Start: 3, End: 4, ID: 4, Key: "4"}, - {Start: 4, End: 5, ID: 5, Key: "5"}, - }) -} - -func testPredictBasic(t *testing.T, build predictorBuilder) { - ptor := build(t, "ab", "bc", "bab", "d", "abcde") - testPredict(t, ptor, "ab", []prediction{ - {Start: 0, End: 2, ID: 1, Key: "ab"}, - }) - testPredict(t, ptor, "bc", []prediction{ - {Start: 0, End: 2, ID: 2, Key: "bc"}, - }) - testPredict(t, ptor, "bab", []prediction{ - {Start: 0, End: 3, ID: 3, Key: "bab"}, - {Start: 1, End: 3, ID: 1, Key: "ab"}, - }) - testPredict(t, ptor, "d", []prediction{ - {Start: 0, End: 1, ID: 4, Key: "d"}, - }) - testPredict(t, ptor, "abcde", []prediction{ - {Start: 0, End: 2, ID: 1, Key: "ab"}, - {Start: 1, End: 3, ID: 2, Key: "bc"}, - {Start: 3, End: 4, ID: 4, Key: "d"}, - {Start: 0, End: 5, ID: 5, Key: "abcde"}, - }) -} - -func testPredictAll(t *testing.T, builder predictorBuilder) { - t.Run("single", func(t *testing.T) { - testPredictSingle(t, builder) - }) - t.Run("multiple", func(t *testing.T) { - testPredictMultiple(t, builder) - }) - t.Run("basic", func(t *testing.T) { - testPredictBasic(t, builder) - }) -} - -func TestPredictSeq(t *testing.T) { - t.Run("dynamic", func(t *testing.T) { - testPredictAll(t, func(t *testing.T, keys ...string) predictor { - return testDTreePut(t, &trietree.DTree{}, keys...) - }) - }) - t.Run("static", func(t *testing.T) { - testPredictAll(t, func(t *testing.T, keys ...string) predictor { - dt := testDTreePut(t, &trietree.DTree{}, keys...) - return trietree.Freeze(dt) - }) - }) -}