Skip to content

Commit

Permalink
run: Allow mocking by passing mock function through context
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidGamba committed Feb 1, 2024
1 parent e6d6166 commit 6dc3d3a
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 39 deletions.
45 changes: 44 additions & 1 deletion run/README.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ Import: `github.com/DavidGamba/dgtools/run`
defer cancel()
out, err := run.CMD("./command", "arg1", "arg2").Ctx(ctx).CombinedOutput()
----
+
Or:
+
[source, go]
----
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
out, err := run.CMDCtx(ctx, "./command", "arg1", "arg2").CombinedOutput()
----

.Run a command and pass a custom io.Writer to run:
[source, go]
Expand Down Expand Up @@ -107,11 +116,45 @@ Import: `github.com/DavidGamba/dgtools/run`
log.Printf("Failed with exit code: %d, full error output: %s\n", exitErr.ExitCode(), string(errOutput))
----

== Testing

A mocking function can be stored in the context and retrieved automatically:

. Store the mock function in the context:
+
[source, go]
----
ctx := context.Background()
mockR := run.CMD().Mock(func(r *run.RunInfo) error {
r.Stdout.Write([]byte("hello world\n"))
r.Stderr.Write([]byte("hola mundo\n"))
return nil
})
ctx = run.ContextWithRunInfo(ctx, mockR)
----

. Automatically run the mock function if it exists in the context:
+
[source, go]
----
r := run.CMDCtx(ctx, "ls", "./run")
out, err := r.CombinedOutput()
if err != nil {
t.Errorf("unexpected error")
}
if string(out) != "hello world\nhola mundo\n" {
t.Errorf("wrong output: %s\n", out)
}
----

NOTE: Must use `run.CMDCtx` to automatically run the mock function if it exists in the context.
If the function doesn't exist it runs the command as usual.

== LICENSE

This file is part of run.

Copyright (C) 2020-2021 David Gamba Rios
Copyright (C) 2020-2024 David Gamba Rios

This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
Expand Down
2 changes: 1 addition & 1 deletion run/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/DavidGamba/dgtools/run

go 1.14
go 1.18
114 changes: 77 additions & 37 deletions run/run.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of run.
//
// Copyright (C) 2020-2021 David Gamba Rios
// Copyright (C) 2020-2024 David Gamba Rios
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -27,28 +27,49 @@ var osStdout io.Writer = os.Stdout
var osStderr io.Writer = os.Stderr

type RunInfo struct {
cmd []string
Cmd []string // exposed for mocking purposes only
debug bool
env []string
dir string
stdout io.Writer
stderr io.Writer
Stdout io.Writer // exposed for mocking purposes only
Stderr io.Writer // exposed for mocking purposes only
stdin io.Reader
saveErr bool
printErr bool
ctx context.Context
mockFn MockFn
}

type runInfoContextKey string

func ContextWithRunInfo(ctx context.Context, value *RunInfo) context.Context {
return context.WithValue(ctx, runInfoContextKey("runInfo"), value)
}

// CMD - Normal constructor.
func CMD(cmd ...string) *RunInfo {
r := &RunInfo{cmd: cmd}
r := &RunInfo{Cmd: cmd}
r.env = os.Environ()
r.stdout = nil
r.stderr = nil
r.Stdout = nil
r.Stderr = nil
r.ctx = context.Background()
r.printErr = true
return r
}

// CMD - Pulls RunInfo from context if it exists and if not it initializes a new one.
// Useful when loading a RunInfo from context to ease testing.
func CMDCtx(ctx context.Context, cmd ...string) *RunInfo {
v, ok := ctx.Value(runInfoContextKey("runInfo")).(*RunInfo)
if ok {
v.Cmd = cmd
return v
}
r := CMD(cmd...)
r.ctx = ctx
return r
}

