diff --git a/pipe/command.go b/pipe/command.go index 27d34c6..14a18d7 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -29,9 +29,12 @@ type commandStage struct { wg errgroup.Group stderr bytes.Buffer - // If the context expired, and we attempted to kill the command, - // `ctx.Err()` is stored here. - ctxErr atomic.Value + // If we attempted to kill the command, the first reason is stored here. + ctxErr atomic.Pointer[commandKillError] +} + +type commandKillError struct { + err error } var ( @@ -288,8 +291,8 @@ func (s *commandStage) filterCmdError(err error) error { return err } - ctxErr, ok := s.ctxErr.Load().(error) - if ok { + ctxErr := s.ctxErr.Load() + if ctxErr != nil { // If the process looks like it was killed by us, substitute // `ctxErr` for the process's own exit error. Note that this // doesn't do anything on Windows, where the `Signaled()` @@ -298,7 +301,7 @@ func (s *commandStage) filterCmdError(err error) error { ps, ok := eErr.Sys().(syscall.WaitStatus) if ok && ps.Signaled() && (ps.Signal() == syscall.SIGTERM || ps.Signal() == syscall.SIGKILL) { - return ctxErr + return ctxErr.err } } @@ -306,6 +309,12 @@ func (s *commandStage) filterCmdError(err error) error { return eErr } +func (s *commandStage) recordKillError(err error) { + if err != nil { + s.ctxErr.CompareAndSwap(nil, &commandKillError{err: err}) + } +} + func (s *commandStage) Wait() error { defer close(s.done) diff --git a/pipe/command_test.go b/pipe/command_test.go index 531f11f..39467c4 100644 --- a/pipe/command_test.go +++ b/pipe/command_test.go @@ -1,6 +1,8 @@ package pipe import ( + "context" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -83,3 +85,16 @@ func TestCopyEnvWithOverride(t *testing.T) { }) } } + +func TestCommandStageRecordKillErrorAcceptsDifferentErrorTypesAndKeepsFirst(t *testing.T) { + errMemoryLimitExceeded := errors.New("memory limit exceeded") + var stage commandStage + + stage.recordKillError(errMemoryLimitExceeded) + stage.recordKillError(context.DeadlineExceeded) + + got := stage.ctxErr.Load() + if assert.NotNil(t, got, "expected ctxErr to store commandKillError") { + assert.ErrorIs(t, got.err, errMemoryLimitExceeded) + } +} diff --git a/pipe/command_unix.go b/pipe/command_unix.go index 1ccc7a2..ce07cfa 100644 --- a/pipe/command_unix.go +++ b/pipe/command_unix.go @@ -44,9 +44,8 @@ func (s *commandStage) Kill(err error) { default: } - // Record the `ctx.Err()`, which will be used as the error result - // for this stage. - s.ctxErr.Store(err) + // Record the kill reason, which will be used as the error result for this stage. + s.recordKillError(err) // First try to kill using a relatively gentle signal so that // the processes have a chance to clean up after themselves: diff --git a/pipe/command_windows.go b/pipe/command_windows.go index f8cdf3a..4c7abfd 100644 --- a/pipe/command_windows.go +++ b/pipe/command_windows.go @@ -21,9 +21,8 @@ func (s *commandStage) Kill(err error) { default: } - // Record the `ctx.Err()`, which will be used as the error result - // for this stage. - s.ctxErr.Store(err) + // Record the kill reason, which will be used as the error result for this stage. + s.recordKillError(err) s.cmd.Process.Kill() }