Skip to content

Commit

Permalink
Added wrappers for the windows AD objects and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rnishtala-sumo committed Oct 20, 2023
1 parent bbe5c4b commit e6f3980
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 58 deletions.
114 changes: 61 additions & 53 deletions pkg/receiver/activedirectoryinvreceiver/adinvreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package activedirectoryinvreceiver
import (
"context"
"fmt"
"log"
"sync"
"time"

Expand All @@ -29,19 +28,32 @@ import (
"go.uber.org/zap"
)

// Client is an interface for an Active Directory client
type Client interface {
Open(path string) (interface{}, error)
Open(path string, resourceLogs *plog.ResourceLogs) (Container, error)
}

// ADSIClient is a wrapper for an Active Directory client
type ADSIClient struct{}

func (c ADSIClient) Open(path string) (interface{}, error) {
// Open an Active Directory container
func (c ADSIClient) Open(path string, resourceLogs *plog.ResourceLogs) (Container, error) {
client, err := adsi.NewClient()
if err != nil {
return nil, err
}
ldapPath := fmt.Sprintf("LDAP://%s", path)
return client.Open(ldapPath)
root, err := client.Open(ldapPath)
if err != nil {
return nil, err
}
rootContainer, err := root.ToContainer()
if err != nil {
return nil, err
}
windowsContainer := ADSIContainer{rootContainer}
defer rootContainer.Close()
return windowsContainer, nil
}

type ADReceiver struct {
Expand All @@ -53,6 +65,7 @@ type ADReceiver struct {
doneChan chan bool
}

// newLogsReceiver creates a new Active Directory Inventory receiver
func newLogsReceiver(cfg *ADConfig, logger *zap.Logger, client Client, consumer consumer.Logs) *ADReceiver {

return &ADReceiver{
Expand All @@ -65,20 +78,23 @@ func newLogsReceiver(cfg *ADConfig, logger *zap.Logger, client Client, consumer
}
}

// Start the logs receiver
func (l *ADReceiver) Start(ctx context.Context, _ component.Host) error {
l.logger.Debug("starting to poll for active directory inventory records")
l.logger.Debug("Starting to poll for active directory inventory records")
l.wg.Add(1)
go l.startPolling(ctx)
return nil
}

// Shutdown the logs receiver
func (l *ADReceiver) Shutdown(_ context.Context) error {
l.logger.Debug("shutting down logs receiver")
l.logger.Debug("Shutting down logs receiver")
close(l.doneChan)
l.wg.Wait()
return nil
}

// Start polling for Active Directory inventory records
func (l *ADReceiver) startPolling(ctx context.Context) {
defer l.wg.Done()
t := time.NewTicker(l.config.PollInterval * time.Second)
Expand All @@ -97,40 +113,55 @@ func (l *ADReceiver) startPolling(ctx context.Context) {
}
}

func (r *ADReceiver) poll(ctx context.Context) error {
go func() {
root, err := r.client.Open(r.config.DN)
if err != nil {
r.logger.Error("Failed to open root object:", zap.Error(err))
return
}
rootObject := root.(*adsi.Object)
rootContainer, err := rootObject.ToContainer()
// Traverse the Active Directory tree and set user attributes to log records
func (r *ADReceiver) traverse(node Container, attrs []string, resourceLogs *plog.ResourceLogs) {
nodeObject, err := node.ToObject()
if err != nil {
r.logger.Error("Failed to convert container to object", zap.Error(err))
return
}
setUserAttributes(nodeObject, attrs, resourceLogs)
children, err := node.Children()
if err != nil {
r.logger.Error("Failed to retrieve children", zap.Error(err))
return
}
for child, err := children.Next(); err == nil; child, err = children.Next() {
windowsChildContainer, err := child.ToContainer()
if err != nil {
r.logger.Error("Failed to open root object:", zap.Error(err))
r.logger.Error("Failed to convert child object to container", zap.Error(err))
return
}
defer rootContainer.Close()
logs := plog.NewLogs()
rl := logs.ResourceLogs().AppendEmpty()
resourceLogs := &rl
_ = resourceLogs.ScopeLogs().AppendEmpty()
r.traverse(rootContainer, resourceLogs)
err = r.consumer.ConsumeLogs(ctx, logs)
if err != nil {
r.logger.Error("Error consuming log", zap.Error(err))
}
}()
childContainer := ADSIContainer{windowsChildContainer}
r.traverse(childContainer, attrs, resourceLogs)
}
children.Close()
}

// Poll for Active Directory inventory records
func (r *ADReceiver) poll(ctx context.Context) error {
logs := plog.NewLogs()
rl := logs.ResourceLogs().AppendEmpty()
resourceLogs := &rl
_ = resourceLogs.ScopeLogs().AppendEmpty()
root, err := r.client.Open(r.config.DN, resourceLogs)
r.traverse(root, r.config.Attributes, resourceLogs)
if err != nil {
return err
}
err = r.consumer.ConsumeLogs(ctx, logs)
if err != nil {
r.logger.Error("Error consuming log", zap.Error(err))
}
return nil
}

func (l *ADReceiver) printAttrs(user *adsi.Object, resourceLogs *plog.ResourceLogs) {
attrs := l.config.Attributes
// Set user attributes to a log record body
func setUserAttributes(user Object, attrs []string, resourceLogs *plog.ResourceLogs) {
observedTime := pcommon.NewTimestampFromTime(time.Now())
attributes := ""
for _, attr := range attrs {
values, err := user.Attr(attr)
values, err := user.Attrs(attr)
if err == nil && len(values) > 0 {
attributes += fmt.Sprintf("%s: %v\n", attr, values)
}
Expand All @@ -140,26 +171,3 @@ func (l *ADReceiver) printAttrs(user *adsi.Object, resourceLogs *plog.ResourceLo
logRecord.SetTimestamp(observedTime)
logRecord.Body().SetStr(attributes)
}

func (l *ADReceiver) traverse(node *adsi.Container, resourceLogs *plog.ResourceLogs) {
nodeObject, err := node.ToObject()
if err != nil {
log.Printf("Error creating objects: %v\n", err)
return
}
l.printAttrs(nodeObject, resourceLogs)
children, err := node.Children()
if err != nil {
log.Printf("Error retrieving children: %v\n", err)
return
}
for child, err := children.Next(); err == nil; child, err = children.Next() {
childContainer, err := child.ToContainer()
if err != nil {
log.Println("Failed to traverse child object:", err)
return
}
l.traverse(childContainer, resourceLogs)
}
children.Close()
}
99 changes: 94 additions & 5 deletions pkg/receiver/activedirectoryinvreceiver/adinvreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,70 @@ package activedirectoryinvreceiver

import (
"context"
"testing"

"fmt"
adsi "github.com/go-adsi/adsi"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/collector/component/componenttest"
"go.opentelemetry.io/collector/consumer/consumertest"
"go.opentelemetry.io/collector/pdata/plog"
"go.uber.org/zap"
"testing"
"time"
)

type MockClient struct{}
type MockClient struct {
mock.Mock
}

func (mc MockClient) Open(path string, resourceLogs *plog.ResourceLogs) (Container, error) {
args := mc.Called(path, resourceLogs)
return args.Get(0).(Container), args.Error(1)
}

type MockContainer struct {
mock.Mock
}

func (mc MockContainer) ToObject() (Object, error) {
args := mc.Called()
return args.Get(0).(Object), args.Error(1)
}

func (mc MockContainer) Close() {
mc.Called()
}

func (mc MockContainer) Children() (ObjectIter, error) {
args := mc.Called()
return args.Get(0).(ObjectIter), args.Error(1)
}

type MockObject struct {
mock.Mock
}

func (mo MockObject) Attrs(key string) ([]interface{}, error) {
args := mo.Called(key)
return args.Get(0).([]interface{}), args.Error(1)
}

func (mo MockObject) ToContainer() (Container, error) {
args := mo.Called()
return args.Get(0).(Container), args.Error(1)
}

type MockObjectIter struct {
mock.Mock
}

func (mo MockObjectIter) Next() (*adsi.Object, error) {
args := mo.Called()
return args.Get(0).(*adsi.Object), args.Error(1)
}

func (c MockClient) Open(path string) (interface{}, error) {
return nil, nil
func (mo MockObjectIter) Close() {
mo.Called()
}

func TestStart(t *testing.T) {
Expand All @@ -45,3 +97,40 @@ func TestStart(t *testing.T) {
err = logsRcvr.Shutdown(context.Background())
require.NoError(t, err)
}

func TestPoll(t *testing.T) {
cfg := CreateDefaultConfig().(*ADConfig)
cfg.DN = "CN=Guest,CN=Users,DC=exampledomain,DC=com"
cfg.PollInterval = 1
cfg.Attributes = []string{"name"}

sink := &consumertest.LogsSink{}
mockClient := defaultMockClient()

logsRcvr := newLogsReceiver(cfg, zap.NewNop(), mockClient, sink)

err := logsRcvr.Start(context.Background(), componenttest.NewNopHost())
require.NoError(t, err)

require.Eventually(t, func() bool {
return sink.LogRecordCount() > 0
}, 2*time.Second, 10*time.Millisecond)

err = logsRcvr.Shutdown(context.Background())
require.NoError(t, err)
}

func defaultMockClient() Client {
mockClient := MockClient{}
mockContainer := MockContainer{}
mockObject := MockObject{}
mockObjectIter := MockObjectIter{}
attrs := []interface{}{"Guest", "test"}
mockContainer.On("ToObject").Return(&mockObject, nil)
mockContainer.On("Children").Return(mockObjectIter, fmt.Errorf("no children"))
mockContainer.On("Close").Return(nil)
mockObject.On("Attrs", mock.Anything).Return(attrs, nil)
mockObject.On("ToContainer").Return(&mockContainer, nil)
mockClient.On("Open", mock.Anything, mock.Anything).Return(&mockContainer, nil)
return mockClient
}
Loading

0 comments on commit e6f3980

Please sign in to comment.