Skip to content

Commit

Permalink
add txns support
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanulit committed Dec 2, 2024
1 parent 90edbde commit 89255a5
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 79 deletions.
34 changes: 21 additions & 13 deletions service/cmd/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,29 @@ var (
panic(fmt.Errorf("could not load config: %w", err))
}

res := dbClient.AttrFqnReindex(context.Background())
cmd.Print("Namespace FQNs reindexed:\n")
for _, r := range res.Namespaces {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}
ctx := context.Background()

cmd.Print("Attribute FQNs reindexed:\n")
for _, r := range res.Attributes {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}
// ignore error as dbClient.AttrFqnReindex will panic on error
_ = dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
res := txClient.AttrFqnReindex(ctx)

cmd.Print("Attribute Value FQNs reindexed:\n")
for _, r := range res.Values {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}
cmd.Print("Namespace FQNs reindexed:\n")
for _, r := range res.Namespaces {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}

cmd.Print("Attribute FQNs reindexed:\n")
for _, r := range res.Attributes {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}

cmd.Print("Attribute Value FQNs reindexed:\n")
for _, r := range res.Values {
cmd.Printf("\t%s: %s\n", r.ID, r.Fqn)
}

return nil
})
},
}
)
Expand Down
24 changes: 16 additions & 8 deletions service/policy/attributes/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,25 @@ func (s *AttributesService) CreateAttributeValue(ctx context.Context, req *conne
ActionType: audit.ActionTypeCreate,
}

item, err := s.dbClient.CreateAttributeValue(ctx, req.Msg.GetAttributeId(), req.Msg)
err := s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
item, err := txClient.CreateAttributeValue(ctx, req.Msg.GetAttributeId(), req.Msg)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

auditParams.ObjectID = item.GetId()
auditParams.Original = item
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Value = item

return nil
})
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("attributeId", req.Msg.GetAttributeId()), slog.String("value", req.Msg.String()))
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("value", req.Msg.String()))
}

auditParams.ObjectID = item.GetId()
auditParams.Original = item
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Value = item
return connect.NewResponse(rsp), nil
}

Expand Down
27 changes: 17 additions & 10 deletions service/policy/namespaces/namespaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,26 @@ func (ns NamespacesService) CreateNamespace(ctx context.Context, req *connect.Re
}
rsp := &namespaces.CreateNamespaceResponse{}

n, err := ns.dbClient.CreateNamespace(ctx, req.Msg)
err := ns.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
n, err := txClient.CreateNamespace(ctx, req.Msg)
if err != nil {
ns.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

auditParams.ObjectID = n.GetId()
auditParams.Original = n
ns.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

ns.logger.Debug("created new namespace", slog.String("name", req.Msg.GetName()))
rsp.Namespace = n

return nil
})
if err != nil {
ns.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("name", req.Msg.GetName()))
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("namespace", req.Msg.String()))
}

auditParams.ObjectID = n.GetId()
auditParams.Original = n
ns.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

ns.logger.Debug("created new namespace", slog.String("name", req.Msg.GetName()))
rsp.Namespace = n

return connect.NewResponse(rsp), nil
}

Expand Down
136 changes: 88 additions & 48 deletions service/policy/unsafe/unsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,38 @@ func (s *UnsafeService) UnsafeUpdateNamespace(ctx context.Context, req *connect.
ObjectID: id,
}

original, err := s.dbClient.GetNamespace(ctx, id)
var (
original *policy.Namespace
updated *policy.Namespace
err error
)

err = s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
original, err = txClient.GetNamespace(ctx, id)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

updated, err = txClient.UnsafeUpdateNamespace(ctx, id, name)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

auditParams.Original = original
auditParams.Updated = updated

s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Namespace = &policy.Namespace{
Id: id,
}

return nil
})
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", id))
}

updated, err := s.dbClient.UnsafeUpdateNamespace(ctx, id, name)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("id", id), slog.String("namespace", name))
}

auditParams.Original = original
auditParams.Updated = updated

s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Namespace = &policy.Namespace{
Id: id,
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("namespace", req.Msg.String()))
}

return connect.NewResponse(rsp), nil
Expand Down Expand Up @@ -163,25 +176,38 @@ func (s *UnsafeService) UnsafeUpdateAttribute(ctx context.Context, req *connect.
ObjectID: id,
}

original, err := s.dbClient.GetAttribute(ctx, id)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", id))
}
var (
original *policy.Attribute
updated *policy.Attribute
err error
)

updated, err := s.dbClient.UnsafeUpdateAttribute(ctx, req.Msg)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("id", id), slog.String("attribute", req.Msg.String()))
}
err = s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
original, err = txClient.GetAttribute(ctx, id)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

auditParams.Original = original
auditParams.Updated = updated
updated, err = txClient.UnsafeUpdateAttribute(ctx, req.Msg)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)
auditParams.Original = original
auditParams.Updated = updated

rsp.Attribute = &policy.Attribute{
Id: id,
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Attribute = &policy.Attribute{
Id: id,
}

return nil
})
if err != nil {
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("attribute", req.Msg.String()))
}

return connect.NewResponse(rsp), nil
Expand Down Expand Up @@ -269,26 +295,40 @@ func (s *UnsafeService) UnsafeUpdateAttributeValue(ctx context.Context, req *con
ObjectID: id,
}

original, err := s.dbClient.GetAttributeValue(ctx, id)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", id))
}
var (
original *policy.Value
updated *policy.Value
err error
)

updated, err := s.dbClient.UnsafeUpdateAttributeValue(ctx, req.Msg)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("id", id), slog.String("attribute_value", req.Msg.String()))
}
err = s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
original, err = txClient.GetAttributeValue(ctx, id)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

auditParams.Original = original
auditParams.Updated = updated
updated, err = txClient.UnsafeUpdateAttributeValue(ctx, req.Msg)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return err
}

s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)
auditParams.Original = original
auditParams.Updated = updated

rsp.Value = &policy.Value{
Id: id,
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Value = &policy.Value{
Id: id,
}

return nil
})
if err != nil {
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("attribute_value", req.Msg.String()))
}

return connect.NewResponse(rsp), nil
}

Expand Down

0 comments on commit 89255a5

Please sign in to comment.