diff --git a/trie2/trie2.go b/trie2/trie2.go index ec36b33..a8c9230 100644 --- a/trie2/trie2.go +++ b/trie2/trie2.go @@ -23,6 +23,7 @@ func (dt *DTrie[T]) Put(k string, v T) { dt.values = append(dt.values, v) return } + // update an existed value. dt.values[id-1] = v } @@ -44,7 +45,7 @@ func (dt *DTrie[T]) Freeze(copyValues bool) *STrie[T] { tree := trietree.Freeze(&dt.tree) var values []T if copyValues { - values = make([]T, len(values)) + values = make([]T, len(dt.values)) copy(values, dt.values) } else { values = dt.values @@ -52,20 +53,26 @@ func (dt *DTrie[T]) Freeze(copyValues bool) *STrie[T] { return &STrie[T]{tree: *tree, values: values} } -func (st *STrie[T]) Marshal(w io.Writer) error { +func (st *STrie[T]) Marshal(w io.Writer, marshalValues func(io.Writer, []T) error) error { if len(st.values) != len(st.tree.Levels) { return fmt.Errorf("number of values and levels unmatched: value=%d levels=%d", len(st.values), len(st.tree.Levels)) } if err := st.tree.Write(w); err != nil { return err } + if marshalValues != nil { + if err := marshalValues(w, st.values); err != nil { + return fmt.Errorf("failed to marshal values: %w", err) + } + return nil + } if err := gob.NewEncoder(w).Encode(st.values); err != nil { return err } return nil } -func Unmarshal[T any](r io.Reader) (*STrie[T], error) { +func Unmarshal[T any](r io.Reader, unmarshalValues func(io.Reader, int) ([]T, error)) (*STrie[T], error) { tree, err := trietree.Read(r) if err != nil { return nil, err @@ -73,7 +80,15 @@ func Unmarshal[T any](r io.Reader) (*STrie[T], error) { if len(tree.Levels) == 0 { return &STrie[T]{tree: *tree}, nil } - // read v from r then append it to values. + // read values from r with unmarshalValues. + if unmarshalValues != nil { + values, err := unmarshalValues(r, len(tree.Levels)) + if err != nil { + return nil, err + } + return &STrie[T]{tree: *tree, values: values}, nil + } + // read values from r without unmarshalValues. values := make([]T, 0, len(tree.Levels)) if err := gob.NewDecoder(r).Decode(&values); err != nil { return nil, err diff --git a/trie2/trie2_test.go b/trie2/trie2_test.go new file mode 100644 index 0000000..bb88115 --- /dev/null +++ b/trie2/trie2_test.go @@ -0,0 +1,169 @@ +package trie2 + +import ( + "bytes" + "encoding/json" + "io" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type Data struct { + N int + S string +} + +func TestMarshal0(t *testing.T) { + dt := DTrie[Data]{} + bb := &bytes.Buffer{} + if err := dt.Freeze(false).Marshal(bb, nil); err != nil { + t.Fatalf("failed to marshal: %s", err) + } + st, err := Unmarshal[Data](bb, nil) + if err != nil { + t.Fatalf("failed to unmarshal: %s", err) + } + if d := cmp.Diff(dt.values, st.values); d != "" { + t.Errorf("failed unmarshal values: -want +got\n%s", d) + } +} + +func TestMarshalValue(t *testing.T) { + dt := DTrie[Data]{} + dt.Put("a", Data{111, "aaa"}) + dt.Put("ab", Data{222, "bbb"}) + dt.Put("abc", Data{333, "ccc"}) + dt.Put("d", Data{444, "ddd"}) + dt.Put("de", Data{555, "eee"}) + bb := &bytes.Buffer{} + if err := dt.Freeze(false).Marshal(bb, nil); err != nil { + t.Fatalf("failed to marshal: %s", err) + } + st, err := Unmarshal[Data](bb, nil) + if err != nil { + t.Fatalf("failed to unmarshal: %s", err) + } + if d := cmp.Diff(dt.values, st.values); d != "" { + t.Errorf("failed unmarshal values: -want +got\n%s", d) + } +} + +func TestMarshalPointer(t *testing.T) { + dt := DTrie[*Data]{} + dt.Put("a", &Data{111, "aaa"}) + dt.Put("ab", &Data{222, "bbb"}) + dt.Put("abc", &Data{333, "ccc"}) + dt.Put("d", &Data{444, "ddd"}) + dt.Put("de", &Data{555, "eee"}) + bb := &bytes.Buffer{} + if err := dt.Freeze(false).Marshal(bb, nil); err != nil { + t.Fatalf("failed to marshal: %s", err) + } + st, err := Unmarshal[*Data](bb, nil) + if err != nil { + t.Fatalf("failed to unmarshal: %s", err) + } + if d := cmp.Diff(dt.values, st.values); d != "" { + t.Errorf("failed unmarshal values: -want +got\n%s", d) + } +} + +type DataJSON struct { + N int `json:"n"` + S string `json:"s"` +} + +func TestMarshalCustom(t *testing.T) { + dt := DTrie[DataJSON]{} + dt.Put("a", DataJSON{111, "aaa"}) + dt.Put("ab", DataJSON{222, "bbb"}) + dt.Put("abc", DataJSON{333, "ccc"}) + dt.Put("d", DataJSON{444, "ddd"}) + dt.Put("de", DataJSON{555, "eee"}) + bb := &bytes.Buffer{} + err := dt.Freeze(true).Marshal(bb, func(w io.Writer, values []DataJSON) error { + return json.NewEncoder(w).Encode(values) + }) + if err != nil { + t.Fatalf("failed to marshal: %s", err) + } + st, err := Unmarshal[DataJSON](bb, func(r io.Reader, n int) ([]DataJSON, error) { + values := make([]DataJSON, 0, n) + if err := json.NewDecoder(r).Decode(&values); err != nil { + return nil, err + } + return values, nil + }) + if err != nil { + t.Fatalf("failed to unmarshal: %s", err) + } + if d := cmp.Diff(dt.values, st.values); d != "" { + t.Errorf("failed unmarshal values: -want +got\n%s", d) + } +} + +func TestLongestPrefixDTree(t *testing.T) { + dt := DTrie[Data]{} + dt.Put("a", Data{111, "aaa"}) + dt.Put("ab", Data{222, "bbb"}) + dt.Put("abc", Data{333, "ccc"}) + dt.Put("d", Data{444, "ddd"}) + dt.Put("de", Data{555, "eee"}) + for i, c := range []struct { + query string + wantV Data + wantP string + wantF bool + }{ + {"az", Data{111, "aaa"}, "a", true}, + {"za", Data{}, "", false}, + {"abcde", Data{333, "ccc"}, "abc", true}, + {"ababc", Data{222, "bbb"}, "ab", true}, + } { + gotV, gotP, gotF := dt.LongestPrefix(c.query) + if gotF != c.wantF { + t.Errorf("existence unmatch #%d: want=%t got=%t", i, c.wantF, gotF) + continue + } + if gotP != c.wantP { + t.Errorf("prefix unmatch #%d: want=%s got=%s", i, c.wantP, gotP) + } + if d := cmp.Diff(c.wantV, gotV); d != "" { + t.Errorf("values unmatch #%d: -want +got\n%s", i, d) + } + } +} + +func TestLongestPrefixSTree(t *testing.T) { + dt := DTrie[Data]{} + dt.Put("a", Data{111, "aaa"}) + dt.Put("ab", Data{222, "bbb"}) + dt.Put("abc", Data{333, "ccc"}) + dt.Put("d", Data{444, "ddd"}) + dt.Put("de", Data{555, "eee"}) + st := dt.Freeze(false) + for i, c := range []struct { + query string + wantV Data + wantP string + wantF bool + }{ + {"az", Data{111, "aaa"}, "a", true}, + {"za", Data{}, "", false}, + {"abcde", Data{333, "ccc"}, "abc", true}, + {"ababc", Data{222, "bbb"}, "ab", true}, + } { + gotV, gotP, gotF := st.LongestPrefix(c.query) + if gotF != c.wantF { + t.Errorf("existence unmatch #%d: want=%t got=%t", i, c.wantF, gotF) + continue + } + if gotP != c.wantP { + t.Errorf("prefix unmatch #%d: want=%s got=%s", i, c.wantP, gotP) + } + if d := cmp.Diff(c.wantV, gotV); d != "" { + t.Errorf("values unmatch #%d: -want +got\n%s", i, d) + } + } +}