From b4929aaf22f811a395cdd826e0d585459814760f Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 12 Jul 2024 18:43:01 +0800 Subject: [PATCH] feat: new strmap --- container/strmap/strmap.go | 111 ++++++++++++++++++++++++++++++++ container/strmap/strmap_test.go | 98 ++++++++++++++++++++++++++++ container/strmap/utils.go | 101 +++++++++++++++++++++++++++++ container/strmap/utils_test.go | 32 +++++++++ 4 files changed, 342 insertions(+) create mode 100644 container/strmap/strmap.go create mode 100644 container/strmap/strmap_test.go create mode 100644 container/strmap/utils.go create mode 100644 container/strmap/utils_test.go diff --git a/container/strmap/strmap.go b/container/strmap/strmap.go new file mode 100644 index 0000000..c0b948d --- /dev/null +++ b/container/strmap/strmap.go @@ -0,0 +1,111 @@ +package strmap + +import ( + "fmt" + "sort" + "strings" +) + +// StrMap represents GC friendly string map implementation. +// it's readonly after it's created +type StrMap struct { + data []byte + items []mapItem + + hashtable []int +} + +type mapItem struct { + off int + sz int + slot uint + v uintptr +} + +// New creates StrMap from map[string]uintptr +// uintptr can be any value and it will be returned by Get. +func New(m map[string]uintptr) *StrMap { + sz := 0 + for k, _ := range m { + sz += len(k) + } + b := make([]byte, 0, sz) + items := make([]mapItem, 0, len(m)) + for k, v := range m { + items = append(items, mapItem{off: len(b), sz: len(k), slot: uint(hashstr(k)), v: v}) + b = append(b, k...) + } + ret := &StrMap{data: b, items: items} + ret.makeHashtable() + return ret +} + +// Len returns the size of map +func (m *StrMap) Len() int { + return len(m.items) +} + +func (m *StrMap) makeHashtable() { + slots := calcHashtableSlots(len(m.items)) + m.hashtable = make([]int, slots) + + for i := range m.items { + m.items[i].slot = m.items[i].slot % slots + } + + // make sure items with the same slot stored together + // good for cpu cache + sort.Slice(m.items, func(i, j int) bool { + return m.items[i].slot < m.items[j].slot + }) + + for i := 0; i < len(m.hashtable); i++ { + m.hashtable[i] = -1 + } + for i := range m.items { + e := &m.items[i] + if m.hashtable[e.slot] < 0 { + // we only need to store the 1st item if hash conflict + // since they're already stored together + // will check the next item when Get + m.hashtable[e.slot] = i + } + } +} + +// Get ... +func (m *StrMap) Get(s string) (uintptr, bool) { + slot := uint(hashstr(s)) % uint(len(m.hashtable)) + i := m.hashtable[slot] + if i < 0 { + return 0, false + } + e := &m.items[i] + for { + if string(m.data[e.off:e.off+e.sz]) == s { // double check + return e.v, true + } + i++ + if i >= len(m.items) { + break + } + e = &m.items[i] + if e.slot != slot { // items sorted by slot + break + } + } + return 0, false +} + +func (m *StrMap) String() string { + b := &strings.Builder{} + b.WriteByte('[') + for i, e := range m.items { + if i != 0 { + b.WriteString(", ") + } + fmt.Fprintf(b, "{off:%d, slot:%x, str:%q}", e.off, e.slot, string(m.data[e.off:e.off+e.sz])) + } + b.WriteByte(']') + return b.String() +} diff --git a/container/strmap/strmap_test.go b/container/strmap/strmap_test.go new file mode 100644 index 0000000..c80246a --- /dev/null +++ b/container/strmap/strmap_test.go @@ -0,0 +1,98 @@ +package strmap + +import ( + "crypto/rand" + "runtime" + "testing" + + "github.com/cloudwego/gopkg/internal/unsafe" + "github.com/stretchr/testify/require" +) + +func randString(m int) string { + b := make([]byte, m) + rand.Read(b) + return string(b) +} + +func randStrings(m, n int) []string { + b := make([]byte, m*n) + rand.Read(b) + ret := make([]string, 0, n) + for i := 0; i < n; i++ { + s := b[m*i:] + s = s[:m] + ret = append(ret, unsafe.ByteSliceToString(s)) + } + return ret +} + +func newStdStrMap(ss []string) map[string]uintptr { + v := uintptr(1) + m := make(map[string]uintptr) + for _, s := range ss { + _, ok := m[s] + if !ok { + m[s] = v + v++ + } + } + return m +} + +func TestStrMap(t *testing.T) { + ss := randStrings(20, 100000) + m := newStdStrMap(ss) + strset := New(m) + require.Equal(t, len(m), strset.Len()) + for i, s := range ss { + v0 := m[s] + v1, _ := strset.Get(s) + require.Equal(t, v0, v1, i) + } +} + +func Benchmark_StrMap(b *testing.B) { + ss := randStrings(20, 100000) + m := newStdStrMap(ss) + strset := New(m) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = strset.Get(ss[i%len(ss)]) + } +} + +func Benchmark_StdMap(b *testing.B) { + ss := randStrings(20, 100000) + m := newStdStrMap(ss) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m[ss[i%len(ss)]] + } +} + +func Benchmark_StrMap_GC(b *testing.B) { + ss := randStrings(50, 1000000) + m := newStdStrMap(ss) + strset := New(m) + ss = nil + m = nil + runtime.GC() + b.ResetTimer() + for i := 0; i < b.N; i++ { + runtime.GC() + } + runtime.KeepAlive(strset) +} + +func Benchmark_StdMap_GC(b *testing.B) { + ss := randStrings(50, 1000000) + m := newStdStrMap(ss) + ss = nil + runtime.GC() + b.ResetTimer() + for i := 0; i < b.N; i++ { + runtime.GC() + } + runtime.KeepAlive(m) +} diff --git a/container/strmap/utils.go b/container/strmap/utils.go new file mode 100644 index 0000000..b901f8a --- /dev/null +++ b/container/strmap/utils.go @@ -0,0 +1,101 @@ +package strmap + +import ( + "time" + "unsafe" +) + +const ( + fnvHashOffset64 = uint64(14695981039346656037) // fnv hash offset64 + fnvHashPrime64 = uint64(1099511628211) +) + +var hashseed = fnvHashOffset64 + +func init() { + hashseed = hashstr(time.Now().String()) +} + +func strDataPtr(s string) unsafe.Pointer { + // XXX: for str or slice, the Data ptr is always the 1st field + return *(*unsafe.Pointer)(unsafe.Pointer(&s)) +} + +func hashstr(s string) uint64 { + // a modified version of fnv hash, + // it computes 8 bytes per round, + // and doesn't generate the same result for diff cpu arch + + h := hashseed + p := strDataPtr(s) + + // 8 byte per round + i := 0 + for n := len(s) >> 3; i < n; i++ { + h *= fnvHashPrime64 + h ^= *(*uint64)(unsafe.Add(p, i<<3)) // p[i*8] + } + + // left 0-7 bytes + i = i << 3 + for ; i < len(s); i++ { + h *= fnvHashPrime64 + h ^= uint64(s[i]) + } + return h +} + +var bits2primes = []uint{ + 0: 17, // 1 + 1: 17, // 2 + 2: 17, // 4 + 3: 17, // 8 + 4: 17, // at least 17 for <= 16 + 5: 31, // 32 + 6: 61, // 64 + 7: 127, // 128 + 8: 251, // 256 + 9: 509, // 512 + 10: 1021, // 1024 + 11: 2039, // 2048 + 12: 4093, // 4096 + 13: 8191, // 8192 + 14: 16381, // 16384 + 15: 32749, // 32768 + 16: 65521, // 65536 + 17: 131071, // 131072 + 18: 262139, // 262144 + 19: 524287, // 524288 + 20: 1048573, // 1048576 + 21: 2097143, // 2097152 + 22: 4194301, // 4194304 + 23: 8388593, // 8388608 + 24: 16777213, // 16777216 + 25: 33554393, // 33554432 + 26: 67108859, // 67108864 + 27: 134217689, // 134217728 + 28: 268435399, // 268435456 + 29: 536870909, // 536870912 + 30: 1073741789, // 1073741824 +} + +func calcHashtableSlots(n int) uint { + // load factor + n = int(float32(n) / 0.75) + + // count bits to decide which prime number to use + bits := 0 + for v := uint64(n); v > 0; v = v >> 1 { + bits++ + } + + // add one more bit, + // so if n=1500, than returns 2039 instead of 1021 + bits++ + + if bits > len(bits2primes) { + // ???? are you sure we need to hold so many items? ~ 1B items for 30 bits + return uint(n) + } + return bits2primes[bits] // a prime bigger than n +} diff --git a/container/strmap/utils_test.go b/container/strmap/utils_test.go new file mode 100644 index 0000000..4ce402c --- /dev/null +++ b/container/strmap/utils_test.go @@ -0,0 +1,32 @@ +package strmap + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHashStr(t *testing.T) { + require.Equal(t, hashstr("1234"), hashstr("1234")) + require.NotEqual(t, hashstr("12345"), hashstr("12346")) + require.Equal(t, hashstr("12345678"), hashstr("12345678")) + require.NotEqual(t, hashstr("123456789"), hashstr("123456788")) +} + +func BenchmarkHashStr(b *testing.B) { + strSizes := []int{8, 16, 32, 64, 128, 512} + ss := make([]string, len(strSizes)) + for i := range ss { + ss[i] = randString(strSizes[i]) + } + b.ResetTimer() + for _, s := range ss { + b.Run(fmt.Sprintf("size-%d", len(s)), func(b *testing.B) { + b.SetBytes(int64(len(s))) + for i := 0; i < b.N; i++ { + _ = hashstr(s) + } + }) + } +}