diff --git a/.gitignore b/.gitignore index 52c3bddd..1d5e17cb 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ breakpoints.txt # coverage output coverage.txt + +# go workspace +go.work diff --git a/cache/lru/lru.go b/cache/lru/lru.go index a58e98f4..c98bc20a 100644 --- a/cache/lru/lru.go +++ b/cache/lru/lru.go @@ -7,9 +7,6 @@ import ( "github.com/lightninglabs/neutrino/cache" ) -// elementMap is an alias for a map from a generic interface to a list.Element. -type elementMap[K comparable, V any] map[K]V - // entry represents a (key,value) pair entry in the Cache. The Cache's list // stores entries which let us get the cache key when an entry is evicted. type entry[K comparable, V cache.Value] struct { @@ -33,7 +30,7 @@ type Cache[K comparable, V cache.Value] struct { // cache is a generic cache which allows us to find an elements position // in the ll list from a given key. - cache elementMap[K, *Element[entry[K, V]]] + cache syncMap[K, *Element[entry[K, V]]] // mtx is used to make sure the Cache is thread-safe. mtx sync.RWMutex @@ -45,7 +42,7 @@ func NewCache[K comparable, V cache.Value](capacity uint64) *Cache[K, V] { return &Cache[K, V]{ capacity: capacity, ll: NewList[entry[K, V]](), - cache: make(map[K]*Element[entry[K, V]]), + cache: syncMap[K, *Element[entry[K, V]]]{}, } } @@ -84,7 +81,7 @@ func (c *Cache[K, V]) evict(needed uint64) (bool, error) { // Remove the element from the cache. c.ll.Remove(elr) - delete(c.cache, ce.key) + c.cache.Delete(ce.key) evicted = true } } @@ -108,17 +105,22 @@ func (c *Cache[K, V]) Put(key K, value V) (bool, error) { "cache with capacity %v", vs, c.capacity) } + // Load the element. + el, ok := c.cache.Load(key) + + // Update the internal list inside a lock. c.mtx.Lock() - defer c.mtx.Unlock() // If the element already exists, remove it and decrease cache's size. - el, ok := c.cache[key] if ok { es, err := el.Value.value.Size() if err != nil { + c.mtx.Unlock() + return false, fmt.Errorf("couldn't determine size of "+ "existing cache value %v", err) } + c.ll.Remove(el) c.size -= es } @@ -132,26 +134,31 @@ func (c *Cache[K, V]) Put(key K, value V) (bool, error) { // We have made enough space in the cache, so just insert it. el = c.ll.PushFront(entry[K, V]{key, value}) - c.cache[key] = el c.size += vs + // Release the lock. + c.mtx.Unlock() + + // Update the cache. + c.cache.Store(key, el) + return evicted, nil } // Get will return value for a given key, making the element the most recently // accessed item in the process. Will return nil if the key isn't found. func (c *Cache[K, V]) Get(key K) (V, error) { - c.mtx.Lock() - defer c.mtx.Unlock() - var defaultVal V - el, ok := c.cache[key] + el, ok := c.cache.Load(key) if !ok { // Element not found in the cache. return defaultVal, cache.ErrElementNotFound } + c.mtx.Lock() + defer c.mtx.Unlock() + // When the cache needs to evict a element to make space for another // one, it starts eviction from the back, so by moving this element to // the front, it's eviction is delayed because it's recently accessed. @@ -166,3 +173,45 @@ func (c *Cache[K, V]) Len() int { return c.ll.Len() } + +// Delete removes an item from the cache. +func (c *Cache[K, V]) Delete(key K) { + c.LoadAndDelete(key) +} + +// LoadAndDelete queries an item and deletes it from the cache using the +// specified key. +func (c *Cache[K, V]) LoadAndDelete(key K) (V, bool) { + var defaultVal V + + // Noop if the element doesn't exist. + el, ok := c.cache.LoadAndDelete(key) + if !ok { + return defaultVal, false + } + + c.mtx.Lock() + defer c.mtx.Unlock() + + // Get its size. + vs, err := el.Value.value.Size() + if err != nil { + return defaultVal, false + } + + // Remove the element from the list and update the cache's size. + c.ll.Remove(el) + c.size -= vs + + return el.Value.value, true +} + +// Range iterates the cache. +func (c *Cache[K, V]) Range(visitor func(K, V) bool) { + // valueVisitor is a closure to help unwrap the value from the cache. + valueVisitor := func(key K, value *Element[entry[K, V]]) bool { + return visitor(key, value.Value.value) + } + + c.cache.Range(valueVisitor) +} diff --git a/cache/lru/lru_test.go b/cache/lru/lru_test.go index b594c26a..3be389fa 100644 --- a/cache/lru/lru_test.go +++ b/cache/lru/lru_test.go @@ -97,7 +97,7 @@ func TestElementSizeCapacityEvictsEverything(t *testing.T) { // Insert element with size=capacity of cache, should evict everything. c.Put(4, &sizeable{value: 4, size: 3}) require.Equal(t, c.Len(), 1) - require.Equal(t, len(c.cache), 1) + require.Equal(t, c.cache.Len(), 1) four := getSizeableValue(c.Get(4)) require.Equal(t, four, 4) @@ -110,7 +110,7 @@ func TestElementSizeCapacityEvictsEverything(t *testing.T) { // Insert element with size=capacity of cache. c.Put(4, &sizeable{value: 4, size: 6}) require.Equal(t, c.Len(), 1) - require.Equal(t, len(c.cache), 1) + require.Equal(t, c.cache.Len(), 1) four = getSizeableValue(c.Get(4)) require.Equal(t, four, 4) } @@ -296,3 +296,94 @@ func TestConcurrencyBigCache(t *testing.T) { wg.Wait() } + +// TestLoadAndDelete checks the `LoadAndDelete` method. +func TestLoadAndDelete(t *testing.T) { + t.Parallel() + + c := NewCache[int, *sizeable](3) + + // Create a test item. + item1 := &sizeable{value: 1, size: 1} + + // Put the item. + _, err := c.Put(0, item1) + require.NoError(t, err) + + // Load the item and check it's returned as expected. + loadedItem, loaded := c.LoadAndDelete(0) + require.True(t, loaded) + require.Equal(t, item1, loadedItem) + + // Now check that the item has been deleted. + _, err = c.Get(0) + require.ErrorIs(t, err, cache.ErrElementNotFound) + + // Load the item again should give us a nil value and false. + loadedItem, loaded = c.LoadAndDelete(0) + require.False(t, loaded) + require.Nil(t, loadedItem) + + // The length should be 0. + require.Zero(t, c.Len()) + require.Zero(t, c.size) +} + +// TestRangeIteration checks that the `Range` method works as expected. +func TestRangeIteration(t *testing.T) { + t.Parallel() + + c := NewCache[int, *sizeable](100) + + // Create test items. + const numItems = 10 + for i := 0; i < numItems; i++ { + _, err := c.Put(i, &sizeable{value: i, size: 1}) + require.NoError(t, err) + } + + // Create a dummy visitor that just counts the number of items visited. + visited := 0 + testVisitor := func(key int, value *sizeable) bool { + visited++ + return true + } + + // Call the method. + c.Range(testVisitor) + + // Check the number of items visited. + require.Equal(t, numItems, visited) +} + +// TestRangeAbort checks that the `Range` will abort when the visitor returns +// false. +func TestRangeAbort(t *testing.T) { + t.Parallel() + + c := NewCache[int, *sizeable](100) + + // Create test items. + const numItems = 10 + for i := 0; i < numItems; i++ { + _, err := c.Put(i, &sizeable{value: i, size: 1}) + require.NoError(t, err) + } + + // Create a visitor that counts the number of items visited and returns + // false when visited 5 times. + visited := 0 + testVisitor := func(key int, value *sizeable) bool { + visited++ + if visited >= numItems/2 { + return false + } + return true + } + + // Call the method. + c.Range(testVisitor) + + // Check the number of items visited. + require.Equal(t, numItems/2, visited) +} diff --git a/cache/lru/sync_map.go b/cache/lru/sync_map.go new file mode 100644 index 00000000..448435b4 --- /dev/null +++ b/cache/lru/sync_map.go @@ -0,0 +1,66 @@ +package lru + +import "sync" + +// syncMap wraps a sync.Map with type parameters such that it's easier to +// access the items stored in the map since no type assertion is needed. It +// also requires explicit type definition when declaring and initiating the +// variables, which helps us understanding what's stored in a given map. +// +// NOTE: this is unexported to avoid confusion with `lnd`'s `SyncMap`. +type syncMap[K comparable, V any] struct { + sync.Map +} + +// Store puts an item in the map. +func (m *syncMap[K, V]) Store(key K, value V) { + m.Map.Store(key, value) +} + +// Load queries an item from the map using the specified key. If the item +// cannot be found, an empty value and false will be returned. If the stored +// item fails the type assertion, a nil value and false will be returned. +func (m *syncMap[K, V]) Load(key K) (V, bool) { + result, ok := m.Map.Load(key) + if !ok { + return *new(V), false // nolint: gocritic + } + + item, ok := result.(V) + return item, ok +} + +// Delete removes an item from the map specified by the key. +func (m *syncMap[K, V]) Delete(key K) { + m.Map.Delete(key) +} + +// LoadAndDelete queries an item and deletes it from the map using the +// specified key. +func (m *syncMap[K, V]) LoadAndDelete(key K) (V, bool) { + result, loaded := m.Map.LoadAndDelete(key) + if !loaded { + return *new(V), loaded // nolint: gocritic + } + + item, ok := result.(V) + return item, ok +} + +// Range iterates the map. +func (m *syncMap[K, V]) Range(visitor func(K, V) bool) { + m.Map.Range(func(k any, v any) bool { + return visitor(k.(K), v.(V)) + }) +} + +// Len returns the number of items in the map. +func (m *syncMap[K, V]) Len() int { + var count int + m.Range(func(K, V) bool { + count++ + return true + }) + + return count +}