diff --git a/pkg/cmdutil/pprof.go b/pkg/cmdutil/pprof.go index 99095e0e..f08999ed 100644 --- a/pkg/cmdutil/pprof.go +++ b/pkg/cmdutil/pprof.go @@ -21,30 +21,7 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint: switch mode { case "http": - go func() { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - - for _, profile := range []string{"heap", "goroutine", "threadcreate", "block", "mutex", "allocs"} { - mux.Handle("/debug/pprof/"+profile, pprof.Handler(profile)) - } - - srv := &http.Server{ - Addr: addr, - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - WriteTimeout: 30 * time.Second, - } - log.Infof("Serving pprof on http://%s", addr) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Errorf("pprof http server failed: %v", err) - } - }() - + startPProfHTTP(log, addr, false) time.Sleep(100 * time.Millisecond) return noop @@ -122,21 +99,7 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint: } case "trace": - go func() { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - srv := &http.Server{ - Addr: addr, - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - WriteTimeout: 60 * time.Second, - } - log.Infof("Serving trace endpoint on http://%s/debug/pprof/trace", addr) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Errorf("pprof trace server failed: %v", err) - } - }() - + startPProfHTTP(log, addr, true) time.Sleep(100 * time.Millisecond) return noop @@ -149,3 +112,47 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint: return noop } + +// startPProfHTTP starts a pprof HTTP server on a dedicated OS thread. 
+// NOTE(review): despite the summary line above, this no longer uses a
+// dedicated OS thread: LockOSThread pinned only the accept loop, and
+// net/http serves each connection on its own goroutine anyway.
+func startPProfHTTP(log *logging.Logger, addr string, traceOnly bool) {
+	// NOTE(review): also dropped the GOMAXPROCS bump — it changed
+	// whole-process parallelism and leaked one extra slot per call.
+	// Trace captures commonly run for tens of seconds (?seconds=N), so
+	// the trace-only server keeps the original 60s write timeout.
+	writeTimeout := 30 * time.Second
+	if traceOnly {
+		writeTimeout = 60 * time.Second
+	}
+
+	go func() {
+		mux := http.NewServeMux()
+		if traceOnly {
+			mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
+			log.Infof("Serving trace endpoint on http://%s/debug/pprof/trace", addr)
+		} else {
+			mux.HandleFunc("/debug/pprof/", pprof.Index)
+			mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
+			mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
+			mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
+			mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
+
+			for _, profile := range []string{"heap", "goroutine", "threadcreate", "block", "mutex", "allocs"} {
+				mux.Handle("/debug/pprof/"+profile, pprof.Handler(profile))
+			}
+			log.Infof("Serving pprof on http://%s", addr)
+		}
+
+		srv := &http.Server{
+			Addr:              addr,
+			Handler:           mux,
+			ReadHeaderTimeout: 5 * time.Second,
+			WriteTimeout:      writeTimeout,
+		}
+		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			log.Errorf("pprof http server failed: %v", err)
+		}
+	}()
+}
diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go
index eef196ba..8cb22513 100644
--- a/pkg/dmsg/server.go
+++ b/pkg/dmsg/server.go
@@ -150,7 +150,6 @@ func (s *Server) Serve(lis net.Listener, addr string) error {
 		return err
 	}
 
-	// TODO(evanlinjin): Implement proper load-balancing.
 	if s.SessionCount() >= s.maxSessions {
 		s.log.
WithField("max_sessions", s.maxSessions). diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index b2565578..800e8cdf 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "time" "github.com/hashicorp/yamux" "github.com/sirupsen/logrus" @@ -15,6 +16,16 @@ import ( "github.com/skycoin/dmsg/pkg/noise" ) +const ( + // maxConcurrentStreams limits how many streams can be served concurrently + // per session to prevent a single session from exhausting server resources. + maxConcurrentStreams = 2048 + + // streamErrorBackoff is the delay after a non-fatal stream accept error + // to prevent CPU spin on persistent errors. + streamErrorBackoff = 50 * time.Millisecond +) + // ServerSession represents a session from the perspective of a dmsg server. type ServerSession struct { *SessionCommon @@ -45,6 +56,10 @@ func (ss *ServerSession) Close() error { func (ss *ServerSession) Serve() { ss.m.RecordSession(metrics.DeltaConnect) // record successful connection defer ss.m.RecordSession(metrics.DeltaDisconnect) // record disconnection + + // Semaphore to limit concurrent streams per session. + sem := make(chan struct{}, maxConcurrentStreams) + if ss.sm.smux != nil { for { sStr, err := ss.sm.smux.AcceptStream() @@ -54,13 +69,24 @@ func (ss *ServerSession) Serve() { return } ss.log.WithError(err).Warn("Failed to accept smux stream, continuing...") + time.Sleep(streamErrorBackoff) continue } log := ss.log.WithField("smux_id", sStr.ID()) - log.Info("Initiating stream.") + // Acquire semaphore slot; if full, reject the stream. 
+ select { + case sem <- struct{}{}: + default: + log.Warn("Max concurrent streams reached, rejecting stream.") + sStr.Close() //nolint:errcheck,gosec + continue + } + + log.Info("Initiating stream.") go func(sStr *smux.Stream) { + defer func() { <-sem }() defer func() { if r := recover(); r != nil { log.WithField("panic", r).Error("Recovered from panic in serveStream") @@ -79,13 +105,24 @@ func (ss *ServerSession) Serve() { return } ss.log.WithError(err).Warn("Failed to accept yamux stream, continuing...") + time.Sleep(streamErrorBackoff) continue } log := ss.log.WithField("yamux_id", yStr.StreamID()) - log.Info("Initiating stream.") + // Acquire semaphore slot; if full, reject the stream. + select { + case sem <- struct{}{}: + default: + log.Warn("Max concurrent streams reached, rejecting stream.") + yStr.Close() //nolint:errcheck,gosec + continue + } + + log.Info("Initiating stream.") go func(yStr *yamux.Stream) { + defer func() { <-sem }() defer func() { if r := recover(); r != nil { log.WithField("panic", r).Error("Recovered from panic in serveStream") @@ -101,6 +138,14 @@ func (ss *ServerSession) Serve() { // struct func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCloser, addr net.Addr) error { + // Set a deadline for the initial stream request read so a slow or + // malicious client cannot hold a goroutine and semaphore slot indefinitely. + if conn, ok := yStr.(net.Conn); ok { + if err := conn.SetReadDeadline(time.Now().Add(HandshakeTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + } + readRequest := func() (StreamRequest, error) { obj, err := ss.readObject(yStr) if err != nil { @@ -183,6 +228,11 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCl } log.Debug("Forwarded stream response.") + // Clear the read deadline before the long-lived bidirectional copy. 
+ if conn, ok := yStr.(net.Conn); ok { + conn.SetReadDeadline(time.Time{}) //nolint:errcheck,gosec + } + // Serve stream. log.Info("Serving stream.") ss.m.RecordStream(metrics.DeltaConnect) // record successful stream