diff --git a/README.md b/README.md index 995ee86..aabd484 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,31 @@ A package used to easily build command pipelines in your Go applications # Important We have not thoroughly tested this package on OSs other than Linux, especially Windows. At this time, using this package on Windows based systems is considered experimental and will be supported only on a best effort basis. +# Migrating to v2 + +It's normal for pipelines to stop before all input has been consumed[^1]. If an earlier stage continues writing after that happens, the write side of the pipe can fail with `EPIPE`, `SIGPIPE`, or `io.ErrClosedPipe`. + +In go-pipe v1 it was possible to get away without handling this case, because a command stage's stdin was connected in a way that often (but not necessarily!) drained the write side and hid the error from the previous stage feeding it. That was an implementation detail, not a guarantee. In go-pipe v2, producer stages are more likely to be connected directly to a command's stdin, and thus see the error themselves. + +Fortunately, this is easily handled by wrapping the stage with `pipe.IgnoreError(stage, IsPipeError)`. If the producer only writes output and is otherwise stateless, that's the only thing needed. + +If the producer also updates state, metrics, cursors, or has other side effects, in a way that depends on how much of the output was produced, then in addition to using `pipe.IgnoreError`, you must also ensure producer-owned state is brought to a consistent point before returning the error. + +For example, if a stateful producer function must process its entire input for correctness regardless of whether it was read by the consumer, it should use a pattern like: + +```go +var writeErr error +for _, item := range items { + updateState(item) + if writeErr == nil { + _, writeErr = fmt.Fprintln(stdout, item) + } +} +return writeErr +``` + # Links * [Docs](https://pkg.go.dev/github.com/github/go-pipe/v2) + +[^1]: In `cat foo | head | grep -q`, for example, either `head` or `grep` could exit before its input is fully consumed. diff --git a/pipe/close_responsibility_test.go b/pipe/close_responsibility_test.go index 23f6107..20a5a28 100644 --- a/pipe/close_responsibility_test.go +++ b/pipe/close_responsibility_test.go @@ -15,28 +15,28 @@ import ( // readCloseSpy records whether Close was called. type readCloseSpy struct { io.Reader - closed atomic.Bool + closeCount atomic.Uint32 } func (r *readCloseSpy) Close() error { - r.closed.Store(true) + r.closeCount.Add(1) return nil } // writeCloseSpy records whether Close was called. type writeCloseSpy struct { io.Writer - closed atomic.Bool + closeCount atomic.Uint32 } func (w *writeCloseSpy) Close() error { - w.closed.Store(true) + w.closeCount.Add(1) return nil } -// TestGoStageHonorsCloseFlags verifies that a Function stage closes -// stdin/stdout iff the corresponding close flag is true. -func TestGoStageHonorsCloseFlags(t *testing.T) { +// TestGoStageHonorsStreamOwnership verifies that a Function stage closes +// stdin/stdout iff the corresponding stream is closing. +func TestGoStageHonorsStreamOwnership(t *testing.T) { cases := []struct { name string leaveIn, leaveOut bool @@ -58,34 +58,62 @@ func TestGoStageHonorsCloseFlags(t *testing.T) { require.NoError(t, s.Start( context.Background(), StageOptions{}, - in, !tc.leaveIn, - out, !tc.leaveOut, + inputForTest(in, !tc.leaveIn), + outputForTest(out, !tc.leaveOut), )) require.NoError(t, s.Wait()) - assert.Equal(t, !tc.leaveIn, in.closed.Load(), "closeStdin=%v", !tc.leaveIn) - assert.Equal(t, !tc.leaveOut, out.closed.Load(), "closeStdout=%v", !tc.leaveOut) + if tc.leaveIn { + assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn) + } else { + assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn) + } + if tc.leaveOut { + assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut) + } else { + assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut) + } }) } } -func TestStagePanicsWhenOwnedStreamIsNotCloseable(t *testing.T) { - s := Function("f", func(_ context.Context, _ Env, _ io.Reader, _ io.Writer) error { - return nil - }) - - assert.PanicsWithValue(t, "stage asked to close *strings.Reader, which does not implement io.Closer", func() { - _ = s.Start( - context.Background(), StageOptions{}, - strings.NewReader("not closeable"), true, - nil, false, - ) - }) +func TestStreamConstructorsPreserveOwnershipAndDynamicType(t *testing.T) { + borrowedReader := &readCloseSpy{Reader: strings.NewReader("borrowed")} + borrowedInput := Input(borrowedReader) + assert.Same(t, borrowedReader, borrowedInput.Reader()) + assert.NoError(t, borrowedInput.Close()) + assert.EqualValues(t, 0, borrowedReader.closeCount.Load()) + assert.NoError(t, borrowedInput.Close()) + assert.EqualValues(t, 0, borrowedReader.closeCount.Load()) + + ownedReader := &readCloseSpy{Reader: strings.NewReader("owned")} + ownedInput := ClosingInput(ownedReader) + assert.Same(t, ownedReader, ownedInput.Reader()) + assert.NoError(t, ownedInput.Close()) + assert.EqualValues(t, 1, ownedReader.closeCount.Load()) + assert.NoError(t, ownedInput.Close()) + assert.EqualValues(t, 1, ownedReader.closeCount.Load()) + + borrowedWriter := &writeCloseSpy{Writer: &strings.Builder{}} + borrowedOutput := Output(borrowedWriter) + assert.Same(t, borrowedWriter, borrowedOutput.Writer()) + assert.NoError(t, borrowedOutput.Close()) + assert.EqualValues(t, 0, borrowedWriter.closeCount.Load()) + assert.NoError(t, borrowedOutput.Close()) + assert.EqualValues(t, 0, borrowedWriter.closeCount.Load()) + + ownedWriter := &writeCloseSpy{Writer: &writeCloseSpy{Writer: io.Discard}} + ownedOutput := ClosingOutput(ownedWriter) + assert.Same(t, ownedWriter, ownedOutput.Writer()) + assert.NoError(t, ownedOutput.Close()) + assert.EqualValues(t, 1, ownedWriter.closeCount.Load()) + assert.NoError(t, ownedOutput.Close()) + assert.EqualValues(t, 1, ownedWriter.closeCount.Load()) } // TestCommandStageHonorsCloseStdin verifies that a command stage closes a -// non-file stdin (a "late" closer) iff closeStdin is true. An empty -// reader is used so exec.Cmd's input-copy goroutine sees EOF promptly. +// non-file stdin (a "late" closer) iff the input stream is closing. An +// empty reader is used so exec.Cmd's input-copy goroutine sees EOF promptly. func TestCommandStageHonorsCloseStdin(t *testing.T) { for _, leave := range []bool{false, true} { name := "owns stdin" @@ -100,19 +128,23 @@ func TestCommandStageHonorsCloseStdin(t *testing.T) { require.NoError(t, s.Start( context.Background(), StageOptions{}, - in, !leave, - nil, false, + inputForTest(in, !leave), + Output(nil), )) require.NoError(t, s.Wait()) - assert.Equal(t, !leave, in.closed.Load(), "closeStdin=%v", !leave) + if leave { + assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !leave) + } else { + assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !leave) + } }) } } // TestCommandStageHonorsCloseStdout verifies the stdout counterpart: a // non-file stdout (routed through the pooled-copy path) is closed iff -// closeStdout is true. +// the output stream is closing. func TestCommandStageHonorsCloseStdout(t *testing.T) { for _, leave := range []bool{false, true} { name := "owns stdout" @@ -127,12 +159,30 @@ func TestCommandStageHonorsCloseStdout(t *testing.T) { require.NoError(t, s.Start( context.Background(), StageOptions{}, - nil, false, - out, !leave, + Input(nil), + outputForTest(out, !leave), )) require.NoError(t, s.Wait()) - assert.Equal(t, !leave, out.closed.Load(), "closeStdout=%v", !leave) + if leave { + assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !leave) + } else { + assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !leave) + } }) } } + +func inputForTest(r io.ReadCloser, closing bool) *InputStream { + if closing { + return ClosingInput(r) + } + return Input(r) +} + +func outputForTest(w io.WriteCloser, closing bool) *OutputStream { + if closing { + return ClosingOutput(w) + } + return Output(w) +} diff --git a/pipe/command.go b/pipe/command.go index 9cfd9c3..27d34c6 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -81,18 +81,17 @@ func (s *commandStage) Process() *os.Process { func (s *commandStage) Requirements() StageRequirements { return StageRequirements{ - StdinNeedsFile: true, - StdoutNeedsFile: true, + Stdin: StreamPreferFile, + Stdout: StreamPreferFile, } } func (s *commandStage) Start( ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *InputStream, stdout *OutputStream, ) error { - stdinCloser := ownedCloser(stdin, closeStdin) - stdoutCloser := ownedCloser(stdout, closeStdout) + r := stdin.Reader() + w := stdout.Writer() if s.cmd.Dir == "" { s.cmd.Dir = opts.Dir @@ -100,45 +99,49 @@ func (s *commandStage) Start( s.setupEnv(ctx, opts.Env) + // It is important that the streams that are used by a command be + // closed at the right time. When that is depends on the type of + // the stream. + // + // A subprocess ultimately needs its own copies of `*os.File` file + // descriptors for its stdin and stdout. The external command will + // "always" close those when it exits. + // + // (It's theoretically possible for a command to pass the open + // file descriptor to another, longer-lived process, in which case + // the file descriptor wouldn't necessarily get closed even when + // the command finishes. But that's ill-behaved in a command that + // is being used in a pipeline, so we'll ignore that possibility.) + // + // If a stream provided for use as stdin/stdout is an `*os.File`, + // then we set the corresponding field of `exec.Cmd` to that + // argument. This causes `exec.Cmd` to duplicate that file + // descriptor and passes the dup to the subprocess. Therefore, we + // want to close our own copy "early", namely as soon as the + // external command has started, because the external command will + // keep its own copy open as long as necessary (and no longer!). + // + // If a stdin/stdout stream is _not_ an `*os.File`, then + // `exec.Cmd` will take care of creating an `os.Pipe()`, copying + // from the provided stream into/out of the pipe, and eventually + // close both ends of the pipe. In that case, we must close the + // provided stream "late", namely only after the external command + // and the copy have finished. + // Things that have to be closed as soon as the command has started: var earlyClosers []io.Closer // See the type comment for `Stage` for the explanation of this closing behavior. - if stdin != nil { - s.cmd.Stdin = stdin + if r != nil { + s.cmd.Stdin = r } - if stdinCloser != nil { - if _, ok := stdin.(*os.File); ok { - // We can close our copy as soon as the command has started - earlyClosers = append(earlyClosers, stdinCloser) - } else { - // We need to close `stdin`, but only after the command has finished - s.lateClosers = append(s.lateClosers, stdinCloser) - } - } - - if stdout != nil { - if f, ok := stdout.(*os.File); ok { - s.cmd.Stdout = f - if stdoutCloser != nil { - earlyClosers = append(earlyClosers, stdoutCloser) - } - } else { - // Route the copy through our own pipe so we can use a - // pooled buffer rather than letting exec.Cmd allocate a - // fresh 32KB buffer for its internal io.Copy. - ec, err := s.setupPooledStdout(stdout) - if err != nil { - return err - } - earlyClosers = append(earlyClosers, ec) - if stdoutCloser != nil { - s.lateClosers = append(s.lateClosers, stdoutCloser) - } - } - } else if stdoutCloser != nil { - s.lateClosers = append(s.lateClosers, stdoutCloser) + if _, ok := r.(*os.File); ok { + // We can close our copy as soon as the command has started + earlyClosers = append(earlyClosers, stdin) + } else { + // We need to close `stdin`, but only after the command has finished + s.lateClosers = append(s.lateClosers, stdin) } closeEarlyClosers := func() { @@ -155,6 +158,26 @@ func (s *commandStage) Start( _ = s.closeLateClosers() } + if w != nil { + if f, ok := w.(*os.File); ok { + s.cmd.Stdout = f + earlyClosers = append(earlyClosers, stdout) + } else { + s.lateClosers = append(s.lateClosers, stdout) + // Route the copy through our own pipe so we can use a + // pooled buffer rather than letting exec.Cmd allocate a + // fresh 32KB buffer for its internal io.Copy. + ec, err := s.setupPooledStdout(w) + if err != nil { + cleanupOnStartFailure() + return err + } + earlyClosers = append(earlyClosers, ec) + } + } else { + s.lateClosers = append(s.lateClosers, stdout) + } + // If the caller hasn't arranged otherwise, read the command's // standard error into our `stderr` field: if s.cmd.Stderr == nil { diff --git a/pipe/command_stdout_fastpath_test.go b/pipe/command_stdout_fastpath_test.go index dc1c854..0de3f51 100644 --- a/pipe/command_stdout_fastpath_test.go +++ b/pipe/command_stdout_fastpath_test.go @@ -18,15 +18,15 @@ import ( // subprocess can detect when that fd is closed. func TestCommandStageStdoutFastPath(t *testing.T) { cases := []struct { - name string - closeStdout bool + name string + closingStdout bool }{ { - name: "raw *os.File with closeStdout", - closeStdout: true, + name: "raw *os.File with closing stdout", + closingStdout: true, }, { - name: "raw *os.File without closeStdout", + name: "raw *os.File with non-closing stdout", }, } for _, tc := range cases { @@ -42,7 +42,14 @@ func TestCommandStageStdoutFastPath(t *testing.T) { cmd := exec.Command("true") s := CommandStage("true", cmd).(*commandStage) - require.NoError(t, s.Start(ctx, StageOptions{}, nil, false, f, tc.closeStdout)) + var stdout *OutputStream + if tc.closingStdout { + stdout = ClosingOutput(f) + } else { + stdout = Output(f) + } + + require.NoError(t, s.Start(ctx, StageOptions{}, Input(nil), stdout)) t.Cleanup(func() { _ = s.Wait() }) gotFile, ok := s.cmd.Stdout.(*os.File) diff --git a/pipe/env_stage.go b/pipe/env_stage.go index 64dab22..3db2296 100644 --- a/pipe/env_stage.go +++ b/pipe/env_stage.go @@ -1,9 +1,6 @@ package pipe -import ( - "context" - "io" -) +import "context" // WithExtraEnv returns a Stage that adds env to the environment seen by inner. func WithExtraEnv(inner Stage, env []EnvVar) Stage { @@ -40,13 +37,12 @@ func (s *stageWithExtraEnv) Requirements() StageRequirements { func (s *stageWithExtraEnv) Start( ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *InputStream, stdout *OutputStream, ) error { opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar { return append(vars, s.env...) }) - return s.inner.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout) + return s.inner.Start(ctx, opts, stdin, stdout) } func (s *stageWithExtraEnv) Wait() error { diff --git a/pipe/filter-error.go b/pipe/filter-error.go index 654796a..9bdee27 100644 --- a/pipe/filter-error.go +++ b/pipe/filter-error.go @@ -48,6 +48,12 @@ type ErrorMatcher func(err error) bool // the functions from the standard library that has the same signature // (e.g., `os.IsTimeout`), or some combination of these (e.g., // `AnyError(IsSIGPIPE, os.IsTimeout)`). +// +// `IgnoreError` only suppresses the error returned by the wrapped +// stage. If a producer ignores pipe errors because a later stage can +// stop reading early, the producer is still responsible for keeping any +// producer-owned state, metrics, cursors, or other side effects +// consistent before returning the ignored error. func IgnoreError(s Stage, em ErrorMatcher) Stage { return FilterError(s, func(err error) error { @@ -128,7 +134,11 @@ var ( // IsPipeError is an `ErrorMatcher` that matches a few different // errors that typically result if a stage writes to a subsequent - // stage that has stopped reading from its stdin. Use like + // stage that has stopped reading from its stdin. This is commonly + // useful with `IgnoreError` for stateless producer stages whose only + // job is writing output. Stateful producers should continue any + // producer-owned state updates needed for consistency before + // returning the pipe error for `IgnoreError` to suppress. Use like // // p.Add(IgnoreError(someStage, IsPipeError)) IsPipeError = AnyError(IsSIGPIPE, IsEPIPE, IsErrClosedPipe) diff --git a/pipe/function.go b/pipe/function.go index e5422f2..a80947f 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -17,6 +17,15 @@ import ( // Neither `stdin` nor `stdout` are necessarily buffered. If the // `StageFunc` requires buffering, it needs to arrange that itself. // +// A later stage can stop reading before this function has written all +// of its output. In that case, writes to `stdout` can fail with an +// error matched by `IsPipeError`. If the function only writes output +// and is otherwise stateless, callers can usually wrap the stage with +// `IgnoreError(stage, IsPipeError)`. If the function also updates +// producer-owned state, metrics, cursors, or other side effects that +// depend on how much output was produced, it should bring those side +// effects to a consistent point before returning the write error. +// // A `StageFunc` is run in a separate goroutine, so it must be careful // to synchronize any data access aside from reading and writing. type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) error @@ -78,19 +87,15 @@ func (s *goStage) Requirements() StageRequirements { func (s *goStage) Start( ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *InputStream, stdout *OutputStream, ) error { - stdinCloser := ownedCloser(stdin, closeStdin) - stdoutCloser := ownedCloser(stdout, closeStdout) - - r := stdin + r := stdin.Reader() if r == nil { // treat nil as empty input. r = strings.NewReader("") } - w := stdout + w := stdout.Writer() if w == nil { // treat nil output as /dev/null w = io.Discard @@ -103,15 +108,11 @@ func (s *goStage) Start( s.err = opts.PanicHandler(p) } } - if stdoutCloser != nil { - if err := stdoutCloser.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) - } + if err := stdout.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) } - if stdinCloser != nil { - if err := stdinCloser.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) - } + if err := stdin.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) } close(s.done) }() diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go index 542e857..89c5ae7 100644 --- a/pipe/pipe_matching_test.go +++ b/pipe/pipe_matching_test.go @@ -49,14 +49,10 @@ func writeCloser() io.WriteCloser { } func newPipeSniffingStage( - stdinNeedsFile bool, stdinExpectation ioExpectation, - stdoutNeedsFile bool, stdoutExpectation ioExpectation, + req pipe.StageRequirements, stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return &pipeSniffingStage{ - requirements: pipe.StageRequirements{ - StdinNeedsFile: stdinNeedsFile, - StdoutNeedsFile: stdoutNeedsFile, - }, + requirements: req, expect: pipeExpectations{ stdin: stdinExpectation, stdout: stdoutExpectation, @@ -68,8 +64,11 @@ func newPipeSniffingFunc( stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return newPipeSniffingStage( - false, stdinExpectation, - false, stdoutExpectation, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + stdinExpectation, stdoutExpectation, ) } @@ -77,8 +76,11 @@ func newPipeSniffingCmd( stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return newPipeSniffingStage( - true, stdinExpectation, - true, stdoutExpectation, + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamPreferFile, + }, + stdinExpectation, stdoutExpectation, ) } @@ -104,17 +106,12 @@ func (s *pipeSniffingStage) Requirements() pipe.StageRequirements { func (s *pipeSniffingStage) Start( _ context.Context, _ pipe.StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *pipe.InputStream, stdout *pipe.OutputStream, ) error { - s.stdin = stdin - if closeStdin { - _ = stdin.(io.Closer).Close() - } - s.stdout = stdout - if closeStdout { - _ = stdout.(io.Closer).Close() - } + s.stdin = stdin.Reader() + _ = stdin.Close() + s.stdout = stdout.Writer() + _ = stdout.Close() return nil } @@ -330,16 +327,25 @@ func TestPipeTypes(t *testing.T) { opts: []pipe.Option{}, stages: []pipe.Stage{ newPipeSniffingStage( - false, expectNil, - false, expectOther, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectNil, expectOther, ), newPipeSniffingStage( - false, expectOther, - true, expectFile, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamPreferFile, + }, + expectOther, expectFile, ), newPipeSniffingStage( - false, expectFile, - false, expectNil, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectNil, ), }, }, @@ -348,16 +354,25 @@ func TestPipeTypes(t *testing.T) { opts: []pipe.Option{}, stages: []pipe.Stage{ newPipeSniffingStage( - false, expectNil, - false, expectFile, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectNil, expectFile, ), newPipeSniffingStage( - true, expectFile, - false, expectOther, + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectOther, ), newPipeSniffingStage( - false, expectOther, - false, expectNil, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectOther, expectNil, ), }, }, diff --git a/pipe/pipeline.go b/pipe/pipeline.go index 80db20a..44fedd1 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "os" "sync/atomic" ) @@ -55,12 +54,10 @@ type ContextValuesFunc func(context.Context) []EnvVar type Pipeline struct { env Env - stdin io.Reader - stdinCloser io.Closer - stdout io.Writer - stdoutCloser io.Closer - stages []Stage - cancel func() + stdin *InputStream + stdout *OutputStream + stages []Stage + cancel func() // Atomically written and read value, nonzero if the pipeline has // been started. This is only used for lifecycle sanity checks but @@ -99,28 +96,31 @@ func WithDir(dir string) Option { } } -// WithStdin assigns stdin to the first command in the pipeline. +// WithStdin assigns stdin to the first command in the pipeline. The +// caller retains ownership of stdin; the pipeline will not close it, +// even if `Start()` returns an error. func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { - p.stdin = stdin - p.stdinCloser = nil + p.stdin = Input(stdin) } } -// WithStdout assigns stdout to the last command in the pipeline. +// WithStdout assigns stdout to the last command in the pipeline. The +// caller retains ownership of stdout; the pipeline will not close it, +// even if `Start()` returns an error. func WithStdout(stdout io.Writer) Option { return func(p *Pipeline) { - p.stdout = stdout - p.stdoutCloser = nil + p.stdout = Output(stdout) } } // WithStdoutCloser assigns stdout to the last command in the -// pipeline, and closes stdout when it's done. +// pipeline, and closes stdout when the pipeline is done with it. The +// pipeline is responsible for closing stdout even if `Start()` returns +// an error. func WithStdoutCloser(stdout io.WriteCloser) Option { return func(p *Pipeline) { - p.stdout = stdout - p.stdoutCloser = stdout + p.stdout = ClosingOutput(stdout) } } @@ -217,54 +217,6 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } -type stageStarter struct { - requirements StageRequirements - stdin io.Reader - stdinCloser io.Closer - stdout io.Writer - stdoutCloser io.Closer -} - -func (requirement StreamRequirement) validate() error { - switch requirement { - case StreamOptional, StreamForbidden: - return nil - default: - return fmt.Errorf("invalid stream requirement %d", requirement) - } -} - -func (requirements StageRequirements) validate(s Stage, stdinConnected, stdoutConnected bool) error { - if err := requirements.Stdin.validate(); err != nil { - return fmt.Errorf("stdin: %w", err) - } - if err := requirements.Stdout.validate(); err != nil { - return fmt.Errorf("stdout: %w", err) - } - if requirements.Stdin == StreamForbidden && stdinConnected { - return fmt.Errorf("stage %q forbids stdin, but stdin is connected", s.Name()) - } - if requirements.Stdout == StreamForbidden && stdoutConnected { - return fmt.Errorf("stage %q forbids stdout, but stdout is connected", s.Name()) - } - return nil -} - -func (p *Pipeline) abortBeforeStart(s Stage, err error) error { - if p.stdoutCloser != nil { - _ = p.stdoutCloser.Close() - } - p.cancel() - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, - }) - return fmt.Errorf( - "starting pipeline stage %q: %w", s.Name(), err, - ) -} - func (p *Pipeline) stageOptions() StageOptions { return StageOptions{Env: p.env, PanicHandler: p.panicHandler} } @@ -272,6 +224,13 @@ func (p *Pipeline) stageOptions() StageOptions { // Start starts the commands in the pipeline. If `Start()` exits // without an error, `Wait()` must also be called, to allow all // resources to be freed. +// +// If `Start()` returns an error, `Wait()` must not be called. Before +// returning an error, `Start()` cancels and waits for any stages that +// were started, closes any inter-stage pipes that the pipeline owns, +// and closes stdout if it was supplied with `WithStdoutCloser()`. +// Streams supplied with `WithStdin()` or `WithStdout()` remain owned by +// the caller and are not closed by the pipeline. func (p *Pipeline) Start(ctx context.Context) error { if p.hasStarted() { panic("attempt to start a pipeline that has already started") @@ -305,46 +264,68 @@ func (p *Pipeline) Start(ctx context.Context) error { // We need to decide how to start the stages, especially what // pipes to use to connect adjacent stages (`os.Pipe()` vs. // `io.Pipe()`) based on the two stages' requirements. - stageStarters := make([]stageStarter, len(p.stages)) + stageJoiners := make([]stageJoiner, len(p.stages)+1) + + // Arrange for the input of the 0th stage to come from `p.stdin`: + stageJoiners[0].nextStdin = p.stdin + + // Arrange for the output of the last stage to go to `p.stdout`: + stageJoiners[len(p.stages)].prevStdout = p.stdout + + // closePipes closes all of the streams that are currently stored + // in the joiners. This should be called if startup fails. As we + // call `Stage.Start()` and pass that method streams, we clear + // them from the corresponding joiners to avoid closing them + // twice. + closePipes := func() { + for _, sj := range stageJoiners { + _ = sj.closePipe() + } + } - // Collect information about each stage's type and requirements: + // Store the stages in the joiners, and verify that the stages' + // requirements are well-formed: for i, s := range p.stages { - stageStarters[i].requirements = s.Requirements() - - err := stageStarters[i].requirements.validate( - s, - i > 0 || p.stdin != nil, - i < len(p.stages)-1 || p.stdout != nil, - ) - if err != nil { - return p.abortBeforeStart(s, err) + // Make sure that the stage's requirements are well-formed: + requirements := s.Requirements() + if err := requirements.Stdin.Validate(); err != nil { + return fmt.Errorf("stdin: %w", err) } + if err := requirements.Stdout.Validate(); err != nil { + return fmt.Errorf("stdout: %w", err) + } + + stageJoiners[i].nextStage = s + stageJoiners[i].nextStageReq = requirements + stageJoiners[i+1].prevStage = s + stageJoiners[i+1].prevStageReq = requirements } - if p.stdin != nil { - // Arrange for the input of the 0th stage to come from - // `p.stdin`: - stageStarters[0].stdin = p.stdin - stageStarters[0].stdinCloser = p.stdinCloser + // Check that each of the stages' requirements are satisfiable: + for i := range stageJoiners { + if err := stageJoiners[i].validate(); err != nil { + closePipes() + return err + } } - if p.stdout != nil { - i := len(p.stages) - 1 - ss := &stageStarters[i] - ss.stdout = p.stdout - ss.stdoutCloser = p.stdoutCloser + // Create the "inner" pipes (i.e, all but the first and last + // `stageJoiners`): + for i := 1; i < len(stageJoiners)-1; i++ { + if err := stageJoiners[i].createPipe(); err != nil { + closePipes() + return err + } } - // Clean up any processes and pipes that have been created. `i` is - // the index of the stage that failed to start (whose output pipe - // has already been cleaned up if necessary). + // We're about to start up the stages, one by one. If something + // goes wrong during that process, this function should be called + // to kill any stages that have already been started and to close + // any pipes that have not yet been passed to a stage. `i` is the + // index of the stage that failed to start. If the stage already + // received its streams, it is responsible for closing them. abort := func(i int, err error) error { - // Close the pipe that the previous stage was writing to. - // That should cause it to exit even if it's not minding - // its context. - if stageStarters[i].stdinCloser != nil { - _ = stageStarters[i].stdinCloser.Close() - } + closePipes() // Kill and wait for any stages that have been started // already to finish: @@ -362,57 +343,19 @@ func (p *Pipeline) Start(ctx context.Context) error { ) } - // Loop over all but the last stage, starting them. By the time we - // get to a stage, its stdin will have already been determined, - // but we still need to figure out its stdout and set the stdin - // that will be used for the subsequent stage. - for i, s := range p.stages[:len(p.stages)-1] { - ss := &stageStarters[i] - nextSS := &stageStarters[i+1] - - // We need to generate a pipe pair for this stage to use - // to communicate with its successor: - if ss.requirements.StdoutNeedsFile || nextSS.requirements.StdinNeedsFile { - // Use an OS-level pipe for the communication: - nextStdin, stdout, err := os.Pipe() - if err != nil { - return abort(i, err) - } - nextSS.stdin = nextStdin - nextSS.stdinCloser = nextStdin - ss.stdout = stdout - ss.stdoutCloser = stdout - } else { - nextStdin, stdout := io.Pipe() - nextSS.stdin = nextStdin - nextSS.stdinCloser = nextStdin - ss.stdout = stdout - ss.stdoutCloser = stdout - } - if err := s.Start( - ctx, p.stageOptions(), - ss.stdin, ss.stdinCloser != nil, - ss.stdout, ss.stdoutCloser != nil, - ); err != nil { - nextSS.stdinCloser.Close() - ss.stdoutCloser.Close() - return abort(i, err) - } - } + // Loop over all of the stages, starting them in order. + for i, s := range p.stages { + prevSJ := &stageJoiners[i] + nextSJ := &stageJoiners[i+1] - // The last stage needs special handling, because its stdout - // doesn't need to flow into another stage (it's already set in - // `ss.stdout` if it's needed). - { - i := len(p.stages) - 1 - s := p.stages[i] - ss := &stageStarters[i] + err := s.Start(ctx, p.stageOptions(), prevSJ.nextStdin, nextSJ.prevStdout) - if err := s.Start( - ctx, p.stageOptions(), - ss.stdin, ss.stdinCloser != nil, - ss.stdout, ss.stdoutCloser != nil, - ); err != nil { + // Even if that stage failed to start, we are no longer + // responsible for closing its streams: + prevSJ.nextStdin = nil + nextSJ.prevStdout = nil + + if err != nil { return abort(i, err) } } @@ -422,8 +365,7 @@ func (p *Pipeline) Start(ctx context.Context) error { func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { var buf bytes.Buffer - p.stdout = &buf - p.stdoutCloser = nil + p.stdout = Output(&buf) err := p.Run(ctx) return buf.Bytes(), err } @@ -518,7 +460,9 @@ func (p *Pipeline) Wait() error { return nil } -// Run starts and waits for the commands in the pipeline. +// Run starts and waits for the commands in the pipeline. If startup +// fails, it returns the `Start()` error after `Start()` has performed +// its failure cleanup. func (p *Pipeline) Run(ctx context.Context) error { if err := p.Start(ctx); err != nil { return err diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index be54b5f..82dc2cd 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "os" - "runtime" "strconv" "strings" "testing" @@ -84,13 +83,30 @@ func TestPipelineFirstStageFailsToStart(t *testing.T) { t.Parallel() ctx := context.Background() startErr := errors.New("foo") + stdout := &closeTrackingWriter{} - p := pipe.New() + p := pipe.New(pipe.WithStdoutCloser(stdout)) p.Add( ErrorStartingStage{startErr}, ErrorStartingStage{errors.New("this error should never happen")}, ) assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") +} + +func TestPipelineFirstStageFailsToStartClosesStdoutCloser(t *testing.T) { + t.Parallel() + ctx := context.Background() + startErr := errors.New("foo") + stdout := &closeTrackingWriter{} + + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add( + ErrorStartingStage{startErr}, + pipe.Command("this-stage-should-not-start"), + ) + assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") } func TestPipelineSecondStageFailsToStart(t *testing.T) { @@ -106,6 +122,22 @@ func TestPipelineSecondStageFailsToStart(t *testing.T) { assert.ErrorIs(t, p.Run(ctx), startErr) } +func TestPipelineMiddleStageFailsToStartClosesUnstartedStdoutCloser(t *testing.T) { + t.Parallel() + ctx := context.Background() + startErr := errors.New("foo") + stdout := &closeTrackingWriter{} + + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add( + seqFunction(20000), + ErrorStartingStage{startErr}, + ErrorStartingStage{errors.New("this error should never happen")}, + ) + assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") +} + func TestPipelineSingleCommandOutput(t *testing.T) { t.Parallel() ctx := context.Background() @@ -271,10 +303,6 @@ func TestIOPipePipelineReadFromSlowly(t *testing.T) { } func TestPipelineReadFromSlowly2(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -370,10 +398,6 @@ func TestPipelineStderr(t *testing.T) { } func TestPipelineInterrupted(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() stdout := &bytes.Buffer{} @@ -392,10 +416,6 @@ func TestPipelineInterrupted(t *testing.T) { } func TestPipelineCanceled(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() stdout := &bytes.Buffer{} @@ -419,10 +439,6 @@ func TestPipelineCanceled(t *testing.T) { // unread output in this case *does fit* within the OS-level pipe // buffer. func TestLittleEPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() p := pipe.New() @@ -442,10 +458,6 @@ func TestLittleEPIPE(t *testing.T) { // amount of unread output in this case *does not fit* within the // OS-level pipe buffer. func TestBigEPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() p := pipe.New() @@ -465,10 +477,6 @@ func TestBigEPIPE(t *testing.T) { // amount of unread output in this case *does not fit* within the // OS-level pipe buffer. func TestIgnoredSIGPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() p := pipe.New() @@ -484,6 +492,71 @@ func TestIgnoredSIGPIPE(t *testing.T) { assert.EqualValues(t, "foo\n", out) } +func TestGoProducerSeesPipeErrorWhenCommandStopsReading(t *testing.T) { + t.Parallel() + + p := pipe.New() + p.Add( + pipe.Function( + "write-to-closed-command", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + w := bufio.NewWriter(stdout) + for i := 0; i < 100000; i++ { + if _, err := fmt.Fprintln(w, i); err != nil { + return err + } + } + return w.Flush() + }, + ), + pipe.Command("true"), + ) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + err := p.Run(ctx) + require.Error(t, err) + assert.True(t, pipe.IsPipeError(err), "expected a pipe error, got %v", err) +} + +func TestIgnoredPipeErrorStillAllowsStatefulProducerToFinish(t *testing.T) { + t.Parallel() + + const total = 100000 + processed := 0 + p := pipe.New() + p.Add( + pipe.IgnoreError( + pipe.Function( + "stateful-producer", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + w := bufio.NewWriter(stdout) + var writeErr error + for i := 0; i < total; i++ { + processed++ + if writeErr == nil { + if _, err := fmt.Fprintln(w, i); err != nil { + writeErr = err + } + } + } + if writeErr == nil { + writeErr = w.Flush() + } + return writeErr + }, + ), + pipe.IsPipeError, + ), + pipe.Command("true"), + ) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, p.Run(ctx)) + assert.Equal(t, total, processed) +} + func TestFunction(t *testing.T) { t.Parallel() ctx := context.Background() @@ -600,15 +673,10 @@ func (s ErrorStartingStage) Requirements() pipe.StageRequirements { func (s ErrorStartingStage) Start( _ context.Context, _ pipe.StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *pipe.InputStream, stdout *pipe.OutputStream, ) error { - if closeStdin { - _ = stdin.(io.Closer).Close() - } - if closeStdout { - _ = stdout.(io.Closer).Close() - } + _ = stdin.Close() + _ = stdout.Close() return s.err } @@ -632,18 +700,13 @@ func (s requirementStage) Requirements() pipe.StageRequirements { func (s requirementStage) Start( _ context.Context, _ pipe.StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *pipe.InputStream, stdout *pipe.OutputStream, ) error { if s.started != nil { *s.started = true } - if closeStdin { - _ = stdin.(io.Closer).Close() - } - if closeStdout { - _ = stdout.(io.Closer).Close() - } + _ = stdin.Close() + _ = stdout.Close() return nil } diff --git a/pipe/stage.go b/pipe/stage.go index e35f428..131a861 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -2,80 +2,36 @@ package pipe import ( "context" - "fmt" - "io" ) // Stage is an element of a `Pipeline`. It reads from standard input // and writes to standard output. // -// Who closes stdin and stdout? +// # Who closes stdin and stdout? // -// A `Stage` as a whole is responsible for closing its end of stdin -// and stdout (assuming that `Start()` returns successfully) if the -// corresponding close flag passed to `Start()` is true. Its doing so -// tells the previous/next stage that it is done reading/writing data, -// which can affect their behavior. Therefore, it should close each -// one as soon as it is done with it. If the caller wants to suppress -// the closing of stdin/stdout, it passes a false close flag. +// A `Stage` is responsible for calling `Close()` on the +// `InputStream`/`OutputStream` that represent its stdin and stdout as +// soon as it doesn't need them anymore. That responsibility begins as +// soon as the stage's `Start()` method is called, and applies +// regardless of whether `Start()` returns an error. It must close the +// streams before its `Wait()` method returns. The caller must not +// close the streams after calling `Start()`. // -// How this should be done depends on whether stdin/stdout are of type -// `*os.File`. -// -// If a stage is an external command, then the subprocess ultimately -// needs its own copies of `*os.File` file descriptors for its stdin -// and stdout. The external command will "always" [1] close those when -// it exits. -// -// If the stage is an external command and one of the arguments is an -// `*os.File`, then it can set the corresponding field of `exec.Cmd` -// to that argument directly. This has the result that `exec.Cmd` -// duplicates that file descriptor and passes the dup to the -// subprocess. Therefore, the stage must close its copy of that -// argument as soon as the external command has started, because the -// external command will keep its own copy open as long as necessary -// (and no longer!). It should use roughly the following sequence: -// -// cmd.Stdin = f // Similarly for stdout -// cmd.Start(…) -// f.Close() // close our copy -// cmd.Wait() -// -// If the stage is an external command and one of its arguments is not -// an `*os.File`, then `exec.Cmd` will take care of creating an -// `os.Pipe()`, copying from the provided argument in/out of the pipe, -// and eventually closing both ends of the pipe. The stage must close -// the argument itself, but only _after_ the external command has -// finished, like so: -// -// cmd.Stdin = r // Similarly for stdout -// cmd.Start(…) -// cmd.Wait() -// r.Close() -// -// If the stage is a Go function, then it holds the only copy of -// stdin/stdout, so it must wait until the function is done before -// closing them (regardless of their underlying type, like so: -// -// go func() { -// f(…, stdin, stdout) -// stdin.Close() -// stdout.Close() -// }() +// Closing stdin/stdout tells the previous/next stage that this stage +// is done reading/writing data, which can affect their behavior. +// Therefore, it is important for a stage to close each one as soon as +// it is done with it. // // From the point of view of the pipeline as a whole, if stdin is // provided by the user (`WithStdin()`), then we don't want the first -// stage to close it at all, whether it's an `*os.File` or not. The -// pipeline communicates this by passing closeStdin=false when it -// starts that stage. For stdout, it depends on whether the user -// supplied it using `WithStdout()` or `WithStdoutCloser()`. -// -// [1] It's theoretically possible for a command to pass the open file -// descriptor to another, longer-lived process, in which case the -// file descriptor wouldn't necessarily get closed when the -// command finishes. But that's ill-behaved in a command that is -// being used in a pipeline, so we'll ignore that possibility. - +// stage to close it at all. This is arranged by by passing a +// non-closing `InputStream` when it starts that stage. For stdout, it +// depends on whether the user supplied it using `WithStdout()` or +// `WithStdoutCloser()`, and in the former case provides the last +// stage with a non-closing `OutputStream`. Calling `Close()` on a +// non-closing stream (or even on a nil stream) is a NOP, so the +// `Stage` can always call `Close()` and doesn't have to worry about +// whether a stdin/stdout stream is non-closing. type Stage interface { // Name returns the name of the stage. Name() string @@ -86,20 +42,20 @@ type Stage interface { // Start starts the stage in the background, in the environment // described by `opts.Env`, using `stdin` to provide its input and - // `stdout` to collect its output. (`stdin`/`stdout` might be set - // to `nil` if the stage is to receive no input, which might be the - // case for the first/last stage in a pipeline.) If `closeStdin` or - // `closeStdout` is true, the stage is responsible for closing the - // corresponding stream. A stream with a true close flag must - // implement `io.Closer`. See the `Stage` type comment for more + // `stdout` to collect its output. (`stdin.Reader()` or + // `stdout.Writer()` might be `nil` if the stage is to receive no + // input or produce no output, which might be the case for the + // first/last stage in a pipeline.) The stage is responsible for + // calling `stdin.Close()` and `stdout.Close()`, even if `Start()` + // returns an error. See the `Stage` type comment for more // information about responsibility for closing stdin and stdout. // // If `Start()` returns without an error, `Wait()` must also be - // called, to allow all resources to be freed. + // called, to allow all resources to be freed. If `Start()` returns + // an error, `Wait()` must not be called. Start( ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, + stdin *InputStream, stdout *OutputStream, ) error // Wait waits for the stage to be done, either because it has @@ -108,17 +64,6 @@ type Stage interface { Wait() error } -func ownedCloser(stream any, owned bool) io.Closer { - if !owned { - return nil - } - closer, ok := stream.(io.Closer) - if !ok { - panic(fmt.Sprintf("stage asked to close %T, which does not implement io.Closer", stream)) - } - return closer -} - // StageOptions carries everything (other than `ctx`, `stdin`, and // `stdout`) that a pipeline passes to `Stage.Start`. type StageOptions struct { @@ -136,25 +81,11 @@ type StageOptions struct { // StagePanicHandler is a function that handles panics in a pipeline's stages. type StagePanicHandler func(p any) error -type StreamRequirement int - -const ( - // StreamOptional means the stream may be connected or nil. - StreamOptional StreamRequirement = iota - - // StreamForbidden means the stream must be nil. - StreamForbidden -) - -// StageRequirements describes what a Stage needs from the streams connected to -// its stdin and stdout. The zero value is correct for stages that are happy -// with arbitrary io.Reader/io.Writer streams, such as Function stages. +// StageRequirements describes what a Stage needs from the streams +// connected to its stdin and stdout. The zero value is correct for +// stages that are happy with arbitrary io.Reader/io.Writer streams, +// such as Function stages. type StageRequirements struct { Stdin StreamRequirement Stdout StreamRequirement - - // {Stdin,Stdout}NeedsFile indicate that, if stdio is connected, the - // stage requires it to be backed by an *os.File (a real file descriptor) - StdinNeedsFile bool - StdoutNeedsFile bool } diff --git a/pipe/stage_joiner.go b/pipe/stage_joiner.go new file mode 100644 index 0000000..03569cb --- /dev/null +++ b/pipe/stage_joiner.go @@ -0,0 +1,133 @@ +package pipe + +import ( + "errors" + "fmt" + "io" + "os" +) + +// stageJoiner is a helper type that helps join two adjacent stages +// together. stageJoiners[i] tells how to connect stage `i-1` to stage +// `i`. From the point of view of stages, `stageJoiners[i].nextStdin` +// and `stageJoiners[i+1].prevStdout` are the input and output +// streams, respectively, of `stage[i]`. The first and last elements +// of `stageJoiners` manage `p.stdin` and `p.stdout`, respectively. +// Schematically, the data flows through like this: +// +// p.stdin == stageJoiners[0].nextStdin → +// stage[0] → +// stageJoiners[1].prevStdout → stageJoiners[1].nextStdin → +// stage[1] → +// stageJoiners[2].prevStdout → stageJoiners[2].nextStdin → +// stage[2] → +// ... → +// stageJoiners[i].prevStdout → stageJoiners[i].nextStdin → +// stage[i] → +// stageJoiners[i+1].prevStdout → stageJoiners[i+1].nextStdin → +// ... → +// stageJoiners[len(stages)-1].prevStdout → stageJoiners[len(stages)-1].nextStdin → +// stage[len(stages)-1] → +// stageJoiners[len(stages)].prevStdout == p.stdout +// +// In pseudo-Shell notation, the stages are run like this: +// +// stage[0] stageJoiners[1].prevStdout +// stage[1] stageJoiners[2].prevStdout +// stage[2] stageJoiners[3].prevStdout +// ... +// stage[i] stageJoiners[i].prevStdout +// ... +// stage[len(stages)-1] p.stdout +type stageJoiner struct { + // prevStage holds the stage that needs to write to the pipe. + prevStage Stage + + // prevStageReq caches `prevStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `prevStage` is nil. + prevStageReq StageRequirements + + // prevStdout will be used as the stdout of `prevStage`. It is + // usually the "write" end of the `(nextStdin, prevStdout)` pipe + // pair, with the connected pipe ends in the same `stageJoiner` + // instance. + prevStdout *OutputStream + + // nextStage holds the stage that needs to read from the pipe. + nextStage Stage + + // nextStageReq caches `nextStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `nextStage` is nil. + nextStageReq StageRequirements + + // nextStdin will be used as the stdin of `nextStage`. It is + // usually the "read" end of the `(nextStdin, prevStdout)` pipe + // pair. + nextStdin *InputStream +} + +// needFilePipe returns `true` if the pipe that joins the two adjacent +// stages should be an `os.Pipe()` rather than an `io.Pipe()`. +func (sj *stageJoiner) needFilePipe() bool { + return sj.prevStageReq.Stdout == StreamPreferFile || + sj.nextStageReq.Stdin == StreamPreferFile +} + +func (sj *stageJoiner) createPipe() error { + var r io.ReadCloser + var w io.WriteCloser + if sj.needFilePipe() { + var err error + r, w, err = os.Pipe() + if err != nil { + return fmt.Errorf("creating os.Pipe: %w", err) + } + } else { + r, w = io.Pipe() + } + + sj.prevStdout = ClosingOutput(w) + sj.nextStdin = ClosingInput(r) + + return nil +} + +// closePipe closes both ends of the pipe that was allocated by +// `createPipe()`. This should only be called if the corresponding +// stage's `Start()` method was never called (otherwise the stage is +// responsible for closing its stdin and stdout). +func (sj *stageJoiner) closePipe() error { + return errors.Join( + sj.prevStdout.Close(), + sj.nextStdin.Close(), + ) +} + +// validate verifies that the adjacent stages' stream requirements are +// satisfiable, in particular that a stage that forbids its stdin or +// stdout is not connected to anything. +func (sj *stageJoiner) validate() error { + // `prevStage`'s stdout is connected if there is a `nextStage` to + // consume it (in which case an inner pipe will be created) or if + // a stream (`p.stdout`) has already been stored in `prevStdout`. + if sj.prevStage != nil && sj.prevStageReq.Stdout == StreamForbidden && + (sj.nextStage != nil || sj.prevStdout != nil) { + return fmt.Errorf( + "stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(), + ) + } + + // `nextStage`'s stdin is connected if there is a `prevStage` to + // produce it (in which case an inner pipe will be created) or if + // a stream (`p.stdin`) has already been stored in `nextStdin`. + if sj.nextStage != nil && sj.nextStageReq.Stdin == StreamForbidden && + (sj.prevStage != nil || sj.nextStdin != nil) { + return fmt.Errorf( + "stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(), + ) + } + + return nil +} diff --git a/pipe/stream_requirement.go b/pipe/stream_requirement.go new file mode 100644 index 0000000..ddff829 --- /dev/null +++ b/pipe/stream_requirement.go @@ -0,0 +1,38 @@ +package pipe + +import "fmt" + +// StreamRequirement describes a `Stage`'s requirement for its stdin +// or stdout, namely whether it can be anything, whether it should +// preferably be an `*os.File`, or whether it must be `nil`. The zero +// value `StreamAcceptAny` is a valid value that indicates that the +// stage has no particular requirements or preferences for its +// stdin/stdout, such as a typical `Function` stage. +type StreamRequirement int + +const ( + // StreamAcceptAny indicates that the stage hasn't declared what + // kind of stream it requires, maybe even `nil`. + StreamAcceptAny StreamRequirement = iota + + // StreamPreferFile indicates that the stage prefers the + // corresponding stream to be backed by an `*os.File` (a real file + // descriptor), but it can work with any io.Reader/io.Writer. + StreamPreferFile + + // StreamForbidden indicates that the stage requires the + // corresponding stream to be nil. It won't read/write the stream + // or close it. + StreamForbidden +) + +// Validate checks that `req` has a valid value and returns an error +// otherwise. +func (requirement StreamRequirement) Validate() error { + switch requirement { + case StreamAcceptAny, StreamPreferFile, StreamForbidden: + return nil + default: + return fmt.Errorf("invalid stream requirement %d", requirement) + } +} diff --git a/pipe/streams.go b/pipe/streams.go new file mode 100644 index 0000000..8ddd578 --- /dev/null +++ b/pipe/streams.go @@ -0,0 +1,142 @@ +package pipe + +import ( + "io" + "sync" +) + +// InputStream represents `stdin` for a stage, which might or might +// not need to be closed when the stage is done with it. It usually +// holds an `io.Reader`, which can be retrieved using `Reader()`. Its +// `Close()` method closes the reader if necessary (i.e., if the +// `InputStream` was constructed using `ClosingInput()`. The +// `Close()` method is idempotent. +// +// A nil `*InputStream` is a valid value. Its `Reader()` method +// returns `nil` and `Close()` does nothing successfully. +// +// It might seem like `InputStream` should implement `io.Reader` +// itself. But we want to avoid hiding the dynamic type of the +// `io.Reader` that is being used as the stdin of a pipeline. That +// object might be of a type that is subject to optimizations that +// aren't available for a generic `io.Reader`. For example, it might +// be an `*os.File` (which can be passed directly to subcommands or to +// `splice(2)`), or it might implement `io.WriterTo`. +type InputStream struct { + reader io.Reader + + // once is used to ensure that `Close()` is only called once. + once sync.Once + + // closer is set to `nil` after the first call to `Close()`. + closer io.Closer + + // closeErr is set to the error returned by the first call to + // `Close()`, and returned from that and any subsequent calls to + // `Close()`. + closeErr error +} + +// The stage may read from r but must not close it. +func Input(r io.Reader) *InputStream { + return &InputStream{reader: r} +} + +// The stage is responsible for closing r. +func ClosingInput(r io.ReadCloser) *InputStream { + return &InputStream{reader: r, closer: r} +} + +func (s *InputStream) Reader() io.Reader { + if s == nil { + return nil + } + return s.reader +} + +// Close closes the underlying reader if necessary. If `s` was +// constructed using `ClosingInput()`, then close the `io.ReadCloser` +// that was passed to that function. If `s` is `nil` or was +// constructed using `Input()`, then do nothing successfully. +func (s *InputStream) Close() error { + if s == nil { + return nil + } + + s.once.Do(func() { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil + } + }) + + return s.closeErr +} + +// OutputStream represents `stdout` for a stage, which might or might +// not need to be closed when the stage is done with it. It usually +// holds an `io.Writer`, which can be retrieved using `Writer()`. Its +// `Close()` method closes the writer if necessary (i.e., if the +// `OutputStream` was constructed using `ClosingOutput()`. The +// `Close()` method is idempotent. +// +// A nil `*OutputStream` is a valid value. Its `Writer()` method +// returns `nil` and `Close()` does nothing successfully. +// +// It might seem like `OutputStream` should implement `io.Writer` +// itself. But we want to avoid hiding the dynamic type of the +// `io.Writer` that is being used as the stdout of a pipeline. That +// object might be of a type that is subject to optimizations that +// aren't available for a generic `io.Writer`. For example, it might +// be an `*os.File` (which can be passed directly to subcommands or to +// `splice(2)`), or it might implement `io.ReaderFrom`. +type OutputStream struct { + writer io.Writer + + // once is used to ensure that `Close()` is only called once. + once sync.Once + + // closer is set to `nil` after the first call to `Close()`. + closer io.Closer + + // closeErr is set to the error returned by the first call to + // `Close()`, and returned from that and any subsequent calls to + // `Close()`. + closeErr error +} + +// The stage may write to w but must not close it. +func Output(w io.Writer) *OutputStream { + return &OutputStream{writer: w} +} + +// The stage is responsible for closing w. +func ClosingOutput(w io.WriteCloser) *OutputStream { + return &OutputStream{writer: w, closer: w} +} + +func (s *OutputStream) Writer() io.Writer { + if s == nil { + return nil + } + return s.writer +} + +// Close closes the underlying writer if necessary. If `s` was +// constructed using `ClosingOutput()`, then close the +// `io.WriteCloser` that was passed to that function. If `s` is `nil` +// or was constructed using `Output()`, then do nothing successfully. +func (s *OutputStream) Close() error { + if s == nil { + return nil + } + + s.once.Do(func() { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil + } + }) + + return s.closeErr +}