Skip to content

Commit

Permalink
Allow workflow variables to be set using flags to daisy (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
adjackura authored Aug 7, 2017
1 parent 6576179 commit 0b82e5b
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 33 deletions.
12 changes: 12 additions & 0 deletions daisy/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package daisy
import (
"fmt"
"math/rand"
"os"
"os/user"
"path"
"reflect"
"regexp"
Expand Down Expand Up @@ -51,6 +53,16 @@ var (
gcsAPIBase = "https://storage.cloud.google.com"
)

func getUser() string {
if cu, err := user.Current(); err == nil {
return cu.Username
}
if hn, err := os.Hostname(); err == nil {
return hn
}
return "unknown"
}

func namedSubexp(re *regexp.Regexp, s string) map[string]string {
match := re.FindStringSubmatch(s)
if match == nil {
Expand Down
1 change: 1 addition & 0 deletions daisy/compute/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (Client, error)
}
c := &client{hc: hc, raw: rawService}
c.i = c

return c, nil
}

Expand Down
59 changes: 49 additions & 10 deletions daisy/daisy/daisy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,29 @@ var (
se = flag.String("storage_endpoint_override", "", "API endpoint to override default")
)

func splitVariables(input string) map[string]string {
const (
flgDefValue = "flag generated for workflow variable"
varFlagPrefix = "var:"
)

func populateVars(input string) map[string]string {
varMap := map[string]string{}
if input == "" {
return varMap
}
for _, v := range strings.Split(input, ",") {
i := strings.Index(v, "=")
if i == -1 {
continue
if input != "" {
for _, v := range strings.Split(input, ",") {
i := strings.Index(v, "=")
if i == -1 {
continue
}
varMap[v[:i]] = v[i+1:]
}
varMap[v[:i]] = v[i+1:]
}

flag.Visit(func(flg *flag.Flag) {
if strings.HasPrefix(flg.Name, varFlagPrefix) {
varMap[strings.TrimPrefix(flg.Name, varFlagPrefix)] = flg.Value.String()
}
})

return varMap
}

Expand Down Expand Up @@ -108,15 +119,43 @@ func parseWorkflow(ctx context.Context, path string, varMap map[string]string, p
return w, nil
}

func addFlags(args []string) {
for _, arg := range args {
if len(arg) <= 1 || arg[0] != '-' {
continue
}

name := arg[1:]
if name[0] == '-' {
name = name[1:]
}

if !strings.HasPrefix(name, varFlagPrefix) {
continue
}

name = strings.SplitN(name, "=", 2)[0]

if flag.Lookup(name) != nil {
continue
}

flag.String(name, "", flgDefValue)
}
}

func main() {
addFlags(os.Args[1:])
flag.Parse()

if len(flag.Args()) == 0 {
log.Fatal("Not enough args, first arg needs to be the path to a workflow.")
}
ctx := context.Background()

var ws []*daisy.Workflow
varMap := splitVariables(*variables)
varMap := populateVars(*variables)

for _, path := range flag.Args() {
w, err := parseWorkflow(ctx, path, varMap, *project, *zone, *gcsPath, *oauth, *ce, *se)
if err != nil {
Expand Down
45 changes: 39 additions & 6 deletions daisy/daisy/daisy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,63 @@ package main

import (
"context"
"flag"
"fmt"
"reflect"
"runtime"
"testing"
)

func TestSplitVariables(t *testing.T) {
func TestPopulateVars(t *testing.T) {
var tests = []struct {
input string
want map[string]string
}{
{"", map[string]string{}},
{",", map[string]string{}},
{"key=var", map[string]string{"key": "var"}},
{"key1=var1,key2=var2", map[string]string{"key1": "var1", "key2": "var2"}},
{"", map[string]string{"test1": "value"}},
{",", map[string]string{"test1": "value"}},
{"key=var", map[string]string{"test1": "value", "key": "var"}},
{"key1=var1,key2=var2", map[string]string{"test1": "value", "key1": "var1", "key2": "var2"}},
}

// Add a generated var flag.
flag.String("var:test1", "", "")
flag.CommandLine.Parse([]string{"-var:test1", "value"})

for _, tt := range tests {
got := splitVariables(tt.input)
got := populateVars(tt.input)
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("splitVariables did not split %q as expected, want: %q, got: %q", tt.input, tt.want, got)
}
}
}

func TestAddFlags(t *testing.T) {
firstFlag := "var:first_var"
secondFlag := "var:second_var"
value := "value"

flag.Bool("var:test2", false, "")
flag.CommandLine.Parse([]string{"-validate", "-var:test2"})

args := []string{"-validate", "-var:test2", "-" + firstFlag, value, fmt.Sprintf("--%s=%s", secondFlag, value), "var:not_a_flag", "also_not_a_flag"}
before := flag.NFlag()
addFlags(args)
flag.CommandLine.Parse(args)
after := flag.NFlag()

want := before + 2
if after != want {
t.Errorf("number of flags after does not match expectation, want %d, got %d", want, after)
}

for _, fn := range []string{firstFlag, secondFlag} {
got := flag.Lookup(fn).Value.String()
if got != value {
t.Errorf("flag %q value %q!=%q", fn, got, value)
}
}
}

func TestParseWorkflows(t *testing.T) {
path := "../test_data/test.wf.json"
varMap := map[string]string{"key1": "var1", "key2": "var2"}
Expand Down
8 changes: 1 addition & 7 deletions daisy/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"io/ioutil"
"log"
"os"
"os/user"
"path"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -306,12 +305,7 @@ func (w *Workflow) populate(ctx context.Context) error {

w.id = randString(5)
now := time.Now().UTC()
cu, err := user.Current()
if err != nil {
w.username = "unknown"
} else {
w.username = cu.Username
}
w.username = getUser()

cwd, _ := os.Getwd()

Expand Down
14 changes: 4 additions & 10 deletions daisy/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"net/http"
"net/http/httptest"
"os"
"os/user"
"path/filepath"
"regexp"
"runtime"
Expand Down Expand Up @@ -490,11 +489,6 @@ func TestPopulate(t *testing.T) {
t.Fatalf("error creating temp file: %v", err)
}

cu, err := user.Current()
if err != nil {
t.Fatal(err)
}

called := false
var stepPopErr error
stepPop := func(ctx context.Context, s *Step) error {
Expand Down Expand Up @@ -555,7 +549,7 @@ func TestPopulate(t *testing.T) {
sourcesPath: fmt.Sprintf("%s/sources", got.scratchPath),
logsPath: fmt.Sprintf("%s/logs", got.scratchPath),
outsPath: fmt.Sprintf("%s/outs", got.scratchPath),
username: cu.Username,
username: got.username,
Steps: map[string]*Step{
"wf-name-step1": {
name: "wf-name-step1",
Expand Down Expand Up @@ -867,12 +861,12 @@ func TestWrite(t *testing.T) {
func TestRunStepTimeout(t *testing.T) {
w := testWorkflow()
s, _ := w.NewStep("test")
s.timeout = 1 * time.Microsecond
s.timeout = 1 * time.Nanosecond
s.testType = &mockStep{runImpl: func(ctx context.Context, s *Step) error {
time.Sleep(1 * time.Millisecond)
time.Sleep(1 * time.Second)
return nil
}}
want := `step "test" did not stop in specified timeout of 1µs`
want := `step "test" did not stop in specified timeout of 1ns`
if err := w.runStep(context.Background(), s); err == nil || err.Error() != want {
t.Errorf("did not get expected error, got: %q, want: %q", err.Error(), want)
}
Expand Down

0 comments on commit 0b82e5b

Please sign in to comment.