Skip to content

Commit

Permalink
feat: add context for interrupt handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mrcljx committed Sep 19, 2023
1 parent 58e8769 commit 7258254
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 35 deletions.
11 changes: 10 additions & 1 deletion internal/lefthook/run.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package lefthook

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"path/filepath"
"slices"
"strings"
Expand Down Expand Up @@ -143,8 +145,11 @@ Run 'lefthook install' manually.`,
)
}

ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()

go func() {
runner.RunAll(sourceDirs)
runner.RunAll(ctx, sourceDirs)
close(resultChan)
}()

Expand All @@ -153,6 +158,10 @@ Run 'lefthook install' manually.`,
results = append(results, res)
}

if ctx.Err() != nil {
return errors.New("Interrupted")
}

if !logSettings.SkipSummary() {
printSummary(time.Since(startTime), results, logSettings)
}
Expand Down
13 changes: 7 additions & 6 deletions internal/lefthook/run/exec/execute_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package exec

import (
"context"
"fmt"
"io"
"os"
Expand All @@ -27,7 +28,7 @@ type executeArgs struct {
interactive, useStdin bool
}

func (e CommandExecutor) Execute(opts Options, out io.Writer) error {
func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
in := os.Stdin
if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) {
tty, err := os.Open("/dev/tty")
Expand Down Expand Up @@ -60,25 +61,25 @@ func (e CommandExecutor) Execute(opts Options, out io.Writer) error {
// We can have one command split into separate to fit into shell command max length.
// In this case we execute those commands one by one.
for _, command := range opts.Commands {
if err := e.execute(command, args); err != nil {
if err := e.execute(ctx, command, args); err != nil {
return err
}
}

return nil
}

func (e CommandExecutor) RawExecute(command []string, out io.Writer) error {
cmd := exec.Command(command[0], command[1:]...)
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
cmd := exec.CommandContext(ctx, command[0], command[1:]...)

cmd.Stdout = out
cmd.Stderr = os.Stderr

return cmd.Run()
}

func (e CommandExecutor) execute(cmdstr string, args *executeArgs) error {
command := exec.Command("sh", "-c", cmdstr)
func (e CommandExecutor) execute(ctx context.Context, cmdstr string, args *executeArgs) error {
command := exec.CommandContext(ctx, "sh", "-c", cmdstr)
command.Dir = args.root
command.Env = append(os.Environ(), args.envs...)

Expand Down
5 changes: 3 additions & 2 deletions internal/lefthook/run/exec/execute_windows.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package exec

import (
"context"
"fmt"
"io"
"os"
Expand All @@ -18,7 +19,7 @@ type executeArgs struct {
root string
}

func (e CommandExecutor) Execute(opts Options, out io.Writer) error {
func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
root, _ := filepath.Abs(opts.Root)
envs := make([]string, len(opts.Env))
for name, value := range opts.Env {
Expand All @@ -44,7 +45,7 @@ func (e CommandExecutor) Execute(opts Options, out io.Writer) error {
return nil
}

func (e CommandExecutor) RawExecute(command []string, out io.Writer) error {
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
cmd := exec.Command(command[0], command[1:]...)

cmd.Stdout = out
Expand Down
5 changes: 3 additions & 2 deletions internal/lefthook/run/exec/executor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package exec

import (
"context"
"io"
)

Expand All @@ -15,6 +16,6 @@ type Options struct {
// Executor provides an interface for command execution.
// It is used here for testing purpose mostly.
type Executor interface {
Execute(opts Options, out io.Writer) error
RawExecute(command []string, out io.Writer) error
Execute(ctx context.Context, opts Options, out io.Writer) error
RawExecute(ctx context.Context, command []string, out io.Writer) error
}
51 changes: 30 additions & 21 deletions internal/lefthook/run/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package run

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -63,8 +64,8 @@ func NewRunner(opts Options) *Runner {

// RunAll runs scripts and commands.
// LFS hook is executed at first if needed.
func (r *Runner) RunAll(sourceDirs []string) {
if err := r.runLFSHook(); err != nil {
func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) {
if err := r.runLFSHook(ctx); err != nil {
log.Error(err)
}

Expand All @@ -88,10 +89,10 @@ func (r *Runner) RunAll(sourceDirs []string) {
r.preHook()

for _, dir := range scriptDirs {
r.runScripts(dir)
r.runScripts(ctx, dir)
}

r.runCommands()
r.runCommands(ctx)

r.postHook()
}
Expand All @@ -105,7 +106,7 @@ func (r *Runner) success(name string) {
r.ResultChan <- resultSuccess(name)
}

func (r *Runner) runLFSHook() error {
func (r *Runner) runLFSHook(ctx context.Context) error {
if !git.IsLFSHook(r.HookName) {
return nil
}
Expand All @@ -128,6 +129,7 @@ func (r *Runner) runLFSHook() error {
)
out := bytes.NewBuffer(make([]byte, 0))
err := r.executor.RawExecute(
ctx,
append(
[]string{"git", "lfs", r.HookName},
r.GitArgs...,
Expand Down Expand Up @@ -223,7 +225,7 @@ func (r *Runner) postHook() {
}
}

func (r *Runner) runScripts(dir string) {
func (r *Runner) runScripts(ctx context.Context, dir string) {
files, err := afero.ReadDir(r.Repo.Fs, dir) // ReadDir already sorts files by .Name()
if err != nil || len(files) == 0 {
return
Expand All @@ -233,6 +235,10 @@ func (r *Runner) runScripts(dir string) {
var wg sync.WaitGroup

for _, file := range files {
if ctx.Err() != nil {
return
}

script, ok := r.Hook.Scripts[file.Name()]
if !ok {
r.logSkip(file.Name(), "not specified in config file")
Expand All @@ -255,29 +261,32 @@ func (r *Runner) runScripts(dir string) {
wg.Add(1)
go func(script *config.Script, path string, file os.FileInfo) {
defer wg.Done()
r.runScript(script, path, file)
r.runScript(ctx, script, path, file)
}(script, path, file)
} else {
r.runScript(script, path, file)
r.runScript(ctx, script, path, file)
}
}

wg.Wait()

for _, file := range interactiveScripts {
if ctx.Err() != nil {
return
}

script := r.Hook.Scripts[file.Name()]
if r.failed.Load() {
r.logSkip(file.Name(), "non-interactive scripts failed")
continue
}

path := filepath.Join(dir, file.Name())

r.runScript(script, path, file)
r.runScript(ctx, script, path, file)
}
}

func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo) {
func (r *Runner) runScript(ctx context.Context, script *config.Script, path string, file os.FileInfo) {
command, err := r.prepareScript(script, path, file)
if err != nil {
r.logSkip(file.Name(), err.Error())
Expand All @@ -289,7 +298,7 @@ func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo)
defer log.StartSpinner()
}

finished := r.run(exec.Options{
finished := r.run(ctx, exec.Options{
Name: file.Name(),
Root: r.Repo.RootPath,
Commands: []string{command},
Expand All @@ -310,7 +319,7 @@ func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo)
}
}

func (r *Runner) runCommands() {
func (r *Runner) runCommands(ctx context.Context) {
commands := make([]string, 0, len(r.Hook.Commands))
for name := range r.Hook.Commands {
if len(r.RunOnlyCommands) == 0 || slices.Contains(r.RunOnlyCommands, name) {
Expand Down Expand Up @@ -338,10 +347,10 @@ func (r *Runner) runCommands() {
wg.Add(1)
go func(name string, command *config.Command) {
defer wg.Done()
r.runCommand(name, command)
r.runCommand(ctx, name, command)
}(name, r.Hook.Commands[name])
} else {
r.runCommand(name, r.Hook.Commands[name])
r.runCommand(ctx, name, r.Hook.Commands[name])
}
}

Expand All @@ -353,11 +362,11 @@ func (r *Runner) runCommands() {
continue
}

r.runCommand(name, r.Hook.Commands[name])
r.runCommand(ctx, name, r.Hook.Commands[name])
}
}

func (r *Runner) runCommand(name string, command *config.Command) {
func (r *Runner) runCommand(ctx context.Context, name string, command *config.Command) {
run, err := r.prepareCommand(name, command)
if err != nil {
r.logSkip(name, err.Error())
Expand All @@ -369,7 +378,7 @@ func (r *Runner) runCommand(name string, command *config.Command) {
defer log.StartSpinner()
}

finished := r.run(exec.Options{
finished := r.run(ctx, exec.Options{
Name: name,
Root: filepath.Join(r.Repo.RootPath, command.Root),
Commands: run.commands,
Expand Down Expand Up @@ -409,7 +418,7 @@ func (r *Runner) addStagedFiles(files []string) {
}
}

func (r *Runner) run(opts exec.Options, follow bool) bool {
func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool {
log.SetName(opts.Name)
defer log.UnsetName(opts.Name)

Expand All @@ -423,7 +432,7 @@ func (r *Runner) run(opts exec.Options, follow bool) bool {
out = os.Stdout
}

err := r.executor.Execute(opts, out)
err := r.executor.Execute(ctx, opts, out)
if err != nil {
r.fail(opts.Name, errors.New(opts.FailText))
} else {
Expand All @@ -434,7 +443,7 @@ func (r *Runner) run(opts exec.Options, follow bool) bool {
}

out := bytes.NewBuffer(make([]byte, 0))
err := r.executor.Execute(opts, out)
err := r.executor.Execute(ctx, opts, out)

if err != nil {
r.fail(opts.Name, errors.New(opts.FailText))
Expand Down
7 changes: 4 additions & 3 deletions internal/lefthook/run/runner_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package run

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -19,7 +20,7 @@ import (

type TestExecutor struct{}

func (e TestExecutor) Execute(opts exec.Options, _out io.Writer) (err error) {
func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.Writer) (err error) {
if strings.HasPrefix(opts.Commands[0], "success") {
err = nil
} else {
Expand All @@ -29,7 +30,7 @@ func (e TestExecutor) Execute(opts exec.Options, _out io.Writer) (err error) {
return
}

func (e TestExecutor) RawExecute(_command []string, _out io.Writer) error {
func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _out io.Writer) error {
return nil
}

Expand Down Expand Up @@ -755,7 +756,7 @@ func TestRunAll(t *testing.T) {
}

t.Run(fmt.Sprintf("%d: %s", i, tt.name), func(t *testing.T) {
runner.RunAll(tt.sourceDirs)
runner.RunAll(context.Background(), tt.sourceDirs)
close(resultChan)

var success, fail []Result
Expand Down
Loading

0 comments on commit 7258254

Please sign in to comment.