From b4503a450070da701333ff4844a0b33e7eb3138d Mon Sep 17 00:00:00 2001 From: Dongcheol Choe <40932237+dc7303@users.noreply.github.com> Date: Thu, 22 Jul 2021 10:25:16 +0900 Subject: [PATCH] Implement authorization webhook cache (#216) Co-authored-by: Hackerwins --- internal/cli/agent.go | 12 +++ pkg/cache/cache.go | 113 ++++++++++++++++++++++++++ pkg/cache/cache_test.go | 55 +++++++++++++ test/integration/auth_webhook_test.go | 110 ++++++++++++++++++++++++- yorkie/auth/webhook.go | 29 ++++++- yorkie/backend/backend.go | 34 ++++++-- yorkie/config.go | 6 +- 7 files changed, 342 insertions(+), 17 deletions(-) create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go diff --git a/internal/cli/agent.go b/internal/cli/agent.go index a60198c51..4e7eb4a20 100644 --- a/internal/cli/agent.go +++ b/internal/cli/agent.go @@ -231,6 +231,18 @@ func init() { yorkie.DefaultAuthorizationWebhookWaitIntervalMillis, "Maximum wait interval for authorization webhook.", ) + cmd.Flags().Uint64Var( + &conf.Backend.AuthorizationWebhookCacheAuthorizedTTLSec, + "authorization-webhook-cache-authorized-ttl-sec", + yorkie.DefaultAuthorizationWebhookCacheAuthorizedTTLSec, + "TTL value to set when caching authorized webhook response.", + ) + cmd.Flags().Uint64Var( + &conf.Backend.AuthorizationWebhookCacheUnauthorizedTTLSec, + "authorization-webhook-cache-unauthorized-ttl-sec", + yorkie.DefaultAuthorizationWebhookCacheUnauthorizedTTLSec, + "TTL value to set when caching unauthorized webhook response.", + ) rootCmd.AddCommand(cmd) } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 000000000..f2187ccd6 --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,113 @@ +/* + * Copyright 2021 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * reference from the Kubernetes repository: + * https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apimachinery/pkg/util/cache/lruexpirecache.go + */ + +package cache + +import ( + "container/list" + "errors" + "sync" + "time" +) + +var ( + // ErrInvalidMaxSize is returned when the given max size is not positive. + ErrInvalidMaxSize = errors.New("max size must be > 0") +) + +// LRUExpireCache is a cache that ensures the mostly recently accessed keys are returned with +// a ttl beyond which keys are forcibly expired. +type LRUExpireCache struct { + lock sync.Mutex + + maxSize int + evictionList list.List + entries map[string]*list.Element +} + +// NewLRUExpireCache creates an expiring cache with the given size +func NewLRUExpireCache(maxSize int) (*LRUExpireCache, error) { + if maxSize <= 0 { + return nil, ErrInvalidMaxSize + } + + return &LRUExpireCache{ + maxSize: maxSize, + entries: map[string]*list.Element{}, + }, nil +} + +type cacheEntry struct { + key string + value interface{} + expireTime time.Time +} + +// Add adds the value to the cache at key with the specified maximum duration. +func (c *LRUExpireCache) Add( + key string, + value interface{}, + ttl time.Duration, +) { + c.lock.Lock() + defer c.lock.Unlock() + + oldElement, ok := c.entries[key] + if ok { + c.evictionList.MoveToFront(oldElement) + oldElement.Value.(*cacheEntry).value = value + oldElement.Value.(*cacheEntry).expireTime = time.Now() + return + } + + if c.evictionList.Len() >= c.maxSize { + toEvict := c.evictionList.Back() + c.evictionList.Remove(toEvict) + delete(c.entries, toEvict.Value.(*cacheEntry).key) + } + + element := c.evictionList.PushFront(&cacheEntry{ + key: key, + value: value, + expireTime: time.Now().Add(ttl), + }) + c.entries[key] = element +} + +// Get returns the value at the specified key from the cache if it exists and is not +// expired, or returns false. +func (c *LRUExpireCache) Get(key string) (interface{}, bool) { + c.lock.Lock() + defer c.lock.Unlock() + + element, ok := c.entries[key] + if !ok { + return nil, false + } + + if time.Now().After(element.Value.(*cacheEntry).expireTime) { + c.evictionList.Remove(element) + delete(c.entries, key) + return nil, false + } + + c.evictionList.MoveToFront(element) + + return element.Value.(*cacheEntry).value, true +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 000000000..66e47d943 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,55 @@ +package cache_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/yorkie-team/yorkie/pkg/cache" +) + +func TestCache(t *testing.T) { + t.Run("create lru expire cache test", func(t *testing.T) { + lruCache, err := cache.NewLRUExpireCache(1) + assert.NoError(t, err) + assert.NotNil(t, lruCache) + + lruCache, err = cache.NewLRUExpireCache(0) + assert.ErrorIs(t, err, cache.ErrInvalidMaxSize) + assert.Nil(t, lruCache) + }) + + t.Run("add test", func(t *testing.T) { + lruCache, err := cache.NewLRUExpireCache(1) + assert.NoError(t, err) + + lruCache.Add("request1", "response1", time.Second) + response1, ok := lruCache.Get("request1") + assert.True(t, ok) + assert.NotNil(t, response1) + + lruCache.Add("request2", "response2", time.Second) + response2, ok := lruCache.Get("request2") + assert.True(t, ok) + assert.NotNil(t, response2) + + // max size of the current cache is 1 + response1, ok = lruCache.Get("request1") + assert.False(t, ok) + assert.Nil(t, response1) + }) + + t.Run("get expired cache test", func(t *testing.T) { + lruCache, err := cache.NewLRUExpireCache(1) + assert.NoError(t, err) + + ttl := time.Millisecond + lruCache.Add("request", "response", ttl) + + time.Sleep(ttl) + response, ok := lruCache.Get("request") + assert.False(t, ok) + assert.Nil(t, response) + }) +} diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index e322d68f9..7629b0d7a 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -31,6 +32,7 @@ import ( "github.com/yorkie-team/yorkie/client" "github.com/yorkie-team/yorkie/pkg/document" + "github.com/yorkie-team/yorkie/pkg/document/proxy" "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/test/helper" "github.com/yorkie-team/yorkie/yorkie" @@ -168,9 +170,7 @@ func TestAuthWebhook(t *testing.T) { }) t.Run("authorization webhook that fails after retries test", func(t *testing.T) { - var recoveryCnt uint64 - recoveryCnt = 4 - server := newUnavailableAuthServer(t, recoveryCnt) + server := newUnavailableAuthServer(t, 4) conf := helper.TestConfig(server.URL) conf.Backend.AuthorizationWebhookMaxRetries = 2 @@ -188,4 +188,108 @@ func TestAuthWebhook(t *testing.T) { err = cli.Activate(ctx) assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code()) }) + + t.Run("authorized request cache test", func(t *testing.T) { + reqCnt := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + var res types.AuthWebhookResponse + res.Allowed = true + + _, err = res.Write(w) + assert.NoError(t, err) + + if req.Method == types.PushPull { + reqCnt++ + } + })) + + var authorizedTTLSec uint64 = 1 + conf := helper.TestConfig(server.URL) + conf.Backend.AuthorizationWebhookCacheAuthorizedTTLSec = authorizedTTLSec + + agent, err := yorkie.New(conf) + assert.NoError(t, err) + assert.NoError(t, agent.Start()) + defer func() { assert.NoError(t, agent.Shutdown(true)) }() + + ctx := context.Background() + cli, err := client.Dial(agent.RPCAddr(), client.Option{Token: "token"}) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + err = cli.Activate(ctx) + assert.NoError(t, err) + + doc := document.New(helper.Collection, t.Name()) + err = cli.Attach(ctx, doc) + assert.NoError(t, err) + + // 01. multiple requests to update the document. + for i := 0; i < 3; i++ { + assert.NoError(t, doc.Update(func(root *proxy.ObjectProxy) error { + root.SetNewObject("k1") + return nil + })) + assert.NoError(t, cli.Sync(ctx)) + } + + // 02. multiple requests to update the document after eviction by ttl. + time.Sleep(time.Duration(authorizedTTLSec) * time.Second) + for i := 0; i < 3; i++ { + assert.NoError(t, doc.Update(func(root *proxy.ObjectProxy) error { + root.SetNewObject("k1") + return nil + })) + assert.NoError(t, cli.Sync(ctx)) + } + + assert.Equal(t, 2, reqCnt) + }) + + t.Run("unauthorized request cache test", func(t *testing.T) { + reqCnt := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + var res types.AuthWebhookResponse + res.Allowed = false + + _, err = res.Write(w) + assert.NoError(t, err) + + reqCnt++ + })) + + var unauthorizedTTLSec uint64 = 1 + conf := helper.TestConfig(server.URL) + conf.Backend.AuthorizationWebhookCacheUnauthorizedTTLSec = unauthorizedTTLSec + + agent, err := yorkie.New(conf) + assert.NoError(t, err) + assert.NoError(t, agent.Start()) + defer func() { assert.NoError(t, agent.Shutdown(true)) }() + + ctx := context.Background() + cli, err := client.Dial(agent.RPCAddr(), client.Option{Token: "token"}) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + // 01. multiple requests. + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code()) + } + + // 02. multiple requests after eviction by ttl. + time.Sleep(time.Duration(unauthorizedTTLSec) * time.Second) + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code()) + } + assert.Equal(t, 2, reqCnt) + }) } diff --git a/yorkie/auth/webhook.go b/yorkie/auth/webhook.go index ff314b2c1..e3c58bf84 100644 --- a/yorkie/auth/webhook.go +++ b/yorkie/auth/webhook.go @@ -58,7 +58,17 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn return err } - return withExponentialBackoff(ctx, be.Config, func() (int, error) { + cacheKey := string(reqBody) + if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok { + resp := entry.(*types.AuthWebhookResponse) + if !resp.Allowed { + return fmt.Errorf("%s: %w", resp.Reason, ErrNotAllowed) + } + return nil + } + + var authResp *types.AuthWebhookResponse + if err := withExponentialBackoff(ctx, be.Config, func() (int, error) { resp, err := http.Post( be.Config.AuthorizationWebhookURL, "application/json", @@ -78,7 +88,7 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn return resp.StatusCode, ErrUnexpectedStatusCode } - authResp, err := types.NewAuthWebhookResponse(resp.Body) + authResp, err = types.NewAuthWebhookResponse(resp.Body) if err != nil { return resp.StatusCode, err } @@ -88,7 +98,19 @@ func VerifyAccess(ctx context.Context, be *backend.Backend, info *types.AccessIn } return resp.StatusCode, nil - }) + }); err != nil { + if errors.Is(err, ErrNotAllowed) { + unauthorizedTTL := time.Duration(be.Config.AuthorizationWebhookCacheUnauthorizedTTLSec) * time.Second + be.AuthWebhookCache.Add(cacheKey, authResp, unauthorizedTTL) + } + + return err + } + + authorizedTTL := time.Duration(be.Config.AuthorizationWebhookCacheAuthorizedTTLSec) * time.Second + be.AuthWebhookCache.Add(cacheKey, authResp, authorizedTTL) + + return nil } func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn func() (int, error)) error { @@ -100,6 +122,7 @@ func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn if err == ErrUnexpectedStatusCode { return fmt.Errorf("unexpected status code from webhook: %d", statusCode) } + return err } diff --git a/yorkie/backend/backend.go b/yorkie/backend/backend.go index 9d88d7ae6..4165a9357 100644 --- a/yorkie/backend/backend.go +++ b/yorkie/backend/backend.go @@ -25,6 +25,7 @@ import ( "github.com/rs/xid" "github.com/yorkie-team/yorkie/internal/log" + "github.com/yorkie-team/yorkie/pkg/cache" "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/yorkie/backend/db" "github.com/yorkie-team/yorkie/yorkie/backend/db/mongo" @@ -34,6 +35,8 @@ import ( "github.com/yorkie-team/yorkie/yorkie/metrics" ) +const authWebhookCacheSize = 5000 + // Config is the configuration for creating a Backend instance. type Config struct { // SnapshotThreshold is the threshold that determines if changes should be @@ -56,6 +59,12 @@ type Config struct { // AuthorizationWebhookMaxWaitIntervalMillis is the max interval // that waits before retrying the authorization webhook. AuthorizationWebhookMaxWaitIntervalMillis uint64 `json:"AuthorizationWebhookMaxWaitIntervalMillis"` + + // AuthorizationWebhookCacheAuthorizedTTLSec is the TTL value to set when caching the authorized result. + AuthorizationWebhookCacheAuthorizedTTLSec uint64 `json:"AuthorizationWebhookCacheAuthorizedTTLSec"` + + // AuthorizationWebhookCacheAuthorizedTTLSec is the TTL value to set when caching the unauthorized result. + AuthorizationWebhookCacheUnauthorizedTTLSec uint64 `json:"AuthorizationWebhookCacheUnauthorizedTTLSec"` } // RequireAuth returns whether the given method require authorization. @@ -94,9 +103,10 @@ type Backend struct { Config *Config agentInfo *sync.AgentInfo - DB db.DB - Coordinator sync.Coordinator - Metrics metrics.Metrics + DB db.DB + Coordinator sync.Coordinator + Metrics metrics.Metrics + AuthWebhookCache *cache.LRUExpireCache // closing is closed by backend close. closing chan struct{} @@ -155,13 +165,19 @@ func New( agentInfo.RPCAddr, ) + lruCache, err := cache.NewLRUExpireCache(authWebhookCacheSize) + if err != nil { + return nil, err + } + return &Backend{ - Config: conf, - agentInfo: agentInfo, - DB: mongoClient, - Coordinator: coordinator, - Metrics: met, - closing: make(chan struct{}), + Config: conf, + agentInfo: agentInfo, + DB: mongoClient, + Coordinator: coordinator, + Metrics: met, + AuthWebhookCache: lruCache, + closing: make(chan struct{}), }, nil } diff --git a/yorkie/config.go b/yorkie/config.go index c554b6bf2..b9318ce00 100644 --- a/yorkie/config.go +++ b/yorkie/config.go @@ -43,8 +43,10 @@ const ( DefaultSnapshotThreshold = 500 DefaultSnapshotInterval = 100 - DefaultAuthorizationWebhookMaxRetries = 10 - DefaultAuthorizationWebhookWaitIntervalMillis = 3000 + DefaultAuthorizationWebhookMaxRetries = 10 + DefaultAuthorizationWebhookWaitIntervalMillis = 3000 + DefaultAuthorizationWebhookCacheAuthorizedTTLSec = 10 + DefaultAuthorizationWebhookCacheUnauthorizedTTLSec = 10 ) // Config is the configuration for creating a Yorkie instance.