From 8f3d7a872801eaf759309ae5ee2ddaa3385c605d Mon Sep 17 00:00:00 2001 From: Barnaby Gray Date: Thu, 23 Feb 2023 17:13:36 +0000 Subject: [PATCH] Add `--timeout` option --- commands.go | 99 ++++++++++++++------------- instances.go | 9 +-- internal/features/step_definitions.go | 13 ++-- main.go | 65 ++++++++++++++---- util.go | 11 +-- 5 files changed, 121 insertions(+), 76 deletions(-) diff --git a/commands.go b/commands.go index f3858a4a..97348c67 100644 --- a/commands.go +++ b/commands.go @@ -1,6 +1,7 @@ package cli53 import ( + "context" "fmt" "io" "os" @@ -14,7 +15,7 @@ import ( const ChangeBatchSize = 100 -func createZone(name, comment, vpcId, vpcRegion, delegationSetId string) { +func createZone(ctx context.Context, name, comment, vpcId, vpcRegion, delegationSetId string) { callerReference := uniqueReference() req := route53.CreateHostedZoneInput{ CallerReference: &callerReference, @@ -34,12 +35,12 @@ func createZone(name, comment, vpcId, vpcRegion, delegationSetId string) { delegationSetId = strings.Replace(delegationSetId, "/delegationset/", "", 1) req.DelegationSetId = aws.String(delegationSetId) } - resp, err := r53.CreateHostedZone(&req) + resp, err := r53.CreateHostedZoneWithContext(ctx, &req) fatalIfErr(err) fmt.Printf("Created zone: '%s' ID: '%s'\n", *resp.HostedZone.Name, *resp.HostedZone.Id) } -func createReusableDelegationSet(zoneId string) { +func createReusableDelegationSet(ctx context.Context, zoneId string) { callerReference := uniqueReference() req := route53.CreateReusableDelegationSetInput{ CallerReference: &callerReference, @@ -47,7 +48,7 @@ func createReusableDelegationSet(zoneId string) { if zoneId != "" { req.HostedZoneId = &zoneId } - resp, err := r53.CreateReusableDelegationSet(&req) + resp, err := r53.CreateReusableDelegationSetWithContext(ctx, &req) fatalIfErr(err) ds := resp.DelegationSet fmt.Printf("Created reusable delegation set ID: '%s'\n", *ds.Id) @@ -56,9 +57,9 @@ func createReusableDelegationSet(zoneId string) { } } -func listReusableDelegationSets() { +func listReusableDelegationSets(ctx context.Context) { req := route53.ListReusableDelegationSetsInput{} - resp, err := r53.ListReusableDelegationSets(&req) + resp, err := r53.ListReusableDelegationSetsWithContext(ctx, &req) fatalIfErr(err) fmt.Printf("Reusable delegation sets:\n") if len(resp.DelegationSets) == 0 { @@ -74,19 +75,19 @@ func listReusableDelegationSets() { } } -func deleteReusableDelegationSet(id string) { +func deleteReusableDelegationSet(ctx context.Context, id string) { if !strings.HasPrefix(id, "/delegationset/") { id = "/delegationset/" + id } req := route53.DeleteReusableDelegationSetInput{ Id: &id, } - _, err := r53.DeleteReusableDelegationSet(&req) + _, err := r53.DeleteReusableDelegationSetWithContext(ctx, &req) fatalIfErr(err) fmt.Printf("Deleted reusable delegation set\n") } -func deleteRecordSets(zone *route53.HostedZone, rrsets []*route53.ResourceRecordSet, wait bool) (int, error) { +func deleteRecordSets(ctx context.Context, zone *route53.HostedZone, rrsets []*route53.ResourceRecordSet, wait bool) (int, error) { // delete all non-default SOA/NS records changes := []*route53.Change{} for _, rrset := range rrsets { @@ -106,21 +107,21 @@ func deleteRecordSets(zone *route53.HostedZone, rrsets []*route53.ResourceRecord Changes: changes, }, } - resp, err := r53.ChangeResourceRecordSets(&req) + resp, err := r53.ChangeResourceRecordSetsWithContext(ctx, &req) if err != nil { return 0, err } if wait { - waitForChange(resp.ChangeInfo) + waitForChange(ctx, resp.ChangeInfo) } } return len(changes), nil } -func purgeZoneRecords(zone *route53.HostedZone, wait bool) { +func purgeZoneRecords(ctx context.Context, zone *route53.HostedZone, wait bool) { total := 0 - err := batchListAllRecordSets(r53, *zone.Id, func(rrsets []*route53.ResourceRecordSet) { - n, err := deleteRecordSets(zone, rrsets, wait) + err := batchListAllRecordSets(ctx, r53, *zone.Id, func(rrsets []*route53.ResourceRecordSet) { + n, err := deleteRecordSets(ctx, zone, rrsets, wait) fatalIfErr(err) total += n }) @@ -129,24 +130,24 @@ func purgeZoneRecords(zone *route53.HostedZone, wait bool) { fmt.Printf("%d record sets deleted\n", total) } -func deleteZone(name string, purge bool) { - zone := lookupZone(name) +func deleteZone(ctx context.Context, name string, purge bool) { + zone := lookupZone(ctx, name) if purge { - purgeZoneRecords(zone, false) + purgeZoneRecords(ctx, zone, false) } req := route53.DeleteHostedZoneInput{Id: zone.Id} - _, err := r53.DeleteHostedZone(&req) + _, err := r53.DeleteHostedZoneWithContext(ctx, &req) fatalIfErr(err) fmt.Printf("Deleted zone: '%s' ID: '%s'\n", *zone.Name, *zone.Id) } -func listZones(formatter Formatter) { +func listZones(ctx context.Context, formatter Formatter) { zones := make(chan *route53.HostedZone) go func() { req := route53.ListHostedZonesInput{} for { // paginated - resp, err := r53.ListHostedZones(&req) + resp, err := r53.ListHostedZonesWithContext(ctx, &req) fatalIfErr(err) for _, zone := range resp.HostedZones { zones <- zone @@ -276,8 +277,8 @@ func validateBindFile(args importArgs) { parseBindFile(reader, args.file, "validate.test") } -func importBind(args importArgs) { - zone := lookupZone(args.name) +func importBind(ctx context.Context, args importArgs) { + zone := lookupZone(ctx, args.name) var reader io.Reader if args.file == "-" { @@ -295,7 +296,7 @@ func importBind(args importArgs) { grouped := groupRecords(records) existing := map[string]*route53.ResourceRecordSet{} if args.replace || args.upsert { - rrsets, err := ListAllRecordSets(r53, *zone.Id) + rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id) fatalIfErr(err) for _, rrset := range rrsets { if args.editauth || !isAuthRecord(zone, rrset) { @@ -363,16 +364,16 @@ func importBind(args importArgs) { } } } else { - resp := batchChanges(additions, deletions, zone) + resp := batchChanges(ctx, additions, deletions, zone) fmt.Printf("%d records imported (%d changes / %d additions / %d deletions)\n", len(records), len(additions)+len(deletions), len(additions), len(deletions)) if args.wait && resp != nil { - waitForChange(resp.ChangeInfo) + waitForChange(ctx, resp.ChangeInfo) } } } -func batchChanges(additions, deletions []*route53.Change, zone *route53.HostedZone) *route53.ChangeResourceRecordSetsOutput { +func batchChanges(ctx context.Context, additions, deletions []*route53.Change, zone *route53.HostedZone) *route53.ChangeResourceRecordSetsOutput { // sort additions so aliases are last sort.Sort(changeSorter{additions}) @@ -392,7 +393,7 @@ func batchChanges(additions, deletions []*route53.Change, zone *route53.HostedZo ChangeBatch: &batch, } var err error - resp, err = r53.ChangeResourceRecordSets(&req) + resp, err = r53.ChangeResourceRecordSetsWithContext(ctx, &req) fatalIfErr(err) } return resp @@ -416,9 +417,9 @@ func UnexpandSelfAliases(records []dns.RR, zone *route53.HostedZone, full bool) } } -func exportBind(name string, full bool, writer io.Writer) { - zone := lookupZone(name) - ExportBindToWriter(r53, zone, full, writer) +func exportBind(ctx context.Context, name string, full bool, writer io.Writer) { + zone := lookupZone(ctx, name) + ExportBindToWriter(ctx, r53, zone, full, writer) } type exportSorter struct { @@ -450,8 +451,8 @@ func (r exportSorter) Less(i, j int) bool { return *r.rrsets[i].Name < *r.rrsets[j].Name } -func ExportBindToWriter(r53 *route53.Route53, zone *route53.HostedZone, full bool, out io.Writer) { - rrsets, err := ListAllRecordSets(r53, *zone.Id) +func ExportBindToWriter(ctx context.Context, r53 *route53.Route53, zone *route53.HostedZone, full bool, out io.Writer) { + rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id) fatalIfErr(err) sort.Sort(exportSorter{rrsets, *zone.Name}) @@ -607,8 +608,8 @@ func parseRecordList(args []string, zone *route53.HostedZone) []dns.RR { return records } -func createRecords(args createArgs) { - zone := lookupZone(args.name) +func createRecords(ctx context.Context, args createArgs) { + zone := lookupZone(ctx, args.name) records := parseRecordList(args.records, zone) expandSelfAliases(records, zone) @@ -617,7 +618,7 @@ func createRecords(args createArgs) { var existing []*route53.ResourceRecordSet if args.replace || args.append { var err error - existing, err = ListAllRecordSets(r53, *zone.Id) + existing, err = ListAllRecordSets(ctx, r53, *zone.Id) fatalIfErr(err) } @@ -654,7 +655,7 @@ func createRecords(args createArgs) { } } - resp := batchChanges(additions, deletions, zone) + resp := batchChanges(ctx, additions, deletions, zone) for _, record := range records { txt := strings.Replace(record.String(), "\t", " ", -1) @@ -662,17 +663,17 @@ func createRecords(args createArgs) { } if args.wait { - waitForChange(resp.ChangeInfo) + waitForChange(ctx, resp.ChangeInfo) } } -func batchListAllRecordSets(r53 *route53.Route53, id string, callback func(rrsets []*route53.ResourceRecordSet)) error { +func batchListAllRecordSets(ctx context.Context, r53 *route53.Route53, id string, callback func(rrsets []*route53.ResourceRecordSet)) error { req := route53.ListResourceRecordSetsInput{ HostedZoneId: &id, } for { - resp, err := r53.ListResourceRecordSets(&req) + resp, err := r53.ListResourceRecordSetsWithContext(ctx, &req) if err != nil { return err } else { @@ -690,8 +691,8 @@ func batchListAllRecordSets(r53 *route53.Route53, id string, callback func(rrset } // Paginate request to get all record sets. -func ListAllRecordSets(r53 *route53.Route53, id string) (rrsets []*route53.ResourceRecordSet, err error) { - err = batchListAllRecordSets(r53, id, func(results []*route53.ResourceRecordSet) { +func ListAllRecordSets(ctx context.Context, r53 *route53.Route53, id string) (rrsets []*route53.ResourceRecordSet, err error) { + err = batchListAllRecordSets(ctx, r53, id, func(results []*route53.ResourceRecordSet) { rrsets = append(rrsets, results...) }) @@ -703,9 +704,9 @@ func ListAllRecordSets(r53 *route53.Route53, id string) (rrsets []*route53.Resou return } -func deleteRecord(name string, match string, rtype string, wait bool, identifier string) { - zone := lookupZone(name) - rrsets, err := ListAllRecordSets(r53, *zone.Id) +func deleteRecord(ctx context.Context, name string, match string, rtype string, wait bool, identifier string) { + zone := lookupZone(ctx, name) + rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id) fatalIfErr(err) match = qualifyName(match, *zone.Name) @@ -727,18 +728,18 @@ func deleteRecord(name string, match string, rtype string, wait bool, identifier Changes: changes, }, } - resp, err := r53.ChangeResourceRecordSets(&req2) + resp, err := r53.ChangeResourceRecordSetsWithContext(ctx, &req2) fatalIfErr(err) fmt.Printf("%d record sets deleted\n", len(changes)) if wait { - waitForChange(resp.ChangeInfo) + waitForChange(ctx, resp.ChangeInfo) } } else { fmt.Println("Warning: no records matched - nothing deleted") } } -func purgeRecords(name string, wait bool) { - zone := lookupZone(name) - purgeZoneRecords(zone, wait) +func purgeRecords(ctx context.Context, name string, wait bool) { + zone := lookupZone(ctx, name) + purgeZoneRecords(ctx, zone, wait) } diff --git a/instances.go b/instances.go index b608f9b4..731c4606 100644 --- a/instances.go +++ b/instances.go @@ -1,6 +1,7 @@ package cli53 import ( + "context" "fmt" "regexp" "strings" @@ -28,8 +29,8 @@ type InstanceRecord struct { value string } -func instances(args instancesArgs, config *aws.Config) { - zone := lookupZone(args.name) +func instances(ctx context.Context, args instancesArgs, config *aws.Config) { + zone := lookupZone(ctx, args.name) fmt.Println("Getting DNS records") describeInstancesInput := ec2.DescribeInstancesInput{} @@ -140,11 +141,11 @@ func instances(args instancesArgs, config *aws.Config) { fmt.Printf("+ %s %s %v\n", *rr.Name, *rr.Type, *rr.ResourceRecords[0].Value) } } else { - resp := batchChanges(upserts, []*route53.Change{}, zone) + resp := batchChanges(ctx, upserts, []*route53.Change{}, zone) fmt.Printf("%d records upserted\n", len(upserts)) if args.wait && resp != nil { - waitForChange(resp.ChangeInfo) + waitForChange(ctx, resp.ChangeInfo) } } } diff --git a/internal/features/step_definitions.go b/internal/features/step_definitions.go index 5c5fa7f5..55eb128d 100644 --- a/internal/features/step_definitions.go +++ b/internal/features/step_definitions.go @@ -2,6 +2,7 @@ package features import ( "bytes" + "context" "fmt" "io/ioutil" "log" @@ -90,7 +91,8 @@ func uniqueReference() string { func cleanupDomain(r53 *route53.Route53, id string) { // delete all non-default SOA/NS records - rrsets, err := cli53.ListAllRecordSets(r53, id) + ctx := context.Background() + rrsets, err := cli53.ListAllRecordSets(ctx, r53, id) fatalIfErr(err) changes := []*route53.Change{} for _, rrset := range rrsets { @@ -309,7 +311,8 @@ func init() { name = domain(name) r53 := getService() id := domainId(name) - rrsets, err := cli53.ListAllRecordSets(r53, id) + ctx := context.Background() + rrsets, err := cli53.ListAllRecordSets(ctx, r53, id) fatalIfErr(err) actual := len(rrsets) if expected != actual { @@ -338,7 +341,8 @@ func init() { r53 := getService() zone := domainZone(name) out := new(bytes.Buffer) - cli53.ExportBindToWriter(r53, zone, false, out) + ctx := context.Background() + cli53.ExportBindToWriter(ctx, r53, zone, false, out) actual := out.Bytes() rfile, err := os.Open(filename) fatalIfErr(err) @@ -414,7 +418,8 @@ func init() { func hasRecord(name, record string) bool { r53 := getService() zone := domainZone(name) - rrsets, err := cli53.ListAllRecordSets(r53, *zone.Id) + ctx := context.Background() + rrsets, err := cli53.ListAllRecordSets(ctx, r53, *zone.Id) fatalIfErr(err) for _, rrset := range rrsets { diff --git a/main.go b/main.go index 4bf41041..68cf5273 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,9 @@ package cli53 import ( + "context" "os" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/route53" @@ -11,6 +13,14 @@ import ( var r53 *route53.Route53 var version = "main" +func theContext(c *cli.Context) (context.Context, func()) { + if c.IsSet("timeout") { + timeout := time.Second * time.Duration(c.Float64("timeout")) + return context.WithTimeout(context.Background(), timeout) + } + return context.Background(), func() {} +} + // Main entry point for cli53 application func Main(args []string) int { commonFlags := []cli.Flag{ @@ -31,6 +41,10 @@ func Main(args []string) int { Name: "endpoint-url", Usage: "override Route53 endpoint (hostname or fully qualified URI)", }, + &cli.Float64Flag{ + Name: "timeout", + Usage: "timeout in seconds", + }, } app := cli.NewApp() @@ -64,7 +78,9 @@ func Main(args []string) int { if formatter == nil { return cli.NewExitError("Unknown format", 1) } - listZones(formatter) + ctx, cancel := theContext(c) + defer cancel() + listZones(ctx, formatter) return nil }, }, @@ -103,7 +119,9 @@ func Main(args []string) int { cli.ShowCommandHelp(c, "create") return cli.NewExitError("Expected exactly 1 parameter", 1) } - createZone(c.Args().First(), c.String("comment"), c.String("vpc-id"), c.String("vpc-region"), c.String("delegation-set-id")) + ctx, cancel := theContext(c) + defer cancel() + createZone(ctx, c.Args().First(), c.String("comment"), c.String("vpc-id"), c.String("vpc-region"), c.String("delegation-set-id")) return nil }, }, @@ -127,7 +145,9 @@ func Main(args []string) int { return cli.NewExitError("Expected exactly 1 parameter", 1) } domain := c.Args().First() - deleteZone(domain, c.Bool("purge")) + ctx, cancel := theContext(c) + defer cancel() + deleteZone(ctx, domain, c.Bool("purge")) return nil }, }, @@ -209,7 +229,9 @@ func Main(args []string) int { upsert: c.Bool("upsert"), dryrun: c.Bool("dry-run"), } - importBind(args) + ctx, cancel := theContext(c) + defer cancel() + importBind(ctx, args) return nil }, }, @@ -283,7 +305,9 @@ func Main(args []string) int { aRecord: c.Bool("a-record"), dryRun: c.Bool("dry-run"), } - instances(args, config) + ctx, cancel := theContext(c) + defer cancel() + instances(ctx, args, config) return nil }, }, @@ -321,7 +345,9 @@ func Main(args []string) int { } defer writer.Close() } - exportBind(c.Args().First(), c.Bool("full"), writer) + ctx, cancel := theContext(c) + defer cancel() + exportBind(ctx, c.Args().First(), c.Bool("full"), writer) return nil }, }, @@ -410,11 +436,12 @@ func Main(args []string) int { subdivisionCode: c.String("subdivision-code"), multivalue: c.Bool("multivalue"), } - if args.validate() { - createRecords(args) - } else { + if !args.validate() { return cli.NewExitError("Validation error", 1) } + ctx, cancel := theContext(c) + defer cancel() + createRecords(ctx, args) return nil }, }, @@ -443,7 +470,9 @@ func Main(args []string) int { cli.ShowCommandHelp(c, "rrdelete") return cli.NewExitError("Expected exactly 3 parameters", 1) } - deleteRecord(c.Args().Get(0), c.Args().Get(1), c.Args().Get(2), c.Bool("wait"), c.String("identifier")) + ctx, cancel := theContext(c) + defer cancel() + deleteRecord(ctx, c.Args().Get(0), c.Args().Get(1), c.Args().Get(2), c.Bool("wait"), c.String("identifier")) return nil }, }, @@ -473,7 +502,9 @@ func Main(args []string) int { if !c.Bool("confirm") { return cli.NewExitError("You must --confirm this action", 1) } - purgeRecords(c.Args().First(), c.Bool("wait")) + ctx, cancel := theContext(c) + defer cancel() + purgeRecords(ctx, c.Args().First(), c.Bool("wait")) return nil }, }, @@ -486,7 +517,9 @@ func Main(args []string) int { if err != nil { return err } - listReusableDelegationSets() + ctx, cancel := theContext(c) + defer cancel() + listReusableDelegationSets(ctx) return nil }, }, @@ -505,7 +538,9 @@ func Main(args []string) int { if err != nil { return err } - createReusableDelegationSet(c.String("zone-id")) + ctx, cancel := theContext(c) + defer cancel() + createReusableDelegationSet(ctx, c.String("zone-id")) return nil }, }, @@ -523,7 +558,9 @@ func Main(args []string) int { cli.ShowCommandHelp(c, "dsdelete") return cli.NewExitError("Expected exactly 1 parameter", 1) } - deleteReusableDelegationSet(c.Args().First()) + ctx, cancel := theContext(c) + defer cancel() + deleteReusableDelegationSet(ctx, c.Args().First()) return nil }, }, diff --git a/util.go b/util.go index b5df9ba0..44815223 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,7 @@ package cli53 import ( + "context" "fmt" "math/rand" "os" @@ -114,7 +115,7 @@ func isZoneId(s string) bool { return reZoneId.MatchString(s) } -func lookupZone(nameOrId string) *route53.HostedZone { +func lookupZone(ctx context.Context, nameOrId string) *route53.HostedZone { if isZoneId(nameOrId) { // lookup by id id := nameOrId @@ -124,7 +125,7 @@ func lookupZone(nameOrId string) *route53.HostedZone { req := route53.GetHostedZoneInput{ Id: aws.String(id), } - resp, err := r53.GetHostedZone(&req) + resp, err := r53.GetHostedZoneWithContext(ctx, &req) if err, ok := err.(awserr.Error); ok && err.Code() == "NoSuchHostedZone" { errorAndExit(fmt.Sprintf("Zone '%s' not found", nameOrId)) } @@ -136,7 +137,7 @@ func lookupZone(nameOrId string) *route53.HostedZone { req := route53.ListHostedZonesByNameInput{ DNSName: aws.String(nameOrId), } - resp, err := r53.ListHostedZonesByName(&req) + resp, err := r53.ListHostedZonesByNameWithContext(ctx, &req) fatalIfErr(err) for _, zone := range resp.HostedZones { if zoneName(*zone.Name) == zoneName(nameOrId) { @@ -155,11 +156,11 @@ func lookupZone(nameOrId string) *route53.HostedZone { return nil } -func waitForChange(change *route53.ChangeInfo) { +func waitForChange(ctx context.Context, change *route53.ChangeInfo) { fmt.Printf("Waiting for sync") for { req := route53.GetChangeInput{Id: change.Id} - resp, err := r53.GetChange(&req) + resp, err := r53.GetChangeWithContext(ctx, &req) fatalIfErr(err) if *resp.ChangeInfo.Status == "INSYNC" { fmt.Println("\nCompleted")