Skip to content

Commit

Permalink
Add a more complex caching mechanism which loads assets concurrently,…
Browse files Browse the repository at this point in the history
… to handle slow asset servers. Abstracted enough to be able to handle caching to disk without having to change users of the Cache.
  • Loading branch information
xxxserxxx committed Oct 14, 2024
1 parent 2e762f2 commit 97524e1
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 23 deletions.
122 changes: 122 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package main

import (
	"sync"
	"time"

	"github.com/spezifisch/stmps/logger"
)

// Cache fetches assets and holds a copy, returning them on request.
// A Cache is composed of four mechanisms:
//
//  1. a zero object
//  2. a function for fetching assets
//  3. a function for invalidating assets
//  4. a call-back function for when an asset is fetched
//
// When an asset is requested, Cache returns the asset if it is cached.
// Otherwise, it returns the zero object, and queues up a fetch for the object
// in the background. When the fetch is complete, the callback function is
// called, allowing the caller to get the real asset. An invalidation function
// allows Cache to manage the cache size by removing cached invalid objects.
//
// Caches are indexed by strings; stmps has no need for any other key type.
type Cache[T any] struct {
	// zero is returned immediately on every cache miss.
	zero T
	// cache is the in-memory asset store. Guarded by mu.
	cache map[string]T
	// mu guards cache, which is written by the fetch and invalidation
	// goroutines while Get and Close access it from the caller's goroutine.
	// A pointer is used so Cache values remain copyable without copying the
	// lock itself.
	mu *sync.RWMutex
	// pipeline carries asset IDs to the background fetch goroutine.
	pipeline chan string
	// quit stops the ticker and shuts down both background goroutines.
	quit func()
}

// NewCache sets up a new cache, given
//
//   - a zero value, returned immediately on cache misses
//   - a fetcher, which can be a long-running function that loads assets.
//     fetcher should take a key ID and return an asset, or an error.
//   - a call-back, which will be called when a requested asset is available. It
//     will be called with the asset ID, and the loaded asset.
//   - an invalidation function, returning true if a cached object stored under a
//     key can be removed from the cache. It will be called with an asset ID to
//     check.
//   - an invalidation frequency; the invalidation function will be called for
//     every cached object this frequently.
//   - a logger, used for reporting errors returned by the fetching function
//
// The invalidation should be reasonably efficient.
func NewCache[T any](
	zeroValue T,
	fetcher func(string) (T, error),
	fetchedItem func(string, T),
	isInvalid func(string) bool,
	invalidateFrequency time.Duration,
	logger *logger.Logger,
) Cache[T] {

	cache := make(map[string]T)
	mu := &sync.RWMutex{}
	getPipe := make(chan string, 1000)

	// Fetch goroutine: drains the pipeline until quit() closes it. Fetching
	// happens outside the lock because fetcher may be slow (that is the whole
	// point of this cache).
	go func() {
		for key := range getPipe {
			asset, err := fetcher(key)
			if err != nil {
				logger.Printf("error fetching asset %s: %s", key, err)
				continue
			}
			mu.Lock()
			cache[key] = asset
			mu.Unlock()
			// Invoke the callback outside the lock so it may safely call Get.
			fetchedItem(key, asset)
		}
	}()

	ticker := time.NewTicker(invalidateFrequency)
	done := make(chan bool)
	// Invalidation goroutine: periodically sweeps the cache, removing entries
	// the caller considers invalid. Exits when quit() signals done.
	go func() {
		for {
			select {
			case <-ticker.C:
				mu.Lock()
				for k := range cache {
					if isInvalid(k) {
						delete(cache, k)
					}
				}
				mu.Unlock()
			case <-done:
				return
			}
		}
	}()

	return Cache[T]{
		zero:     zeroValue,
		cache:    cache,
		mu:       mu,
		pipeline: getPipe,
		quit: func() {
			// Stop the ticker so it doesn't leak after Close.
			ticker.Stop()
			close(getPipe)
			done <- true
		},
	}
}

