Skip to content

Commit

Permalink
add domain verification on register
Browse files Browse the repository at this point in the history
  • Loading branch information
rasoro committed Jan 17, 2025
1 parent 39a4742 commit d75ff5a
Show file tree
Hide file tree
Showing 14 changed files with 338 additions and 9 deletions.
14 changes: 13 additions & 1 deletion api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/go-redis/redis/v8"
"github.com/ilhasoft/wwcs/config"
"github.com/ilhasoft/wwcs/pkg/db"
"github.com/ilhasoft/wwcs/pkg/flows"
"github.com/ilhasoft/wwcs/pkg/history"
"github.com/ilhasoft/wwcs/pkg/metric"
"github.com/ilhasoft/wwcs/pkg/queue"
Expand Down Expand Up @@ -87,7 +88,18 @@ func main() {

clientM := websocket.NewClientManager(rdb, int(queueConfig.ClientTTL))

app := websocket.NewApp(websocket.NewPool(), rdb, mdb, metrics, histories, clientM, queueConn)
flowsClient := flows.NewClient(config.Get().FlowsURL)

app := websocket.NewApp(
websocket.NewPool(),
rdb,
mdb,
metrics,
histories,
clientM,
queueConn,
flowsClient,
)
app.StartConnectionsHeartbeat()
websocket.SetupRoutes(app)

Expand Down
4 changes: 4 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type Configuration struct {
RedisQueue RedisQueue
SentryDSN string `env:"WWC_APP_SENTRY_DSN"`
DB DB

RestrictDomains bool `default:"false" env:"WWC_RESTRICT_DOMAINS"`

FlowsURL string `default:"flows.weni.ai" env:"WWC_FLOWS_URL"`
}

type S3 struct {
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM golang:1.17.11-alpine3.16 AS base
FROM golang:1.23-alpine3.20 AS base

ARG APP_UID=1000
ARG APP_GID=1000
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/ilhasoft/wwcs

go 1.17
go 1.23

require (
github.com/adjust/rmq/v4 v4.0.1
Expand Down
11 changes: 11 additions & 0 deletions local_test_webchat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<script>
(function (d, s, u) {
let h = d.getElementsByTagName(s)[0], k = d.createElement(s);
k.onload = function () {
let l = d.createElement(s); l.src = u; l.async = true;
h.parentNode.insertBefore(l, k.nextSibling);
};
k.async = true; k.src = 'https://storage.googleapis.com/push-webchat/wwc-latest.js';
h.parentNode.insertBefore(k, h);
})(document, 'script', './script.js');
</script>
42 changes: 42 additions & 0 deletions pkg/flows/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package flows

import (
"encoding/json"
"fmt"
"net/http"
)

type IClient interface {
GetChannelAllowedDomains(string) ([]string, error)
}

type Client struct {
BaseURL string `json:"base_url"`
}

func NewClient(baseURL string) *Client {
return &Client{
BaseURL: baseURL,
}
}

func (c *Client) GetChannelAllowedDomains(channelUUID string) ([]string, error) {
url := fmt.Sprintf("%s/api/v2/internals/channel_allowed_domains?channel=%s", c.BaseURL, channelUUID)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get channel allowed domains, status code: %d", resp.StatusCode)
}

var domains []string
err = json.NewDecoder(resp.Body).Decode(&domains)
if err != nil {
return nil, err
}

return domains, nil
}
52 changes: 52 additions & 0 deletions pkg/flows/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package flows

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetChannelAllowedDomains(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("[\"domain1.com\", \"domain2.com\"]"))
}))
defer server.Close()

client := Client{BaseURL: server.URL}

domains, err := client.GetChannelAllowedDomains("09bf3dee-973e-43d3-8b94-441406c4a565")

assert.NoError(t, err)
assert.Equal(t, 2, len(domains))
}

func TestGetChannelAllowedDomainsStatus404(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()

client := Client{BaseURL: server.URL}

_, err := client.GetChannelAllowedDomains("09bf3dee-973e-43d3-8b94-441406c4a565")

assert.Equal(t, err.Error(), "failed to get channel allowed domains, status code: 404")
}

func TestGetChannelAllowedDomainsStatusWithNoDomain(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("[]"))
}))
defer server.Close()

client := Client{BaseURL: server.URL}

domains, err := client.GetChannelAllowedDomains("09bf3dee-973e-43d3-8b94-441406c4a565")

assert.NoError(t, err)
assert.Equal(t, 0, len(domains))
}
49 changes: 49 additions & 0 deletions pkg/memcache/memcache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package memcache

import (
"sync"
"time"
)

type Cache[K comparable, V any] struct {

Check failure on line 8 in pkg/memcache/memcache.go

View workflow job for this annotation

GitHub Actions / Test

expected ']', found comparable
items map[K]item[V]
mu sync.Mutex
}

type item[V any] struct {

Check failure on line 13 in pkg/memcache/memcache.go

View workflow job for this annotation

GitHub Actions / Test

expected ']', found any
value V
expiry time.Time
deleted bool
}

func New[K comparable, V any]() *Cache[K, V] {
return &Cache[K, V]{

Check failure on line 20 in pkg/memcache/memcache.go

View workflow job for this annotation

GitHub Actions / Test

expected declaration, found 'return'
items: make(map[K]item[V]),
}
}

