diff --git a/main.go b/main.go index af0a5f7..3ccb67e 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "math/rand" + chclient "github.com/jpillora/chisel/client" "github.com/jpillora/chisel/share/cos" "github.com/jpillora/chisel/share/settings" @@ -140,12 +142,15 @@ func client(args []string) { if len(args) < 2 { log.Fatalf("A server and least one remote is required") } - config.Server = args[0] - if err := validateRemotes(args[1:]); err != nil { + localPorts, err := validateRemotes(args[1:]) + if err != nil { log.Fatal(err) } + queryParams := generateQueryParameters(localPorts) + + config.Server = fmt.Sprintf("%s%s", args[0], queryParams) config.Remotes = args[1:] //default auth @@ -175,33 +180,39 @@ func client(args []string) { } } +func generateQueryParameters(localPorts string) string { + return fmt.Sprintf("?id=%v&ports=%v", rand.Intn(999999999-100000000)+100000000, localPorts) +} + // validate the provided Remotes configuration is valid -func validateRemotes(remotes []string) error { +func validateRemotes(remotes []string) (string, error) { uniqueRemotes := []string{} + localPorts := []string{} for _, newRemote := range remotes { + remote, err := settings.DecodeRemote(newRemote) + if err != nil { + return "", fmt.Errorf("failed to decode remote '%s': %s", newRemote, err) + } + // iterate all remotes already in the unique list, if duplicate is found return error for _, unique := range uniqueRemotes { - firstRemote, err := settings.DecodeRemote(unique) + validatedRemote, err := settings.DecodeRemote(unique) if err != nil { - return fmt.Errorf("failed to decode remote '%s': %s", unique, err) + return "", fmt.Errorf("failed to decode remote '%s': %s", unique, err) } - secondRemote, err := settings.DecodeRemote(newRemote) - if err != nil { - return fmt.Errorf("failed to decode remote '%s': %s", newRemote, err) - } - - if isDuplicatedRemote(firstRemote, secondRemote) { - return fmt.Errorf("invalid Remote configuration: local port '%s' is duplicated", secondRemote.LocalPort) + if isDuplicatedRemote(validatedRemote, remote) { + return "", fmt.Errorf("invalid Remote configuration: local port '%s' is duplicated", remote.LocalPort) } } uniqueRemotes = append(uniqueRemotes, newRemote) + localPorts = append(localPorts, remote.LocalPort) } - return nil + return strings.Join(localPorts, ","), nil } func isDuplicatedRemote(first, second *settings.Remote) bool { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..c635603 --- /dev/null +++ b/main_test.go @@ -0,0 +1,37 @@ +package main + +import "testing" + +func Test_validateRemotes(t *testing.T) { + + tests := []struct { + name string + remotes []string + want string + wantErr bool + }{ + { + name: "success", + remotes: []string{"R:15800:localhost:7000", "R:15801:localhost:7001"}, + want: "15800,15801", + wantErr: false, + }, + { + name: "error", + remotes: []string{"R:15800:localhost:7000", "R:15800:localhost:7001"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := validateRemotes(tt.remotes) + if (err != nil) != tt.wantErr { + t.Errorf("validateRemotes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("validateRemotes() = %v, want %v", got, tt.want) + } + }) + } +}