From c2a1b8eb836d73a31bfc4ccc0d36b4cb99c03b3e Mon Sep 17 00:00:00 2001 From: Scott Blum Date: Tue, 18 Jun 2019 14:09:56 -0400 Subject: [PATCH] gcache with context --- arc.go | 15 ++++++++------- arc_test.go | 16 ++++++++-------- cache.go | 19 ++++++++++--------- cache_test.go | 27 ++++++++++++++------------- examples/autoloading_cache.go | 9 +++++---- examples/custom_expiration.go | 5 +++-- examples/example.go | 3 ++- helpers_test.go | 18 +++++++++++------- lfu.go | 13 +++++++------ lfu_test.go | 10 +++++----- lru.go | 13 +++++++------ lru_test.go | 10 +++++----- simple.go | 15 ++++++++------- simple_test.go | 12 ++++++------ singleflight.go | 25 ++++++++++++++++--------- singleflight_test.go | 6 +++--- stats_test.go | 35 ++++++++++++++++++----------------- 17 files changed, 136 insertions(+), 115 deletions(-) diff --git a/arc.go b/arc.go index e2015e9..5dec0db 100644 --- a/arc.go +++ b/arc.go @@ -2,6 +2,7 @@ package gcache import ( "container/list" + "context" "time" ) @@ -163,10 +164,10 @@ func (c *ARC) set(key, value interface{}) (interface{}, error) { } // Get a value from cache pool using key if it exists. If not exists and it has LoaderFunc, it will generate the value using you have specified LoaderFunc method returns value. -func (c *ARC) Get(key interface{}) (interface{}, error) { +func (c *ARC) Get(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, true) + return c.getWithLoader(ctx, key, true) } return v, err } @@ -174,10 +175,10 @@ func (c *ARC) Get(key interface{}) (interface{}, error) { // GetIFPresent gets a value from cache pool using key if it exists. // If it dose not exists key, returns KeyNotFoundError. // And send a request which refresh value for specified key if cache object has LoaderFunc. -func (c *ARC) GetIFPresent(key interface{}) (interface{}, error) { +func (c *ARC) GetIFPresent(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, false) + return c.getWithLoader(ctx, key, false) } return v, err } @@ -237,11 +238,11 @@ func (c *ARC) getValue(key interface{}, onLoad bool) (interface{}, error) { return nil, KeyNotFoundError } -func (c *ARC) getWithLoader(key interface{}, isWait bool) (interface{}, error) { - if c.loaderExpireFunc == nil { +func (c *ARC) getWithLoader(ctx context.Context, key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil || ctx == nil { return nil, KeyNotFoundError } - value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + value, _, err := c.load(ctx, key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { if e != nil { return nil, e } diff --git a/arc_test.go b/arc_test.go index 824fcc6..f13e9d8 100644 --- a/arc_test.go +++ b/arc_test.go @@ -21,16 +21,16 @@ func TestLoadingARCGet(t *testing.T) { func TestARCLength(t *testing.T) { gc := buildTestLoadingCacheWithExpiration(t, TYPE_ARC, 2, time.Millisecond) - gc.Get("test1") - gc.Get("test2") - gc.Get("test3") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") + gc.Get(ctx, "test3") length := gc.Len(true) expectedLength := 2 if length != expectedLength { t.Errorf("Expected length is %v, not %v", expectedLength, length) } time.Sleep(time.Millisecond) - gc.Get("test4") + gc.Get(ctx, "test4") length = gc.Len(true) expectedLength = 1 if length != expectedLength { @@ -44,7 +44,7 @@ func TestARCEvictItem(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_ARC, cacheSize, loader) for i := 0; i < numbers; i++ { - _, err := gc.Get(fmt.Sprintf("Key-%d", i)) + _, err := gc.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -63,7 +63,7 @@ func TestARCPurgeCache(t *testing.T) { Build() for i := 0; i < cacheSize; i++ { - _, err := gc.Get(fmt.Sprintf("Key-%d", i)) + _, err := gc.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -85,8 +85,8 @@ func TestARCHas(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") if gc.Has("test0") { t.Fatal("should not have test0") diff --git a/cache.go b/cache.go index e13e6f1..0eff2e1 100644 --- a/cache.go +++ b/cache.go @@ -1,6 +1,7 @@ package gcache import ( + "context" "errors" "fmt" "sync" @@ -19,8 +20,8 @@ var KeyNotFoundError = errors.New("Key not found.") type Cache interface { Set(key, value interface{}) error SetWithExpire(key, value interface{}, expiration time.Duration) error - Get(key interface{}) (interface{}, error) - GetIFPresent(key interface{}) (interface{}, error) + Get(ctx context.Context, key interface{}) (interface{}, error) + GetIFPresent(ctx context.Context, key interface{}) (interface{}, error) GetALL(checkExpired bool) map[interface{}]interface{} get(key interface{}, onLoad bool) (interface{}, error) Remove(key interface{}) bool @@ -48,8 +49,8 @@ type baseCache struct { } type ( - LoaderFunc func(interface{}) (interface{}, error) - LoaderExpireFunc func(interface{}) (interface{}, *time.Duration, error) + LoaderFunc func(context.Context, interface{}) (interface{}, error) + LoaderExpireFunc func(context.Context, interface{}) (interface{}, *time.Duration, error) EvictedFunc func(interface{}, interface{}) PurgeVisitorFunc func(interface{}, interface{}) AddedFunc func(interface{}, interface{}) @@ -86,8 +87,8 @@ func (cb *CacheBuilder) Clock(clock Clock) *CacheBuilder { // Set a loader function. // loaderFunc: create a new value with this function if cached value is expired. func (cb *CacheBuilder) LoaderFunc(loaderFunc LoaderFunc) *CacheBuilder { - cb.loaderExpireFunc = func(k interface{}) (interface{}, *time.Duration, error) { - v, err := loaderFunc(k) + cb.loaderExpireFunc = func(ctx context.Context, k interface{}) (interface{}, *time.Duration, error) { + v, err := loaderFunc(ctx, k) return v, nil, err } return cb @@ -189,14 +190,14 @@ func buildCache(c *baseCache, cb *CacheBuilder) { } // load a new value using by specified key. -func (c *baseCache) load(key interface{}, cb func(interface{}, *time.Duration, error) (interface{}, error), isWait bool) (interface{}, bool, error) { - v, called, err := c.loadGroup.Do(key, func() (v interface{}, e error) { +func (c *baseCache) load(ctx context.Context, key interface{}, cb func(interface{}, *time.Duration, error) (interface{}, error), isWait bool) (interface{}, bool, error) { + v, called, err := c.loadGroup.Do(ctx, key, func() (v interface{}, e error) { defer func() { if r := recover(); r != nil { e = fmt.Errorf("Loader panics: %v", r) } }() - return cb(c.loaderExpireFunc(key)) + return cb(c.loaderExpireFunc(ctx, key)) }, isWait) if err != nil { return nil, called, err diff --git a/cache_test.go b/cache_test.go index 21a9ce0..1356154 100644 --- a/cache_test.go +++ b/cache_test.go @@ -2,6 +2,7 @@ package gcache import ( "bytes" + "context" "encoding/gob" "sync" "sync/atomic" @@ -21,7 +22,7 @@ func TestLoaderFunc(t *testing.T) { var testCounter int64 counter := 1000 cache := builder. - LoaderFunc(func(key interface{}) (interface{}, error) { + LoaderFunc(func(_ context.Context, key interface{}) (interface{}, error) { time.Sleep(10 * time.Millisecond) return atomic.AddInt64(&testCounter, 1), nil }). @@ -34,7 +35,7 @@ func TestLoaderFunc(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := cache.Get(0) + _, err := cache.Get(ctx, 0) if err != nil { t.Error(err) } @@ -60,7 +61,7 @@ func TestLoaderExpireFuncWithoutExpire(t *testing.T) { var testCounter int64 counter := 1000 cache := builder. - LoaderExpireFunc(func(key interface{}) (interface{}, *time.Duration, error) { + LoaderExpireFunc(func(_ context.Context, key interface{}) (interface{}, *time.Duration, error) { return atomic.AddInt64(&testCounter, 1), nil, nil }). EvictedFunc(func(key, value interface{}) { @@ -72,7 +73,7 @@ func TestLoaderExpireFuncWithoutExpire(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := cache.Get(0) + _, err := cache.Get(ctx, 0) if err != nil { t.Error(err) } @@ -100,7 +101,7 @@ func TestLoaderExpireFuncWithExpire(t *testing.T) { counter := 1000 expire := 200 * time.Millisecond cache := builder. - LoaderExpireFunc(func(key interface{}) (interface{}, *time.Duration, error) { + LoaderExpireFunc(func(_ context.Context, key interface{}) (interface{}, *time.Duration, error) { return atomic.AddInt64(&testCounter, 1), &expire, nil }). Build() @@ -110,7 +111,7 @@ func TestLoaderExpireFuncWithExpire(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := cache.Get(0) + _, err := cache.Get(ctx, 0) if err != nil { t.Error(err) } @@ -121,7 +122,7 @@ func TestLoaderExpireFuncWithExpire(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := cache.Get(0) + _, err := cache.Get(ctx, 0) if err != nil { t.Error(err) } @@ -164,7 +165,7 @@ func TestLoaderPurgeVisitorFunc(t *testing.T) { var purgeCounter, evictCounter, loaderCounter int64 counter := 1000 cache := test.cacheBuilder. - LoaderFunc(func(key interface{}) (interface{}, error) { + LoaderFunc(func(_ context.Context, key interface{}) (interface{}, error) { return atomic.AddInt64(&loaderCounter, 1), nil }). EvictedFunc(func(key, value interface{}) { @@ -181,7 +182,7 @@ func TestLoaderPurgeVisitorFunc(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := cache.Get(i) + _, err := cache.Get(ctx, i) if err != nil { t.Error(err) } @@ -220,7 +221,7 @@ func TestDeserializeFunc(t *testing.T) { key2, value2 := "key2", "value2" cc := New(32). EvictType(cs.tp). - LoaderFunc(func(k interface{}) (interface{}, error) { + LoaderFunc(func(_ context.Context, k interface{}) (interface{}, error) { return value1, nil }). DeserializeFunc(func(k, v interface{}) (interface{}, error) { @@ -239,14 +240,14 @@ func TestDeserializeFunc(t *testing.T) { return buf.Bytes(), err }). Build() - v, err := cc.Get(key1) + v, err := cc.Get(ctx, key1) if err != nil { t.Fatal(err) } if v != value1 { t.Errorf("%v != %v", v, value1) } - v, err = cc.Get(key1) + v, err = cc.Get(ctx, key1) if err != nil { t.Fatal(err) } @@ -256,7 +257,7 @@ func TestDeserializeFunc(t *testing.T) { if err := cc.Set(key2, value2); err != nil { t.Error(err) } - v, err = cc.Get(key2) + v, err = cc.Get(ctx, key2) if err != nil { t.Error(err) } diff --git a/examples/autoloading_cache.go b/examples/autoloading_cache.go index 5200b46..774bbe4 100644 --- a/examples/autoloading_cache.go +++ b/examples/autoloading_cache.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/bluele/gcache" ) @@ -8,12 +9,12 @@ import ( func main() { gc := gcache.New(10). LFU(). - LoaderFunc(func(key interface{}) (interface{}, error) { - return fmt.Sprintf("%v-value", key), nil - }). + LoaderFunc(func(_ context.Context, key interface{}) (interface{}, error) { + return fmt.Sprintf("%v-value", key), nil + }). Build() - v, err := gc.Get("key") + v, err := gc.Get(context.Background(), "key") if err != nil { panic(err) } diff --git a/examples/custom_expiration.go b/examples/custom_expiration.go index 54f12a6..1383e7a 100644 --- a/examples/custom_expiration.go +++ b/examples/custom_expiration.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/bluele/gcache" "time" @@ -13,7 +14,7 @@ func main() { gc.SetWithExpire("key", "ok", time.Second*3) - v, err := gc.Get("key") + v, err := gc.Get(context.Background(), "key") if err != nil { panic(err) } @@ -22,7 +23,7 @@ func main() { fmt.Println("waiting 3s for value to expire:") time.Sleep(time.Second * 3) - v, err = gc.Get("key") + v, err = gc.Get(context.Background(), "key") if err != nil { panic(err) } diff --git a/examples/example.go b/examples/example.go index 97c1b32..55d045a 100644 --- a/examples/example.go +++ b/examples/example.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/bluele/gcache" ) @@ -11,7 +12,7 @@ func main() { Build() gc.Set("key", "ok") - v, err := gc.Get("key") + v, err := gc.Get(context.Background(), "key") if err != nil { panic(err) } diff --git a/helpers_test.go b/helpers_test.go index e254acc..f6eaf9d 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -1,19 +1,23 @@ package gcache import ( + "context" "fmt" "testing" "time" ) -func loader(key interface{}) (interface{}, error) { +func loader(_ context.Context, key interface{}) (interface{}, error) { return fmt.Sprintf("valueFor%s", key), nil } +// general background context for test calls +var ctx = context.Background() + func testSetCache(t *testing.T, gc Cache, numbers int) { for i := 0; i < numbers; i++ { key := fmt.Sprintf("Key-%d", i) - value, err := loader(key) + value, err := loader(ctx, key) if err != nil { t.Error(err) return @@ -25,11 +29,11 @@ func testSetCache(t *testing.T, gc Cache, numbers int) { func testGetCache(t *testing.T, gc Cache, numbers int) { for i := 0; i < numbers; i++ { key := fmt.Sprintf("Key-%d", i) - v, err := gc.Get(key) + v, err := gc.Get(ctx, key) if err != nil { t.Errorf("Unexpected error: %v", err) } - expectedV, _ := loader(key) + expectedV, _ := loader(ctx, key) if v != expectedV { t.Errorf("Expected value is %v, not %v", expectedV, v) } @@ -41,19 +45,19 @@ func testGetIFPresent(t *testing.T, evT string) { New(8). EvictType(evT). LoaderFunc( - func(key interface{}) (interface{}, error) { + func(_ context.Context, key interface{}) (interface{}, error) { return "value", nil }). Build() - v, err := cache.GetIFPresent("key") + v, err := cache.GetIFPresent(ctx, "key") if err != KeyNotFoundError { t.Errorf("err should not be %v", err) } time.Sleep(2 * time.Millisecond) - v, err = cache.GetIFPresent("key") + v, err = cache.GetIFPresent(ctx, "key") if err != nil { t.Errorf("err should not be %v", err) } diff --git a/lfu.go b/lfu.go index f781a1f..4d2d7fc 100644 --- a/lfu.go +++ b/lfu.go @@ -2,6 +2,7 @@ package gcache import ( "container/list" + "context" "time" ) @@ -99,10 +100,10 @@ func (c *LFUCache) set(key, value interface{}) (interface{}, error) { // Get a value from cache pool using key if it exists. // If it dose not exists key and has LoaderFunc, // generate a value using `LoaderFunc` method returns value. -func (c *LFUCache) Get(key interface{}) (interface{}, error) { +func (c *LFUCache) Get(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, true) + return c.getWithLoader(ctx, key, true) } return v, err } @@ -110,10 +111,10 @@ func (c *LFUCache) Get(key interface{}) (interface{}, error) { // GetIFPresent gets a value from cache pool using key if it exists. // If it dose not exists key, returns KeyNotFoundError. // And send a request which refresh value for specified key if cache object has LoaderFunc. -func (c *LFUCache) GetIFPresent(key interface{}) (interface{}, error) { +func (c *LFUCache) GetIFPresent(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, false) + return c.getWithLoader(ctx, key, false) } return v, err } @@ -151,11 +152,11 @@ func (c *LFUCache) getValue(key interface{}, onLoad bool) (interface{}, error) { return nil, KeyNotFoundError } -func (c *LFUCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { +func (c *LFUCache) getWithLoader(ctx context.Context, key interface{}, isWait bool) (interface{}, error) { if c.loaderExpireFunc == nil { return nil, KeyNotFoundError } - value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + value, _, err := c.load(ctx, key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { if e != nil { return nil, e } diff --git a/lfu_test.go b/lfu_test.go index f18b523..8055aec 100644 --- a/lfu_test.go +++ b/lfu_test.go @@ -25,8 +25,8 @@ func TestLoadingLFUGet(t *testing.T) { func TestLFULength(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_LFU, 1000, loader) - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") length := gc.Len(true) expectedLength := 2 if length != expectedLength { @@ -40,7 +40,7 @@ func TestLFUEvictItem(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_LFU, cacheSize, loader) for i := 0; i < numbers; i++ { - _, err := gc.Get(fmt.Sprintf("Key-%d", i)) + _, err := gc.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -56,8 +56,8 @@ func TestLFUHas(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") if gc.Has("test0") { t.Fatal("should not have test0") diff --git a/lru.go b/lru.go index a85d660..3228c25 100644 --- a/lru.go +++ b/lru.go @@ -2,6 +2,7 @@ package gcache import ( "container/list" + "context" "time" ) @@ -91,10 +92,10 @@ func (c *LRUCache) SetWithExpire(key, value interface{}, expiration time.Duratio // Get a value from cache pool using key if it exists. // If it dose not exists key and has LoaderFunc, // generate a value using `LoaderFunc` method returns value. -func (c *LRUCache) Get(key interface{}) (interface{}, error) { +func (c *LRUCache) Get(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, true) + return c.getWithLoader(ctx, key, true) } return v, err } @@ -102,10 +103,10 @@ func (c *LRUCache) Get(key interface{}) (interface{}, error) { // GetIFPresent gets a value from cache pool using key if it exists. // If it dose not exists key, returns KeyNotFoundError. // And send a request which refresh value for specified key if cache object has LoaderFunc. -func (c *LRUCache) GetIFPresent(key interface{}) (interface{}, error) { +func (c *LRUCache) GetIFPresent(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, false) + return c.getWithLoader(ctx, key, false) } return v, err } @@ -144,11 +145,11 @@ func (c *LRUCache) getValue(key interface{}, onLoad bool) (interface{}, error) { return nil, KeyNotFoundError } -func (c *LRUCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { +func (c *LRUCache) getWithLoader(ctx context.Context, key interface{}, isWait bool) (interface{}, error) { if c.loaderExpireFunc == nil { return nil, KeyNotFoundError } - value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + value, _, err := c.load(ctx, key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { if e != nil { return nil, e } diff --git a/lru_test.go b/lru_test.go index fa9140f..881b891 100644 --- a/lru_test.go +++ b/lru_test.go @@ -21,8 +21,8 @@ func TestLoadingLRUGet(t *testing.T) { func TestLRULength(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_LRU, 1000, loader) - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") length := gc.Len(true) expectedLength := 2 if length != expectedLength { @@ -36,7 +36,7 @@ func TestLRUEvictItem(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_LRU, cacheSize, loader) for i := 0; i < numbers; i++ { - _, err := gc.Get(fmt.Sprintf("Key-%d", i)) + _, err := gc.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -52,8 +52,8 @@ func TestLRUHas(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") if gc.Has("test0") { t.Fatal("should not have test0") diff --git a/simple.go b/simple.go index 7310af1..5be66b6 100644 --- a/simple.go +++ b/simple.go @@ -1,6 +1,7 @@ package gcache import ( + "context" "time" ) @@ -89,10 +90,10 @@ func (c *SimpleCache) set(key, value interface{}) (interface{}, error) { // Get a value from cache pool using key if it exists. // If it dose not exists key and has LoaderFunc, // generate a value using `LoaderFunc` method returns value. -func (c *SimpleCache) Get(key interface{}) (interface{}, error) { +func (c *SimpleCache) Get(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, true) + return c.getWithLoader(ctx, key, true) } return v, err } @@ -100,10 +101,10 @@ func (c *SimpleCache) Get(key interface{}) (interface{}, error) { // GetIFPresent gets a value from cache pool using key if it exists. // If it dose not exists key, returns KeyNotFoundError. // And send a request which refresh value for specified key if cache object has LoaderFunc. -func (c *SimpleCache) GetIFPresent(key interface{}) (interface{}, error) { +func (c *SimpleCache) GetIFPresent(ctx context.Context, key interface{}) (interface{}, error) { v, err := c.get(key, false) if err == KeyNotFoundError { - return c.getWithLoader(key, false) + return c.getWithLoader(ctx, key, false) } return v, nil } @@ -140,11 +141,11 @@ func (c *SimpleCache) getValue(key interface{}, onLoad bool) (interface{}, error return nil, KeyNotFoundError } -func (c *SimpleCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { - if c.loaderExpireFunc == nil { +func (c *SimpleCache) getWithLoader(ctx context.Context, key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil || ctx == nil { return nil, KeyNotFoundError } - value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + value, _, err := c.load(ctx, key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { if e != nil { return nil, e } diff --git a/simple_test.go b/simple_test.go index 5660d47..87340c7 100644 --- a/simple_test.go +++ b/simple_test.go @@ -21,8 +21,8 @@ func TestLoadingSimpleGet(t *testing.T) { func TestSimpleLength(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_SIMPLE, 1000, loader) - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") length := gc.Len(true) expectedLength := 2 if length != expectedLength { @@ -36,7 +36,7 @@ func TestSimpleEvictItem(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_SIMPLE, cacheSize, loader) for i := 0; i < numbers; i++ { - _, err := gc.Get(fmt.Sprintf("Key-%d", i)) + _, err := gc.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -54,7 +54,7 @@ func TestSimpleUnboundedNoEviction(t *testing.T) { t.Errorf("Excepted cache size is %v not %v", current_size, size_tracker) } - _, err := gcu.Get(fmt.Sprintf("Key-%d", i)) + _, err := gcu.Get(ctx, fmt.Sprintf("Key-%d", i)) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -72,8 +72,8 @@ func TestSimpleHas(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - gc.Get("test1") - gc.Get("test2") + gc.Get(ctx, "test1") + gc.Get(ctx, "test2") if gc.Has("test0") { t.Fatal("should not have test0") diff --git a/singleflight.go b/singleflight.go index 2c6285e..1c9c4b2 100644 --- a/singleflight.go +++ b/singleflight.go @@ -19,13 +19,16 @@ limitations under the License. // This module provides a duplicate function call suppression // mechanism. -import "sync" +import ( + "context" + "sync" +) // call is an in-flight or completed Do call type call struct { - wg sync.WaitGroup - val interface{} - err error + ready chan struct{} + val interface{} + err error } // Group represents a class of work and forms a namespace in which @@ -40,7 +43,7 @@ type Group struct { // sure that only one execution is in-flight for a given key at a // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. -func (g *Group) Do(key interface{}, fn func() (interface{}, error), isWait bool) (interface{}, bool, error) { +func (g *Group) Do(ctx context.Context, key interface{}, fn func() (interface{}, error), isWait bool) (interface{}, bool, error) { g.mu.Lock() v, err := g.cache.get(key, true) if err == nil { @@ -55,11 +58,15 @@ func (g *Group) Do(key interface{}, fn func() (interface{}, error), isWait bool) if !isWait { return nil, false, KeyNotFoundError } - c.wg.Wait() - return c.val, false, c.err + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + case <-c.ready: + return c.val, false, c.err + } } c := new(call) - c.wg.Add(1) + c.ready = make(chan struct{}) g.m[key] = c g.mu.Unlock() if !isWait { @@ -72,7 +79,7 @@ func (g *Group) Do(key interface{}, fn func() (interface{}, error), isWait bool) func (g *Group) call(c *call, key interface{}, fn func() (interface{}, error)) (interface{}, error) { c.val, c.err = fn() - c.wg.Done() + close(c.ready) g.mu.Lock() delete(g.m, key) diff --git a/singleflight_test.go b/singleflight_test.go index 81b9490..5510426 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -28,7 +28,7 @@ import ( func TestDo(t *testing.T) { var g Group g.cache = New(32).Build() - v, _, err := g.Do("key", func() (interface{}, error) { + v, _, err := g.Do(ctx, "key", func() (interface{}, error) { return "bar", nil }, true) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -43,7 +43,7 @@ func TestDoErr(t *testing.T) { var g Group g.cache = New(32).Build() someErr := errors.New("Some error") - v, _, err := g.Do("key", func() (interface{}, error) { + v, _, err := g.Do(ctx, "key", func() (interface{}, error) { return nil, someErr }, true) if err != someErr { @@ -69,7 +69,7 @@ func TestDoDupSuppress(t *testing.T) { for i := 0; i < n; i++ { wg.Add(1) go func() { - v, _, err := g.Do("key", fn, true) + v, _, err := g.Do(ctx, "key", fn, true) if err != nil { t.Errorf("Do error: %v", err) } diff --git a/stats_test.go b/stats_test.go index 1a6543a..f55f7ab 100644 --- a/stats_test.go +++ b/stats_test.go @@ -1,6 +1,7 @@ package gcache import ( + "context" "testing" ) @@ -30,7 +31,7 @@ func TestStats(t *testing.T) { } } -func getter(key interface{}) (interface{}, error) { +func getter(_ context.Context, key interface{}) (interface{}, error) { return key, nil } @@ -43,8 +44,8 @@ func TestCacheStats(t *testing.T) { builder: func() Cache { cc := New(32).Simple().Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -53,8 +54,8 @@ func TestCacheStats(t *testing.T) { builder: func() Cache { cc := New(32).LRU().Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -63,8 +64,8 @@ func TestCacheStats(t *testing.T) { builder: func() Cache { cc := New(32).LFU().Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -73,8 +74,8 @@ func TestCacheStats(t *testing.T) { builder: func() Cache { cc := New(32).ARC().Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -86,8 +87,8 @@ func TestCacheStats(t *testing.T) { LoaderFunc(getter). Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -99,8 +100,8 @@ func TestCacheStats(t *testing.T) { LoaderFunc(getter). Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -112,8 +113,8 @@ func TestCacheStats(t *testing.T) { LoaderFunc(getter). Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5, @@ -125,8 +126,8 @@ func TestCacheStats(t *testing.T) { LoaderFunc(getter). Build() cc.Set(0, 0) - cc.Get(0) - cc.Get(1) + cc.Get(ctx, 0) + cc.Get(ctx, 1) return cc }, rate: 0.5,