Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions pipe/close_responsibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ import (
// readCloseSpy records whether Close was called.
type readCloseSpy struct {
io.Reader
closed atomic.Bool
closeCount atomic.Uint32
}

func (r *readCloseSpy) Close() error {
r.closed.Store(true)
r.closeCount.Add(1)
return nil
}

// writeCloseSpy records whether Close was called.
type writeCloseSpy struct {
io.Writer
closed atomic.Bool
closeCount atomic.Uint32
}

func (w *writeCloseSpy) Close() error {
w.closed.Store(true)
w.closeCount.Add(1)
return nil
}

Expand Down Expand Up @@ -63,28 +63,52 @@ func TestGoStageHonorsStreamOwnership(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !tc.leaveIn, in.closed.Load(), "closing stdin=%v", !tc.leaveIn)
assert.Equal(t, !tc.leaveOut, out.closed.Load(), "closing stdout=%v", !tc.leaveOut)
if tc.leaveIn {
assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn)
} else {
assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn)
}
if tc.leaveOut {
assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut)
} else {
assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut)
}
})
}
}

func TestStreamConstructorsPreserveOwnershipAndDynamicType(t *testing.T) {
borrowedInput := strings.NewReader("borrowed")
assert.Same(t, borrowedInput, Input(borrowedInput).Reader())
assert.Nil(t, Input(borrowedInput).Closer())

ownedInput := &readCloseSpy{Reader: strings.NewReader("owned")}
assert.Same(t, ownedInput, ClosingInput(ownedInput).Reader())
assert.Same(t, ownedInput, ClosingInput(ownedInput).Closer())

borrowedOutput := &strings.Builder{}
assert.Same(t, borrowedOutput, Output(borrowedOutput).Writer())
assert.Nil(t, Output(borrowedOutput).Closer())

ownedOutput := &writeCloseSpy{Writer: io.Discard}
assert.Same(t, ownedOutput, ClosingOutput(ownedOutput).Writer())
assert.Same(t, ownedOutput, ClosingOutput(ownedOutput).Closer())
borrowedReader := &readCloseSpy{Reader: strings.NewReader("borrowed")}
borrowedInput := Input(borrowedReader)
assert.Same(t, borrowedReader, borrowedInput.Reader())
assert.NoError(t, borrowedInput.Close())
assert.EqualValues(t, 0, borrowedReader.closeCount.Load())
assert.NoError(t, borrowedInput.Close())
assert.EqualValues(t, 0, borrowedReader.closeCount.Load())

ownedReader := &readCloseSpy{Reader: strings.NewReader("owned")}
ownedInput := ClosingInput(ownedReader)
assert.Same(t, ownedReader, ownedInput.Reader())
assert.NoError(t, ownedInput.Close())
assert.EqualValues(t, 1, ownedReader.closeCount.Load())
assert.NoError(t, ownedInput.Close())
assert.EqualValues(t, 1, ownedReader.closeCount.Load())

borrowedWriter := &writeCloseSpy{Writer: &strings.Builder{}}
borrowedOutput := Output(borrowedWriter)
assert.Same(t, borrowedWriter, borrowedOutput.Writer())
assert.NoError(t, borrowedOutput.Close())
assert.EqualValues(t, 0, borrowedWriter.closeCount.Load())
assert.NoError(t, borrowedOutput.Close())
assert.EqualValues(t, 0, borrowedWriter.closeCount.Load())

ownedWriter := &writeCloseSpy{Writer: &writeCloseSpy{Writer: io.Discard}}
ownedOutput := ClosingOutput(ownedWriter)
assert.Same(t, ownedWriter, ownedOutput.Writer())
assert.NoError(t, ownedOutput.Close())
assert.EqualValues(t, 1, ownedWriter.closeCount.Load())
assert.NoError(t, ownedOutput.Close())
assert.EqualValues(t, 1, ownedWriter.closeCount.Load())
}

// TestCommandStageHonorsCloseStdin verifies that a command stage closes a
Expand All @@ -109,7 +133,11 @@ func TestCommandStageHonorsCloseStdin(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !leave, in.closed.Load(), "closing stdin=%v", !leave)
if leave {
assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !leave)
} else {
assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !leave)
}
})
}
}
Expand All @@ -136,19 +164,23 @@ func TestCommandStageHonorsCloseStdout(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !leave, out.closed.Load(), "closing stdout=%v", !leave)
if leave {
assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !leave)
} else {
assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !leave)
}
})
}
}