// Get returns a cached asset, or the zero asset on a cache miss.
// On a cache miss, the requested asset is queued for fetching.
//
// Calling Get after Close panics: the fetch pipeline has been closed.
func (c *Cache[T]) Get(key string) T {
	c.mu.RLock()
	v, ok := c.cache[key]
	c.mu.RUnlock()
	if ok {
		return v
	}
	c.pipeline <- key
	return c.zero
}

// Close releases resources used by the cache, clearing the cache
// and shutting down goroutines. It should be called when the
// Cache is no longer used, and before program exit.
//
// Note: since the current iteration of Cache is a memory cache, it isn't
// strictly necessary to call this on program exit; however, as the caching
// mechanism may change and use other system resources, it's good practice to
// call this on exit.
func (c Cache[T]) Close() {
	c.mu.Lock()
	for k := range c.cache {
		delete(c.cache, k)
	}
	c.mu.Unlock()
	c.quit()
}
208 changes: 208 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package main

import (
"testing"
"time"

"github.com/spezifisch/stmps/logger"
)

// TestNewCache checks that freshly constructed caches, for more than one
// element type, carry the zero value and initialized internals.
func TestNewCache(t *testing.T) {
	log := logger.Logger{}

	t.Run("basic string cache creation", func(t *testing.T) {
		const zero = "empty"
		c := NewCache(
			zero,
			func(string) (string, error) { return zero, nil },
			func(string, string) {},
			func(string) bool { return false },
			time.Second,
			&log,
		)
		defer c.Close()
		if c.zero != zero {
			t.Errorf("expected %q, got %q", zero, c.zero)
		}
		if c.cache == nil || len(c.cache) != 0 {
			t.Errorf("expected non-nil, empty map; got %#v", c.cache)
		}
		if c.pipeline == nil {
			t.Errorf("expected non-nil chan; got %#v", c.pipeline)
		}
	})

	t.Run("different data type cache creation", func(t *testing.T) {
		const zero = -1
		c := NewCache(
			zero,
			func(string) (int, error) { return zero, nil },
			func(string, int) {},
			func(string) bool { return false },
			time.Second,
			&log,
		)
		defer c.Close()
		if c.zero != zero {
			t.Errorf("expected %d, got %d", zero, c.zero)
		}
		if c.cache == nil || len(c.cache) != 0 {
			t.Errorf("expected non-nil, empty map; got %#v", c.cache)
		}
		if c.pipeline == nil {
			t.Errorf("expected non-nil chan; got %#v", c.pipeline)
		}
	})
}

// TestGet checks the two halves of the Get contract: a miss returns the zero
// value (and queues a fetch), and a later hit returns the fetched value.
func TestGet(t *testing.T) {
	log := logger.Logger{}
	const zero = "zero"
	items := map[string]string{"a": "1", "b": "2", "c": "3"}
	c := NewCache(
		zero,
		func(key string) (string, error) { return items[key], nil },
		func(string, string) {},
		func(string) bool { return false },
		time.Second,
		&log,
	)
	defer c.Close()

	t.Run("empty cache get returns zero", func(t *testing.T) {
		if got := c.Get("a"); got != zero {
			t.Errorf("expected %q, got %q", zero, got)
		}
	})

	// Give the background fetcher a chance to populate the cache.
	time.Sleep(time.Millisecond)

	t.Run("non-empty cache get returns value", func(t *testing.T) {
		const want = "1"
		if got := c.Get("a"); got != want {
			t.Errorf("expected %q, got %q", want, got)
		}
	})
}

// TestCallback verifies that the fetched-item callback is invoked with the
// requested key and the fetched value after a cache miss.
//
// The callback's arguments are handed to the test over a channel rather than
// via shared variables: the callback runs on the cache's fetch goroutine, so
// plain variable writes would race with the test's reads (and the original
// sleep-based wait was flaky). The channel also fixes the key-mismatch error
// message, which previously reported the value pair instead of the key pair.
func TestCallback(t *testing.T) {
	logger := logger.Logger{}
	zero := "zero"
	expectedK := "a"
	expectedV := "1"
	type pair struct{ k, v string }
	// Buffered so the fetch goroutine never blocks on the hand-off.
	called := make(chan pair, 1)
	c := NewCache(
		zero,
		func(k string) (string, error) {
			return expectedV, nil
		},
		func(k, v string) {
			called <- pair{k, v}
		},
		func(k string) bool { return false },
		time.Second,
		&logger,
	)
	defer c.Close()
	t.Run("callback gets called back", func(t *testing.T) {
		c.Get(expectedK)
		select {
		case got := <-called:
			if got.k != expectedK {
				t.Errorf("expected key %q, got %q", expectedK, got.k)
			}
			if got.v != expectedV {
				t.Errorf("expected value %q, got %q", expectedV, got.v)
			}
		case <-time.After(time.Second):
			t.Fatal("timed out waiting for fetch callback")
		}
	})
}

