diff --git a/pipe/pipeline.go b/pipe/pipeline.go index d237461..44fedd1 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -286,9 +286,6 @@ func (p *Pipeline) Start(ctx context.Context) error { // Store the stages in the joiners, and verify that the stages' // requirements are well-formed: for i, s := range p.stages { - stageJoiners[i].nextStage = s - stageJoiners[i+1].prevStage = s - // Make sure that the stage's requirements are well-formed: requirements := s.Requirements() if err := requirements.Stdin.Validate(); err != nil { @@ -297,21 +294,25 @@ func (p *Pipeline) Start(ctx context.Context) error { if err := requirements.Stdout.Validate(); err != nil { return fmt.Errorf("stdout: %w", err) } + + stageJoiners[i].nextStage = s + stageJoiners[i].nextStageReq = requirements + stageJoiners[i+1].prevStage = s + stageJoiners[i+1].prevStageReq = requirements } - // Create the "inner" pipes (i.e, all but the first and last - // `stageJoiners`): - for i := 1; i < len(stageJoiners)-1; i++ { - if err := stageJoiners[i].createPipe(); err != nil { + // Check that each of the stages' requirements are satisfiable: + for i := range stageJoiners { + if err := stageJoiners[i].validate(); err != nil { closePipes() return err } } - // Check that each of the stages' requirements are compatible with - // the pipes that we have created for them: - for i := range stageJoiners { - if err := stageJoiners[i].validate(); err != nil { + // Create the "inner" pipes (i.e, all but the first and last + // `stageJoiners`): + for i := 1; i < len(stageJoiners)-1; i++ { + if err := stageJoiners[i].createPipe(); err != nil { closePipes() return err } diff --git a/pipe/stage_joiner.go b/pipe/stage_joiner.go index 24fb789..03569cb 100644 --- a/pipe/stage_joiner.go +++ b/pipe/stage_joiner.go @@ -43,6 +43,11 @@ type stageJoiner struct { // prevStage holds the stage that needs to write to the pipe. prevStage Stage + // prevStageReq caches `prevStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `prevStage` is nil. + prevStageReq StageRequirements + // prevStdout will be used as the stdout of `prevStage`. It is // usually the "write" end of the `(nextStdin, prevStdout)` pipe // pair, with the connected pipe ends in the same `stageJoiner` @@ -52,6 +57,11 @@ type stageJoiner struct { // nextStage holds the stage that needs to read from the pipe. nextStage Stage + // nextStageReq caches `nextStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `nextStage` is nil. + nextStageReq StageRequirements + // nextStdin will be used as the stdin of `nextStage`. It is // usually the "read" end of the `(nextStdin, prevStdout)` pipe // pair. @@ -61,13 +71,8 @@ type stageJoiner struct { // needFilePipe returns `true` if the pipe that joins the two adjacent // stages should be an `os.Pipe()` rather than an `io.Pipe()`. func (sj *stageJoiner) needFilePipe() bool { - if sj.prevStage.Requirements().Stdout == StreamPreferFile { - return true - } - if sj.nextStage.Requirements().Stdin == StreamPreferFile { - return true - } - return false + return sj.prevStageReq.Stdout == StreamPreferFile || + sj.nextStageReq.Stdin == StreamPreferFile } func (sj *stageJoiner) createPipe() error { @@ -100,26 +105,28 @@ func (sj *stageJoiner) closePipe() error { ) } -// validate verifies that `sj.prevStdout` and `sj.nextStdin` are -// suitable for the adjacent stages, in particular that no pipe is -// created if the stage requirements are `StreamForbidden`. +// validate verifies that the adjacent stages' stream requirements are +// satisfiable, in particular that a stage that forbids its stdin or +// stdout is not connected to anything. func (sj *stageJoiner) validate() error { - if sj.prevStage != nil { - stdoutRequirements := sj.prevStage.Requirements().Stdout - if stdoutRequirements == StreamForbidden && sj.prevStdout != nil { - return fmt.Errorf( - "stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(), - ) - } + // `prevStage`'s stdout is connected if there is a `nextStage` to + // consume it (in which case an inner pipe will be created) or if + // a stream (`p.stdout`) has already been stored in `prevStdout`. + if sj.prevStage != nil && sj.prevStageReq.Stdout == StreamForbidden && + (sj.nextStage != nil || sj.prevStdout != nil) { + return fmt.Errorf( + "stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(), + ) } - if sj.nextStage != nil { - stdinRequirements := sj.nextStage.Requirements().Stdin - if stdinRequirements == StreamForbidden && sj.nextStdin != nil { - return fmt.Errorf( - "stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(), - ) - } + // `nextStage`'s stdin is connected if there is a `prevStage` to + // produce it (in which case an inner pipe will be created) or if + // a stream (`p.stdin`) has already been stored in `nextStdin`. + if sj.nextStage != nil && sj.nextStageReq.Stdin == StreamForbidden && + (sj.prevStage != nil || sj.nextStdin != nil) { + return fmt.Errorf( + "stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(), + ) } return nil