Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,32 @@ type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Wri
// FunctionOption configures a Function stage.
type FunctionOption func(*goStage)

// WithStdinRequirement returns a FunctionOption declaring the stage's stdin
// requirement.
func WithStdinRequirement(requirement StreamRequirement) FunctionOption {
return func(s *goStage) {
s.requirements.Stdin = requirement
}
}

// WithStdoutRequirement returns a FunctionOption declaring the stage's stdout
// requirement.
func WithStdoutRequirement(requirement StreamRequirement) FunctionOption {
return func(s *goStage) {
s.requirements.Stdout = requirement
}
}

// ForbidStdin returns a FunctionOption declaring that the stage must not be
// connected to stdin.
func ForbidStdin() FunctionOption {
return func(s *goStage) {
s.requirements.Stdin = StreamForbidden
}
return WithStdinRequirement(StreamForbidden)
}

// ForbidStdout returns a FunctionOption declaring that the stage must not be
// connected to stdout.
func ForbidStdout() FunctionOption {
return func(s *goStage) {
s.requirements.Stdout = StreamForbidden
}
return WithStdoutRequirement(StreamForbidden)
}

// Function returns a pipeline `Stage` that will run a `StageFunc` in
Expand Down
33 changes: 31 additions & 2 deletions pipe/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ func WithDir(dir string) Option {
// WithStdin assigns stdin to the first command in the pipeline. The
// caller retains ownership of stdin; the pipeline will not close it,
// even if `Start()` returns an error.
//
// If the first stage is a `Command` and stdin is not an `*os.File`,
// `exec.Cmd` has to copy stdin through an internal goroutine, and
// `Cmd.Wait()` waits for that copy to finish. This is fine for bounded
// readers such as `strings.Reader` and `bytes.Reader`, and for
// `*os.File` values, which are passed to the command directly. But a
// borrowed, non-file reader that can block forever can also block the
// pipeline forever if the command exits without consuming all of its
// stdin. See `TestPipelineIOPipeStdinThatIsNeverClosed` for the known
// limitation.
func WithStdin(stdin io.Reader) Option {
return func(p *Pipeline) {
p.stdin = Input(stdin)
Expand Down Expand Up @@ -238,6 +248,12 @@ func (p *Pipeline) Start(ctx context.Context) error {

atomic.StoreUint32(&p.started, 1)
ctx, p.cancel = context.WithCancel(ctx)
startedOK := false
defer func() {
if !startedOK {
p.cancel()
}
Comment thread
znull marked this conversation as resolved.
}()

if len(p.stages) == 0 {
if p.stdout == nil {
Expand Down Expand Up @@ -290,11 +306,15 @@ func (p *Pipeline) Start(ctx context.Context) error {
requirements := s.Requirements()
if err := requirements.Stdin.Validate(); err != nil {
closePipes()
return fmt.Errorf("stdin: %w", err)
return fmt.Errorf(
"stage %q has invalid stdin requirement: %w", s.Name(), err,
)
}
if err := requirements.Stdout.Validate(); err != nil {
closePipes()
return fmt.Errorf("stdout: %w", err)
return fmt.Errorf(
"stage %q has invalid stdout requirement: %w", s.Name(), err,
)
}

stageJoiners[i].nextStage = s
Expand Down Expand Up @@ -362,10 +382,19 @@ func (p *Pipeline) Start(ctx context.Context) error {
}
}

startedOK = true
return nil
}

func (p *Pipeline) Output(ctx context.Context) ([]byte, error) {
if p.hasStarted() {
panic("attempt to get output from a pipeline that has already started")
}

if err := p.stdout.Close(); err != nil {
return nil, fmt.Errorf("closing previous stdout: %w", err)
}

Comment on lines +394 to +397

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISTM that trying to call Output() after having already set stdout would always be a programming error: never a sensible thing to do, likely to happen every time the code is run (i.e., high chance of detecting it in CI), and possibly tricky to figure out if this kind of caller error slipped through. What would you think about making this case panic() instead of silently covering up the mistake?

For that matter, using WithStdout() or WithStderr()/WithStderrCloser() more than once could also panic for the same reason.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd support that, yeah. Making things impossible to misuse is good.

var buf bytes.Buffer
p.stdout = Output(&buf)
err := p.Run(ctx)
Expand Down
50 changes: 47 additions & 3 deletions pipe/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ func TestPipelineEmptyOutput(t *testing.T) {
}
}

func TestPipelineOutputClosesConfiguredStdoutCloser(t *testing.T) {
t.Parallel()
ctx := context.Background()
stdout := &closeTrackingWriter{}
p := pipe.New(
pipe.WithStdin(strings.NewReader("hello world\n")),
pipe.WithStdoutCloser(stdout),
)

out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.Equal(t, "hello world\n", string(out))
assert.Equal(t, "", stdout.buf.String())
assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed")
}
}

func TestPipelineEmptyWithStdoutCloser(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down Expand Up @@ -951,6 +968,24 @@ func TestFunctionOptionsForbidStreams(t *testing.T) {
})
}

func TestFunctionOptionsSetStreamRequirements(t *testing.T) {
t.Parallel()

stage := pipe.Function(
"file-preferring",
func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error {
return nil
},
pipe.WithStdinRequirement(pipe.StreamPreferFile),
pipe.WithStdoutRequirement(pipe.StreamPreferFile),
)

assert.Equal(t, pipe.StageRequirements{
Stdin: pipe.StreamPreferFile,
Stdout: pipe.StreamPreferFile,
}, stage.Requirements())
}

func TestStreamForbiddenStdin(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down Expand Up @@ -1024,7 +1059,10 @@ func TestInvalidStreamRequirements(t *testing.T) {
Stdin: pipe.StreamRequirement(123),
},
})
require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 123`)
require.ErrorContains(
t, p.Run(ctx),
`stage "source" has invalid stdin requirement: invalid stream requirement 123`,
)
assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed")
})

Expand All @@ -1038,7 +1076,10 @@ func TestInvalidStreamRequirements(t *testing.T) {
Stdout: pipe.StreamRequirement(123),
},
})
require.ErrorContains(t, p.Run(ctx), `stdout: invalid stream requirement 123`)
require.ErrorContains(
t, p.Run(ctx),
`stage "sink" has invalid stdout requirement: invalid stream requirement 123`,
)
assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed")
})
}
Expand Down Expand Up @@ -1071,7 +1112,10 @@ func TestInvalidStreamRequirement(t *testing.T) {
Stdin: pipe.StreamRequirement(99),
},
})
require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 99`)
require.ErrorContains(
t, p.Run(ctx),
`stage "invalid" has invalid stdin requirement: invalid stream requirement 99`,
)
}

func TestFunctionNoInput(t *testing.T) {
Expand Down
Loading