Skip to content

Commit

Permalink
refactoring, use generics aggressively
Browse files Browse the repository at this point in the history
  • Loading branch information
koron committed Feb 25, 2024
1 parent fc43b06 commit 29260bd
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 104 deletions.
166 changes: 96 additions & 70 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,89 +7,115 @@ import (
"unicode/utf8"
)

// Prediction is identifier of a key.
type Prediction struct {
Index int
ID int
Depth int
Label rune
Start int // Start is start index of key in query.
End int // End is end index of key in query.
ID int // ID is for edge node identifier.
}

func (dt *DTree) Predict(s string) iter.Seq[Prediction] {
var (
query = s
idx = 0
pivot = &dt.Root // won't be nil
nextNode = func() (*DNode, int, rune) {
x := idx
r, sz := utf8.DecodeRuneInString(query)
query = query[sz:]
idx += sz
pivot = dt.nextNode(pivot, r)
return pivot, x, r
}
)
// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (dt *DTree) Predict(query string) iter.Seq[Prediction] {
return predict[*DNode](dt, query)
}

return func(yield func(Prediction) bool) {
var (
currNode *DNode = nil
currIdx int
currRune rune
)
for {
if currNode == nil {
if query == "" {
return
}
currNode, currIdx, currRune = nextNode()
//log.Printf("update: cx=%d cr=%c cn=%v", currIdx, currRune, currNode)
}
if currNode.EdgeID > 0 {
//log.Printf("yield: cx=%d cr=%c cn=%v", currIdx, currRune, currNode)
if !yield(Prediction{Index: currIdx, ID: currNode.EdgeID, Depth: currNode.Level, Label: currRune}) {
query = ""
return
}
}
currNode = currNode.Failure
}
// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (st *STree) Predict(query string) iter.Seq[Prediction] {
return predict[int](st, query)
}

type predictableTree[T comparable] interface {
root() T
nextNode(T, rune) T
nodeId(T) int
nodeLevel(T) int
nodeFail(T) T
}

// methods DTree satisfies predictableTree[*DNode]
func (dt *DTree) root() *DNode { return &dt.Root }
func (dt *DTree) nodeId(n *DNode) int { return n.EdgeID }
func (dt *DTree) nodeLevel(n *DNode) int { return n.Level }
func (dt *DTree) nodeFail(n *DNode) *DNode { return n.Failure }

// methods STree satisfies predictableTree[int]
func (st *STree) root() int { return 0 }
func (st *STree) nodeId(n int) int { return st.Nodes[n].EdgeID }
func (st *STree) nodeLevel(n int) int { return st.Levels[st.nodeId(n)-1] }
func (st *STree) nodeFail(n int) int { return st.Nodes[n].Fail }

type traverser[T comparable] struct {
tree predictableTree[T]
query string
pivot T
index int
}

func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] {
return traverser[T]{
tree: tree,
query: query,
pivot: tree.root(),
index: 0,
}
}

// next consumes a rune from query, and determine next node to travese tree.
// this returns next node, and end index of last parsed rune in query.
func (tr *traverser[T]) next() (node T, end int) {
var zero T
if tr.query == "" {
return zero, 0
}
r, sz := utf8.DecodeRuneInString(tr.query)
if sz == 0 {
return zero, 0
}
tr.query = tr.query[sz:]
tr.index += sz
tr.pivot = tr.tree.nextNode(tr.pivot, r)
return tr.pivot, tr.index
}

