Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
koron committed May 18, 2024
1 parent 3d28336 commit 747a019
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 4 deletions.
23 changes: 19 additions & 4 deletions trie2/trie2.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func (dt *DTrie[T]) Put(k string, v T) {
dt.values = append(dt.values, v)
return
}
// update an existed value.
dt.values[id-1] = v
}

Expand All @@ -44,36 +45,50 @@ func (dt *DTrie[T]) Freeze(copyValues bool) *STrie[T] {
tree := trietree.Freeze(&dt.tree)
var values []T
if copyValues {
values = make([]T, len(values))
values = make([]T, len(dt.values))
copy(values, dt.values)
} else {
values = dt.values
}
return &STrie[T]{tree: *tree, values: values}
}

func (st *STrie[T]) Marshal(w io.Writer) error {
func (st *STrie[T]) Marshal(w io.Writer, marshalValues func(io.Writer, []T) error) error {
if len(st.values) != len(st.tree.Levels) {
return fmt.Errorf("number of values and levels unmatched: value=%d levels=%d", len(st.values), len(st.tree.Levels))
}
if err := st.tree.Write(w); err != nil {
return err
}
if marshalValues != nil {
if err := marshalValues(w, st.values); err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
return nil
}
if err := gob.NewEncoder(w).Encode(st.values); err != nil {
return err
}
return nil
}

func Unmarshal[T any](r io.Reader) (*STrie[T], error) {
func Unmarshal[T any](r io.Reader, unmarshalValues func(io.Reader, int) ([]T, error)) (*STrie[T], error) {
tree, err := trietree.Read(r)
if err != nil {
return nil, err
}
if len(tree.Levels) == 0 {
return &STrie[T]{tree: *tree}, nil
}
// read v from r then append it to values.
// read values from r with unmarshalValues.
if unmarshalValues != nil {
values, err := unmarshalValues(r, len(tree.Levels))
if err != nil {
return nil, err
}
return &STrie[T]{tree: *tree, values: values}, nil
}
// read values from r without unmarshalValues.
values := make([]T, 0, len(tree.Levels))
if err := gob.NewDecoder(r).Decode(&values); err != nil {
return nil, err
Expand Down
169 changes: 169 additions & 0 deletions trie2/trie2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package trie2

import (
"bytes"
"encoding/json"
"io"
"testing"

"github.com/google/go-cmp/cmp"
)

type Data struct {
N int
S string
}

func TestMarshal0(t *testing.T) {
dt := DTrie[Data]{}
bb := &bytes.Buffer{}
if err := dt.Freeze(false).Marshal(bb, nil); err != nil {
t.Fatalf("failed to marshal: %s", err)
}
st, err := Unmarshal[Data](bb, nil)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if d := cmp.Diff(dt.values, st.values); d != "" {
t.Errorf("failed unmarshal values: -want +got\n%s", d)
}
}

func TestMarshalValue(t *testing.T) {
dt := DTrie[Data]{}
dt.Put("a", Data{111, "aaa"})
dt.Put("ab", Data{222, "bbb"})
dt.Put("abc", Data{333, "ccc"})
dt.Put("d", Data{444, "ddd"})
dt.Put("de", Data{555, "eee"})
bb := &bytes.Buffer{}
if err := dt.Freeze(false).Marshal(bb, nil); err != nil {
t.Fatalf("failed to marshal: %s", err)
}
st, err := Unmarshal[Data](bb, nil)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if d := cmp.Diff(dt.values, st.values); d != "" {
t.Errorf("failed unmarshal values: -want +got\n%s", d)
}
}

func TestMarshalPointer(t *testing.T) {
dt := DTrie[*Data]{}
dt.Put("a", &Data{111, "aaa"})
dt.Put("ab", &Data{222, "bbb"})
dt.Put("abc", &Data{333, "ccc"})
dt.Put("d", &Data{444, "ddd"})
dt.Put("de", &Data{555, "eee"})
bb := &bytes.Buffer{}
if err := dt.Freeze(false).Marshal(bb, nil); err != nil {
t.Fatalf("failed to marshal: %s", err)
}
st, err := Unmarshal[*Data](bb, nil)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if d := cmp.Diff(dt.values, st.values); d != "" {
t.Errorf("failed unmarshal values: -want +got\n%s", d)
}
}

type DataJSON struct {
N int `json:"n"`
S string `json:"s"`
}

func TestMarshalCustom(t *testing.T) {
dt := DTrie[DataJSON]{}
dt.Put("a", DataJSON{111, "aaa"})
dt.Put("ab", DataJSON{222, "bbb"})
dt.Put("abc", DataJSON{333, "ccc"})
dt.Put("d", DataJSON{444, "ddd"})
dt.Put("de", DataJSON{555, "eee"})
bb := &bytes.Buffer{}
err := dt.Freeze(true).Marshal(bb, func(w io.Writer, values []DataJSON) error {
return json.NewEncoder(w).Encode(values)
})
if err != nil {
t.Fatalf("failed to marshal: %s", err)
}
st, err := Unmarshal[DataJSON](bb, func(r io.Reader, n int) ([]DataJSON, error) {
values := make([]DataJSON, 0, n)
if err := json.NewDecoder(r).Decode(&values); err != nil {
return nil, err
}
return values, nil
})
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if d := cmp.Diff(dt.values, st.values); d != "" {
t.Errorf("failed unmarshal values: -want +got\n%s", d)
}
}

func TestLongestPrefixDTree(t *testing.T) {
dt := DTrie[Data]{}
dt.Put("a", Data{111, "aaa"})
dt.Put("ab", Data{222, "bbb"})
dt.Put("abc", Data{333, "ccc"})
dt.Put("d", Data{444, "ddd"})
dt.Put("de", Data{555, "eee"})
for i, c := range []struct {
query string
wantV Data
wantP string
wantF bool
}{
{"az", Data{111, "aaa"}, "a", true},
{"za", Data{}, "", false},
{"abcde", Data{333, "ccc"}, "abc", true},
{"ababc", Data{222, "bbb"}, "ab", true},
} {
gotV, gotP, gotF := dt.LongestPrefix(c.query)
if gotF != c.wantF {
t.Errorf("existence unmatch #%d: want=%t got=%t", i, c.wantF, gotF)
continue
}
if gotP != c.wantP {
t.Errorf("prefix unmatch #%d: want=%s got=%s", i, c.wantP, gotP)
}
if d := cmp.Diff(c.wantV, gotV); d != "" {
t.Errorf("values unmatch #%d: -want +got\n%s", i, d)
}
}
}

func TestLongestPrefixSTree(t *testing.T) {
dt := DTrie[Data]{}
dt.Put("a", Data{111, "aaa"})
dt.Put("ab", Data{222, "bbb"})
dt.Put("abc", Data{333, "ccc"})
dt.Put("d", Data{444, "ddd"})
dt.Put("de", Data{555, "eee"})
st := dt.Freeze(false)
for i, c := range []struct {
query string
wantV Data
wantP string
wantF bool
}{
{"az", Data{111, "aaa"}, "a", true},
{"za", Data{}, "", false},
{"abcde", Data{333, "ccc"}, "abc", true},
{"ababc", Data{222, "bbb"}, "ab", true},
} {
gotV, gotP, gotF := st.LongestPrefix(c.query)
if gotF != c.wantF {
t.Errorf("existence unmatch #%d: want=%t got=%t", i, c.wantF, gotF)
continue
}
if gotP != c.wantP {
t.Errorf("prefix unmatch #%d: want=%s got=%s", i, c.wantP, gotP)
}
if d := cmp.Diff(c.wantV, gotV); d != "" {
t.Errorf("values unmatch #%d: -want +got\n%s", i, d)
}
}
}

0 comments on commit 747a019

Please sign in to comment.