func inputForTest(r io.ReadCloser, closing bool) InputStream {
func inputForTest(r io.ReadCloser, closing bool) *InputStream {
if closing {
return ClosingInput(r)
}
return Input(r)
}

func outputForTest(w io.WriteCloser, closing bool) OutputStream {
func outputForTest(w io.WriteCloser, closing bool) *OutputStream {
if closing {
return ClosingOutput(w)
}
Expand Down
73 changes: 47 additions & 26 deletions pipe/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,35 +88,60 @@ func (s *commandStage) Requirements() StageRequirements {

func (s *commandStage) Start(
ctx context.Context, opts StageOptions,
ins InputStream, outs OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
stdin := ins.Reader()
stdinCloser := ins.Closer()
stdout := outs.Writer()
stdoutCloser := outs.Closer()
r := stdin.Reader()
w := stdout.Writer()

if s.cmd.Dir == "" {
s.cmd.Dir = opts.Dir
}

s.setupEnv(ctx, opts.Env)

// It is important that the streams that are used by a command be
// closed at the right time. When that is depends on the type of
// the stream.
//
// A subprocess ultimately needs its own copies of `*os.File` file
// descriptors for its stdin and stdout. The external command will
// "always" close those when it exits.
//
// (It's theoretically possible for a command to pass the open
// file descriptor to another, longer-lived process, in which case
// the file descriptor wouldn't necessarily get closed even when
// the command finishes. But that's ill-behaved in a command that
// is being used in a pipeline, so we'll ignore that possibility.)
//
// If a stream provided for use as stdin/stdout is an `*os.File`,
// then we set the corresponding field of `exec.Cmd` to that
// argument. This causes `exec.Cmd` to duplicate that file
// descriptor and passes the dup to the subprocess. Therefore, we
// want to close our own copy "early", namely as soon as the
// external command has started, because the external command will
// keep its own copy open as long as necessary (and no longer!).
//
// If a stdin/stdout stream is _not_ an `*os.File`, then
// `exec.Cmd` will take care of creating an `os.Pipe()`, copying
// from the provided stream into/out of the pipe, and eventually
// close both ends of the pipe. In that case, we must close the
// provided stream "late", namely only after the external command
// and the copy have finished.

// Things that have to be closed as soon as the command has started:
var earlyClosers []io.Closer

// See the type comment for `Stage` for the explanation of this closing behavior.
if stdin != nil {
s.cmd.Stdin = stdin
if r != nil {
s.cmd.Stdin = r
}

if stdinCloser != nil {
if _, ok := stdin.(*os.File); ok {
// We can close our copy as soon as the command has started
earlyClosers = append(earlyClosers, stdinCloser)
} else {
// We need to close `stdin`, but only after the command has finished
s.lateClosers = append(s.lateClosers, stdinCloser)
}
if _, ok := r.(*os.File); ok {
// We can close our copy as soon as the command has started
earlyClosers = append(earlyClosers, stdin)
} else {
// We need to close `stdin`, but only after the command has finished
s.lateClosers = append(s.lateClosers, stdin)
}

closeEarlyClosers := func() {
Expand All @@ -133,28 +158,24 @@ func (s *commandStage) Start(
_ = s.closeLateClosers()
}

if stdout != nil {
if f, ok := stdout.(*os.File); ok {
if w != nil {
if f, ok := w.(*os.File); ok {
s.cmd.Stdout = f
if stdoutCloser != nil {
earlyClosers = append(earlyClosers, stdoutCloser)
}
earlyClosers = append(earlyClosers, stdout)
} else {
if stdoutCloser != nil {
s.lateClosers = append(s.lateClosers, stdoutCloser)
}
s.lateClosers = append(s.lateClosers, stdout)
// Route the copy through our own pipe so we can use a
// pooled buffer rather than letting exec.Cmd allocate a
// fresh 32KB buffer for its internal io.Copy.
ec, err := s.setupPooledStdout(stdout)
ec, err := s.setupPooledStdout(w)
if err != nil {
cleanupOnStartFailure()
return err
}
earlyClosers = append(earlyClosers, ec)
}
} else if stdoutCloser != nil {
s.lateClosers = append(s.lateClosers, stdoutCloser)
} else {
s.lateClosers = append(s.lateClosers, stdout)
}

// If the caller hasn't arranged otherwise, read the command's
Expand Down
4 changes: 3 additions & 1 deletion pipe/command_stdout_fastpath_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ func TestCommandStageStdoutFastPath(t *testing.T) {
cmd := exec.Command("true")
s := CommandStage("true", cmd).(*commandStage)

stdout := OutputStream{writer: f}
var stdout *OutputStream
if tc.closingStdout {
stdout = ClosingOutput(f)
} else {
stdout = Output(f)
}

require.NoError(t, s.Start(ctx, StageOptions{}, Input(nil), stdout))
Expand Down
2 changes: 1 addition & 1 deletion pipe/env_stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (s *stageWithExtraEnv) Requirements() StageRequirements {

func (s *stageWithExtraEnv) Start(
ctx context.Context, opts StageOptions,
stdin InputStream, stdout OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar {
return append(vars, s.env...)
Expand Down
16 changes: 5 additions & 11 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,15 @@ func (s *goStage) Requirements() StageRequirements {

func (s *goStage) Start(
ctx context.Context, opts StageOptions,
stdin InputStream, stdout OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
r := stdin.Reader()
stdinCloser := stdin.Closer()
if r == nil {
// treat nil as empty input.
r = strings.NewReader("")
}

w := stdout.Writer()
stdoutCloser := stdout.Closer()
if w == nil {
// treat nil output as /dev/null
w = io.Discard
Expand All @@ -110,15 +108,11 @@ func (s *goStage) Start(
s.err = opts.PanicHandler(p)
}
}
if stdoutCloser != nil {
if err := stdoutCloser.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
if err := stdout.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
if stdinCloser != nil {
if err := stdinCloser.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
close(s.done)
}()
Expand Down
6 changes: 3 additions & 3 deletions pipe/pipe_matching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ func (s *pipeSniffingStage) Requirements() pipe.StageRequirements {

func (s *pipeSniffingStage) Start(
_ context.Context, _ pipe.StageOptions,
stdin pipe.InputStream, stdout pipe.OutputStream,
stdin *pipe.InputStream, stdout *pipe.OutputStream,
) error {
s.stdin = stdin.Reader()
stdin.Close()
_ = stdin.Close()
s.stdout = stdout.Writer()
stdout.Close()
_ = stdout.Close()
return nil
}

Expand Down
Loading
Loading