diff --git a/cmd/main.go b/cmd/main.go index d01c2a4..7683f5a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -28,8 +28,7 @@ func main() { client := git.NewGitClient() model := gpt.NewGPTModel() reader := func() (string, error) { - reader := bufio.NewReader(os.Stdin) - input, err := reader.ReadString('\n') + input, err := bufio.NewReader(os.Stdin).ReadString('\n') return strings.TrimSpace(input), err } options := commands.GenerateOptions{Ctx: ctx, Client: client, Model: model} diff --git a/cmd/main_test.go b/cmd/main_test.go index fa9a2c1..7830529 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -1,55 +1,55 @@ package main import ( - "bytes" - "io" - "os" - "testing" + "bytes" + "io" + "os" + "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) func TestMainRun(t *testing.T) { - origArgs := os.Args - defer func() { os.Args = origArgs }() - - tests := []struct { - name string - args []string - output string - }{ - { - name: "Help command", - args: []string{"--help"}, - output: "Usage:", - }, - { - name: "Version command", - args: []string{"--version"}, - output: "cmt 0.2.0\n", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - os.Args = append([]string{"cmd"}, tt.args...) - - r, w, _ := os.Pipe() - origStdout := os.Stdout - os.Stdout = w - defer func() { os.Stdout = origStdout }() - - main() - - w.Close() - var buf bytes.Buffer - _, err := io.Copy(&buf, r) - if err != nil { - return - } - result := buf.String() - - assert.Contains(t, result, tt.output) - }) - } + origArgs := os.Args + defer func() { os.Args = origArgs }() + + tests := []struct { + name string + args []string + output string + }{ + { + name: "Help command", + args: []string{"--help"}, + output: "Usage:", + }, + { + name: "Version command", + args: []string{"--version"}, + output: "cmt 0.3.0\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Args = append([]string{"cmd"}, tt.args...) + + r, w, _ := os.Pipe() + origStdout := os.Stdout + os.Stdout = w + defer func() { os.Stdout = origStdout }() + + main() + + w.Close() + var buf bytes.Buffer + _, err := io.Copy(&buf, r) + if err != nil { + return + } + result := buf.String() + + assert.Contains(t, result, tt.output) + }) + } } diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 4c9a7c4..fb38596 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -4,7 +4,7 @@ import ( "fmt" ) -const VERSION = "0.2.0" +const VERSION = "0.3.0" func Help() { fmt.Println("Usage:") diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 99ee92c..2041d42 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -49,7 +49,7 @@ func TestVersion(t *testing.T) { }{ { name: "Version", - output: "cmt 0.2.0\n", + output: "cmt 0.3.0\n", }, }