Skip to content

Commit

Permalink
Implement authorization webhook cache (#216)
Browse files Browse the repository at this point in the history
Co-authored-by: Hackerwins <[email protected]>
  • Loading branch information
dc7303 and hackerwins authored Jul 22, 2021
1 parent 89b549c commit b4503a4
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 17 deletions.
12 changes: 12 additions & 0 deletions internal/cli/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,18 @@ func init() {
yorkie.DefaultAuthorizationWebhookWaitIntervalMillis,
"Maximum wait interval for authorization webhook.",
)
cmd.Flags().Uint64Var(
&conf.Backend.AuthorizationWebhookCacheAuthorizedTTLSec,
"authorization-webhook-cache-authorized-ttl-sec",
yorkie.DefaultAuthorizationWebhookCacheAuthorizedTTLSec,
"TTL value to set when caching authorized webhook response.",
)
cmd.Flags().Uint64Var(
&conf.Backend.AuthorizationWebhookCacheUnauthorizedTTLSec,
"authorization-webhook-cache-unauthorized-ttl-sec",
yorkie.DefaultAuthorizationWebhookCacheUnauthorizedTTLSec,
"TTL value to set when caching unauthorized webhook response.",
)

rootCmd.AddCommand(cmd)
}
113 changes: 113 additions & 0 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2021 The Yorkie Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* reference from the Kubernetes repository:
* https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apimachinery/pkg/util/cache/lruexpirecache.go
*/

package cache

import (
"container/list"
"errors"
"sync"
"time"
)

var (
// ErrInvalidMaxSize is returned when the given max size is not positive.
ErrInvalidMaxSize = errors.New("max size must be > 0")
)

// LRUExpireCache is a cache that ensures the mostly recently accessed keys are returned with
// a ttl beyond which keys are forcibly expired.
type LRUExpireCache struct {
lock sync.Mutex

maxSize int
evictionList list.List
entries map[string]*list.Element
}

// NewLRUExpireCache creates an expiring cache with the given size
func NewLRUExpireCache(maxSize int) (*LRUExpireCache, error) {
if maxSize <= 0 {
return nil, ErrInvalidMaxSize
}

return &LRUExpireCache{
maxSize: maxSize,
entries: map[string]*list.Element{},
}, nil
}

type cacheEntry struct {
key string
value interface{}
expireTime time.Time
}

// Add adds the value to the cache at key with the specified maximum duration.
func (c *LRUExpireCache) Add(
key string,
value interface{},
ttl time.Duration,
) {
c.lock.Lock()
defer c.lock.Unlock()

oldElement, ok := c.entries[key]
if ok {
c.evictionList.MoveToFront(oldElement)
oldElement.Value.(*cacheEntry).value = value
oldElement.Value.(*cacheEntry).expireTime = time.Now()
return
}

if c.evictionList.Len() >= c.maxSize {
toEvict := c.evictionList.Back()
c.evictionList.Remove(toEvict)
delete(c.entries, toEvict.Value.(*cacheEntry).key)
}

element := c.evictionList.PushFront(&cacheEntry{
key: key,
value: value,
expireTime: time.Now().Add(ttl),
})
c.entries[key] = element
}

// Get returns the value at the specified key from the cache if it exists and is not
// expired, or returns false.
func (c *LRUExpireCache) Get(key string) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()

element, ok := c.entries[key]
if !ok {
return nil, false
}

if time.Now().After(element.Value.(*cacheEntry).expireTime) {
c.evictionList.Remove(element)
delete(c.entries, key)
return nil, false
}

c.evictionList.MoveToFront(element)

return element.Value.(*cacheEntry).value, true
}
55 changes: 55 additions & 0 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package cache_test

import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/yorkie-team/yorkie/pkg/cache"
)

func TestCache(t *testing.T) {
t.Run("create lru expire cache test", func(t *testing.T) {
lruCache, err := cache.NewLRUExpireCache(1)
assert.NoError(t, err)
assert.NotNil(t, lruCache)

lruCache, err = cache.NewLRUExpireCache(0)
assert.ErrorIs(t, err, cache.ErrInvalidMaxSize)
assert.Nil(t, lruCache)
})

t.Run("add test", func(t *testing.T) {
lruCache, err := cache.NewLRUExpireCache(1)
assert.NoError(t, err)

lruCache.Add("request1", "response1", time.Second)
response1, ok := lruCache.Get("request1")
assert.True(t, ok)
assert.NotNil(t, response1)

lruCache.Add("request2", "response2", time.Second)
response2, ok := lruCache.Get("request2")
assert.True(t, ok)
assert.NotNil(t, response2)

// max size of the current cache is 1
response1, ok = lruCache.Get("request1")
assert.False(t, ok)
assert.Nil(t, response1)
})

t.Run("get expired cache test", func(t *testing.T) {
lruCache, err := cache.NewLRUExpireCache(1)
assert.NoError(t, err)

ttl := time.Millisecond
lruCache.Add("request", "response", ttl)

time.Sleep(ttl)
response, ok := lruCache.Get("request")
assert.False(t, ok)
assert.Nil(t, response)
})
}
110 changes: 107 additions & 3 deletions test/integration/auth_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/rs/xid"
"github.com/stretchr/testify/assert"
Expand All @@ -31,6 +32,7 @@ import (

"github.com/yorkie-team/yorkie/client"
"github.com/yorkie-team/yorkie/pkg/document"
"github.com/yorkie-team/yorkie/pkg/document/proxy"
"github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/test/helper"
"github.com/yorkie-team/yorkie/yorkie"
Expand Down Expand Up @@ -168,9 +170,7 @@ func TestAuthWebhook(t *testing.T) {
})