// TestClose verifies that Close empties the cache and closes the fetch
// pipeline, so that any later Get panics when it sends on the closed channel.
func TestClose(t *testing.T) {
	logger := logger.Logger{}
	t.Run("pipeline is closed", func(t *testing.T) {
		c0 := NewCache(
			"",
			func(k string) (string, error) { return "A", nil },
			func(k, v string) {},
			func(k string) bool { return false },
			time.Second,
			&logger,
		)
		// Put something in the cache
		c0.Get("")
		// Give the cache time to populate the cache
		time.Sleep(time.Millisecond)
		// Make sure the cache isn't empty
		if len(c0.cache) == 0 {
			t.Fatalf("expected the cache to be non-empty, but it was. Probably a threading issue with the test, and we need a longer timeout.")
		}
		// Get on a closed Cache sends on the closed pipeline channel, which
		// panics; the deferred recover turns the *absence* of that panic into
		// a test failure. It must be registered before Close so the checks
		// below still run when the final Get panics as expected.
		defer func() {
			if r := recover(); r == nil {
				t.Error("expected panic on pipeline use; got none")
			}
		}()
		c0.Close()
		// Close must have removed every cached entry.
		if len(c0.cache) > 0 {
			t.Errorf("expected empty cache; was %d", len(c0.cache))
		}
		c0.Get("")
	})

	// NOTE(review): this subtest name looks copy-pasted from TestCallback; it
	// actually checks that Get panics on an empty, closed cache.
	t.Run("callback gets called back", func(t *testing.T) {
		c0 := NewCache(
			"",
			func(k string) (string, error) { return "", nil },
			func(k, v string) {},
			func(k string) bool { return false },
			time.Second,
			&logger,
		)
		// As above: the panic from Get on a closed pipeline is the expected
		// outcome; failing to panic fails the test.
		defer func() {
			if r := recover(); r == nil {
				t.Error("expected panic on pipeline use; got none")
			}
		}()
		c0.Close()
		c0.Get("")
	})
}

// TestInvalidate verifies the invalidation sweep: with an isInvalid function
// that always reports true, a cached entry must disappear after one tick of
// the invalidation frequency, so Get falls back to the zero value.
//
// Fixes from the original: the failure messages reported gotV (a variable
// written by the callback on the fetch goroutine — a data race) even though
// the comparisons were against c.Get("a"); they now report the value actually
// compared, and the racy shared variable is gone.
func TestInvalidate(t *testing.T) {
	logger := logger.Logger{}
	zero := "zero"
	expected := "1"
	c := NewCache(
		zero,
		func(k string) (string, error) {
			return expected, nil
		},
		func(k, v string) {},
		func(k string) bool {
			// Every entry is always invalid, so each sweep empties the cache.
			return true
		},
		500*time.Millisecond,
		&logger,
	)
	defer c.Close()
	t.Run("basic invalidation", func(t *testing.T) {
		// First Get misses and queues the fetch.
		if got := c.Get("a"); got != zero {
			t.Errorf("expected %q, got %q", zero, got)
		}
		// Give the fetch goroutine a chance to do its thing
		time.Sleep(time.Millisecond)
		if got := c.Get("a"); got != expected {
			t.Errorf("expected %q, got %q", expected, got)
		}
		// Give the invalidation time to be called
		time.Sleep(600 * time.Millisecond)
		if got := c.Get("a"); got != zero {
			t.Errorf("expected %q, got %q", zero, got)
		}
	})
}
Loading

0 comments on commit 97524e1

Please sign in to comment.