Skip to content

Commit

Permalink
fix: share STDIN across different commands on pre-push hook (#732)
Browse files Browse the repository at this point in the history
* fix: error and bail if multiple commands useStdin

Since Stdin from Git is not intercepted and forwared to each commands/scripts,
this could lead to hang.

* fix: CommandExecutor.RawExecute forward os.Stdin

This is only used by LFS pre-push Hook which require Stdin.

* fix: use cached reader for stdin when use_stdin is specified

* chore: remove special checks for pre-push hook

* chore: fix typos

* fix: fail the hook if git-lfs command failed

* chore: add more verbose error

---------

Co-authored-by: Thomas Desveaux <[email protected]>
  • Loading branch information
mrexox and tdesveaux authored May 30, 2024
1 parent 556201d commit afc1125
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 36 deletions.
5 changes: 4 additions & 1 deletion internal/lefthook/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ Run 'lefthook install' manually.`,
)

startTime := time.Now()
results := r.RunAll(ctx, sourceDirs)
results, runErr := r.RunAll(ctx, sourceDirs)
if runErr != nil {
return fmt.Errorf("failed to run the hook: %w", runErr)
}

if ctx.Err() != nil {
return errors.New("Interrupted")
Expand Down
49 changes: 49 additions & 0 deletions internal/lefthook/runner/cached_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package runner

import (
"bytes"
"io"
)

// cachedReader reads from the provided `io.Reader` until `io.EOF` and saves
// the read content into the inner buffer.
//
// After `io.EOF` it will be providing the read data again and again.
type cachedReader struct {
in io.Reader
useBuffer bool
buf []byte
reader *bytes.Reader
}

func NewCachedReader(in io.Reader) *cachedReader {
return &cachedReader{
in: in,
buf: []byte{},
reader: bytes.NewReader([]byte{}),
}
}

func (r *cachedReader) Read(p []byte) (int, error) {
if r.useBuffer {
n, err := r.reader.Read(p)
if err == io.EOF {
_, seekErr := r.reader.Seek(0, io.SeekStart)
if seekErr != nil {
panic(seekErr)
}

return n, err
}

return n, err
}

n, err := r.in.Read(p)
r.buf = append(r.buf, p[:n]...)
if err == io.EOF {
r.useBuffer = true
r.reader = bytes.NewReader(r.buf)
}
return n, err
}
24 changes: 24 additions & 0 deletions internal/lefthook/runner/cached_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package runner

import (
"bytes"
"io"
"testing"
)

func TestCachedReader(t *testing.T) {
testSlice := []byte("Some example string\nMultiline")

cachedReader := NewCachedReader(bytes.NewReader(testSlice))

for range 5 {
res, err := io.ReadAll(cachedReader)
if err != nil {
t.Errorf("unexpected err: %s", err)
}

if !bytes.Equal(res, testSlice) {
t.Errorf("expected %v to be equal to %v", res, testSlice)
}
}
}
9 changes: 3 additions & 6 deletions internal/lefthook/runner/exec/execute_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ type executeArgs struct {
interactive, useStdin bool
}

func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
var in io.Reader = nullReader{}
if opts.UseStdin {
in = os.Stdin
}
func (e CommandExecutor) Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error {
if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) {
tty, err := os.Open("/dev/tty")
if err == nil {
Expand Down Expand Up @@ -72,9 +68,10 @@ func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Write
return nil
}

func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error {
cmd := exec.CommandContext(ctx, command[0], command[1:]...)

cmd.Stdin = in
cmd.Stdout = out
cmd.Stderr = os.Stderr

Expand Down
11 changes: 4 additions & 7 deletions internal/lefthook/runner/exec/execute_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ type executeArgs struct {
root string
}

func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
var in io.Reader = nullReader{}
if opts.UseStdin {
in = os.Stdin
}
func (e CommandExecutor) Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error {
if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) {
tty, err := tty.Open()
if err == nil {
Expand Down Expand Up @@ -63,9 +59,10 @@ func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Write
return nil
}

func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
cmd := exec.Command(command[0], command[1:]...)
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error {
cmd := exec.CommandContext(ctx, command[0], command[1:]...)

cmd.Stdin = in
cmd.Stdout = out
cmd.Stderr = os.Stderr

Expand Down
4 changes: 2 additions & 2 deletions internal/lefthook/runner/exec/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ type Options struct {
// Executor provides an interface for command execution.
// It is used here for testing purpose mostly.
type Executor interface {
Execute(ctx context.Context, opts Options, out io.Writer) error
RawExecute(ctx context.Context, command []string, out io.Writer) error
Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error
RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error
}
10 changes: 0 additions & 10 deletions internal/lefthook/runner/exec/nullReader.go

This file was deleted.

15 changes: 15 additions & 0 deletions internal/lefthook/runner/null_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package runner

import "io"

// nullReader always returns `io.EOF`.
type nullReader struct{}

func NewNullReader() io.Reader {
return nullReader{}
}

// Implements io.Reader interface.
func (nullReader) Read(b []byte) (int, error) {
return 0, io.EOF
}
20 changes: 20 additions & 0 deletions internal/lefthook/runner/null_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package runner

import (
"bytes"
"io"
"testing"
)

func TestNullReader(t *testing.T) {
nullReader := NewNullReader()

res, err := io.ReadAll(nullReader)
if err != nil {
t.Errorf("unexpected err: %s", err)
}

if !bytes.Equal(res, []byte{}) {
t.Errorf("expected %v to be equal to %v", res, []byte{})
}
}
28 changes: 21 additions & 7 deletions internal/lefthook/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,19 @@ type Options struct {
type Runner struct {
Options

stdin io.Reader
partiallyStagedFiles []string
failed atomic.Bool
executor exec.Executor
}

func New(opts Options) *Runner {
return &Runner{
Options: opts,
Options: opts,

// Some hooks use STDIN for parsing data from Git. To allow multiple commands
// and scripts access the same Git data STDIN is cached via cachedReader.
stdin: NewCachedReader(os.Stdin),
executor: exec.CommandExecutor{},
}
}
Expand All @@ -79,16 +84,16 @@ type executable interface {

// RunAll runs scripts and commands.
// LFS hook is executed at first if needed.
func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) []Result {
func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) ([]Result, error) {
results := make([]Result, 0, len(r.Hook.Commands)+len(r.Hook.Scripts))

if err := r.runLFSHook(ctx); err != nil {
log.Error(err)
return results, err
}

if r.Hook.DoSkip(r.Repo.State()) {
r.logSkip(r.HookName, "hook setting")
return results
return results, nil
}

if !r.DisableTTY && !r.Hook.Follow {
Expand All @@ -113,7 +118,7 @@ func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) []Result {

r.postHook()

return results
return results, nil
}

func (r *Runner) runLFSHook(ctx context.Context) error {
Expand Down Expand Up @@ -144,6 +149,7 @@ func (r *Runner) runLFSHook(ctx context.Context) error {
[]string{"git", "lfs", r.HookName},
r.GitArgs...,
),
r.stdin,
out,
)

Expand Down Expand Up @@ -490,6 +496,12 @@ func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool {
log.SetName(opts.Name)
defer log.UnsetName(opts.Name)

// If the command does not explicitly `use_stdin` no input will be provided.
var in io.Reader = NewNullReader()
if opts.UseStdin {
in = r.stdin
}

if (follow || opts.Interactive) && r.LogSettings.LogExecution() {
r.logExecute(opts.Name, nil, nil)

Expand All @@ -500,12 +512,14 @@ func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool {
out = io.Discard
}

err := r.executor.Execute(ctx, opts, out)
err := r.executor.Execute(ctx, opts, in, out)

return err == nil
}

out := bytes.NewBuffer(make([]byte, 0))
err := r.executor.Execute(ctx, opts, out)

err := r.executor.Execute(ctx, opts, in, out)

r.logExecute(opts.Name, err, out)

Expand Down
9 changes: 6 additions & 3 deletions internal/lefthook/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

type TestExecutor struct{}

func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.Writer) (err error) {
func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _in io.Reader, _out io.Writer) (err error) {
if strings.HasPrefix(opts.Commands[0], "success") {
err = nil
} else {
Expand All @@ -31,7 +31,7 @@ func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.W
return
}

func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _out io.Writer) error {
func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _in io.Reader, _out io.Writer) error {
return nil
}

Expand Down Expand Up @@ -766,7 +766,10 @@ func TestRunAll(t *testing.T) {
}

t.Run(fmt.Sprintf("%d: %s", i, tt.name), func(t *testing.T) {
results := runner.RunAll(context.Background(), tt.sourceDirs)
results, err := runner.RunAll(context.Background(), tt.sourceDirs)
if err != nil {
t.Errorf("unexpected error %s", err)
}

var success, fail []Result
for _, result := range results {
Expand Down

0 comments on commit afc1125

Please sign in to comment.