Skip to content

Commit

Permalink
Move executable to cmd directory
Browse files Browse the repository at this point in the history
This patch also allow this repository to be imported in another project.
Only the execution part and config loading is moved to the cmd directory
allowing the use of the core features in other projects.
  • Loading branch information
Maxime Vidori committed Nov 6, 2018
1 parent c160a22 commit bec1116
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 121 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
bin:
govendor sync
go build
go build cmd/*.go

test:
govendor sync
go test -v
go test -v ./...

test-cov-html:
go test -coverprofile=coverage.out
go tool cover -html=coverage.out

bench:
go test -bench=.
go test -bench=. ./...

bench-cpu:
go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof
Expand Down
10 changes: 5 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package main
package audit

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync/atomic"
"syscall"
"time"
"fmt"
)

// Endianness is an alias for what we assume is the current machine endianness
Expand Down Expand Up @@ -63,13 +63,13 @@ func NewNetlinkClient(recvSize int) (*NetlinkClient, error) {
// Set the buffer size if we were asked
if recvSize > 0 {
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil {
el.Println("Failed to set receive buffer size")
Stderr.Println("Failed to set receive buffer size")
}
}

// Print the current receive buffer size
if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil {
l.Println("Socket receive buffer size:", v)
Std.Println("Socket receive buffer size:", v)
}

go func() {
Expand Down Expand Up @@ -151,6 +151,6 @@ func (n *NetlinkClient) KeepConnection() {

err := n.Send(packet, payload)
if err != nil {
el.Println("Error occurred while trying to keep the connection:", err)
Stderr.Println("Error occurred while trying to keep the connection:", err)
}
}
31 changes: 8 additions & 23 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package main
package audit

import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"os"
"syscall"
"testing"

"github.com/slackhq/go-audit/internal/test"
"github.com/stretchr/testify/assert"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
Expand All @@ -29,8 +30,8 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
assert.EqualValues(t, msg.Data[:40], expectedData, "data was wrong")

// Make sure we get errors printed
lb, elb := hookLogger()
defer resetLogger()
lb, elb := test.HookLogger(Std, Stderr)
defer test.ResetLogger(Std, Stderr)
syscall.Close(n.fd)
n.KeepConnection()
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
Expand Down Expand Up @@ -88,8 +89,8 @@ func TestNetlinkClient_SendReceive(t *testing.T) {
}

func TestNewNetlinkClient(t *testing.T) {
lb, elb := hookLogger()
defer resetLogger()
lb, elb := test.HookLogger(Std, Stderr)
defer test.ResetLogger(Std, Stderr)

n, err := NewNetlinkClient(1024)

Expand Down Expand Up @@ -143,19 +144,3 @@ func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload

return msg
}

// Resets global loggers
func resetLogger() {
l.SetOutput(os.Stdout)
el.SetOutput(os.Stderr)
}

// Hooks the global loggers writers so you can assert their contents
func hookLogger() (lb *bytes.Buffer, elb *bytes.Buffer) {
lb = &bytes.Buffer{}
l.SetOutput(lb)

elb = &bytes.Buffer{}
el.SetOutput(elb)
return
}
84 changes: 40 additions & 44 deletions audit.go → cmd/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"flag"
"fmt"
"log"
"log/syslog"
"os"
"os/exec"
Expand All @@ -15,12 +14,10 @@ import (
"strings"
"syscall"

"github.com/slackhq/go-audit"
"github.com/spf13/viper"
)

var l = log.New(os.Stdout, "", 0)
var el = log.New(os.Stderr, "", 0)

type executor func(string, ...string) error

func lExec(s string, a ...string) error {
Expand All @@ -46,8 +43,8 @@ func loadConfig(configFile string) (*viper.Viper, error) {
return nil, err
}

l.SetFlags(config.GetInt("log.flags"))
el.SetFlags(config.GetInt("log.flags"))
audit.Std.SetFlags(config.GetInt("log.flags"))
audit.Stderr.SetFlags(config.GetInt("log.flags"))

return config, nil
}
Expand All @@ -58,7 +55,7 @@ func setRules(config *viper.Viper, e executor) error {
return fmt.Errorf("Failed to flush existing audit rules. Error: %s", err)
}

l.Println("Flushed existing audit rules")
audit.Std.Println("Flushed existing audit rules")

// Add ours in
if rules := config.GetStringSlice("rules"); len(rules) != 0 {
Expand All @@ -72,7 +69,7 @@ func setRules(config *viper.Viper, e executor) error {
return fmt.Errorf("Failed to add rule #%d. Error: %s", i+1, err)
}

l.Printf("Added audit rule #%d\n", i+1)
audit.Std.Printf("Added audit rule #%d\n", i+1)
}
} else {
return errors.New("No audit rules found")
Expand All @@ -81,8 +78,8 @@ func setRules(config *viper.Viper, e executor) error {
return nil
}

func createOutput(config *viper.Viper) (*AuditWriter, error) {
var writer *AuditWriter
func createOutput(config *viper.Viper) (*audit.AuditWriter, error) {
var writer *audit.AuditWriter
var err error
i := 0

Expand Down Expand Up @@ -123,7 +120,7 @@ func createOutput(config *viper.Viper) (*AuditWriter, error) {
return writer, nil
}

func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
func createSyslogOutput(config *viper.Viper) (*audit.AuditWriter, error) {
attempts := config.GetInt("output.syslog.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for syslog must be at least 1, %v provided", attempts)
Expand All @@ -140,10 +137,10 @@ func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
return nil, fmt.Errorf("Failed to open syslog writer. Error: %v", err)
}

return NewAuditWriter(syslogWriter, attempts), nil
return audit.NewAuditWriter(syslogWriter, attempts), nil
}

func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
func createFileOutput(config *viper.Viper) (*audit.AuditWriter, error) {
attempts := config.GetInt("output.file.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for file must be at least 1, %v provided", attempts)
Expand Down Expand Up @@ -193,10 +190,10 @@ func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
return nil, fmt.Errorf("Could not chown output file. Error: %s", err)
}

return NewAuditWriter(f, attempts), nil
return audit.NewAuditWriter(f, attempts), nil
}

func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
func handleLogRotation(config *viper.Viper, writer *audit.AuditWriter) {
// Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation

sigc := make(chan os.Signal, 1)
Expand All @@ -205,38 +202,37 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
for range sigc {
newWriter, err := createFileOutput(config)
if err != nil {
el.Fatalln("Error re-opening log file. Exiting.")
audit.Stderr.Fatalln("Error re-opening log file. Exiting.")
}

oldFile := writer.w.(*os.File)
writer.w = newWriter.w
writer.e = newWriter.e
oldFile := writer.IOWriter().(*os.File)
writer.SetIOWriter(newWriter.IOWriter())

err = oldFile.Close()
if err != nil {
el.Printf("Error closing old log file: %+v\n", err)
audit.Stderr.Printf("Error closing old log file: %+v\n", err)
}
}
}

func createStdOutOutput(config *viper.Viper) (*AuditWriter, error) {
func createStdOutOutput(config *viper.Viper) (*audit.AuditWriter, error) {
attempts := config.GetInt("output.stdout.attempts")
if attempts < 1 {
return nil, fmt.Errorf("Output attempts for stdout must be at least 1, %v provided", attempts)
}

// l logger is no longer stdout
l.SetOutput(os.Stderr)
audit.Std.SetOutput(os.Stderr)

return NewAuditWriter(os.Stdout, attempts), nil
return audit.NewAuditWriter(os.Stdout, attempts), nil
}

func createFilters(config *viper.Viper) ([]AuditFilter, error) {
func createFilters(config *viper.Viper) ([]audit.AuditFilter, error) {
var err error
var ok bool

fs := config.Get("filters")
filters := []AuditFilter{}
filters := []audit.AuditFilter{}

if fs == nil {
return filters, nil
Expand All @@ -253,7 +249,7 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
return filters, fmt.Errorf("Could not parse filter %d; '%+v'", i+1, f)
}

af := AuditFilter{}
af := audit.AuditFilter{}
for k, v := range f2 {
switch k {
case "message_type":
Expand All @@ -262,10 +258,10 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
if err != nil {
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
}
af.messageType = uint16(fv)
af.MessageType = uint16(fv)

} else if ev, ok := v.(int); ok {
af.messageType = uint16(ev)
af.MessageType = uint16(ev)

} else {
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`", i+1, v)
Expand All @@ -277,31 +273,31 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`", i+1, v)
}

if af.regex, err = regexp.Compile(re); err != nil {
if af.Regex, err = regexp.Compile(re); err != nil {
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
}

case "syscall":
if af.syscall, ok = v.(string); ok {
if af.Syscall, ok = v.(string); ok {
// All is good
} else if ev, ok := v.(int); ok {
af.syscall = strconv.Itoa(ev)
af.Syscall = strconv.Itoa(ev)
} else {
return filters, fmt.Errorf("`syscall` in filter %d could not be parsed; Value: `%+v`", i+1, v)
}
}
}

if af.regex == nil {
if af.Regex == nil {
return filters, fmt.Errorf("Filter %d is missing the `regex` entry", i+1)
}

if af.messageType == 0 {
if af.MessageType == 0 {
return filters, fmt.Errorf("Filter %d is missing the `message_type` entry", i+1)
}

filters = append(filters, af)
l.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.syscall, af.messageType, af.regex.String())
audit.Std.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.Syscall, af.MessageType, af.Regex.String())
}

return filters, nil
Expand All @@ -313,37 +309,37 @@ func main() {
flag.Parse()

if *configFile == "" {
el.Println("A config file must be provided")
audit.Stderr.Println("A config file must be provided")
flag.Usage()
os.Exit(1)
}

config, err := loadConfig(*configFile)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

// output needs to be created before anything that write to stdout
writer, err := createOutput(config)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

if err := setRules(config, lExec); err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

filters, err := createFilters(config)
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

nlClient, err := NewNetlinkClient(config.GetInt("socket_buffer.receive"))
nlClient, err := audit.NewNetlinkClient(config.GetInt("socket_buffer.receive"))
if err != nil {
el.Fatal(err)
audit.Stderr.Fatal(err)
}

marshaller := NewAuditMarshaller(
marshaller := audit.NewAuditMarshaller(
writer,
uint16(config.GetInt("events.min")),
uint16(config.GetInt("events.max")),
Expand All @@ -353,13 +349,13 @@ func main() {
filters,
)

l.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))
audit.Std.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))

//Main loop. Get data from netlink and send it to the json lib for processing
for {
msg, err := nlClient.Receive()
if err != nil {
el.Printf("Error during message receive: %+v\n", err)
audit.Stderr.Printf("Error during message receive: %+v\n", err)
continue
}

Expand Down
Loading

0 comments on commit bec1116

Please sign in to comment.