t.Run("authorization webhook that fails after retries test", func(t *testing.T) {
var recoveryCnt uint64
recoveryCnt = 4
server := newUnavailableAuthServer(t, recoveryCnt)
server := newUnavailableAuthServer(t, 4)

conf := helper.TestConfig(server.URL)
conf.Backend.AuthorizationWebhookMaxRetries = 2
Expand All @@ -188,4 +188,108 @@ func TestAuthWebhook(t *testing.T) {
err = cli.Activate(ctx)
assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code())
})

t.Run("authorized request cache test", func(t *testing.T) {
reqCnt := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, err := types.NewAuthWebhookRequest(r.Body)
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Allowed = true

_, err = res.Write(w)
assert.NoError(t, err)

if req.Method == types.PushPull {
reqCnt++
}
}))

var authorizedTTLSec uint64 = 1
conf := helper.TestConfig(server.URL)
conf.Backend.AuthorizationWebhookCacheAuthorizedTTLSec = authorizedTTLSec

agent, err := yorkie.New(conf)
assert.NoError(t, err)
assert.NoError(t, agent.Start())
defer func() { assert.NoError(t, agent.Shutdown(true)) }()

ctx := context.Background()
cli, err := client.Dial(agent.RPCAddr(), client.Option{Token: "token"})
assert.NoError(t, err)
defer func() { assert.NoError(t, cli.Close()) }()

err = cli.Activate(ctx)
assert.NoError(t, err)

doc := document.New(helper.Collection, t.Name())
err = cli.Attach(ctx, doc)
assert.NoError(t, err)

// 01. multiple requests to update the document.
for i := 0; i < 3; i++ {
assert.NoError(t, doc.Update(func(root *proxy.ObjectProxy) error {
root.SetNewObject("k1")
return nil
}))
assert.NoError(t, cli.Sync(ctx))
}

// 02. multiple requests to update the document after eviction by ttl.
time.Sleep(time.Duration(authorizedTTLSec) * time.Second)
for i := 0; i < 3; i++ {
assert.NoError(t, doc.Update(func(root *proxy.ObjectProxy) error {
root.SetNewObject("k1")
return nil
}))
assert.NoError(t, cli.Sync(ctx))
}

assert.Equal(t, 2, reqCnt)
})

t.Run("unauthorized request cache test", func(t *testing.T) {
reqCnt := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := types.NewAuthWebhookRequest(r.Body)
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Allowed = false

_, err = res.Write(w)
assert.NoError(t, err)

reqCnt++
}))

var unauthorizedTTLSec uint64 = 1
conf := helper.TestConfig(server.URL)
conf.Backend.AuthorizationWebhookCacheUnauthorizedTTLSec = unauthorizedTTLSec

agent, err := yorkie.New(conf)
assert.NoError(t, err)
assert.NoError(t, agent.Start())
defer func() { assert.NoError(t, agent.Shutdown(true)) }()

ctx := context.Background()
cli, err := client.Dial(agent.RPCAddr(), client.Option{Token: "token"})
assert.NoError(t, err)
defer func() { assert.NoError(t, cli.Close()) }()

// 01. multiple requests.
for i := 0; i < 3; i++ {
err = cli.Activate(ctx)
assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code())
}

// 02. multiple requests after eviction by ttl.
time.Sleep(time.Duration(unauthorizedTTLSec) * time.Second)
for i := 0; i < 3; i++ {
err = cli.Activate(ctx)
assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code())
}
assert.Equal(t, 2, reqCnt)
})
}
29 changes: 26 additions & 3 deletions yorkie/auth/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,17 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn
return err
}

return withExponentialBackoff(ctx, be.Config, func() (int, error) {
cacheKey := string(reqBody)
if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok {
resp := entry.(*types.AuthWebhookResponse)
if !resp.Allowed {
return fmt.Errorf("%s: %w", resp.Reason, ErrNotAllowed)
}
return nil
}

var authResp *types.AuthWebhookResponse
if err := withExponentialBackoff(ctx, be.Config, func() (int, error) {
resp, err := http.Post(
be.Config.AuthorizationWebhookURL,
"application/json",
Expand All @@ -78,7 +88,7 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn
return resp.StatusCode, ErrUnexpectedStatusCode
}

authResp, err := types.NewAuthWebhookResponse(resp.Body)
authResp, err = types.NewAuthWebhookResponse(resp.Body)
if err != nil {
return resp.StatusCode, err
}
Expand All @@ -88,7 +98,19 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn
}

return resp.StatusCode, nil
})
}); err != nil {
if errors.Is(err, ErrNotAllowed) {
unauthorizedTTL := time.Duration(be.Config.AuthorizationWebhookCacheUnauthorizedTTLSec) * time.Second
be.AuthWebhookCache.Add(cacheKey, authResp, unauthorizedTTL)
}

return err
}

authorizedTTL := time.Duration(be.Config.AuthorizationWebhookCacheAuthorizedTTLSec) * time.Second
be.AuthWebhookCache.Add(cacheKey, authResp, authorizedTTL)

return nil
}

func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn func() (int, error)) error {
Expand All @@ -100,6 +122,7 @@ func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn
if err == ErrUnexpectedStatusCode {
return fmt.Errorf("unexpected status code from webhook: %d", statusCode)
}

return err
}

Expand Down
Loading

0 comments on commit b4503a4

Please sign in to comment.