Skip to content

Commit

Permalink
Fix concurrency issues. Add posture check verification.
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed Jan 6, 2025
1 parent 1f963c1 commit 5b35d72
Show file tree
Hide file tree
Showing 11 changed files with 438 additions and 95 deletions.
26 changes: 8 additions & 18 deletions common/event_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"github.com/openziti/ziti/common/pb/edge_ctrl_pb"
"sync"
"sync/atomic"
)

type OnStoreSuccess func(index uint64, event *edge_ctrl_pb.DataState_ChangeSet)
Expand Down Expand Up @@ -51,17 +52,15 @@ type EventCache interface {
// when replaying events is not expected (i.e. in routers)
type ForgetfulEventCache struct {
lock sync.Mutex
index *uint64
index uint64
}

func NewForgetfulEventCache() *ForgetfulEventCache {
return &ForgetfulEventCache{}
}

func (cache *ForgetfulEventCache) SetCurrentIndex(index uint64) {
cache.lock.Lock()
defer cache.lock.Unlock()
cache.index = &index
cache.index = index
}

func (cache *ForgetfulEventCache) WhileLocked(callback func(uint64, bool)) {
Expand All @@ -82,16 +81,14 @@ func (cache *ForgetfulEventCache) Store(event *edge_ctrl_pb.DataState_ChangeSet,
return nil
}

if cache.index != nil {
if *cache.index >= event.Index {
return fmt.Errorf("out of order event detected, currentIndex: %d, receivedIndex: %d, type :%T", *cache.index, event.Index, cache)
}
if cache.index > 0 && cache.index >= event.Index {
return fmt.Errorf("out of order event detected, currentIndex: %d, receivedIndex: %d, type :%T", cache.index, event.Index, cache)
}

cache.index = &event.Index
cache.index = event.Index

if onSuccess != nil {
onSuccess(*cache.index, event)
onSuccess(cache.index, event)
}

return nil
Expand All @@ -102,18 +99,11 @@ func (cache *ForgetfulEventCache) ReplayFrom(_ uint64) ([]*edge_ctrl_pb.DataStat
}

func (cache *ForgetfulEventCache) CurrentIndex() (uint64, bool) {
cache.lock.Lock()
defer cache.lock.Unlock()

return cache.currentIndex()
}

func (cache *ForgetfulEventCache) currentIndex() (uint64, bool) {
if cache.index == nil {
return 0, false
}

return *cache.index, true
return atomic.LoadUint64(&cache.index), true
}

// LoggingEventCache stores events in order to support replaying (i.e. in controllers).
Expand Down
107 changes: 57 additions & 50 deletions common/router_data_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type DataStateIdentity = edge_ctrl_pb.DataState_Identity

type Identity struct {
*DataStateIdentity
ServicePolicies map[string]struct{} `json:"servicePolicies"`
ServicePolicies *concurrenz.SyncSet[string] `json:"servicePolicies"`
identityIndex uint64
serviceSetIndex uint64
}
Expand Down Expand Up @@ -94,8 +94,8 @@ type DataStateServicePolicy = edge_ctrl_pb.DataState_ServicePolicy

type ServicePolicy struct {
*DataStateServicePolicy
Services map[string]struct{} `json:"services"`
PostureChecks map[string]struct{} `json:"postureChecks"`
Services *concurrenz.SyncSet[string] `json:"services"`
PostureChecks *concurrenz.SyncSet[string] `json:"postureChecks"`
}

// RouterDataModel represents a sub-set of a controller's data model. Enough to validate an identities access to dial/bind
Expand Down Expand Up @@ -202,6 +202,8 @@ func NewReceiverRouterDataModelFromExisting(existing *RouterDataModel, listenerB
closeNotify: closeNotify,
stopNotify: make(chan struct{}),
}
currentIndex, _ := existing.CurrentIndex()
result.SetCurrentIndex(currentIndex)
go result.processSubscriberEvents()
return result
}
Expand Down Expand Up @@ -232,6 +234,7 @@ func NewReceiverRouterDataModelFromFile(path string, listenerBufferSize uint, cl

err = json.Unmarshal(data, rdmContents)
if err != nil {
rdmContents.RouterDataModel.Stop()
return nil, err
}

Expand Down Expand Up @@ -371,7 +374,7 @@ func (rdm *RouterDataModel) HandleIdentityEvent(index uint64, event *edge_ctrl_p
if valueInMap == nil {
identity = &Identity{
DataStateIdentity: model.Identity,
ServicePolicies: map[string]struct{}{},
ServicePolicies: concurrenz.NewSyncSet[string](),
identityIndex: index,
}
} else {
Expand Down Expand Up @@ -400,7 +403,7 @@ func (rdm *RouterDataModel) HandleServiceEvent(index uint64, event *edge_ctrl_pb
if event.Action == edge_ctrl_pb.DataState_Delete {
rdm.Services.Remove(model.Service.Id)
rdm.ServicePolicies.IterCb(func(key string, v *ServicePolicy) {
delete(v.Services, model.Service.Id)
v.Services.Remove(model.Service.Id)
})
} else {
rdm.Services.Set(model.Service.Id, &Service{
Expand Down Expand Up @@ -444,8 +447,8 @@ func (rdm *RouterDataModel) applyUpdateServicePolicyEvent(model *edge_ctrl_pb.Da
if valueInMap == nil {
return &ServicePolicy{
DataStateServicePolicy: servicePolicy,
Services: map[string]struct{}{},
PostureChecks: map[string]struct{}{},
Services: concurrenz.NewSyncSet[string](),
PostureChecks: concurrenz.NewSyncSet[string](),
}
} else {
return &ServicePolicy{
Expand Down Expand Up @@ -526,9 +529,9 @@ func (rdm *RouterDataModel) HandleServicePolicyChange(index uint64, model *edge_
rdm.Identities.Upsert(identityId, nil, func(exist bool, valueInMap *Identity, newValue *Identity) *Identity {
if valueInMap != nil {
if model.Add {
valueInMap.ServicePolicies[model.PolicyId] = struct{}{}
valueInMap.ServicePolicies.Add(model.PolicyId)
} else {
delete(valueInMap.ServicePolicies, model.PolicyId)
valueInMap.ServicePolicies.Remove(model.PolicyId)
}
valueInMap.serviceSetIndex = index
}
Expand All @@ -551,21 +554,21 @@ func (rdm *RouterDataModel) HandleServicePolicyChange(index uint64, model *edge_
case edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedService:
if model.Add {
for _, serviceId := range model.RelatedEntityIds {
valueInMap.Services[serviceId] = struct{}{}
valueInMap.Services.Add(serviceId)
}
} else {
for _, serviceId := range model.RelatedEntityIds {
delete(valueInMap.Services, serviceId)
valueInMap.Services.Remove(serviceId)
}
}
case edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedPostureCheck:
if model.Add {
for _, postureCheckId := range model.RelatedEntityIds {
valueInMap.PostureChecks[postureCheckId] = struct{}{}
valueInMap.PostureChecks.Add(postureCheckId)
}
} else {
for _, postureCheckId := range model.RelatedEntityIds {
delete(valueInMap.PostureChecks, postureCheckId)
valueInMap.PostureChecks.Remove(postureCheckId)
}
}
}
Expand Down Expand Up @@ -648,7 +651,7 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState {
}
events = append(events, newEvent)

for policyId := range v.ServicePolicies {
v.ServicePolicies.RangeAll(func(policyId string) {
change := servicePolicyIdentities[policyId]
if change == nil {
change = &edge_ctrl_pb.DataState_ServicePolicyChange{
Expand All @@ -659,7 +662,7 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState {
servicePolicyIdentities[policyId] = change
}
change.RelatedEntityIds = append(change.RelatedEntityIds, v.Id)
}
})
})

rdm.Services.IterCb(func(key string, v *Service) {
Expand Down Expand Up @@ -696,9 +699,9 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState {
RelatedEntityType: edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedService,
Add: true,
}
for serviceId := range v.Services {
v.Services.RangeAll(func(serviceId string) {
addServicesChange.RelatedEntityIds = append(addServicesChange.RelatedEntityIds, serviceId)
}
})
events = append(events, &edge_ctrl_pb.DataState_Event{
Model: &edge_ctrl_pb.DataState_Event_ServicePolicyChange{
ServicePolicyChange: addServicesChange,
Expand All @@ -710,9 +713,9 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState {
RelatedEntityType: edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedPostureCheck,
Add: true,
}
for postureCheckId := range v.PostureChecks {
v.PostureChecks.RangeAll(func(postureCheckId string) {
addPostureCheckChanges.RelatedEntityIds = append(addPostureCheckChanges.RelatedEntityIds, postureCheckId)
}
})
events = append(events, &edge_ctrl_pb.DataState_Event{
Model: &edge_ctrl_pb.DataState_Event_ServicePolicyChange{
ServicePolicyChange: addPostureCheckChanges,
Expand Down Expand Up @@ -818,28 +821,22 @@ func (rdm *RouterDataModel) GetServiceAccessPolicies(identityId string, serviceI

postureChecks := map[string]*edge_ctrl_pb.DataState_PostureCheck{}

for servicePolicyId := range identity.ServicePolicies {
identity.ServicePolicies.RangeAll(func(servicePolicyId string) {
servicePolicy, ok := rdm.ServicePolicies.Get(servicePolicyId)

if !ok {
continue
}

if servicePolicy.PolicyType != policyType {
continue
}

policies = append(policies, servicePolicy)
if ok && servicePolicy.PolicyType != policyType {
policies = append(policies, servicePolicy)

for postureCheckId := range servicePolicy.PostureChecks {
if _, ok := postureChecks[postureCheckId]; !ok {
//ignore ok, if !ok postureCheck == nil which will trigger
//failure during evaluation
postureCheck, _ := rdm.PostureChecks.Get(postureCheckId)
postureChecks[postureCheckId] = postureCheck.DataStatePostureCheck
}
servicePolicy.PostureChecks.RangeAll(func(postureCheckId string) {
if _, ok := postureChecks[postureCheckId]; !ok {
//ignore ok, if !ok postureCheck == nil which will trigger
//failure during evaluation
postureCheck, _ := rdm.PostureChecks.Get(postureCheckId)
postureChecks[postureCheckId] = postureCheck.DataStatePostureCheck
}
})
}
}
})

return &AccessPolicies{
Identity: identity,
Expand Down Expand Up @@ -895,20 +892,20 @@ func (rdm *RouterDataModel) buildServiceList(sub *IdentitySubscription) (map[str
services := map[string]*IdentityService{}
postureChecks := map[string]*PostureCheck{}

for policyId := range sub.Identity.ServicePolicies {
sub.Identity.ServicePolicies.RangeAll(func(policyId string) {
policy, ok := rdm.ServicePolicies.Get(policyId)
if !ok {
log.WithField("policyId", policyId).Error("could not find service policy")
continue
return
}

for serviceId := range policy.Services {
policy.Services.RangeAll(func(serviceId string) {
service, ok := rdm.Services.Get(serviceId)
if !ok {
log.WithField("policyId", policyId).
WithField("serviceId", serviceId).
Error("could not find service")
continue
return
}

identityService, ok := services[serviceId]
Expand All @@ -920,16 +917,16 @@ func (rdm *RouterDataModel) buildServiceList(sub *IdentitySubscription) (map[str
}
services[serviceId] = identityService
rdm.loadServiceConfigs(sub.Identity, identityService)
rdm.loadServicePostureChecks(sub.Identity, policy, identityService, postureChecks)
}
rdm.loadServicePostureChecks(sub.Identity, policy, identityService, postureChecks)

if policy.PolicyType == edge_ctrl_pb.PolicyType_BindPolicy {
identityService.BindAllowed = true
} else if policy.PolicyType == edge_ctrl_pb.PolicyType_DialPolicy {
identityService.DialAllowed = true
}
}
}
})
})

return services, postureChecks
}
Expand All @@ -940,15 +937,15 @@ func (rdm *RouterDataModel) loadServicePostureChecks(identity *Identity, policy
WithField("serviceId", svc.Service.Id).
WithField("policyId", policy.Id)

for postureCheckId := range policy.PostureChecks {
policy.PostureChecks.RangeAll(func(postureCheckId string) {
check, ok := rdm.PostureChecks.Get(postureCheckId)
if !ok {
log.WithField("postureCheckId", postureCheckId).Error("could not find posture check")
} else {
svc.Checks[postureCheckId] = struct{}{}
checks[postureCheckId] = check
}
}
})
}

func (rdm *RouterDataModel) loadServiceConfigs(identity *Identity, svc *IdentityService) {
Expand Down Expand Up @@ -1026,7 +1023,7 @@ type DiffSink func(entityType string, id string, diffType DiffType, detail strin
func (rdm *RouterDataModel) Validate(correct *RouterDataModel, sink DiffSink) {
correct.Diff(rdm, sink)
rdm.subscriptions.IterCb(func(key string, v *IdentitySubscription) {
v.Diff(correct, sink)
v.Diff(rdm, sink)
})
}

Expand All @@ -1053,12 +1050,19 @@ func (rdm *RouterDataModel) Diff(o *RouterDataModel, sink DiffSink) {
diffType("identity", rdm.Identities, o.Identities, sink, Identity{}, DataStateIdentity{})
diffType("service", rdm.Services, o.Services, sink, Service{}, DataStateService{})
diffType("service-policy", rdm.ServicePolicies, o.ServicePolicies, sink, ServicePolicy{}, DataStateServicePolicy{})
diffType("posture-check", rdm.PostureChecks, o.PostureChecks, sink, PostureCheck{}, DataStatePostureCheck{})
diffType("posture-check", rdm.PostureChecks, o.PostureChecks, sink,
PostureCheck{}, DataStatePostureCheck{},
edge_ctrl_pb.DataState_PostureCheck_Domains_{}, edge_ctrl_pb.DataState_PostureCheck_Domains{},
edge_ctrl_pb.DataState_PostureCheck_Mac_{}, edge_ctrl_pb.DataState_PostureCheck_Mac{},
edge_ctrl_pb.DataState_PostureCheck_Mfa_{}, edge_ctrl_pb.DataState_PostureCheck_Mfa{},
edge_ctrl_pb.DataState_PostureCheck_OsList_{}, edge_ctrl_pb.DataState_PostureCheck_OsList{}, edge_ctrl_pb.DataState_PostureCheck_Os{},
edge_ctrl_pb.DataState_PostureCheck_Process_{}, edge_ctrl_pb.DataState_PostureCheck_Process{},
edge_ctrl_pb.DataState_PostureCheck_ProcessMulti_{}, edge_ctrl_pb.DataState_PostureCheck_ProcessMulti{})
diffType("public-keys", rdm.PublicKeys, o.PublicKeys, sink, edge_ctrl_pb.DataState_PublicKey{})
diffType("revocations", rdm.Revocations, o.Revocations, sink, edge_ctrl_pb.DataState_Revocation{})
diffMaps("cached-public-keys", rdm.getPublicKeysAsCmap(), o.getPublicKeysAsCmap(), sink, func(a, b crypto.PublicKey) []string {
if a == nil || b == nil {
return []string{fmt.Sprintf("cached public key is nil: orig: %v, dest: %v", a, a)}
return []string{fmt.Sprintf("cached public key is nil: orig: %v, dest: %v", a, b)}
}
return nil
})
Expand Down Expand Up @@ -1098,14 +1102,17 @@ func diffType[P any, T *P](entityType string, m1 cmap.ConcurrentMap[string, T],

hasMissing := false
adapter := cmp.Reporter(diffReporter)
syncSetT := cmp.Transformer("syncSetToMap", func(s *concurrenz.SyncSet[string]) map[string]struct{} {
return s.ToMap()
})
m1.IterCb(func(key string, v T) {
v2, exists := m2.Get(key)
if !exists {
sink(entityType, key, DiffTypeSub, "entity missing")
hasMissing = true
} else {
diffReporter.key = key
cmp.Diff(v, v2, cmpopts.IgnoreUnexported(ignoreTypes...), adapter)
cmp.Diff(v, v2, syncSetT, cmpopts.IgnoreUnexported(ignoreTypes...), adapter)
}
})

Expand Down
Loading

0 comments on commit 5b35d72

Please sign in to comment.