func (r *RunInfo) Log() *RunInfo {
r.debug = true
return r
Expand Down Expand Up @@ -90,11 +111,11 @@ func (r *RunInfo) Ctx(ctx context.Context) *RunInfo {
//
// Retrieval can be done as shown below:
//
// err := run.CMD("./command", "arg").SaveErr().Run() // or .STDOutOutput() or .CombinedOutput()
// if err != nil {
// var exitErr *exec.ExitError
// if errors.As(err, &exitErr) {
// errOutput := exitErr.Stderr
// err := run.CMD("./command", "arg").SaveErr().Run() // or .STDOutOutput() or .CombinedOutput()
// if err != nil {
// var exitErr *exec.ExitError
// if errors.As(err, &exitErr) {
// errOutput := exitErr.Stderr
func (r *RunInfo) SaveErr() *RunInfo {
r.saveErr = true
return r
Expand All @@ -117,8 +138,8 @@ func (r *RunInfo) DiscardErr() *RunInfo {
// CombinedOutput - Runs given CMD and returns STDOut and STDErr combined.
func (r *RunInfo) CombinedOutput() ([]byte, error) {
var b bytes.Buffer
r.stdout = &b
r.stderr = &b
r.Stdout = &b
r.Stderr = &b
err := r.Run()
return b.Bytes(), err
}
Expand All @@ -128,11 +149,18 @@ func (r *RunInfo) CombinedOutput() ([]byte, error) {
// Stderr output is discarded unless a call to SaveErr() or PrintErr() was made.
func (r *RunInfo) STDOutOutput() ([]byte, error) {
var b bytes.Buffer
r.stdout = &b
r.Stdout = &b
err := r.Run()
return b.Bytes(), err
}

type MockFn func(*RunInfo) error

func (r *RunInfo) Mock(fn MockFn) *RunInfo {
r.mockFn = fn
return r
}

// Run - wrapper around os/exec CMD.Run()
//
// Run starts the specified command and waits for it to complete.
Expand All @@ -145,49 +173,61 @@ func (r *RunInfo) STDOutOutput() ([]byte, error) {
//
// Examples:
//
// Run() // Output goes to os.Stdout and os.Stderr
// Run(out) // Sets the command's os.Stdout and os.Stderr to out.
// Run(out, outErr) // Sets the command's os.Stdout to out and os.Stderr to outErr.
// Run() // Output goes to os.Stdout and os.Stderr
// Run(out) // Sets the command's os.Stdout and os.Stderr to out.
// Run(out, outErr) // Sets the command's os.Stdout to out and os.Stderr to outErr.
func (r *RunInfo) Run(w ...io.Writer) error {
if r.debug {
msg := fmt.Sprintf("run %v", r.cmd)
msg := fmt.Sprintf("run %v", r.Cmd)
if r.dir != "" {
msg += fmt.Sprintf(" on %s", r.dir)
}
Logger.Println(msg)
}
c := exec.CommandContext(r.ctx, r.cmd[0], r.cmd[1:]...)
c.Dir = r.dir
c.Env = r.env
if len(w) == 0 {
if r.stdout == nil {
r.stdout = osStdout
if r.Stdout == nil {
r.Stdout = osStdout
}
c.Stdout = r.stdout
c.Stderr = r.stderr
} else if len(w) == 1 {
c.Stdout = w[0]
c.Stderr = w[0]
r.Stdout = w[0]
r.Stderr = w[0]
} else if len(w) > 1 {
c.Stdout = w[0]
c.Stderr = w[1]
r.Stdout = w[0]
r.Stderr = w[1]
}
if r.printErr {
if c.Stderr == nil {
c.Stderr = osStderr
} else if c.Stderr != osStderr {
c.Stderr = io.MultiWriter(c.Stderr, osStderr)
if r.Stderr == nil {
r.Stderr = osStderr
} else if r.Stderr != osStderr {
r.Stderr = io.MultiWriter(r.Stderr, osStderr)
}
}
var b bytes.Buffer
if r.saveErr {
if c.Stderr == nil {
c.Stderr = &b
if r.Stderr == nil {
r.Stderr = &b
} else {
c.Stderr = io.MultiWriter(c.Stderr, &b)
r.Stderr = io.MultiWriter(r.Stderr, &b)
}
}

if r.mockFn != nil {
err := r.mockFn(r)
if err != nil && r.saveErr {
if exitErr, ok := err.(*exec.ExitError); ok {
exitErr.Stderr = b.Bytes()
}
}
return err
}

c := exec.CommandContext(r.ctx, r.Cmd[0], r.Cmd[1:]...)
c.Dir = r.dir
c.Env = r.env
c.Stdout = r.Stdout
c.Stderr = r.Stderr
c.Stdin = r.stdin

err := c.Run()
if err != nil && r.saveErr {
if exitErr, ok := err.(*exec.ExitError); ok {
Expand Down
85 changes: 85 additions & 0 deletions run/run_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package run_test

import (
"context"
"fmt"
"slices"
"testing"

"github.com/DavidGamba/dgtools/run"
)

func TestRunWithMocks(t *testing.T) {
t.Run("mock", func(t *testing.T) {
r := run.CMD("ls", "./run")
r.Mock(func(r *run.RunInfo) error {
r.Stdout.Write([]byte("hello world\n"))
r.Stderr.Write([]byte("hola mundo\n"))
return nil
})
out, err := r.CombinedOutput()
if err != nil {
t.Errorf("unexpected error")
}
if string(out) != "hello world\nhola mundo\n" {
t.Errorf("wrong output: %s\n", out)
}
})

t.Run("mock with context", func(t *testing.T) {
ctx := context.Background()
mockR := run.CMD().Mock(func(r *run.RunInfo) error {
r.Stdout.Write([]byte("hello world\n"))
r.Stderr.Write([]byte("hola mundo\n"))
return nil
})
ctx = run.ContextWithRunInfo(ctx, mockR)

r := run.CMDCtx(ctx, "ls", "./run")
out, err := r.CombinedOutput()
if err != nil {
t.Errorf("unexpected error")
}
if string(out) != "hello world\nhola mundo\n" {
t.Errorf("wrong output: %s\n", out)
}
})

t.Run("mock with context switch", func(t *testing.T) {
ctx := context.Background()
mockR := run.CMD().Mock(func(r *run.RunInfo) error {
cmd := r.Cmd

switch {
case slices.Compare(cmd, []string{"ls", "./run"}) == 0:
r.Stdout.Write([]byte("hello world\n"))
r.Stderr.Write([]byte("hola mundo\n"))
return nil
case slices.Compare(cmd, []string{"ls", "x"}) == 0:
r.Stderr.Write([]byte("not found x\n"))
return fmt.Errorf("not found x")
default:
return fmt.Errorf("unexpected command: %s", cmd)
}
})
ctx = run.ContextWithRunInfo(ctx, mockR)

r := run.CMDCtx(ctx, "ls", "./run")
out, err := r.CombinedOutput()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if string(out) != "hello world\nhola mundo\n" {
t.Errorf("wrong output: %s\n", out)
}

r = run.CMDCtx(ctx, "ls", "x")
out, err = r.CombinedOutput()
if err == nil {
t.Errorf("expected error")
}
if string(out) != "not found x\n" {
t.Errorf("wrong output: %s\n", out)
}
})
}

0 comments on commit 6dc3d3a

Please sign in to comment.