diff --git a/cmd/explode.go b/cmd/explode.go index d78c27d..b190c75 100644 --- a/cmd/explode.go +++ b/cmd/explode.go @@ -6,51 +6,73 @@ import ( "strings" ) +const ( + porterrmsg = "Invalid port specification" +) + +func dashSplit(sp string, ports *[]int) error { + dp := strings.Split(sp, "-") + if len(dp) != 2 { + return errors.New(porterrmsg) + } + start, err := strconv.Atoi(dp[0]) + if err != nil { + return errors.New(porterrmsg) + } + end, err := strconv.Atoi(dp[1]) + if err != nil { + return errors.New(porterrmsg) + } + if start > end || start < 1 || end > 65535 { + return errors.New(porterrmsg) + } + for ; start <= end; start++ { + *ports = append(*ports, start) + } + return nil +} + +func convertAndAddPort(p string, ports *[]int) error { + i, err := strconv.Atoi(p) + if err != nil { + return errors.New(porterrmsg) + } + if i < 1 || i > 65535 { + return errors.New(porterrmsg) + } + *ports = append(*ports, i) + return nil +} + // Turns a string of ports separated by '-' or ',' and returns a slice of Ints. func explode(s string) ([]int, error) { - const errmsg = "Invalid port specification" ports := []int{} - switch { - case strings.Contains(s, "-"): - sp := strings.Split(s, "-") - if len(sp) != 2 { - return ports, errors.New(errmsg) - } - start, err := strconv.Atoi(sp[0]) - if err != nil { - return ports, errors.New(errmsg) - } - end, err := strconv.Atoi(sp[1]) - if err != nil { - return ports, errors.New(errmsg) - } - if start > end || start < 1 || end > 65535 { - return ports, errors.New(errmsg) - } - for ; start <= end; start++ { - ports = append(ports, start) - } - case strings.Contains(s, ","): + if strings.Contains(s, ",") && strings.Contains(s, "-") { sp := strings.Split(s, ",") for _, p := range sp { - i, err := strconv.Atoi(p) - if err != nil { - return ports, errors.New(errmsg) - } - if i < 1 || i > 65535 { - return ports, errors.New(errmsg) + if strings.Contains(p, "-") { + if err := dashSplit(p, &ports); err != nil { + return ports, err + } + } else { + if err := convertAndAddPort(p, &ports); err != nil { + return ports, err + } } - ports = append(ports, i) } - default: - i, err := strconv.Atoi(s) - if err != nil { - return ports, errors.New(errmsg) + } else if strings.Contains(s, ",") { + sp := strings.Split(s, ",") + for _, p := range sp { + convertAndAddPort(p, &ports) + } + } else if strings.Contains(s, "-") { + if err := dashSplit(s, &ports); err != nil { + return ports, err } - if i < 1 || i > 65535 { - return ports, errors.New(errmsg) + } else { + if err := convertAndAddPort(s, &ports); err != nil { + return ports, err } - ports = append(ports, i) } return ports, nil } diff --git a/cmd/explode_test.go b/cmd/explode_test.go new file mode 100644 index 0000000..9150c44 --- /dev/null +++ b/cmd/explode_test.go @@ -0,0 +1,43 @@ +package main + +import "testing" + +const ( + dashRange = "1-20" + singlePort = "4444" + dashAndComma = "1,2,3,4,5-10" +) + +func TestDashSplit(t *testing.T) { + ports := []int{} + err := dashSplit(dashRange, &ports) + if err != nil { + t.Error(err) + } + expected := 20 + if len(ports) != expected { + t.Errorf("Expected length of %d and got %d\n", expected, len(ports)) + } +} + +func TestConvertAndAddPort(t *testing.T) { + ports := []int{} + err := convertAndAddPort(singlePort, &ports) + if err != nil { + t.Error(err) + } + if ports[0] != 4444 { + t.Error("Expected 4444 and got", ports[0]) + } +} + +func TestExplode(t *testing.T) { + ports, err := explode(dashAndComma) + if err != nil { + t.Error(err) + } + expected := 10 + if len(ports) != expected { + t.Errorf("Expexted length of %d and got %d\n", expected, len(ports)) + } +} diff --git a/cmd/parse.go b/cmd/parse.go index a4747ef..f5066ad 100644 --- a/cmd/parse.go +++ b/cmd/parse.go @@ -48,7 +48,7 @@ type O struct { } func parse() *O { - args, err := docopt.Parse(usage, nil, true, "cookiescan 2.1.0", false) + args, err := docopt.Parse(usage, nil, true, "cookiescan 2.2.0", false) if err != nil { log.Fatalf("Error parsing usage. Error: %s\n", err.Error()) }