diff --git a/daisy/common.go b/daisy/common.go index d321cb668..f005f8181 100644 --- a/daisy/common.go +++ b/daisy/common.go @@ -17,6 +17,8 @@ package daisy import ( "fmt" "math/rand" + "os" + "os/user" "path" "reflect" "regexp" @@ -51,6 +53,16 @@ var ( gcsAPIBase = "https://storage.cloud.google.com" ) +func getUser() string { + if cu, err := user.Current(); err == nil { + return cu.Username + } + if hn, err := os.Hostname(); err == nil { + return hn + } + return "unknown" +} + func namedSubexp(re *regexp.Regexp, s string) map[string]string { match := re.FindStringSubmatch(s) if match == nil { diff --git a/daisy/compute/compute.go b/daisy/compute/compute.go index 16f1d741a..24e03bc24 100644 --- a/daisy/compute/compute.go +++ b/daisy/compute/compute.go @@ -70,6 +70,7 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (Client, error) } c := &client{hc: hc, raw: rawService} c.i = c + return c, nil } diff --git a/daisy/daisy/daisy.go b/daisy/daisy/daisy.go index 78fcb4c57..390a069af 100644 --- a/daisy/daisy/daisy.go +++ b/daisy/daisy/daisy.go @@ -44,18 +44,29 @@ var ( se = flag.String("storage_endpoint_override", "", "API endpoint to override default") ) -func splitVariables(input string) map[string]string { +const ( + flgDefValue = "flag generated for workflow variable" + varFlagPrefix = "var:" +) + +func populateVars(input string) map[string]string { varMap := map[string]string{} - if input == "" { - return varMap - } - for _, v := range strings.Split(input, ",") { - i := strings.Index(v, "=") - if i == -1 { - continue + if input != "" { + for _, v := range strings.Split(input, ",") { + i := strings.Index(v, "=") + if i == -1 { + continue + } + varMap[v[:i]] = v[i+1:] } - varMap[v[:i]] = v[i+1:] } + + flag.Visit(func(flg *flag.Flag) { + if strings.HasPrefix(flg.Name, varFlagPrefix) { + varMap[strings.TrimPrefix(flg.Name, varFlagPrefix)] = flg.Value.String() + } + }) + return varMap } @@ -108,15 +119,43 @@ func parseWorkflow(ctx context.Context, path string, varMap map[string]string, p return w, nil } +func addFlags(args []string) { + for _, arg := range args { + if len(arg) <= 1 || arg[0] != '-' { + continue + } + + name := arg[1:] + if name[0] == '-' { + name = name[1:] + } + + if !strings.HasPrefix(name, varFlagPrefix) { + continue + } + + name = strings.SplitN(name, "=", 2)[0] + + if flag.Lookup(name) != nil { + continue + } + + flag.String(name, "", flgDefValue) + } +} + func main() { + addFlags(os.Args[1:]) flag.Parse() + if len(flag.Args()) == 0 { log.Fatal("Not enough args, first arg needs to be the path to a workflow.") } ctx := context.Background() var ws []*daisy.Workflow - varMap := splitVariables(*variables) + varMap := populateVars(*variables) + for _, path := range flag.Args() { w, err := parseWorkflow(ctx, path, varMap, *project, *zone, *gcsPath, *oauth, *ce, *se) if err != nil { diff --git a/daisy/daisy/daisy_test.go b/daisy/daisy/daisy_test.go index c8c5360c4..586dd4d25 100644 --- a/daisy/daisy/daisy_test.go +++ b/daisy/daisy/daisy_test.go @@ -16,30 +16,63 @@ package main import ( "context" + "flag" + "fmt" "reflect" "runtime" "testing" ) -func TestSplitVariables(t *testing.T) { +func TestPopulateVars(t *testing.T) { var tests = []struct { input string want map[string]string }{ - {"", map[string]string{}}, - {",", map[string]string{}}, - {"key=var", map[string]string{"key": "var"}}, - {"key1=var1,key2=var2", map[string]string{"key1": "var1", "key2": "var2"}}, + {"", map[string]string{"test1": "value"}}, + {",", map[string]string{"test1": "value"}}, + {"key=var", map[string]string{"test1": "value", "key": "var"}}, + {"key1=var1,key2=var2", map[string]string{"test1": "value", "key1": "var1", "key2": "var2"}}, } + // Add a generated var flag. + flag.String("var:test1", "", "") + flag.CommandLine.Parse([]string{"-var:test1", "value"}) + for _, tt := range tests { - got := splitVariables(tt.input) + got := populateVars(tt.input) if !reflect.DeepEqual(tt.want, got) { t.Errorf("splitVariables did not split %q as expected, want: %q, got: %q", tt.input, tt.want, got) } } } +func TestAddFlags(t *testing.T) { + firstFlag := "var:first_var" + secondFlag := "var:second_var" + value := "value" + + flag.Bool("var:test2", false, "") + flag.CommandLine.Parse([]string{"-validate", "-var:test2"}) + + args := []string{"-validate", "-var:test2", "-" + firstFlag, value, fmt.Sprintf("--%s=%s", secondFlag, value), "var:not_a_flag", "also_not_a_flag"} + before := flag.NFlag() + addFlags(args) + flag.CommandLine.Parse(args) + after := flag.NFlag() + + want := before + 2 + if after != want { + t.Errorf("number of flags after does not match expectation, want %d, got %d", want, after) + } + + for _, fn := range []string{firstFlag, secondFlag} { + got := flag.Lookup(fn).Value.String() + if got != value { + t.Errorf("flag %q value %q!=%q", fn, got, value) + } + } +} + func TestParseWorkflows(t *testing.T) { path := "../test_data/test.wf.json" varMap := map[string]string{"key1": "var1", "key2": "var2"} diff --git a/daisy/workflow.go b/daisy/workflow.go index 57b6a069b..ea4575752 100644 --- a/daisy/workflow.go +++ b/daisy/workflow.go @@ -25,7 +25,6 @@ import ( "io/ioutil" "log" "os" - "os/user" "path" "path/filepath" "reflect" @@ -306,12 +305,7 @@ func (w *Workflow) populate(ctx context.Context) error { w.id = randString(5) now := time.Now().UTC() - cu, err := user.Current() - if err != nil { - w.username = "unknown" - } else { - w.username = cu.Username - } + w.username = getUser() cwd, _ := os.Getwd() diff --git a/daisy/workflow_test.go b/daisy/workflow_test.go index 28fd9d425..ef982a296 100644 --- a/daisy/workflow_test.go +++ b/daisy/workflow_test.go @@ -25,7 +25,6 @@ import ( "net/http" "net/http/httptest" "os" - "os/user" "path/filepath" "regexp" "runtime" @@ -490,11 +489,6 @@ func TestPopulate(t *testing.T) { t.Fatalf("error creating temp file: %v", err) } - cu, err := user.Current() - if err != nil { - t.Fatal(err) - } - called := false var stepPopErr error stepPop := func(ctx context.Context, s *Step) error { @@ -555,7 +549,7 @@ func TestPopulate(t *testing.T) { sourcesPath: fmt.Sprintf("%s/sources", got.scratchPath), logsPath: fmt.Sprintf("%s/logs", got.scratchPath), outsPath: fmt.Sprintf("%s/outs", got.scratchPath), - username: cu.Username, + username: got.username, Steps: map[string]*Step{ "wf-name-step1": { name: "wf-name-step1", @@ -867,12 +861,12 @@ func TestWrite(t *testing.T) { func TestRunStepTimeout(t *testing.T) { w := testWorkflow() s, _ := w.NewStep("test") - s.timeout = 1 * time.Microsecond + s.timeout = 1 * time.Nanosecond s.testType = &mockStep{runImpl: func(ctx context.Context, s *Step) error { - time.Sleep(1 * time.Millisecond) + time.Sleep(1 * time.Second) return nil }} - want := `step "test" did not stop in specified timeout of 1µs` + want := `step "test" did not stop in specified timeout of 1ns` if err := w.runStep(context.Background(), s); err == nil || err.Error() != want { t.Errorf("did not get expected error, got: %q, want: %q", err.Error(), want) }