-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #94 from netlify/add-ban-list
add banlist and tests
- Loading branch information
Showing
2 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package banlist | ||
|
||
import ( | ||
"encoding/json" | ||
"net/http" | ||
"os" | ||
"os/signal" | ||
"strings" | ||
"sync" | ||
"sync/atomic" | ||
"syscall" | ||
|
||
"github.com/pkg/errors" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
type Config struct { | ||
Domains []string | ||
URLs []string | ||
} | ||
|
||
type Banlist struct { | ||
domainHolder atomic.Value | ||
urlHolder atomic.Value | ||
mtx sync.Mutex | ||
ch chan os.Signal | ||
log logrus.FieldLogger | ||
path string | ||
} | ||
|
||
func New(log logrus.FieldLogger, filepath string) *Banlist { | ||
bl := newBanlist(log, filepath) | ||
bl.listen() | ||
bl.runUpdate() | ||
return bl | ||
} | ||
|
||
func newBanlist(log logrus.FieldLogger, path string) *Banlist { | ||
bl := &Banlist{log: log, path: path} | ||
bl.domainHolder.Store(make(map[string]struct{})) | ||
bl.urlHolder.Store(make(map[string]struct{})) | ||
return bl | ||
} | ||
|
||
func (b *Banlist) listen() { | ||
b.ch = make(chan os.Signal, 1) | ||
signal.Notify(b.ch, syscall.SIGHUP) | ||
go func() { | ||
for range b.ch { | ||
b.runUpdate() | ||
} | ||
b.log.Info("No longer listening for SIGHUP") | ||
}() | ||
} | ||
|
||
func (b *Banlist) runUpdate() { | ||
if err := b.update(); err != nil { | ||
b.log.WithError(err).Warn("error updating banlist") | ||
} else { | ||
b.log.Info("banlist updated") | ||
} | ||
} | ||
|
||
func (b *Banlist) update() error { | ||
b.mtx.Lock() | ||
defer b.mtx.Unlock() | ||
|
||
f, err := os.Open(b.path) | ||
if err != nil { | ||
return errors.Wrap(err, "error opening banlist config") | ||
} | ||
defer f.Close() | ||
|
||
c := new(Config) | ||
if err := json.NewDecoder(f).Decode(c); err != nil { | ||
return errors.Wrap(err, "error decoding banlist config") | ||
} | ||
|
||
domains := make(map[string]struct{}) | ||
urls := make(map[string]struct{}) | ||
for _, el := range c.Domains { | ||
domains[strings.ToLower(el)] = struct{}{} | ||
} | ||
for _, el := range c.URLs { | ||
urls[strings.ToLower(el)] = struct{}{} | ||
} | ||
|
||
b.domainHolder.Store(domains) | ||
b.urlHolder.Store(urls) | ||
return nil | ||
} | ||
|
||
// CheckRequest will check if the domain is blocked or the path is blocked | ||
func (b *Banlist) CheckRequest(r *http.Request) bool { | ||
domain := strings.SplitN(r.Host, ":", 2)[0] | ||
if _, ok := b.domains()[strings.ToLower(domain)]; ok { | ||
return true | ||
} | ||
|
||
url := domain + r.URL.Path | ||
if _, ok := b.urls()[strings.ToLower(url)]; ok { | ||
return true | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (b *Banlist) Close() { | ||
signal.Stop(b.ch) | ||
close(b.ch) | ||
} | ||
|
||
func (b *Banlist) domains() map[string]struct{} { | ||
return b.domainHolder.Load().(map[string]struct{}) | ||
} | ||
|
||
func (b *Banlist) urls() map[string]struct{} { | ||
return b.urlHolder.Load().(map[string]struct{}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package banlist | ||
|
||
import ( | ||
"encoding/json" | ||
"io/ioutil" | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"testing" | ||
|
||
"github.com/sirupsen/logrus" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestBanlistMissingFile(t *testing.T) { | ||
bl := newBanlist(tl(t), "not a path") | ||
require.Error(t, bl.update()) | ||
} | ||
|
||
func TestBanlistInvalidFileContents(t *testing.T) { | ||
path, err := ioutil.TempFile("", "") | ||
require.NoError(t, err) | ||
defer os.Remove(path.Name()) | ||
|
||
_, err = path.WriteString("this isn't valid json") | ||
require.NoError(t, err) | ||
|
||
bl := newBanlist(tl(t), path.Name()) | ||
require.Error(t, bl.update()) | ||
} | ||
|
||
func TestBanlistNoPaths(t *testing.T) { | ||
bl := testList(t, &Config{ | ||
Domains: []string{"something.com"}, | ||
}) | ||
|
||
assert.Empty(t, bl.urls()) | ||
domains := bl.domains() | ||
assert.Len(t, domains, 1) | ||
_, ok := domains["something.com"] | ||
assert.True(t, ok) | ||
} | ||
|
||
func TestBanlistNoDomains(t *testing.T) { | ||
bl := testList(t, &Config{ | ||
URLs: []string{"something.com/path/to/thing"}, | ||
}) | ||
|
||
urls := bl.urls() | ||
assert.Len(t, urls, 1) | ||
_, ok := urls["something.com/path/to/thing"] | ||
assert.True(t, ok) | ||
|
||
assert.Empty(t, bl.domains()) | ||
} | ||
func TestBanlistBanning(t *testing.T) { | ||
bl := testList(t, &Config{ | ||
URLs: []string{"villians.com/the/joker"}, | ||
Domains: []string{"sick.com"}, | ||
}) | ||
|
||
tests := []struct { | ||
url string | ||
isBanned bool | ||
name string | ||
}{ | ||
{"http://heros.com", false, "completely unbanned"}, | ||
{"http://sick.com:12345", true, "banned domain with port"}, | ||
{"http://sick.com", true, "banned domain without port"}, | ||
{"http://siCK.com", true, "banned domain mixed case"}, | ||
{"http://villians.com:12354/the/joker", true, "banned path with port"}, | ||
{"http://villians.com/the/joker", true, "banned path without port"}, | ||
{"http://villians.com/the/Joker", true, "banned path mixed case"}, | ||
{"http://villians.com/the/joker?query=param", true, "banned path with query params"}, | ||
} | ||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
req := httptest.NewRequest(http.MethodGet, test.url, nil) | ||
assert.Equal(t, test.isBanned, bl.CheckRequest(req)) | ||
}) | ||
} | ||
} | ||
|
||
func tl(t *testing.T) logrus.FieldLogger { | ||
l := logrus.New() | ||
l.SetLevel(logrus.DebugLevel) | ||
return l.WithField("test", t.Name()) | ||
} | ||
|
||
func testList(t *testing.T, config *Config) *Banlist { | ||
path, err := ioutil.TempFile("", "") | ||
require.NoError(t, err) | ||
defer os.Remove(path.Name()) | ||
|
||
require.NoError(t, json.NewEncoder(path).Encode(config)) | ||
|
||
bl := newBanlist(tl(t), path.Name()) | ||
require.NoError(t, bl.update()) | ||
return bl | ||
} |