func (c *Cache[K, V]) Set(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = item[V]{
value: value,
expiry: time.Now().Add(ttl),
}
}

func (c *Cache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
item, found := c.items[key]
if !found || time.Now().After(item.expiry) || item.deleted {
delete(c.items, key)
return item.value, false
}
return item.value, true
}

func (c *Cache[K, V]) Remove(key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
26 changes: 26 additions & 0 deletions pkg/memcache/memcache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package memcache

import (
"testing"
"time"
)

func TestCache(t *testing.T) {
cacheDomains := New[string, []string]()

Check failure on line 9 in pkg/memcache/memcache_test.go

View workflow job for this annotation

GitHub Actions / Test

expected ']' or ':', found ','
chanUUID := "5f610454-98c1-4a54-9499-e2d2b9b68334"
cacheDomains.Set(chanUUID, []string{"127.0.0.1", "localhost"}, time.Duration(time.Second*2))

chan1domains, ok := cacheDomains.Get(chanUUID)
if !ok {
t.Error("Expected channel UUID to be found")
}
if len(chan1domains) != 2 {
t.Error("Expected 2 domains in cache, got", len(chan1domains))
}
time.Sleep(3 * time.Second)

_, ok = cacheDomains.Get(chanUUID)
if ok {
t.Error("Expected channel UUID not to be found")
}
}
5 changes: 4 additions & 1 deletion pkg/websocket/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/go-redis/redis/v8"
"github.com/ilhasoft/wwcs/pkg/flows"
"github.com/ilhasoft/wwcs/pkg/history"
"github.com/ilhasoft/wwcs/pkg/metric"
"github.com/ilhasoft/wwcs/pkg/queue"
Expand All @@ -21,10 +22,11 @@ type App struct {
Histories history.Service
ClientManager ClientManager
QueueConnectionManager queue.Connection
FlowsClient flows.IClient
}

// Create new App instance.
func NewApp(pool *ClientPool, rdb *redis.Client, mdb *mongo.Database, metrics *metric.Service, histories history.Service, clientM ClientManager, qconnM queue.Connection) *App {
func NewApp(pool *ClientPool, rdb *redis.Client, mdb *mongo.Database, metrics *metric.Service, histories history.Service, clientM ClientManager, qconnM queue.Connection, fc flows.IClient) *App {
return &App{
ClientPool: pool,
RDB: rdb,
Expand All @@ -33,6 +35,7 @@ func NewApp(pool *ClientPool, rdb *redis.Client, mdb *mongo.Database, metrics *m
Histories: histories,
ClientManager: clientM,
QueueConnectionManager: qconnM,
FlowsClient: fc,
}
}

Expand Down
48 changes: 48 additions & 0 deletions pkg/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/gorilla/websocket"
"github.com/ilhasoft/wwcs/config"
"github.com/ilhasoft/wwcs/pkg/history"
"github.com/ilhasoft/wwcs/pkg/memcache"
"github.com/ilhasoft/wwcs/pkg/metric"
"github.com/ilhasoft/wwcs/pkg/queue"
"github.com/pkg/errors"
Expand All @@ -31,6 +32,8 @@ var (
ErrorNeedRegistration = errors.New("unable to redirect: id and url is blank")
)

var cacheChannelDomains = memcache.New[string, []string]()

Check failure on line 35 in pkg/websocket/client.go

View workflow job for this annotation

GitHub Actions / Test

expected ']' or ':', found ','

// Client side data
type Client struct {
ID string
Expand Down Expand Up @@ -174,8 +177,53 @@ func CloseClientSession(payload OutgoingPayload, app *App) error {
return nil
}

func CheckAllowedDomain(app *App, channelUUID string, originDomain string) bool {
var allowedDomains []string = nil
var err error
cachedDomains, notexpired := cacheChannelDomains.Get(channelUUID)
if !notexpired {
allowedDomains = cachedDomains
} else {
allowedDomains, err = app.FlowsClient.GetChannelAllowedDomains(channelUUID)
if err != nil {
log.Error("Error on get allowed domains", err)
return false
}
cacheChannelDomains.Set(channelUUID, allowedDomains, time.Minute*5)
}
if len(allowedDomains) > 0 {
for _, domain := range allowedDomains {
if originDomain == domain {
return true
}
}
return false
}
return true
}

func OriginToDomain(origin string) (string, error) {
u, err := url.Parse(origin)
if err != nil {
fmt.Println("Error on parse URL to get domain:", err)
return "", err
}
domain := strings.Split(u.Host, ":")[0]
return domain, nil
}

// Register register an user
func (c *Client) Register(payload OutgoingPayload, triggerTo postJSON, app *App) error {
if config.Get().RestrictDomains {
domain, err := OriginToDomain(c.Origin)
if err != nil {
return err
}
allowed := CheckAllowedDomain(app, payload.ChannelUUID(), domain)
if !allowed {
return errors.New("domain not allowed")
}
}
start := time.Now()
err := validateOutgoingPayloadRegister(payload)
if err != nil {
Expand Down
Loading

0 comments on commit d75ff5a

Please sign in to comment.