From f3f6ef1d8ba3a5ed731d450b3512bab261e4e6ba Mon Sep 17 00:00:00 2001 From: Marcel Jackwerth Date: Mon, 18 Sep 2023 09:52:17 +0000 Subject: [PATCH] feat: add context for interrupt handling --- internal/lefthook/run.go | 11 +++- internal/lefthook/run/exec/execute_unix.go | 13 ++--- internal/lefthook/run/exec/execute_windows.go | 5 +- internal/lefthook/run/exec/executor.go | 5 +- internal/lefthook/run/runner.go | 51 ++++++++++-------- internal/lefthook/run/runner_test.go | 7 +-- testdata/run_interrupt.txt | 52 +++++++++++++++++++ 7 files changed, 109 insertions(+), 35 deletions(-) create mode 100644 testdata/run_interrupt.txt diff --git a/internal/lefthook/run.go b/internal/lefthook/run.go index f18f76f7..d51a01f1 100644 --- a/internal/lefthook/run.go +++ b/internal/lefthook/run.go @@ -1,9 +1,11 @@ package lefthook import ( + "context" "errors" "fmt" "os" + "os/signal" "path/filepath" "slices" "strings" @@ -143,8 +145,11 @@ Run 'lefthook install' manually.`, ) } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + go func() { - runner.RunAll(sourceDirs) + runner.RunAll(ctx, sourceDirs) close(resultChan) }() @@ -153,6 +158,10 @@ Run 'lefthook install' manually.`, results = append(results, res) } + if ctx.Err() != nil { + return errors.New("Interrupted") + } + if !logSettings.SkipSummary() { printSummary(time.Since(startTime), results, logSettings) } diff --git a/internal/lefthook/run/exec/execute_unix.go b/internal/lefthook/run/exec/execute_unix.go index ad0ceff2..da72f8af 100644 --- a/internal/lefthook/run/exec/execute_unix.go +++ b/internal/lefthook/run/exec/execute_unix.go @@ -4,6 +4,7 @@ package exec import ( + "context" "fmt" "io" "os" @@ -27,7 +28,7 @@ type executeArgs struct { interactive, useStdin bool } -func (e CommandExecutor) Execute(opts Options, out io.Writer) error { +func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error { in := os.Stdin if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) { tty, err := os.Open("/dev/tty") @@ -60,7 +61,7 @@ func (e CommandExecutor) Execute(opts Options, out io.Writer) error { // We can have one command split into separate to fit into shell command max length. // In this case we execute those commands one by one. for _, command := range opts.Commands { - if err := e.execute(command, args); err != nil { + if err := e.execute(ctx, command, args); err != nil { return err } } @@ -68,8 +69,8 @@ func (e CommandExecutor) Execute(opts Options, out io.Writer) error { return nil } -func (e CommandExecutor) RawExecute(command []string, out io.Writer) error { - cmd := exec.Command(command[0], command[1:]...) +func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error { + cmd := exec.CommandContext(ctx, command[0], command[1:]...) cmd.Stdout = out cmd.Stderr = os.Stderr @@ -77,8 +78,8 @@ func (e CommandExecutor) RawExecute(command []string, out io.Writer) error { return cmd.Run() } -func (e CommandExecutor) execute(cmdstr string, args *executeArgs) error { - command := exec.Command("sh", "-c", cmdstr) +func (e CommandExecutor) execute(ctx context.Context, cmdstr string, args *executeArgs) error { + command := exec.CommandContext(ctx, "sh", "-c", cmdstr) command.Dir = args.root command.Env = append(os.Environ(), args.envs...) diff --git a/internal/lefthook/run/exec/execute_windows.go b/internal/lefthook/run/exec/execute_windows.go index 506f1b7d..4c88c2c0 100644 --- a/internal/lefthook/run/exec/execute_windows.go +++ b/internal/lefthook/run/exec/execute_windows.go @@ -1,6 +1,7 @@ package exec import ( + "context" "fmt" "io" "os" @@ -18,7 +19,7 @@ type executeArgs struct { root string } -func (e CommandExecutor) Execute(opts Options, out io.Writer) error { +func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error { root, _ := filepath.Abs(opts.Root) envs := make([]string, len(opts.Env)) for name, value := range opts.Env { @@ -44,7 +45,7 @@ func (e CommandExecutor) Execute(opts Options, out io.Writer) error { return nil } -func (e CommandExecutor) RawExecute(command []string, out io.Writer) error { +func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error { cmd := exec.Command(command[0], command[1:]...) cmd.Stdout = out diff --git a/internal/lefthook/run/exec/executor.go b/internal/lefthook/run/exec/executor.go index 96cf4099..713fadb1 100644 --- a/internal/lefthook/run/exec/executor.go +++ b/internal/lefthook/run/exec/executor.go @@ -1,6 +1,7 @@ package exec import ( + "context" "io" ) @@ -15,6 +16,6 @@ type Options struct { // Executor provides an interface for command execution. // It is used here for testing purpose mostly. type Executor interface { - Execute(opts Options, out io.Writer) error - RawExecute(command []string, out io.Writer) error + Execute(ctx context.Context, opts Options, out io.Writer) error + RawExecute(ctx context.Context, command []string, out io.Writer) error } diff --git a/internal/lefthook/run/runner.go b/internal/lefthook/run/runner.go index 6ca98b1f..15b9befd 100644 --- a/internal/lefthook/run/runner.go +++ b/internal/lefthook/run/runner.go @@ -2,6 +2,7 @@ package run import ( "bytes" + "context" "errors" "fmt" "io" @@ -63,8 +64,8 @@ func NewRunner(opts Options) *Runner { // RunAll runs scripts and commands. // LFS hook is executed at first if needed. -func (r *Runner) RunAll(sourceDirs []string) { - if err := r.runLFSHook(); err != nil { +func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) { + if err := r.runLFSHook(ctx); err != nil { log.Error(err) } @@ -88,10 +89,10 @@ func (r *Runner) RunAll(sourceDirs []string) { r.preHook() for _, dir := range scriptDirs { - r.runScripts(dir) + r.runScripts(ctx, dir) } - r.runCommands() + r.runCommands(ctx) r.postHook() } @@ -105,7 +106,7 @@ func (r *Runner) success(name string) { r.ResultChan <- resultSuccess(name) } -func (r *Runner) runLFSHook() error { +func (r *Runner) runLFSHook(ctx context.Context) error { if !git.IsLFSHook(r.HookName) { return nil } @@ -128,6 +129,7 @@ func (r *Runner) runLFSHook() error { ) out := bytes.NewBuffer(make([]byte, 0)) err := r.executor.RawExecute( + ctx, append( []string{"git", "lfs", r.HookName}, r.GitArgs..., @@ -223,7 +225,7 @@ func (r *Runner) postHook() { } } -func (r *Runner) runScripts(dir string) { +func (r *Runner) runScripts(ctx context.Context, dir string) { files, err := afero.ReadDir(r.Repo.Fs, dir) // ReadDir already sorts files by .Name() if err != nil || len(files) == 0 { return @@ -233,6 +235,10 @@ func (r *Runner) runScripts(dir string) { var wg sync.WaitGroup for _, file := range files { + if ctx.Err() != nil { + return + } + script, ok := r.Hook.Scripts[file.Name()] if !ok { r.logSkip(file.Name(), "not specified in config file") @@ -255,16 +261,20 @@ func (r *Runner) runScripts(dir string) { wg.Add(1) go func(script *config.Script, path string, file os.FileInfo) { defer wg.Done() - r.runScript(script, path, file) + r.runScript(ctx, script, path, file) }(script, path, file) } else { - r.runScript(script, path, file) + r.runScript(ctx, script, path, file) } } wg.Wait() for _, file := range interactiveScripts { + if ctx.Err() != nil { + return + } + script := r.Hook.Scripts[file.Name()] if r.failed.Load() { r.logSkip(file.Name(), "non-interactive scripts failed") @@ -272,12 +282,11 @@ func (r *Runner) runScripts(dir string) { } path := filepath.Join(dir, file.Name()) - - r.runScript(script, path, file) + r.runScript(ctx, script, path, file) } } -func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo) { +func (r *Runner) runScript(ctx context.Context, script *config.Script, path string, file os.FileInfo) { command, err := r.prepareScript(script, path, file) if err != nil { r.logSkip(file.Name(), err.Error()) @@ -289,7 +298,7 @@ func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo) defer log.StartSpinner() } - finished := r.run(exec.Options{ + finished := r.run(ctx, exec.Options{ Name: file.Name(), Root: r.Repo.RootPath, Commands: []string{command}, @@ -310,7 +319,7 @@ func (r *Runner) runScript(script *config.Script, path string, file os.FileInfo) } } -func (r *Runner) runCommands() { +func (r *Runner) runCommands(ctx context.Context) { commands := make([]string, 0, len(r.Hook.Commands)) for name := range r.Hook.Commands { if len(r.RunOnlyCommands) == 0 || slices.Contains(r.RunOnlyCommands, name) { @@ -338,10 +347,10 @@ func (r *Runner) runCommands() { wg.Add(1) go func(name string, command *config.Command) { defer wg.Done() - r.runCommand(name, command) + r.runCommand(ctx, name, command) }(name, r.Hook.Commands[name]) } else { - r.runCommand(name, r.Hook.Commands[name]) + r.runCommand(ctx, name, r.Hook.Commands[name]) } } @@ -353,11 +362,11 @@ func (r *Runner) runCommands() { continue } - r.runCommand(name, r.Hook.Commands[name]) + r.runCommand(ctx, name, r.Hook.Commands[name]) } } -func (r *Runner) runCommand(name string, command *config.Command) { +func (r *Runner) runCommand(ctx context.Context, name string, command *config.Command) { run, err := r.prepareCommand(name, command) if err != nil { r.logSkip(name, err.Error()) @@ -369,7 +378,7 @@ func (r *Runner) runCommand(name string, command *config.Command) { defer log.StartSpinner() } - finished := r.run(exec.Options{ + finished := r.run(ctx, exec.Options{ Name: name, Root: filepath.Join(r.Repo.RootPath, command.Root), Commands: run.commands, @@ -409,7 +418,7 @@ func (r *Runner) addStagedFiles(files []string) { } } -func (r *Runner) run(opts exec.Options, follow bool) bool { +func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool { log.SetName(opts.Name) defer log.UnsetName(opts.Name) @@ -423,7 +432,7 @@ func (r *Runner) run(opts exec.Options, follow bool) bool { out = os.Stdout } - err := r.executor.Execute(opts, out) + err := r.executor.Execute(ctx, opts, out) if err != nil { r.fail(opts.Name, errors.New(opts.FailText)) } else { @@ -434,7 +443,7 @@ func (r *Runner) run(opts exec.Options, follow bool) bool { } out := bytes.NewBuffer(make([]byte, 0)) - err := r.executor.Execute(opts, out) + err := r.executor.Execute(ctx, opts, out) if err != nil { r.fail(opts.Name, errors.New(opts.FailText)) diff --git a/internal/lefthook/run/runner_test.go b/internal/lefthook/run/runner_test.go index 95977f0c..8c05e514 100644 --- a/internal/lefthook/run/runner_test.go +++ b/internal/lefthook/run/runner_test.go @@ -1,6 +1,7 @@ package run import ( + "context" "errors" "fmt" "io" @@ -19,7 +20,7 @@ import ( type TestExecutor struct{} -func (e TestExecutor) Execute(opts exec.Options, _out io.Writer) (err error) { +func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.Writer) (err error) { if strings.HasPrefix(opts.Commands[0], "success") { err = nil } else { @@ -29,7 +30,7 @@ func (e TestExecutor) Execute(opts exec.Options, _out io.Writer) (err error) { return } -func (e TestExecutor) RawExecute(_command []string, _out io.Writer) error { +func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _out io.Writer) error { return nil } @@ -755,7 +756,7 @@ func TestRunAll(t *testing.T) { } t.Run(fmt.Sprintf("%d: %s", i, tt.name), func(t *testing.T) { - runner.RunAll(tt.sourceDirs) + runner.RunAll(context.Background(), tt.sourceDirs) close(resultChan) var success, fail []Result diff --git a/testdata/run_interrupt.txt b/testdata/run_interrupt.txt new file mode 100644 index 00000000..9103d014 --- /dev/null +++ b/testdata/run_interrupt.txt @@ -0,0 +1,52 @@ +chmod 0700 hook.sh +chmod 0700 commit-with-interrupt.sh +exec git init +exec git config user.email "you@example.com" +exec git config user.name "Your Name" +exec lefthook install +exec git add -A + +exec git commit -m 'init' +stderr 'hook-done' + +exec ./commit-with-interrupt.sh +stderr 'script-done' +! stderr 'hook-done' +stderr 'signal: killed' +stderr 'Error: Interrupted' +grep unstaged newfile.txt +exec git stash list +! stdout 'lefthook auto backup' + +-- lefthook.yml -- +pre-commit: + commands: + slow_job: + run: ./hook.sh + +-- hook.sh -- +#!/usr/bin/env bash + +sleep 2 +>&2 echo hook-done + +-- newfile.txt -- +staged + +-- commit-with-interrupt.sh -- +#!/usr/bin/env bash + +echo staged >> newfile.txt +git add newfile.txt +echo unstaged >> newfile.txt + +# ctrl-c is emulated by sending SIGINT to a process group +# so we first need to emulate being a terminal and enable +# job monitoring so that new PGIDs are assigned. +set -m +nohup git commit -m test & +pgid=$! +sleep 1 +kill -SIGINT -$pgid +wait +>&2 echo 'script-done'