diff --git a/memory/priority/rbtree_cache_node.go b/memory/priority/rbtree_cache_node.go index e181f95..9d695f4 100644 --- a/memory/priority/rbtree_cache_node.go +++ b/memory/priority/rbtree_cache_node.go @@ -32,34 +32,34 @@ type rbTreeCacheNode struct { isDeleted bool //是否被删除 } -func newKVRBTreeCacheNode(key string, value any, expiration time.Duration) *rbTreeCacheNode { - node := &rbTreeCacheNode{ +// newRBTreeCacheNode 创建红黑树节点,注意如果是容器类型节点要value传递初始化一个零值 +func newRBTreeCacheNode(key string, value any) *rbTreeCacheNode { + return &rbTreeCacheNode{ key: key, value: value, } +} + +func newKVRBTreeCacheNode(key string, value any, expiration time.Duration) *rbTreeCacheNode { + node := newRBTreeCacheNode(key, value) node.setExpiration(expiration) return node } func newListRBTreeCacheNode(key string) *rbTreeCacheNode { - return &rbTreeCacheNode{ - key: key, - value: list.NewLinkedList[any](), - } + return newRBTreeCacheNode(key, list.NewLinkedList[any]()) } func newSetRBTreeCacheNode(key string, initSize int) *rbTreeCacheNode { - return &rbTreeCacheNode{ - key: key, - value: set.NewMapSet[any](initSize), - } + return newRBTreeCacheNode(key, set.NewMapSet[any](initSize)) } func newIntRBTreeCacheNode(key string) *rbTreeCacheNode { - return &rbTreeCacheNode{ - key: key, - value: int64(0), - } + return newRBTreeCacheNode(key, int64(0)) +} + +func newFloatRBTreeCacheNode(key string) *rbTreeCacheNode { + return newRBTreeCacheNode(key, float64(0)) } // setExpiration 设置有效期 diff --git a/memory/priority/rbtree_priority_cache.go b/memory/priority/rbtree_priority_cache.go index db78dfd..0c6d9f2 100644 --- a/memory/priority/rbtree_priority_cache.go +++ b/memory/priority/rbtree_priority_cache.go @@ -98,15 +98,8 @@ func (r *RBTreePriorityCache) Set(_ context.Context, key string, val any, expira r.globalLock.Lock() defer r.globalLock.Unlock() - node, cacheErr := r.cacheData.Find(key) - if cacheErr != nil { - if r.isFull() { - r.deleteNodeByPriority() - } - node = newKVRBTreeCacheNode(key, val, expiration) - r.addNode(node) - return nil - } + node := r.findOrCreateNode(key, func() any { return val }) + node.replace(val, expiration) return nil } @@ -157,6 +150,8 @@ func (r *RBTreePriorityCache) Get(ctx context.Context, key string) (val ecache.V return } + r.globalLock.Lock() + defer r.globalLock.Unlock() now := time.Now() if !node.beforeDeadline(now) { r.doubleCheckWhenExpire(node, now) @@ -169,11 +164,8 @@ func (r *RBTreePriorityCache) Get(ctx context.Context, key string) (val ecache.V return } -// doubleCheckWhenExpire 缓存过期时的二次校验,防止被抢先删除了 +// doubleCheckWhenExpire 缓存过期时的二次校验,防止被抢先删除了【调用该方法必须先获得锁】 func (r *RBTreePriorityCache) doubleCheckWhenExpire(node *rbTreeCacheNode, now time.Time) { - r.globalLock.Lock() - defer r.globalLock.Unlock() - checkNode, checkCacheErr := r.cacheData.Find(node.key) if checkCacheErr != nil { return //被抢先删除了 @@ -212,15 +204,9 @@ func (r *RBTreePriorityCache) LPush(ctx context.Context, key string, val ...any) r.globalLock.Lock() defer r.globalLock.Unlock() - node, cacheErr := r.cacheData.Find(key) - if cacheErr != nil { - if r.isFull() { - r.deleteNodeByPriority() - } - node = newListRBTreeCacheNode(key) - r.addNode(node) - } - + node := r.findOrCreateNode(key, func() any { + return list.NewLinkedList[any]() + }) nodeVal, ok := node.value.(*list.LinkedList[any]) if !ok { return 0, errOnlyListCanLPUSH @@ -268,15 +254,9 @@ func (r *RBTreePriorityCache) SAdd(ctx context.Context, key string, members ...a r.globalLock.Lock() defer r.globalLock.Unlock() - node, cacheErr := r.cacheData.Find(key) - if cacheErr != nil { - if r.isFull() { - r.deleteNodeByPriority() - } - node = newSetRBTreeCacheNode(key, r.collectionCap) - r.addNode(node) - } - + node := r.findOrCreateNode(key, func() any { + return set.NewMapSet[any](r.collectionCap) + }) nodeVal, ok := node.value.(*set.MapSet[any]) if !ok { return 0, errOnlySetCanSAdd @@ -327,14 +307,7 @@ func (r *RBTreePriorityCache) IncrBy(ctx context.Context, key string, value int6 r.globalLock.Lock() defer r.globalLock.Unlock() - node, cacheErr := r.cacheData.Find(key) - if cacheErr != nil { - if r.isFull() { - r.deleteNodeByPriority() - } - node = newIntRBTreeCacheNode(key) - r.addNode(node) - } + node := r.findOrCreateNode(key, func() any { return int64(0) }) nodeVal, ok := node.value.(int64) if !ok { @@ -347,18 +320,64 @@ func (r *RBTreePriorityCache) IncrBy(ctx context.Context, key string, value int6 return newVal, nil } -func (r *RBTreePriorityCache) DecrBy(ctx context.Context, key string, value int64) (int64, error) { +func (r *RBTreePriorityCache) IncrByFloat(ctx context.Context, key string, value float64) (float64, error) { r.globalLock.Lock() defer r.globalLock.Unlock() - node, cacheErr := r.cacheData.Find(key) - if cacheErr != nil { - if r.isFull() { - r.deleteNodeByPriority() + node := r.findOrCreateNode(key, func() any { return float64(0) }) + nodeVal, ok := node.value.(float64) + if !ok { + //如果是int类型可以尝试转换 + intNodeVal, ok := node.value.(int64) + if !ok { + return 0, errOnlyNumCanIncrBy } - node = newIntRBTreeCacheNode(key) - r.addNode(node) + nodeVal = float64(intNodeVal) + } + + newVal := nodeVal + value + node.value = newVal + + return newVal, nil +} + +func (r *RBTreePriorityCache) Delete(ctx context.Context, keys ...string) (int64, error) { + delCount := int64(0) + now := time.Now() + for _, key := range keys { + r.globalLock.RLock() + _, cacheErr := r.cacheData.Find(key) + r.globalLock.RUnlock() + if cacheErr != nil { + continue + } + + r.globalLock.Lock() + node, cacheErr := r.cacheData.Find(key) + if cacheErr != nil { + r.globalLock.Unlock() + continue + } + + // 过期删除不添加计数 + if !node.beforeDeadline(now) { + r.deleteNode(node) + r.globalLock.Unlock() + continue + } + + r.deleteNode(node) + r.globalLock.Unlock() + delCount++ } + return delCount, nil +} + +func (r *RBTreePriorityCache) DecrBy(ctx context.Context, key string, value int64) (int64, error) { + r.globalLock.Lock() + defer r.globalLock.Unlock() + + node := r.findOrCreateNode(key, func() any { return int64(0) }) nodeVal, ok := node.value.(int64) if !ok { @@ -403,6 +422,19 @@ func (r *RBTreePriorityCache) isFull() bool { return r.cacheNum >= r.cacheLimit } +// findOrCreateNode 查找节点,不存在时使用默认值创建节点【调用该方法必须先获得锁】 +func (r *RBTreePriorityCache) findOrCreateNode(key string, initFunc func() any) *rbTreeCacheNode { + node, cacheErr := r.cacheData.Find(key) + if cacheErr != nil { + if r.isFull() { + r.deleteNodeByPriority() + } + node = newRBTreeCacheNode(key, initFunc()) + r.addNode(node) + } + return node +} + // deleteNodeByPriority 根据优先级淘汰缓存结点【调用该方法必须先获得锁】 func (r *RBTreePriorityCache) deleteNodeByPriority() { for { @@ -434,7 +466,9 @@ func (r *RBTreePriorityCache) autoClean() { now := time.Now() for _, value := range values { if !value.beforeDeadline(now) { + r.globalLock.Lock() r.doubleCheckWhenExpire(value, now) + r.globalLock.Unlock() } } } diff --git a/memory/priority/rbtree_priority_cache_test.go b/memory/priority/rbtree_priority_cache_test.go index 0b79179..c33b7dd 100644 --- a/memory/priority/rbtree_priority_cache_test.go +++ b/memory/priority/rbtree_priority_cache_test.go @@ -21,6 +21,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/ecodeclub/ecache/internal/errs" "github.com/ecodeclub/ekit/list" "github.com/ecodeclub/ekit/set" @@ -626,7 +628,9 @@ func TestRBTreePriorityCache_doubleCheckInGet(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { startCache := tc.startCache() + startCache.globalLock.Lock() startCache.doubleCheckWhenExpire(tc.node, time.Now()) + startCache.globalLock.Unlock() assert.Equal(t, true, compareTwoRBTreeClient(startCache, tc.wantCache())) }) } @@ -1407,7 +1411,7 @@ func TestRBTreePriorityCache_IncrBy(t *testing.T) { wantRet: 1, }, { - name: "wrong type", + name: "wrong string type ", startCache: func() *RBTreePriorityCache { cache, _ := NewRBTreePriorityCache() cache.globalLock.Lock() @@ -1426,6 +1430,21 @@ func TestRBTreePriorityCache_IncrBy(t *testing.T) { }, wantErr: errOnlyNumCanIncrBy, }, + { + name: "wrong float type", + startCache: func() *RBTreePriorityCache { + cache, _ := NewRBTreePriorityCache() + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = float64(3.14) + cache.addNode(node) + return cache + }, + key: "key1", + value: 1, + wantErr: errOnlyNumCanIncrBy, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -1588,3 +1607,240 @@ func TestRBTreePriorityCache_autoClean(t *testing.T) { value6Str, _ = value6.String() assert.Equal(t, "value6", value6Str) } + +func TestRBTreePriorityCache_Delete(t *testing.T) { + + testCases := []struct { + name string + cache *RBTreePriorityCache + keys []string + wantRes int64 + wantErr error + }{ + { + name: "cache 1 , delete 1", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + cache.addNode(newKVRBTreeCacheNode("key1", testStructForPriority{priority: -1}, 0)) + cache.globalLock.Unlock() + return cache + }(), + keys: []string{"key1"}, + wantRes: 1, + }, + { + name: "cache 1 , delete 0", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + cache.addNode(newKVRBTreeCacheNode("key1", testStructForPriority{priority: -1}, 0)) + cache.globalLock.Unlock() + return cache + }(), + keys: []string{"key2"}, + wantRes: 0, + }, + { + name: "cache 1, expired 1, delete 0", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + cache.addNode(newKVRBTreeCacheNode("key1", testStructForPriority{priority: -1}, time.Second)) + cache.globalLock.Unlock() + + time.Sleep(3 * time.Second) + return cache + }(), + keys: []string{"key1"}, + wantRes: 0, + }, + { + name: "cache 4, delete 3", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + cache.addNode(newKVRBTreeCacheNode("key1", testStructForPriority{priority: -1}, time.Second)) + cache.addNode(newKVRBTreeCacheNode("key2", testStructForPriority{priority: -1}, time.Second)) + cache.addNode(newKVRBTreeCacheNode("key3", testStructForPriority{priority: -1}, time.Second)) + cache.addNode(newKVRBTreeCacheNode("key4", testStructForPriority{priority: -1}, time.Second)) + cache.globalLock.Unlock() + return cache + }(), + keys: []string{"key1", "key2", "key3"}, + wantRes: 3, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cache := tc.cache + res, err := cache.Delete(context.Background(), tc.keys...) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantRes, res) + + //确认已经被删除 + for _, key := range tc.keys { + _, err = tc.cache.cacheData.Find(key) + assert.NotNil(t, err) + } + }) + } +} + +func TestRBTreePriorityCache_IncrByFloat(t *testing.T) { + + testCases := []struct { + name string + cache *RBTreePriorityCache + key string + value float64 + wantCache *RBTreePriorityCache + wantRes float64 + wantErr error + }{ + { + name: "hit cache key1, increase key1", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = float64(3.14) + cache.addNode(node) + return cache + }(), + key: "key1", + value: 0.1, + wantRes: 3.24, + wantCache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = 3.24 + cache.addNode(node) + return cache + }(), + }, + { + name: "miss cache key1, create key1", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + return cache + }(), + key: "key1", + value: 0.1, + wantRes: 0.1, + wantCache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = 0.1 + cache.addNode(node) + return cache + }(), + }, + { + name: "hit cache key1, wrong type", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(8)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + cache.addNode(newKVRBTreeCacheNode("key1", "value1", 0)) + return cache + }(), + key: "key1", + value: 0.1, + wantErr: errOnlyNumCanIncrBy, + }, + { + name: "cache is full, increase key1, evited old key", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(1)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key0") + node.value = 3.14 + cache.addNode(node) + return cache + }(), + key: "key1", + value: 0.1, + wantRes: 0.1, + wantCache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(1)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = 0.1 + cache.addNode(node) + return cache + }(), + }, + { + name: "hit cache key1, convert int64 type to float64 type and increase", + cache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(1)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newIntRBTreeCacheNode("key1") + node.value = int64(3) + cache.addNode(node) + return cache + }(), + key: "key1", + value: 0.1, + wantRes: 3.1, + wantCache: func() *RBTreePriorityCache { + cache, err := newRBTreePriorityCache(WithCacheLimit(1)) + require.NoError(t, err) + + cache.globalLock.Lock() + defer cache.globalLock.Unlock() + node := newFloatRBTreeCacheNode("key1") + node.value = 3.1 + cache.addNode(node) + return cache + }(), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + startCache := tc.cache + value, err := startCache.IncrByFloat(context.Background(), tc.key, tc.value) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantRes, value) + assert.Equal(t, true, compareTwoRBTreeClient(startCache, tc.wantCache)) + }) + } +}