Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move executable to cmd directory #56

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
bin:
govendor sync
go build
go build cmd/*.go

test:
govendor sync
go test -v
go test -v ./...

test-cov-html:
go test -coverprofile=coverage.out
go tool cover -html=coverage.out

bench:
go test -bench=.
go test -bench=. ./...

bench-cpu:
go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof
Expand Down
10 changes: 5 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package main
package audit

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync/atomic"
"syscall"
"time"
"fmt"
)

// Endianness is an alias for what we assume is the current machine endianness
Expand Down Expand Up @@ -63,13 +63,13 @@ func NewNetlinkClient(recvSize int) (*NetlinkClient, error) {
// Set the buffer size if we were asked
if recvSize > 0 {
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil {
el.Println("Failed to set receive buffer size")
Stderr.Println("Failed to set receive buffer size")
}
}

// Print the current receive buffer size
if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil {
l.Println("Socket receive buffer size:", v)
Std.Println("Socket receive buffer size:", v)
}

go func() {
Expand Down Expand Up @@ -151,6 +151,6 @@ func (n *NetlinkClient) KeepConnection() {

err := n.Send(packet, payload)
if err != nil {
el.Println("Error occurred while trying to keep the connection:", err)
Stderr.Println("Error occurred while trying to keep the connection:", err)
}
}
31 changes: 8 additions & 23 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package main
package audit

import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"os"
"syscall"
"testing"

"github.com/slackhq/go-audit/internal/test"
"github.com/stretchr/testify/assert"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
Expand All @@ -29,8 +30,8 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
assert.EqualValues(t, msg.Data[:40], expectedData, "data was wrong")

// Make sure we get errors printed
lb, elb := hookLogger()
defer resetLogger()
lb, elb := test.HookLogger(Std, Stderr)
defer test.ResetLogger(Std, Stderr)
syscall.Close(n.fd)
n.KeepConnection()
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
Expand Down Expand Up @@ -88,8 +89,8 @@ func TestNetlinkClient_SendReceive(t *testing.T) {
}

func TestNewNetlinkClient(t *testing.T) {
lb, elb := hookLogger()
defer resetLogger()
lb, elb := test.HookLogger(Std, Stderr)
defer test.ResetLogger(Std, Stderr)

n, err := NewNetlinkClient(1024)

Expand Down Expand Up @@ -143,19 +144,3 @@ func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload

return msg
}

// Resets global loggers
func resetLogger() {
l.SetOutput(os.Stdout)
el.SetOutput(os.Stderr)
}

// Hooks the global loggers writers so you can assert their contents
func hookLogger() (lb *bytes.Buffer, elb *bytes.Buffer) {
lb = &bytes.Buffer{}
l.SetOutput(lb)

elb = &bytes.Buffer{}
el.SetOutput(elb)
return
}
84 changes: 40 additions & 44 deletions audit.go → cmd/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"flag"
"fmt"
"log"
"log/syslog"
"os"
"os/exec"
Expand All @@ -15,12 +14,10 @@ import (
"strings"
"syscall"

"github.com/slackhq/go-audit"
"github.com/spf13/viper"
)

var l = log.New(os.Stdout, "", 0)
var el = log.New(os.Stderr, "", 0)

type executor func(string, ...string) error

func lExec(s string, a ...string) error {
Expand All @@ -46,8 +43,8 @@ func loadConfig(configFile string) (*viper.Viper, error) {
return nil, err
}

l.SetFlags(config.GetInt("log.flags"))
el.SetFlags(config.GetInt("log.flags"))
audit.Std.SetFlags(config.GetInt("log.flags"))
audit.Stderr.SetFlags(config.GetInt("log.flags"))

return config, nil
}
Expand All @@ -58,7 +55,7 @@ func setRules(config *viper.Viper, e executor) error {
return fmt.Errorf("Failed to flush existing audit rules. Error: %s", err)
}

l.Println("Flushed existing audit rules")
audit.Std.Println("Flushed existing audit rules")

// Add ours in
if rules := config.GetStringSlice("rules"); len(rules) != 0 {
Expand All @@ -72,7 +69,7 @@ func setRules(config *viper.Viper, e executor) error {
return fmt.Errorf("Failed to add rule #%d. Error: %s", i+1, err)
}

l.Printf("Added audit rule #%d\n", i+1)
audit.Std.Printf("Added audit rule #%d\n", i+1)
}
} else {
return errors.New("No audit rules found")
Expand All @@ -81,8 +78,8 @@ func setRules(config *viper.Viper, e executor) error {
return nil
}

func createOutput(config *viper.Viper) (*AuditWriter, error) {
var writer *AuditWriter
func createOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) {
var writer *audit.JSONAuditWriter
var err error
i := 0

Expand Down Expand Up @@ -123,7 +120,7 @@ func createOutput(config *viper.Viper) (*AuditWriter, error) {
return writer, nil
}

func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
func createSyslogOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) {
attempts := config.GetInt("output.syslog.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for syslog must be at least 1, %v provided", attempts)
Expand All @@ -140,10 +137,10 @@ func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
return nil, fmt.Errorf("Failed to open syslog writer. Error: %v", err)
}

return NewAuditWriter(syslogWriter, attempts), nil
return audit.NewAuditWriter(syslogWriter, attempts), nil
}

func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
func createFileOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) {
attempts := config.GetInt("output.file.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for file must be at least 1, %v provided", attempts)
Expand Down Expand Up @@ -193,10 +190,10 @@ func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
return nil, fmt.Errorf("Could not chown output file. Error: %s", err)
}

return NewAuditWriter(f, attempts), nil
return audit.NewAuditWriter(f, attempts), nil
}

func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
func handleLogRotation(config *viper.Viper, writer *audit.JSONAuditWriter) {
// Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation

sigc := make(chan os.Signal, 1)
Expand All @@ -205,38 +202,37 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
for range sigc {
newWriter, err := createFileOutput(config)
if err != nil {
el.Fatalln("Error re-opening log file. Exiting.")
audit.Stderr.Fatalln("Error re-opening log file. Exiting.")
}

oldFile := writer.w.(*os.File)
writer.w = newWriter.w
writer.e = newWriter.e
oldFile := writer.IOWriter().(*os.File)
writer.SetIOWriter(newWriter.IOWriter())

err = oldFile.Close()
if err != nil {
el.Printf("Error closing old log file: %+v\n", err)
audit.Stderr.Printf("Error closing old log file: %+v\n", err)
}
}
}

func createStdOutOutput(config *viper.Viper) (*AuditWriter, error) {
func createStdOutOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) {
attempts := config.GetInt("output.stdout.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for stdout must be at least 1, %v provided", attempts)
}

// l logger is no longer stdout
l.SetOutput(os.Stderr)
audit.Std.SetOutput(os.Stderr)

return NewAuditWriter(os.Stdout, attempts), nil
return audit.NewAuditWriter(os.Stdout, attempts), nil
}

func createFilters(config *viper.Viper) ([]AuditFilter, error) {
func createFilters(config *viper.Viper) ([]audit.AuditFilter, error) {
var err error
var ok bool

fs := config.Get("filters")
filters := []AuditFilter{}
filters := []audit.AuditFilter{}

if fs == nil {
return filters, nil
Expand All @@ -253,7 +249,7 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
return filters, fmt.Errorf("Could not parse filter %d; '%+v'", i+1, f)
}

af := AuditFilter{}
af := audit.AuditFilter{}
for k, v := range f2 {
switch k {
case "message_type":
Expand All @@ -262,10 +258,10 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
if err != nil {
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
}
af.messageType = uint16(fv)
af.MessageType = uint16(fv)

} else if ev, ok := v.(int); ok {
af.messageType = uint16(ev)
af.MessageType = uint16(ev)

} else {
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`", i+1, v)
Expand All @@ -277,31 +273,31 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`", i+1, v)
}

if af.regex, err = regexp.Compile(re); err != nil {
if af.Regex, err = regexp.Compile(re); err != nil {
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
}

case "syscall":
if af.syscall, ok = v.(string); ok {
if af.Syscall, ok = v.(string); ok {
// All is good
} else if ev, ok := v.(int); ok {
af.syscall = strconv.Itoa(ev)
af.Syscall = strconv.Itoa(ev)
} else {
return filters, fmt.Errorf("`syscall` in filter %d could not be parsed; Value: `%+v`", i+1, v)
}
}
}

if af.regex == nil {
if af.Regex == nil {
return filters, fmt.Errorf("Filter %d is missing the `regex` entry", i+1)
}

if af.messageType == 0 {
if af.MessageType == 0 {
return filters, fmt.Errorf("Filter %d is missing the `message_type` entry", i+1)
}

filters = append(filters, af)
l.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.syscall, af.messageType, af.regex.String())
audit.Std.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.Syscall, af.MessageType, af.Regex.String())
}

return filters, nil
Expand All @@ -313,37 +309,37 @@ func main() {
flag.Parse()

if *configFile == "" {
el.Println("A config file must be provided")
audit.Stderr.Println("A config file must be provided")
flag.Usage()
os.Exit(1)
}

config, err := loadConfig(*configFile)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

// output needs to be created before anything that write to stdout
writer, err := createOutput(config)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

if err := setRules(config, lExec); err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

filters, err := createFilters(config)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

nlClient, err := NewNetlinkClient(config.GetInt("socket_buffer.receive"))
nlClient, err := audit.NewNetlinkClient(config.GetInt("socket_buffer.receive"))
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

marshaller := NewAuditMarshaller(
marshaller := audit.NewAuditMarshaller(
writer,
uint16(config.GetInt("events.min")),
uint16(config.GetInt("events.max")),
Expand All @@ -353,13 +349,13 @@ func main() {
filters,
)

l.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))
audit.Std.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))

//Main loop. Get data from netlink and send it to the json lib for processing
for {
msg, err := nlClient.Receive()
if err != nil {
el.Printf("Error during message receive: %+v\n", err)
audit.Stderr.Printf("Error during message receive: %+v\n", err)
continue
}

Expand Down
Loading