Skip to content

Commit

Permalink
Merge pull request #94 from netlify/add-ban-list
Browse files Browse the repository at this point in the history
add banlist and tests
  • Loading branch information
rybit authored Sep 25, 2019
2 parents 4d128d6 + e8272e1 commit bd38f31
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 0 deletions.
119 changes: 119 additions & 0 deletions http/banlist/banlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package banlist

import (
"encoding/json"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

type Config struct {
Domains []string
URLs []string
}

type Banlist struct {
domainHolder atomic.Value
urlHolder atomic.Value
mtx sync.Mutex
ch chan os.Signal
log logrus.FieldLogger
path string
}

func New(log logrus.FieldLogger, filepath string) *Banlist {
bl := newBanlist(log, filepath)
bl.listen()
bl.runUpdate()
return bl
}

func newBanlist(log logrus.FieldLogger, path string) *Banlist {
bl := &Banlist{log: log, path: path}
bl.domainHolder.Store(make(map[string]struct{}))
bl.urlHolder.Store(make(map[string]struct{}))
return bl
}

func (b *Banlist) listen() {
b.ch = make(chan os.Signal, 1)
signal.Notify(b.ch, syscall.SIGHUP)
go func() {
for range b.ch {
b.runUpdate()
}
b.log.Info("No longer listening for SIGHUP")
}()
}

func (b *Banlist) runUpdate() {
if err := b.update(); err != nil {
b.log.WithError(err).Warn("error updating banlist")
} else {
b.log.Info("banlist updated")
}
}

func (b *Banlist) update() error {
b.mtx.Lock()
defer b.mtx.Unlock()

f, err := os.Open(b.path)
if err != nil {
return errors.Wrap(err, "error opening banlist config")
}
defer f.Close()

c := new(Config)
if err := json.NewDecoder(f).Decode(c); err != nil {
return errors.Wrap(err, "error decoding banlist config")
}

domains := make(map[string]struct{})
urls := make(map[string]struct{})
for _, el := range c.Domains {
domains[strings.ToLower(el)] = struct{}{}
}
for _, el := range c.URLs {
urls[strings.ToLower(el)] = struct{}{}
}

b.domainHolder.Store(domains)
b.urlHolder.Store(urls)
return nil
}

// CheckRequest will check if the domain is blocked or the path is blocked
func (b *Banlist) CheckRequest(r *http.Request) bool {
domain := strings.SplitN(r.Host, ":", 2)[0]
if _, ok := b.domains()[strings.ToLower(domain)]; ok {
return true
}

url := domain + r.URL.Path
if _, ok := b.urls()[strings.ToLower(url)]; ok {
return true
}

return false
}

func (b *Banlist) Close() {
signal.Stop(b.ch)
close(b.ch)
}

func (b *Banlist) domains() map[string]struct{} {
return b.domainHolder.Load().(map[string]struct{})
}

func (b *Banlist) urls() map[string]struct{} {
return b.urlHolder.Load().(map[string]struct{})
}
101 changes: 101 additions & 0 deletions http/banlist/banlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package banlist

import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBanlistMissingFile(t *testing.T) {
bl := newBanlist(tl(t), "not a path")
require.Error(t, bl.update())
}

func TestBanlistInvalidFileContents(t *testing.T) {
path, err := ioutil.TempFile("", "")
require.NoError(t, err)
defer os.Remove(path.Name())

_, err = path.WriteString("this isn't valid json")
require.NoError(t, err)

bl := newBanlist(tl(t), path.Name())
require.Error(t, bl.update())
}

func TestBanlistNoPaths(t *testing.T) {
bl := testList(t, &Config{
Domains: []string{"something.com"},
})

assert.Empty(t, bl.urls())
domains := bl.domains()
assert.Len(t, domains, 1)
_, ok := domains["something.com"]
assert.True(t, ok)
}

func TestBanlistNoDomains(t *testing.T) {
bl := testList(t, &Config{
URLs: []string{"something.com/path/to/thing"},
})

urls := bl.urls()
assert.Len(t, urls, 1)
_, ok := urls["something.com/path/to/thing"]
assert.True(t, ok)

assert.Empty(t, bl.domains())
}
func TestBanlistBanning(t *testing.T) {
bl := testList(t, &Config{
URLs: []string{"villians.com/the/joker"},
Domains: []string{"sick.com"},
})

tests := []struct {
url string
isBanned bool
name string
}{
{"http://heros.com", false, "completely unbanned"},
{"http://sick.com:12345", true, "banned domain with port"},
{"http://sick.com", true, "banned domain without port"},
{"http://siCK.com", true, "banned domain mixed case"},
{"http://villians.com:12354/the/joker", true, "banned path with port"},
{"http://villians.com/the/joker", true, "banned path without port"},
{"http://villians.com/the/Joker", true, "banned path mixed case"},
{"http://villians.com/the/joker?query=param", true, "banned path with query params"},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, test.url, nil)
assert.Equal(t, test.isBanned, bl.CheckRequest(req))
})
}
}

func tl(t *testing.T) logrus.FieldLogger {
l := logrus.New()
l.SetLevel(logrus.DebugLevel)
return l.WithField("test", t.Name())
}

func testList(t *testing.T, config *Config) *Banlist {
path, err := ioutil.TempFile("", "")
require.NoError(t, err)
defer os.Remove(path.Name())

require.NoError(t, json.NewEncoder(path).Encode(config))

bl := newBanlist(tl(t), path.Name())
require.NoError(t, bl.update())
return bl
}

0 comments on commit bd38f31

Please sign in to comment.