diff --git a/command/ssh/proxycommand.go b/command/ssh/proxycommand.go index fe35cf187..0a0293b38 100644 --- a/command/ssh/proxycommand.go +++ b/command/ssh/proxycommand.go @@ -6,7 +6,6 @@ import ( "net" "os" "strings" - "sync" "time" "github.com/pkg/errors" @@ -228,6 +227,10 @@ func getBastion(ctx *cli.Context, user, host string) (*api.SSHBastionResponse, e } func proxyDirect(host, port string) error { + return proxyDirectWithIO(host, port, os.Stdin, os.Stdout) +} + +func proxyDirectWithIO(host, port string, stdin io.Reader, stdout io.Writer) error { address := net.JoinHostPort(host, port) addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { @@ -238,22 +241,25 @@ func proxyDirect(host, port string) error { if err != nil { return errors.Wrapf(err, "error connecting to %s", address) } + defer conn.Close() - var wg sync.WaitGroup - wg.Add(1) + // Return as soon as either direction finishes. Waiting for both can + // deadlock when the server closes the connection while stdin stays open. + // See smallstep/cli#1641. Buffered so the slower goroutine never blocks + // sending after we've stopped receiving. + done := make(chan struct{}, 2) go func() { - io.Copy(conn, os.Stdin) + io.Copy(conn, stdin) conn.CloseWrite() - wg.Done() + done <- struct{}{} }() - wg.Add(1) go func() { - io.Copy(os.Stdout, conn) + io.Copy(stdout, conn) conn.CloseRead() - wg.Done() + done <- struct{}{} }() - wg.Wait() + <-done return nil } diff --git a/command/ssh/proxycommand_test.go b/command/ssh/proxycommand_test.go new file mode 100644 index 000000000..4484e930a --- /dev/null +++ b/command/ssh/proxycommand_test.go @@ -0,0 +1,55 @@ +package ssh + +import ( + "bytes" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// Test_proxyDirectWithIO_serverClosesBeforeStdin reproduces smallstep/cli#1641: +// when the server closes the connection before the client has closed stdin, the +// proxycommand must still return promptly. Previously it would block in +// wg.Wait() forever because the stdin->conn goroutine stayed blocked reading a +// stdin that never reaches EOF (the ssh client keeps it open until the +// proxycommand exits). +func Test_proxyDirectWithIO_serverClosesBeforeStdin(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + // Server sends some data and immediately closes the connection. + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + conn.Write([]byte("hello")) + conn.Close() + }() + + host, port, err := net.SplitHostPort(ln.Addr().String()) + require.NoError(t, err) + + // stdin that never reaches EOF, simulating the ssh client keeping the + // proxycommand's stdin open for the lifetime of the session. + stdinR, stdinW := io.Pipe() + defer stdinW.Close() // write end intentionally left open during the call + + var stdout bytes.Buffer + done := make(chan error, 1) + go func() { + done <- proxyDirectWithIO(host, port, stdinR, &stdout) + }() + + select { + case err := <-done: + require.NoError(t, err) + require.Equal(t, "hello", stdout.String()) + case <-time.After(5 * time.Second): + t.Fatal("proxyDirectWithIO did not return after the server closed the connection") + } +}