diff --git a/common/event_cache.go b/common/event_cache.go index db42c0e25..0d68eb457 100644 --- a/common/event_cache.go +++ b/common/event_cache.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "sync" + "sync/atomic" ) type OnStoreSuccess func(index uint64, event *edge_ctrl_pb.DataState_ChangeSet) @@ -51,7 +52,7 @@ type EventCache interface { // when replaying events is not expected (i.e. in routers) type ForgetfulEventCache struct { lock sync.Mutex - index *uint64 + index uint64 } func NewForgetfulEventCache() *ForgetfulEventCache { @@ -59,9 +60,7 @@ func NewForgetfulEventCache() *ForgetfulEventCache { } func (cache *ForgetfulEventCache) SetCurrentIndex(index uint64) { - cache.lock.Lock() - defer cache.lock.Unlock() - cache.index = &index + cache.index = index } func (cache *ForgetfulEventCache) WhileLocked(callback func(uint64, bool)) { @@ -82,16 +81,14 @@ func (cache *ForgetfulEventCache) Store(event *edge_ctrl_pb.DataState_ChangeSet, return nil } - if cache.index != nil { - if *cache.index >= event.Index { - return fmt.Errorf("out of order event detected, currentIndex: %d, receivedIndex: %d, type :%T", *cache.index, event.Index, cache) - } + if cache.index > 0 && cache.index >= event.Index { + return fmt.Errorf("out of order event detected, currentIndex: %d, receivedIndex: %d, type :%T", cache.index, event.Index, cache) } - cache.index = &event.Index + cache.index = event.Index if onSuccess != nil { - onSuccess(*cache.index, event) + onSuccess(cache.index, event) } return nil @@ -102,18 +99,11 @@ func (cache *ForgetfulEventCache) ReplayFrom(_ uint64) ([]*edge_ctrl_pb.DataStat } func (cache *ForgetfulEventCache) CurrentIndex() (uint64, bool) { - cache.lock.Lock() - defer cache.lock.Unlock() - return cache.currentIndex() } func (cache *ForgetfulEventCache) currentIndex() (uint64, bool) { - if cache.index == nil { - return 0, false - } - - return *cache.index, true + return atomic.LoadUint64(&cache.index), true } // LoggingEventCache stores events in order to support replaying (i.e. in controllers). diff --git a/common/router_data_model.go b/common/router_data_model.go index 0eddb8971..7055faa53 100644 --- a/common/router_data_model.go +++ b/common/router_data_model.go @@ -57,7 +57,7 @@ type DataStateIdentity = edge_ctrl_pb.DataState_Identity type Identity struct { *DataStateIdentity - ServicePolicies map[string]struct{} `json:"servicePolicies"` + ServicePolicies *concurrenz.SyncSet[string] `json:"servicePolicies"` identityIndex uint64 serviceSetIndex uint64 } @@ -94,8 +94,8 @@ type DataStateServicePolicy = edge_ctrl_pb.DataState_ServicePolicy type ServicePolicy struct { *DataStateServicePolicy - Services map[string]struct{} `json:"services"` - PostureChecks map[string]struct{} `json:"postureChecks"` + Services *concurrenz.SyncSet[string] `json:"services"` + PostureChecks *concurrenz.SyncSet[string] `json:"postureChecks"` } // RouterDataModel represents a sub-set of a controller's data model. Enough to validate an identities access to dial/bind @@ -202,6 +202,8 @@ func NewReceiverRouterDataModelFromExisting(existing *RouterDataModel, listenerB closeNotify: closeNotify, stopNotify: make(chan struct{}), } + currentIndex, _ := existing.CurrentIndex() + result.SetCurrentIndex(currentIndex) go result.processSubscriberEvents() return result } @@ -232,6 +234,7 @@ func NewReceiverRouterDataModelFromFile(path string, listenerBufferSize uint, cl err = json.Unmarshal(data, rdmContents) if err != nil { + rdmContents.RouterDataModel.Stop() return nil, err } @@ -371,7 +374,7 @@ func (rdm *RouterDataModel) HandleIdentityEvent(index uint64, event *edge_ctrl_p if valueInMap == nil { identity = &Identity{ DataStateIdentity: model.Identity, - ServicePolicies: map[string]struct{}{}, + ServicePolicies: concurrenz.NewSyncSet[string](), identityIndex: index, } } else { @@ -400,7 +403,7 @@ func (rdm *RouterDataModel) HandleServiceEvent(index uint64, event *edge_ctrl_pb if event.Action == edge_ctrl_pb.DataState_Delete { rdm.Services.Remove(model.Service.Id) rdm.ServicePolicies.IterCb(func(key string, v *ServicePolicy) { - delete(v.Services, model.Service.Id) + v.Services.Remove(model.Service.Id) }) } else { rdm.Services.Set(model.Service.Id, &Service{ @@ -444,8 +447,8 @@ func (rdm *RouterDataModel) applyUpdateServicePolicyEvent(model *edge_ctrl_pb.Da if valueInMap == nil { return &ServicePolicy{ DataStateServicePolicy: servicePolicy, - Services: map[string]struct{}{}, - PostureChecks: map[string]struct{}{}, + Services: concurrenz.NewSyncSet[string](), + PostureChecks: concurrenz.NewSyncSet[string](), } } else { return &ServicePolicy{ @@ -526,9 +529,9 @@ func (rdm *RouterDataModel) HandleServicePolicyChange(index uint64, model *edge_ rdm.Identities.Upsert(identityId, nil, func(exist bool, valueInMap *Identity, newValue *Identity) *Identity { if valueInMap != nil { if model.Add { - valueInMap.ServicePolicies[model.PolicyId] = struct{}{} + valueInMap.ServicePolicies.Add(model.PolicyId) } else { - delete(valueInMap.ServicePolicies, model.PolicyId) + valueInMap.ServicePolicies.Remove(model.PolicyId) } valueInMap.serviceSetIndex = index } @@ -551,21 +554,21 @@ func (rdm *RouterDataModel) HandleServicePolicyChange(index uint64, model *edge_ case edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedService: if model.Add { for _, serviceId := range model.RelatedEntityIds { - valueInMap.Services[serviceId] = struct{}{} + valueInMap.Services.Add(serviceId) } } else { for _, serviceId := range model.RelatedEntityIds { - delete(valueInMap.Services, serviceId) + valueInMap.Services.Remove(serviceId) } } case edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedPostureCheck: if model.Add { for _, postureCheckId := range model.RelatedEntityIds { - valueInMap.PostureChecks[postureCheckId] = struct{}{} + valueInMap.PostureChecks.Add(postureCheckId) } } else { for _, postureCheckId := range model.RelatedEntityIds { - delete(valueInMap.PostureChecks, postureCheckId) + valueInMap.PostureChecks.Remove(postureCheckId) } } } @@ -648,7 +651,7 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState { } events = append(events, newEvent) - for policyId := range v.ServicePolicies { + v.ServicePolicies.RangeAll(func(policyId string) { change := servicePolicyIdentities[policyId] if change == nil { change = &edge_ctrl_pb.DataState_ServicePolicyChange{ @@ -659,7 +662,7 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState { servicePolicyIdentities[policyId] = change } change.RelatedEntityIds = append(change.RelatedEntityIds, v.Id) - } + }) }) rdm.Services.IterCb(func(key string, v *Service) { @@ -696,9 +699,9 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState { RelatedEntityType: edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedService, Add: true, } - for serviceId := range v.Services { + v.Services.RangeAll(func(serviceId string) { addServicesChange.RelatedEntityIds = append(addServicesChange.RelatedEntityIds, serviceId) - } + }) events = append(events, &edge_ctrl_pb.DataState_Event{ Model: &edge_ctrl_pb.DataState_Event_ServicePolicyChange{ ServicePolicyChange: addServicesChange, @@ -710,9 +713,9 @@ func (rdm *RouterDataModel) GetDataState() *edge_ctrl_pb.DataState { RelatedEntityType: edge_ctrl_pb.ServicePolicyRelatedEntityType_RelatedPostureCheck, Add: true, } - for postureCheckId := range v.PostureChecks { + v.PostureChecks.RangeAll(func(postureCheckId string) { addPostureCheckChanges.RelatedEntityIds = append(addPostureCheckChanges.RelatedEntityIds, postureCheckId) - } + }) events = append(events, &edge_ctrl_pb.DataState_Event{ Model: &edge_ctrl_pb.DataState_Event_ServicePolicyChange{ ServicePolicyChange: addPostureCheckChanges, @@ -818,28 +821,22 @@ func (rdm *RouterDataModel) GetServiceAccessPolicies(identityId string, serviceI postureChecks := map[string]*edge_ctrl_pb.DataState_PostureCheck{} - for servicePolicyId := range identity.ServicePolicies { + identity.ServicePolicies.RangeAll(func(servicePolicyId string) { servicePolicy, ok := rdm.ServicePolicies.Get(servicePolicyId) - if !ok { - continue - } - - if servicePolicy.PolicyType != policyType { - continue - } - - policies = append(policies, servicePolicy) + if ok && servicePolicy.PolicyType != policyType { + policies = append(policies, servicePolicy) - for postureCheckId := range servicePolicy.PostureChecks { - if _, ok := postureChecks[postureCheckId]; !ok { - //ignore ok, if !ok postureCheck == nil which will trigger - //failure during evaluation - postureCheck, _ := rdm.PostureChecks.Get(postureCheckId) - postureChecks[postureCheckId] = postureCheck.DataStatePostureCheck - } + servicePolicy.PostureChecks.RangeAll(func(postureCheckId string) { + if _, ok := postureChecks[postureCheckId]; !ok { + //ignore ok, if !ok postureCheck == nil which will trigger + //failure during evaluation + postureCheck, _ := rdm.PostureChecks.Get(postureCheckId) + postureChecks[postureCheckId] = postureCheck.DataStatePostureCheck + } + }) } - } + }) return &AccessPolicies{ Identity: identity, @@ -895,20 +892,20 @@ func (rdm *RouterDataModel) buildServiceList(sub *IdentitySubscription) (map[str services := map[string]*IdentityService{} postureChecks := map[string]*PostureCheck{} - for policyId := range sub.Identity.ServicePolicies { + sub.Identity.ServicePolicies.RangeAll(func(policyId string) { policy, ok := rdm.ServicePolicies.Get(policyId) if !ok { log.WithField("policyId", policyId).Error("could not find service policy") - continue + return } - for serviceId := range policy.Services { + policy.Services.RangeAll(func(serviceId string) { service, ok := rdm.Services.Get(serviceId) if !ok { log.WithField("policyId", policyId). WithField("serviceId", serviceId). Error("could not find service") - continue + return } identityService, ok := services[serviceId] @@ -920,16 +917,16 @@ func (rdm *RouterDataModel) buildServiceList(sub *IdentitySubscription) (map[str } services[serviceId] = identityService rdm.loadServiceConfigs(sub.Identity, identityService) - rdm.loadServicePostureChecks(sub.Identity, policy, identityService, postureChecks) } + rdm.loadServicePostureChecks(sub.Identity, policy, identityService, postureChecks) if policy.PolicyType == edge_ctrl_pb.PolicyType_BindPolicy { identityService.BindAllowed = true } else if policy.PolicyType == edge_ctrl_pb.PolicyType_DialPolicy { identityService.DialAllowed = true } - } - } + }) + }) return services, postureChecks } @@ -940,7 +937,7 @@ func (rdm *RouterDataModel) loadServicePostureChecks(identity *Identity, policy WithField("serviceId", svc.Service.Id). WithField("policyId", policy.Id) - for postureCheckId := range policy.PostureChecks { + policy.PostureChecks.RangeAll(func(postureCheckId string) { check, ok := rdm.PostureChecks.Get(postureCheckId) if !ok { log.WithField("postureCheckId", postureCheckId).Error("could not find posture check") @@ -948,7 +945,7 @@ func (rdm *RouterDataModel) loadServicePostureChecks(identity *Identity, policy svc.Checks[postureCheckId] = struct{}{} checks[postureCheckId] = check } - } + }) } func (rdm *RouterDataModel) loadServiceConfigs(identity *Identity, svc *IdentityService) { @@ -1026,7 +1023,7 @@ type DiffSink func(entityType string, id string, diffType DiffType, detail strin func (rdm *RouterDataModel) Validate(correct *RouterDataModel, sink DiffSink) { correct.Diff(rdm, sink) rdm.subscriptions.IterCb(func(key string, v *IdentitySubscription) { - v.Diff(correct, sink) + v.Diff(rdm, sink) }) } @@ -1053,12 +1050,19 @@ func (rdm *RouterDataModel) Diff(o *RouterDataModel, sink DiffSink) { diffType("identity", rdm.Identities, o.Identities, sink, Identity{}, DataStateIdentity{}) diffType("service", rdm.Services, o.Services, sink, Service{}, DataStateService{}) diffType("service-policy", rdm.ServicePolicies, o.ServicePolicies, sink, ServicePolicy{}, DataStateServicePolicy{}) - diffType("posture-check", rdm.PostureChecks, o.PostureChecks, sink, PostureCheck{}, DataStatePostureCheck{}) + diffType("posture-check", rdm.PostureChecks, o.PostureChecks, sink, + PostureCheck{}, DataStatePostureCheck{}, + edge_ctrl_pb.DataState_PostureCheck_Domains_{}, edge_ctrl_pb.DataState_PostureCheck_Domains{}, + edge_ctrl_pb.DataState_PostureCheck_Mac_{}, edge_ctrl_pb.DataState_PostureCheck_Mac{}, + edge_ctrl_pb.DataState_PostureCheck_Mfa_{}, edge_ctrl_pb.DataState_PostureCheck_Mfa{}, + edge_ctrl_pb.DataState_PostureCheck_OsList_{}, edge_ctrl_pb.DataState_PostureCheck_OsList{}, edge_ctrl_pb.DataState_PostureCheck_Os{}, + edge_ctrl_pb.DataState_PostureCheck_Process_{}, edge_ctrl_pb.DataState_PostureCheck_Process{}, + edge_ctrl_pb.DataState_PostureCheck_ProcessMulti_{}, edge_ctrl_pb.DataState_PostureCheck_ProcessMulti{}) diffType("public-keys", rdm.PublicKeys, o.PublicKeys, sink, edge_ctrl_pb.DataState_PublicKey{}) diffType("revocations", rdm.Revocations, o.Revocations, sink, edge_ctrl_pb.DataState_Revocation{}) diffMaps("cached-public-keys", rdm.getPublicKeysAsCmap(), o.getPublicKeysAsCmap(), sink, func(a, b crypto.PublicKey) []string { if a == nil || b == nil { - return []string{fmt.Sprintf("cached public key is nil: orig: %v, dest: %v", a, a)} + return []string{fmt.Sprintf("cached public key is nil: orig: %v, dest: %v", a, b)} } return nil }) @@ -1098,6 +1102,9 @@ func diffType[P any, T *P](entityType string, m1 cmap.ConcurrentMap[string, T], hasMissing := false adapter := cmp.Reporter(diffReporter) + syncSetT := cmp.Transformer("syncSetToMap", func(s *concurrenz.SyncSet[string]) map[string]struct{} { + return s.ToMap() + }) m1.IterCb(func(key string, v T) { v2, exists := m2.Get(key) if !exists { @@ -1105,7 +1112,7 @@ func diffType[P any, T *P](entityType string, m1 cmap.ConcurrentMap[string, T], hasMissing = true } else { diffReporter.key = key - cmp.Diff(v, v2, cmpopts.IgnoreUnexported(ignoreTypes...), adapter) + cmp.Diff(v, v2, syncSetT, cmpopts.IgnoreUnexported(ignoreTypes...), adapter) } }) diff --git a/common/subscriber.go b/common/subscriber.go index ae4f6126d..b0a7a79c0 100644 --- a/common/subscriber.go +++ b/common/subscriber.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/michaelquigley/pfxlog" "github.com/openziti/foundation/v2/concurrenz" + "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "sync" ) @@ -106,14 +107,22 @@ func (self *IdentitySubscription) Diff(rdm *RouterDataModel, sink DiffSink) { } adapter := cmp.Reporter(diffReporter) - cmp.Diff(currentState, self, cmpopts.IgnoreUnexported( + syncSetT := cmp.Transformer("syncSetToMap", func(s *concurrenz.SyncSet[string]) map[string]struct{} { + return s.ToMap() + }) + cmp.Diff(currentState, self, syncSetT, cmpopts.IgnoreUnexported( sync.Mutex{}, IdentitySubscription{}, IdentityService{}, Config{}, ConfigType{}, DataStateConfig{}, DataStateConfigType{}, Identity{}, DataStateIdentity{}, Service{}, DataStateService{}, ServicePolicy{}, DataStateServicePolicy{}, - PostureCheck{}, DataStatePostureCheck{}, + PostureCheck{}, DataStatePostureCheck{}, edge_ctrl_pb.DataState_PostureCheck_Domains_{}, edge_ctrl_pb.DataState_PostureCheck_Domains{}, + edge_ctrl_pb.DataState_PostureCheck_Mac_{}, edge_ctrl_pb.DataState_PostureCheck_Mac{}, + edge_ctrl_pb.DataState_PostureCheck_Mfa_{}, edge_ctrl_pb.DataState_PostureCheck_Mfa{}, + edge_ctrl_pb.DataState_PostureCheck_OsList_{}, edge_ctrl_pb.DataState_PostureCheck_OsList{}, edge_ctrl_pb.DataState_PostureCheck_Os{}, + edge_ctrl_pb.DataState_PostureCheck_Process_{}, edge_ctrl_pb.DataState_PostureCheck_Process{}, + edge_ctrl_pb.DataState_PostureCheck_ProcessMulti_{}, edge_ctrl_pb.DataState_PostureCheck_ProcessMulti{}, ), adapter) } @@ -184,6 +193,11 @@ func (self *IdentitySubscription) initialize(rdm *RouterDataModel, identity *Ide } func (self *IdentitySubscription) checkForChanges(rdm *RouterDataModel) { + idx, _ := rdm.CurrentIndex() + log := pfxlog.Logger(). + WithField("index", idx). + WithField("identity", self.IdentityId) + self.Lock() newIdentity, ok := rdm.Identities.Get(self.IdentityId) notifyRemoved := !ok && self.Identity != nil @@ -197,6 +211,7 @@ func (self *IdentitySubscription) checkForChanges(rdm *RouterDataModel) { newServices := self.Services newChecks := self.Checks self.Unlock() + log.Debugf("identity subscriber updated. identities old: %p new: %p, rdm: %p", oldIdentity, newIdentity, rdm) if notifyRemoved { state := &IdentityState{ diff --git a/controller/sync_strats/sync_instant.go b/controller/sync_strats/sync_instant.go index 7009f8120..93e1fecae 100644 --- a/controller/sync_strats/sync_instant.go +++ b/controller/sync_strats/sync_instant.go @@ -1058,7 +1058,10 @@ func (strategy *InstantStrategy) ValidateIdentities(tx *bbolt.Tx, rdm *common.Ro policyList := strategy.ae.GetStores().Identity.GetRelatedEntitiesIdList(tx, t.Id, db.EntityTypeServicePolicies) policySet := genext.SliceToSet(policyList) - result = diffSets("identity", t.Id, "service policy", policySet, v.ServicePolicies, result) + + v.ServicePolicies.WithReadLock(func(m map[string]struct{}) { + result = diffSets("identity", t.Id, "service policy", policySet, m, result) + }) return result }) @@ -1158,11 +1161,16 @@ func (strategy *InstantStrategy) ValidateServicePolicies(tx *bbolt.Tx, rdm *comm policyList := strategy.ae.GetStores().ServicePolicy.GetRelatedEntitiesIdList(tx, t.Id, db.EntityTypeServices) policySet := genext.SliceToSet(policyList) - result = diffSets("service policy", t.Id, "service", policySet, v.Services, result) + v.Services.WithReadLock(func(m map[string]struct{}) { + result = diffSets("service policy", t.Id, "service", policySet, m, result) + }) policyList = strategy.ae.GetStores().ServicePolicy.GetRelatedEntitiesIdList(tx, t.Id, db.EntityTypePostureChecks) policySet = genext.SliceToSet(policyList) - result = diffSets("service policy", t.Id, "posture check", policySet, v.PostureChecks, result) + + v.PostureChecks.WithReadLock(func(m map[string]struct{}) { + result = diffSets("service policy", t.Id, "posture check", policySet, m, result) + }) return result }) diff --git a/router/posture/access.go b/router/posture/access.go index 39aab7519..58b087344 100644 --- a/router/posture/access.go +++ b/router/posture/access.go @@ -43,18 +43,18 @@ func IsPassing(accessPolicies *common.AccessPolicies, cache *Cache) (*common.Ser Errors: []error{}, } - for postureCheckId := range policy.PostureChecks { + policy.PostureChecks.RangeAll(func(postureCheckId string) { postureCheck, ok := accessPolicies.PostureChecks[postureCheckId] if !ok || postureCheck == nil { policyErr.Errors = append(policyErr.Errors, fmt.Errorf("posture check id %s not found model", postureCheckId)) - continue + return } if err := EvaluatePostureCheck(postureCheck, cache); err != nil { policyErr.Errors = append(policyErr.Errors, err) } - } + }) if len(policyErr.Errors) == 0 { return policy, nil diff --git a/router/state/dataState.go b/router/state/dataState.go index dd45b1b8d..7f9177bcb 100644 --- a/router/state/dataState.go +++ b/router/state/dataState.go @@ -43,11 +43,13 @@ func (self *DataStateHandler) HandleReceive(msg *channel.Message, ch channel.Cha model := common.NewReceiverRouterDataModel(RouterDataModelListerBufferSize, self.state.GetEnv().GetCloseNotify()) logger.WithField("index", newState.EndIndex).Info("received full router data model state") - for _, event := range newState.Events { - model.Handle(newState.EndIndex, event) - } + model.WhileLocked(func(u uint64, b bool) { + for _, event := range newState.Events { + model.Handle(newState.EndIndex, event) + } + model.SetCurrentIndex(newState.EndIndex) + }) - model.SetCurrentIndex(newState.EndIndex) self.state.SetRouterDataModel(model) logger.WithField("index", newState.EndIndex).Info("finished processing full router data model state") }) diff --git a/router/state/manager.go b/router/state/manager.go index ce3a03758..af78626e8 100644 --- a/router/state/manager.go +++ b/router/state/manager.go @@ -470,7 +470,7 @@ func (sm *ManagerImpl) SetRouterDataModel(model *common.RouterDataModel) { logger = logger.WithField("existingIndex", existingIndex) } model.SyncAllSubscribers() - logger.Info("router data model replacement complete") + logger.Infof("router data model replacement complete, old: %p, new: %p", existing, model) } func (sm *ManagerImpl) MarkSyncInProgress(trackerId string) { diff --git a/router/state/validate.go b/router/state/validate.go index d5a4c6ca0..7a4f10185 100644 --- a/router/state/validate.go +++ b/router/state/validate.go @@ -37,12 +37,13 @@ func (self *ValidateDataStateRequestHandler) HandleReceive(msg *channel.Message, newState := request.State model := common.NewBareRouterDataModel() + model.WhileLocked(func(u uint64, b bool) { + for _, event := range newState.Events { + model.Handle(newState.EndIndex, event) + } + model.SetCurrentIndex(newState.EndIndex) + }) - for _, event := range newState.Events { - model.Handle(newState.EndIndex, event) - } - - model.SetCurrentIndex(newState.EndIndex) current := self.state.RouterDataModel() response := &edge_ctrl_pb.RouterDataModelValidateResponse{ diff --git a/zititest/models/router-data-model-test/main.go b/zititest/models/router-data-model-test/main.go index fa0b6de36..c3005255c 100644 --- a/zititest/models/router-data-model-test/main.go +++ b/zititest/models/router-data-model-test/main.go @@ -240,13 +240,6 @@ var m = &model.Model{ workflow.AddAction(zitilibActions.Edge("create", "edge-router-policy", "all", "--edge-router-roles", "#all", "--identity-roles", "#all")) workflow.AddAction(zitilibActions.Edge("create", "service-edge-router-policy", "all", "--service-roles", "#all", "--edge-router-roles", "#all")) - workflow.AddAction(zitilibActions.Edge("create", "config", "host-config", "host.v1", ` - { - "address" : "localhost", - "port" : 8080, - "protocol" : "tcp" - }`)) - workflow.AddAction(model.ActionFunc(func(run model.Run) error { ctrls := &CtrlClients{} if err := ctrls.init(run, "#ctrl1"); err != nil { diff --git a/zititest/models/router-data-model-test/validation.go b/zititest/models/router-data-model-test/validation.go index ad4950e6e..fce2aae5f 100644 --- a/zititest/models/router-data-model-test/validation.go +++ b/zititest/models/router-data-model-test/validation.go @@ -28,6 +28,7 @@ import ( "github.com/openziti/fablab/kernel/lib/parallel" "github.com/openziti/fablab/kernel/model" "github.com/openziti/foundation/v2/errorz" + ptrutil "github.com/openziti/foundation/v2/util" "github.com/openziti/ziti/common/pb/mgmt_pb" "github.com/openziti/ziti/ziti/util" "github.com/openziti/ziti/zitirest" @@ -121,6 +122,7 @@ func sowChaos(run model.Run) error { applyTasks(getRestartTasks) applyTasks(getIdentityChaosTasks) applyTasks(getServicePolicyChaosTasks) + applyTasks(getPostureTasks) if err != nil { return err @@ -137,6 +139,10 @@ func sowChaos(run model.Run) error { if apiErr.GetPayload().Error.Code == errorz.NotFoundCode { return parallel.ErrActionIgnore } + } else if strings.HasPrefix(task.Type(), "create.") && attempt > 1 { + if apiErr.GetPayload().Error.Code == errorz.CouldNotValidateCode { + return parallel.ErrActionIgnore + } } msg = apiErr.GetPayload().Error.Message } @@ -406,6 +412,63 @@ func (self *taskGenerationContext) generateConfigTasks() { } } +func (self *taskGenerationContext) generatePostureCheckTasks() { + if self.err != nil { + return + } + + if scenarioCounter%3 == 0 && len(self.configs) > 2 { // only delete configs every third iteration + for i := 0; i < 2; i++ { + entityId := *self.configs[i].ID + self.lastTasks = append(self.lastTasks, parallel.TaskWithLabel("delete.config", fmt.Sprintf("delete config %s", entityId), func() error { + return models.DeleteConfig(self.ctrls.getRandomCtrl(), entityId, 15*time.Second) + })) + } + } + + // delete any configs used by config types to be deleted + if len(self.configTypesDeleted) > 0 { + for _, config := range self.configs { + if _, deleted := self.configsDeleted[*config.ID]; deleted { + continue + } + if _, deleted := self.configTypesDeleted[*config.ConfigTypeID]; deleted { + entityId := *config.ID + self.configsDeleted[entityId] = struct{}{} + self.lastTasks = append(self.lastTasks, parallel.TaskWithLabel("delete.config", fmt.Sprintf("delete config %s", entityId), func() error { + return models.DeleteConfig(self.ctrls.getRandomCtrl(), entityId, 15*time.Second) + })) + } + } + } + + for i := 2; i < min(7, len(self.configs)); i++ { + entityId := *self.configs[i].ID + self.tasks = append(self.tasks, parallel.TaskWithLabel("modify.config", fmt.Sprintf("modify config %s", entityId), func() error { + entity := self.configs[i] + entity.Name = newId() + entity.Data = map[string]interface{}{ + "hostname": fmt.Sprintf("https://%s.com", uuid.NewString()), + "protocol": func() string { + if rand.Int()%2 == 0 { + return "tcp" + } + return "udp" + }(), + "port": rand.Intn(32000), + } + return models.UpdateConfigFromDetail(self.ctrls.getRandomCtrl(), entity, 15*time.Second) + })) + } + + if len(self.configTypes) > 0 { + createConfigCount := 25 - (len(self.configs) - len(self.configsDeleted)) // target 25 configs available + for i := 0; i < createConfigCount; i++ { + self.tasks = append(self.tasks, createNewConfig(self.ctrls.getRandomCtrl(), self.getConfigTypeId())) + } + } +} + func (self *taskGenerationContext) generateServiceTasks() { if self.err != nil { return @@ -472,6 +535,7 @@ func getServiceAndConfigChaosTasks(_ model.Run, ctrls *CtrlClients) ([]parallel. ctx.loadEntities() ctx.generateConfigTypeTasks() ctx.generateConfigTasks() + ctx.generatePostureCheckTasks() ctx.generateServiceTasks() return ctx.getResults() @@ -517,6 +581,178 @@ func getIdentityChaosTasks(r model.Run, ctrls *CtrlClients) ([]parallel.LabeledT return result, nil } +func getPostureTasks(r model.Run, ctrls *CtrlClients) ([]parallel.LabeledTask, error) { + entities, err := models.ListPostureChecks(ctrls.getRandomCtrl(), "limit none", 15*time.Second) + if err != nil { + return nil, err + } + chaos.Randomize(entities) + + var result []parallel.LabeledTask + + var i int + for len(result) < 5+(len(entities)-100) { + entityId := *entities[i].ID() + result = append(result, parallel.TaskWithLabel("delete.posture-check", fmt.Sprintf("delete posture check %s", entityId), func() error { + return models.DeletePostureCheck(ctrls.getRandomCtrl(), entityId, 15*time.Second) + })) + i++ + } + + for len(result) < min(10, len(entities)) { + entity := entities[i] + entity.SetName(newId()) + entity.SetRoleAttributes(getRoleAttributesAsAttrPtr(3)) + + switch p := entity.(type) { + case *rest_model.PostureCheckDomainDetail: + p.Domains = []string{uuid.NewString(), uuid.NewString()} + case *rest_model.PostureCheckMacAddressDetail: + p.MacAddresses = []string{uuid.NewString(), uuid.NewString()} + case *rest_model.PostureCheckMfaDetail: + p.IgnoreLegacyEndpoints = *newBoolPtr() + p.PromptOnUnlock = *newBoolPtr() + p.PromptOnWake = *newBoolPtr() + p.TimeoutSeconds = int64(rand.Intn(1000)) + case *rest_model.PostureCheckOperatingSystemDetail: + p.OperatingSystems = getRandomOperatingSystems() + case *rest_model.PostureCheckProcessDetail: + p.Process = getRandomProcess() + case *rest_model.PostureCheckProcessMultiDetail: + p.Semantic = ptrutil.Ptr(getRandomSemantic()) + p.Processes = getRandomProcessMultis() + default: + return nil, fmt.Errorf("unhandled posture check type: %T", p) + } + + result = append(result, parallel.TaskWithLabel("modify.posture-check", fmt.Sprintf("modify %s posture-check %s", entity.TypeID(), *entity.ID()), func() error { + return models.UpdatePostureCheckFromDetail(ctrls.getRandomCtrl(), entity, 15*time.Second) + })) + i++ + } + + for i := 0; i < 55-len(entities); i++ { + result = append(result, createNewPostureCheck(ctrls.getRandomCtrl())) + } + + return result, nil +} + +func getRandomSemantic() rest_model.Semantic { + if rand.Int()%2 == 0 { + return rest_model.SemanticAnyOf + } + return rest_model.SemanticAllOf +} + +func getRandomOperatingSystems() []*rest_model.OperatingSystem { + return getRandom(1, 3, getRandomOperatingSystem) +} + +func getRandomOperatingSystem() *rest_model.OperatingSystem { + return &rest_model.OperatingSystem{ + Type: ptrutil.Ptr(getRandomOsType()), + Versions: getRandom(1, 3, getRandomVersion), + } +} + +func getRandomVersion() string { + return fmt.Sprintf("%d.%d.%d", rand.Intn(100), rand.Intn(100), rand.Intn(100)) +} + +func getRandomProcessMultis() []*rest_model.ProcessMulti { + return getRandom(1, 3, getRandomProcessMulti) +} + +func getRandomProcessMulti() *rest_model.ProcessMulti { + return &rest_model.ProcessMulti{ + Hashes: []string{uuid.NewString(), uuid.NewString()}, + OsType: ptrutil.Ptr(getRandomOsType()), + Path: ptrutil.Ptr(uuid.NewString()), + SignerFingerprints: []string{uuid.NewString(), uuid.NewString()}, + } +} + +func getRandomProcess() *rest_model.Process { + return &rest_model.Process{ + Hashes: []string{uuid.NewString(), uuid.NewString()}, + OsType: ptrutil.Ptr(getRandomOsType()), + Path: ptrutil.Ptr(uuid.NewString()), + SignerFingerprint: uuid.NewString(), + } +} + +func getRandom[T any](min, max int, f func() T) []T { + var result []T + count := min + if max > min { + min += rand.Intn(max - min) + } + for i := 0; i < count; i++ { + result = append(result, f()) + } + return result +} + +var osTypes = []rest_model.OsType{ + rest_model.OsTypeLinux, + rest_model.OsTypeWindows, + rest_model.OsTypeMacOS, + rest_model.OsTypeIOS, + rest_model.OsTypeAndroid, + rest_model.OsTypeWindowsServer, +} + +func getRandomOsType() rest_model.OsType { + return osTypes[rand.Intn(len(osTypes))] +} + +func createNewPostureCheck(ctrl *zitirest.Clients) parallel.LabeledTask { + var create rest_model.PostureCheckCreate + + switch rand.Intn(6) { + case 0: + create = &rest_model.PostureCheckDomainCreate{ + Domains: []string{uuid.NewString(), uuid.NewString()}, + } + + case 1: + create = &rest_model.PostureCheckMacAddressCreate{ + MacAddresses: getRandom(1, 3, uuid.NewString), + } + + case 2: + mfaCreate := &rest_model.PostureCheckMfaCreate{} + mfaCreate.IgnoreLegacyEndpoints = *newBoolPtr() + mfaCreate.PromptOnUnlock = *newBoolPtr() + mfaCreate.PromptOnWake = *newBoolPtr() + mfaCreate.TimeoutSeconds = int64(rand.Intn(1000)) + create = mfaCreate + case 3: + create = &rest_model.PostureCheckOperatingSystemCreate{ + OperatingSystems: getRandomOperatingSystems(), + } + case 4: + create = &rest_model.PostureCheckProcessCreate{ + Process: getRandomProcess(), + } + case 5: + create = &rest_model.PostureCheckProcessMultiCreate{ + Semantic: ptrutil.Ptr(getRandomSemantic()), + Processes: getRandomProcessMultis(), + } + default: + panic("programming error") + } + + create.SetName(newId()) + create.SetRoleAttributes(getRoleAttributesAsAttrPtr(3)) + + return parallel.TaskWithLabel("create.posture-check", fmt.Sprintf("create %s posture check", create.TypeID()), func() error { + return models.CreatePostureCheck(ctrl, create, 15*time.Second) + }) +} + func getServicePolicyChaosTasks(_ model.Run, ctrls *CtrlClients) ([]parallel.LabeledTask, error) { entities, err := models.ListServicePolicies(ctrls.getRandomCtrl(), "limit none", 15*time.Second) if err != nil { @@ -642,7 +878,6 @@ func createNewIdentity(ctrl *zitirest.Clients) parallel.LabeledTask { func createNewServicePolicy(ctrl *zitirest.Clients) parallel.LabeledTask { return parallel.TaskWithLabel("create.service-policy", "create new service policy", func() error { - anyOf := rest_model.SemanticAnyOf policyType := rest_model.DialBindDial if rand.Int()%2 == 0 { policyType = rest_model.DialBindBind @@ -651,7 +886,7 @@ func createNewServicePolicy(ctrl *zitirest.Clients) parallel.LabeledTask { Name: newId(), IdentityRoles: getRoles(3), PostureCheckRoles: getRoles(3), - Semantic: &anyOf, + Semantic: ptrutil.Ptr(getRandomSemantic()), ServiceRoles: getRoles(3), Type: &policyType, } diff --git a/zititest/zitilab/models/api.go b/zititest/zitilab/models/api.go index 1a1a77c88..61f93f39e 100644 --- a/zititest/zitilab/models/api.go +++ b/zititest/zitilab/models/api.go @@ -2,8 +2,10 @@ package models import ( "context" + "fmt" "github.com/openziti/edge-api/rest_management_api_client/config" "github.com/openziti/edge-api/rest_management_api_client/identity" + "github.com/openziti/edge-api/rest_management_api_client/posture_checks" "github.com/openziti/edge-api/rest_management_api_client/service" "github.com/openziti/edge-api/rest_management_api_client/service_policy" "github.com/openziti/edge-api/rest_model" @@ -349,3 +351,93 @@ func UpdateConfigType(clients *zitirest.Clients, id string, entity *rest_model.C return err } + +func ListPostureChecks(clients *zitirest.Clients, filter string, timeout time.Duration) ([]rest_model.PostureCheckDetail, error) { + ctx, cancelF := context.WithTimeout(context.Background(), timeout) + defer cancelF() + + result, err := clients.Edge.PostureChecks.ListPostureChecks(&posture_checks.ListPostureChecksParams{ + Filter: &filter, + Context: ctx, + }, nil) + + if err != nil { + return nil, err + } + return result.Payload.Data(), nil +} + +func CreatePostureCheck(clients *zitirest.Clients, entity rest_model.PostureCheckCreate, timeout time.Duration) error { + ctx, cancelF := context.WithTimeout(context.Background(), timeout) + defer cancelF() + + _, err := clients.Edge.PostureChecks.CreatePostureCheck(&posture_checks.CreatePostureCheckParams{ + Context: ctx, + PostureCheck: entity, + }, nil) + + return util.WrapIfApiError(err) +} + +func DeletePostureCheck(clients *zitirest.Clients, id string, timeout time.Duration) error { + ctx, cancelF := context.WithTimeout(context.Background(), timeout) + defer cancelF() + + _, err := clients.Edge.PostureChecks.DeletePostureCheck(&posture_checks.DeletePostureCheckParams{ + Context: ctx, + ID: id, + }, nil) + + return err +} + +func UpdatePostureCheckFromDetail(clients *zitirest.Clients, entity rest_model.PostureCheckDetail, timeout time.Duration) error { + var update rest_model.PostureCheckUpdate + switch p := entity.(type) { + case *rest_model.PostureCheckDomainDetail: + update = &rest_model.PostureCheckDomainUpdate{ + Domains: p.Domains, + } + case *rest_model.PostureCheckMacAddressDetail: + update = &rest_model.PostureCheckMacAddressUpdate{ + MacAddresses: p.MacAddresses, + } + case *rest_model.PostureCheckMfaDetail: + update = &rest_model.PostureCheckMfaUpdate{ + PostureCheckMfaProperties: p.PostureCheckMfaProperties, + } + case *rest_model.PostureCheckOperatingSystemDetail: + update = &rest_model.PostureCheckOperatingSystemUpdate{ + OperatingSystems: p.OperatingSystems, + } + case *rest_model.PostureCheckProcessDetail: + update = &rest_model.PostureCheckProcessUpdate{ + Process: p.Process, + } + case *rest_model.PostureCheckProcessMultiDetail: + update = &rest_model.PostureCheckProcessMultiUpdate{ + Semantic: p.Semantic, + Processes: p.Processes, + } + default: + return fmt.Errorf("unhandled posture check type %T", p) + } + + update.SetName(entity.Name()) + update.SetRoleAttributes(entity.RoleAttributes()) + + return UpdatePostureCheck(clients, *entity.ID(), update, timeout) +} + +func UpdatePostureCheck(clients *zitirest.Clients, id string, entity rest_model.PostureCheckUpdate, timeout time.Duration) error { + ctx, cancelF := context.WithTimeout(context.Background(), timeout) + defer cancelF() + + _, err := clients.Edge.PostureChecks.UpdatePostureCheck(&posture_checks.UpdatePostureCheckParams{ + Context: ctx, + ID: id, + PostureCheck: entity, + }, nil) + + return err +}