diff --git a/pipe/function.go b/pipe/function.go index a80947f..2970a12 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -33,20 +33,32 @@ type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Wri // FunctionOption configures a Function stage. type FunctionOption func(*goStage) +// WithStdinRequirement returns a FunctionOption declaring the stage's stdin +// requirement. +func WithStdinRequirement(requirement StreamRequirement) FunctionOption { + return func(s *goStage) { + s.requirements.Stdin = requirement + } +} + +// WithStdoutRequirement returns a FunctionOption declaring the stage's stdout +// requirement. +func WithStdoutRequirement(requirement StreamRequirement) FunctionOption { + return func(s *goStage) { + s.requirements.Stdout = requirement + } +} + // ForbidStdin returns a FunctionOption declaring that the stage must not be // connected to stdin. func ForbidStdin() FunctionOption { - return func(s *goStage) { - s.requirements.Stdin = StreamForbidden - } + return WithStdinRequirement(StreamForbidden) } // ForbidStdout returns a FunctionOption declaring that the stage must not be // connected to stdout. func ForbidStdout() FunctionOption { - return func(s *goStage) { - s.requirements.Stdout = StreamForbidden - } + return WithStdoutRequirement(StreamForbidden) } // Function returns a pipeline `Stage` that will run a `StageFunc` in diff --git a/pipe/pipeline.go b/pipe/pipeline.go index 878cdc6..3078c10 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -99,6 +99,16 @@ func WithDir(dir string) Option { // WithStdin assigns stdin to the first command in the pipeline. The // caller retains ownership of stdin; the pipeline will not close it, // even if `Start()` returns an error. +// +// If the first stage is a `Command` and stdin is not an `*os.File`, +// `exec.Cmd` has to copy stdin through an internal goroutine, and +// `Cmd.Wait()` waits for that copy to finish. This is fine for bounded +// readers such as `strings.Reader` and `bytes.Reader`, and for +// `*os.File` values, which are passed to the command directly. But a +// borrowed, non-file reader that can block forever can also block the +// pipeline forever if the command exits without consuming all of its +// stdin. See `TestPipelineIOPipeStdinThatIsNeverClosed` for the known +// limitation. func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { p.stdin = Input(stdin) @@ -238,6 +248,12 @@ func (p *Pipeline) Start(ctx context.Context) error { atomic.StoreUint32(&p.started, 1) ctx, p.cancel = context.WithCancel(ctx) + startedOK := false + defer func() { + if !startedOK { + p.cancel() + } + }() if len(p.stages) == 0 { if p.stdout == nil { @@ -290,11 +306,15 @@ func (p *Pipeline) Start(ctx context.Context) error { requirements := s.Requirements() if err := requirements.Stdin.Validate(); err != nil { closePipes() - return fmt.Errorf("stdin: %w", err) + return fmt.Errorf( + "stage %q has invalid stdin requirement: %w", s.Name(), err, + ) } if err := requirements.Stdout.Validate(); err != nil { closePipes() - return fmt.Errorf("stdout: %w", err) + return fmt.Errorf( + "stage %q has invalid stdout requirement: %w", s.Name(), err, + ) } stageJoiners[i].nextStage = s @@ -362,10 +382,19 @@ func (p *Pipeline) Start(ctx context.Context) error { } } + startedOK = true return nil } func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { + if p.hasStarted() { + panic("attempt to get output from a pipeline that has already started") + } + + if err := p.stdout.Close(); err != nil { + return nil, fmt.Errorf("closing previous stdout: %w", err) + } + var buf bytes.Buffer p.stdout = Output(&buf) err := p.Run(ctx) diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index e956964..fe73eea 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -54,6 +54,23 @@ func TestPipelineEmptyOutput(t *testing.T) { } } +func TestPipelineOutputClosesConfiguredStdoutCloser(t *testing.T) { + t.Parallel() + ctx := context.Background() + stdout := &closeTrackingWriter{} + p := pipe.New( + pipe.WithStdin(strings.NewReader("hello world\n")), + pipe.WithStdoutCloser(stdout), + ) + + out, err := p.Output(ctx) + if assert.NoError(t, err) { + assert.Equal(t, "hello world\n", string(out)) + assert.Equal(t, "", stdout.buf.String()) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + } +} + func TestPipelineEmptyWithStdoutCloser(t *testing.T) { t.Parallel() ctx := context.Background() @@ -951,6 +968,24 @@ func TestFunctionOptionsForbidStreams(t *testing.T) { }) } +func TestFunctionOptionsSetStreamRequirements(t *testing.T) { + t.Parallel() + + stage := pipe.Function( + "file-preferring", + func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + return nil + }, + pipe.WithStdinRequirement(pipe.StreamPreferFile), + pipe.WithStdoutRequirement(pipe.StreamPreferFile), + ) + + assert.Equal(t, pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamPreferFile, + }, stage.Requirements()) +} + func TestStreamForbiddenStdin(t *testing.T) { t.Parallel() ctx := context.Background() @@ -1024,7 +1059,10 @@ func TestInvalidStreamRequirements(t *testing.T) { Stdin: pipe.StreamRequirement(123), }, }) - require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 123`) + require.ErrorContains( + t, p.Run(ctx), + `stage "source" has invalid stdin requirement: invalid stream requirement 123`, + ) assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") }) @@ -1038,7 +1076,10 @@ func TestInvalidStreamRequirements(t *testing.T) { Stdout: pipe.StreamRequirement(123), }, }) - require.ErrorContains(t, p.Run(ctx), `stdout: invalid stream requirement 123`) + require.ErrorContains( + t, p.Run(ctx), + `stage "sink" has invalid stdout requirement: invalid stream requirement 123`, + ) assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") }) } @@ -1071,7 +1112,10 @@ func TestInvalidStreamRequirement(t *testing.T) { Stdin: pipe.StreamRequirement(99), }, }) - require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 99`) + require.ErrorContains( + t, p.Run(ctx), + `stage "invalid" has invalid stdin requirement: invalid stream requirement 99`, + ) } func TestFunctionNoInput(t *testing.T) {