func (st *STree) Predict(s string) iter.Seq[Prediction] {
var (
query = s
idx = 0
pivot = 0
nextNode = func() (nodeID int, nodeIdx int, nodeRune rune) {
nodeIdx = idx
nodeRune, sz := utf8.DecodeRuneInString(query)
query = query[sz:]
idx += sz
pivot = st.nextNode(pivot, nodeRune)
return pivot, nodeIdx, nodeRune
func (tr *traverser[T]) close() {
tr.query = ""
}

// trailingIndex returns the index of the n'th character from the end of string s.
func trailingIndex(s string, n int) int {
x := len(s)
for n > 0 && x > 0 {
_, sz := utf8.DecodeLastRuneInString(s[:x])
if sz == 0 {
break
}
)
x -= sz
n--
}
return x
}

func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] {
var zero T
tr := newTraverser[T](tree, query)
return func(yield func(Prediction) bool) {
var (
currNode = 0
currIdx int
currRune rune
)
for {
if currNode == 0 {
if query == "" {
return
}
currNode, currIdx, currRune = nextNode()
node, end := tr.next()
if node == zero {
return
}
n := st.Nodes[currNode]
if id := n.EdgeID; id > 0 {
if !yield(Prediction{Index: currIdx, ID: id, Depth: st.Levels[id-1], Label: currRune}) {
query = ""
return
for node != zero {
if id := tree.nodeId(node); id > 0 {
st := trailingIndex(query[:end], tree.nodeLevel(node))
if !yield(Prediction{Start: st, End: end, ID: id}) {
tr.close()
return
}
}
node = tree.nodeFail(node)
}
currNode = n.Fail
}
}
}
80 changes: 46 additions & 34 deletions iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,27 @@ import (
"github.com/koron-go/trietree"
)

type prediction struct {
Start int
End int
ID int
Key string
}

type predictor interface {
Predict(string) iter.Seq[trietree.Prediction]
}

func testPredict(t *testing.T, ptor predictor, q string, want []trietree.Prediction) {
func testPredict(t *testing.T, ptor predictor, q string, want []prediction) {
t.Helper()
got := make([]trietree.Prediction, 0, 10)
got := make([]prediction, 0, 10)
for p := range ptor.Predict(q) {
got = append(got, p)
got = append(got, prediction{
Start: p.Start,
End: p.End,
ID: p.ID,
Key: q[p.Start:p.End],
})
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("unexpected predictions: -want +got\n%s", d)
Expand All @@ -29,55 +41,55 @@ type predictorBuilder func(t *testing.T, keys ...string) predictor

func testPredictSingle(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1", []trietree.Prediction{
{Index: 0, ID: 1, Depth: 1, Label: '1'},
testPredict(t, ptor, "1", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
})
testPredict(t, ptor, "2", []trietree.Prediction{
{Index: 0, ID: 2, Depth: 1, Label: '2'},
testPredict(t, ptor, "2", []prediction{
{Start: 0, End: 1, ID: 2, Key: "2"},
})
testPredict(t, ptor, "3", []trietree.Prediction{
{Index: 0, ID: 3, Depth: 1, Label: '3'},
testPredict(t, ptor, "3", []prediction{
{Start: 0, End: 1, ID: 3, Key: "3"},
})
testPredict(t, ptor, "4", []trietree.Prediction{
{Index: 0, ID: 4, Depth: 1, Label: '4'},
testPredict(t, ptor, "4", []prediction{
{Start: 0, End: 1, ID: 4, Key: "4"},
})
testPredict(t, ptor, "5", []trietree.Prediction{
{Index: 0, ID: 5, Depth: 1, Label: '5'},
testPredict(t, ptor, "5", []prediction{
{Start: 0, End: 1, ID: 5, Key: "5"},
})
testPredict(t, ptor, "6", []trietree.Prediction{})
testPredict(t, ptor, "6", []prediction{})
}

func testPredictMultiple(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1234567890", []trietree.Prediction{
{Index: 0, ID: 1, Depth: 1, Label: '1'},
{Index: 1, ID: 2, Depth: 1, Label: '2'},
{Index: 2, ID: 3, Depth: 1, Label: '3'},
{Index: 3, ID: 4, Depth: 1, Label: '4'},
{Index: 4, ID: 5, Depth: 1, Label: '5'},
testPredict(t, ptor, "1234567890", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
{Start: 1, End: 2, ID: 2, Key: "2"},
{Start: 2, End: 3, ID: 3, Key: "3"},
{Start: 3, End: 4, ID: 4, Key: "4"},
{Start: 4, End: 5, ID: 5, Key: "5"},
})
}

func testPredictBasic(t *testing.T, build predictorBuilder) {
ptor := build(t, "ab", "bc", "bab", "d", "abcde")
testPredict(t, ptor, "ab", []trietree.Prediction{
{Index: 1, ID: 1, Depth: 2, Label: 'b'},
testPredict(t, ptor, "ab", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "bc", []trietree.Prediction{
{Index: 1, ID: 2, Depth: 2, Label: 'c'},
testPredict(t, ptor, "bc", []prediction{
{Start: 0, End: 2, ID: 2, Key: "bc"},
})
testPredict(t, ptor, "bab", []trietree.Prediction{
{Index: 2, ID: 3, Depth: 3, Label: 'b'},
{Index: 2, ID: 1, Depth: 2, Label: 'b'},
testPredict(t, ptor, "bab", []prediction{
{Start: 0, End: 3, ID: 3, Key: "bab"},
{Start: 1, End: 3, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "d", []trietree.Prediction{
{Index: 0, ID: 4, Depth: 1, Label: 'd'},
testPredict(t, ptor, "d", []prediction{
{Start: 0, End: 1, ID: 4, Key: "d"},
})
testPredict(t, ptor, "abcde", []trietree.Prediction{
{Index: 1, ID: 1, Depth: 2, Label: 'b'},
{Index: 2, ID: 2, Depth: 2, Label: 'c'},
{Index: 3, ID: 4, Depth: 1, Label: 'd'},
{Index: 4, ID: 5, Depth: 5, Label: 'e'},
testPredict(t, ptor, "abcde", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
{Start: 1, End: 3, ID: 2, Key: "bc"},
{Start: 3, End: 4, ID: 4, Key: "d"},
{Start: 0, End: 5, ID: 5, Key: "abcde"},
})
}

Expand Down

0 comments on commit 29260bd

Please sign in to comment.