From 5732b71150ba432e6bd9e5b47766eb02360f9982 Mon Sep 17 00:00:00 2001 From: Sergiy Kulanov Date: Sat, 7 Oct 2023 13:23:07 +0300 Subject: [PATCH] refactor: Add output format validator Signed-off-by: Sergiy Kulanov --- pkg/cli/prerun/prerun.go | 25 +++++++++++ pkg/cli/prerun/prerun_test.go | 33 +++++++++++++++ pkg/cmd/pipeline/graph.go | 4 ++ pkg/cmd/pipelinerun/graph.go | 4 ++ pkg/taskgraph/taskgraph_test.go | 74 +++++++++++++++++++++++++++++++++ 5 files changed, 140 insertions(+) create mode 100644 pkg/cli/prerun/prerun.go create mode 100644 pkg/cli/prerun/prerun_test.go diff --git a/pkg/cli/prerun/prerun.go b/pkg/cli/prerun/prerun.go new file mode 100644 index 0000000..64ae941 --- /dev/null +++ b/pkg/cli/prerun/prerun.go @@ -0,0 +1,25 @@ +package prerun + +import ( + "fmt" +) + +// Define the allowed output formats +var ValidOutputFormats = []string{"dot", "puml", "mmd"} + +func ValidateGraphPreRunE(outputFormat string) error { + if !contains(ValidOutputFormats, outputFormat) { + return fmt.Errorf("Invalid output format: %s. Allowed formats are: %v", outputFormat, ValidOutputFormats) + } + return nil +} + +// Helper function to check if a string is in a slice of strings +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/pkg/cli/prerun/prerun_test.go b/pkg/cli/prerun/prerun_test.go new file mode 100644 index 0000000..28718a5 --- /dev/null +++ b/pkg/cli/prerun/prerun_test.go @@ -0,0 +1,33 @@ +package prerun + +import ( + "testing" +) + +func TestValidateGraphPreRunE(t *testing.T) { + testCases := []struct { + name string + outputFormat string + wantErr bool + }{ + { + name: "Invalid output format", + outputFormat: "invalid", + wantErr: true, + }, + { + name: "Valid output format", + outputFormat: "dot", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateGraphPreRunE(tc.outputFormat) + if (err != nil) != tc.wantErr { + t.Errorf("ValidateGraphPreRunE() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} diff --git a/pkg/cmd/pipeline/graph.go b/pkg/cmd/pipeline/graph.go index f1d0634..aa88a0a 100644 --- a/pkg/cmd/pipeline/graph.go +++ b/pkg/cmd/pipeline/graph.go @@ -3,6 +3,7 @@ package pipeline import ( "fmt" + "github.com/sergk/tkn-graph/pkg/cli/prerun" pipelinepkg "github.com/sergk/tkn-graph/pkg/pipeline" "github.com/sergk/tkn-graph/pkg/taskgraph" "github.com/spf13/cobra" @@ -35,6 +36,9 @@ func graphCommand(p cli.Params) *cobra.Command { } return nil }, + PreRunE: func(cmd *cobra.Command, args []string) error { + return prerun.ValidateGraphPreRunE(opts.OutputFormat) + }, RunE: func(cmd *cobra.Command, args []string) error { cs, err := p.Clients() if err != nil { diff --git a/pkg/cmd/pipelinerun/graph.go b/pkg/cmd/pipelinerun/graph.go index 64fff5d..62d0ca9 100644 --- a/pkg/cmd/pipelinerun/graph.go +++ b/pkg/cmd/pipelinerun/graph.go @@ -3,6 +3,7 @@ package pipelinerun import ( "fmt" + "github.com/sergk/tkn-graph/pkg/cli/prerun" pipelinerunpkg "github.com/sergk/tkn-graph/pkg/pipelinerun" "github.com/sergk/tkn-graph/pkg/taskgraph" "github.com/spf13/cobra" @@ -35,6 +36,9 @@ func graphCommand(p cli.Params) *cobra.Command { } return nil }, + PreRunE: func(cmd *cobra.Command, args []string) error { + return prerun.ValidateGraphPreRunE(opts.OutputFormat) + }, RunE: func(cmd *cobra.Command, args []string) error { cs, err := p.Clients() if err != nil { diff --git a/pkg/taskgraph/taskgraph_test.go b/pkg/taskgraph/taskgraph_test.go index 0a61faa..7a090cc 100644 --- a/pkg/taskgraph/taskgraph_test.go +++ b/pkg/taskgraph/taskgraph_test.go @@ -1,6 +1,8 @@ package taskgraph import ( + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -255,3 +257,75 @@ func TestPrintAllGraphsWithUnsupportedFormat(t *testing.T) { // contains error message assert.Contains(t, err.Error(), "Invalid output format: FAIL") } + +func TestWriteAllGraphs(t *testing.T) { + // Create a temporary directory for testing + tempDir, err := os.MkdirTemp("", "test-output") + if err != nil { + t.Fatalf("Failed to create temporary directory: %v", err) + } + + defer func() { + // Remove the temporary directory and check for errors + if err = os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temporary directory: %v", err) + } + }() + + // Create a test graph + testGraph := &TaskGraph{ + PipelineName: "test-pipeline", + Nodes: map[string]*TaskNode{ + "task1": { + Name: "task1", + TaskRefName: "taskRef1", + Dependencies: []*TaskNode{ + { + Name: "task2", + TaskRefName: "taskRef2", + }, + }, + }, + "task2": { + Name: "task2", + TaskRefName: "taskRef2", + Dependencies: []*TaskNode{ + { + Name: "task3", + TaskRefName: "taskRef3", + }, + }, + }, + "task3": { + Name: "task3", + TaskRefName: "taskRef3", + Dependencies: []*TaskNode{ + { + Name: "task4", + TaskRefName: "taskRef4", + }, + }, + }, + "task4": { + Name: "task4", + TaskRefName: "taskRef4", + }, + }, + } + + // Write the test graph to all supported formats + err = WriteAllGraphs([]*TaskGraph{testGraph}, "dot", tempDir, false) + assert.NoError(t, err) + err = WriteAllGraphs([]*TaskGraph{testGraph}, "puml", tempDir, false) + assert.NoError(t, err) + err = WriteAllGraphs([]*TaskGraph{testGraph}, "mmd", tempDir, false) + assert.NoError(t, err) + + // Check that the files were created + _, err = os.Stat(filepath.Join(tempDir, "test-pipeline.dot")) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(tempDir, "test-pipeline.puml")) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(tempDir, "test-pipeline.mmd")) + assert.NoError(t, err) +}