diff --git a/exec.go b/exec.go index 07eaef0..b5ca852 100644 --- a/exec.go +++ b/exec.go @@ -1,6 +1,7 @@ package main import ( + "errors" "log" "os" "os/exec" @@ -50,8 +51,16 @@ func runCmd(ctx context.Context, cancel context.CancelFunc, cmd string, args ... log.Println("Command finished successfully.") } else { log.Printf("Command exited with error: %s\n", err) - // OPTIMIZE: This could be cleaner - os.Exit(err.(*exec.ExitError).Sys().(syscall.WaitStatus).ExitStatus()) + + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + os.Exit(status.ExitStatus()) + } + } + + // Fallback for non-ExitError types (e.g., *os.SyscallError) + os.Exit(1) } } diff --git a/exec_error_test.go b/exec_error_test.go new file mode 100644 index 0000000..be65473 --- /dev/null +++ b/exec_error_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "errors" + "os/exec" + "syscall" + "testing" +) + +// helper function mirroring the exit code extraction logic +func extractExitCode(err error) int { + if err == nil { + return 0 + } + + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + return status.ExitStatus() + } + } + + return 1 +} + +func TestExtractExitCode_FromExitError(t *testing.T) { + cmd := exec.Command("sh", "-c", "exit 42") + err := cmd.Run() + + if err == nil { + t.Fatalf("expected non-nil error") + } + + code := extractExitCode(err) + if code != 42 { + t.Fatalf("expected exit code 42, got %d", code) + } +} + +func TestExtractExitCode_FromGenericError(t *testing.T) { + err := errors.New("some generic error") + + code := extractExitCode(err) + if code != 1 { + t.Fatalf("expected fallback exit code 1, got %d", code) + } +}