Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 46 additions & 39 deletions pkg/cmdutil/pprof.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,7 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint:

switch mode {
case "http":
go func() {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)

for _, profile := range []string{"heap", "goroutine", "threadcreate", "block", "mutex", "allocs"} {
mux.Handle("/debug/pprof/"+profile, pprof.Handler(profile))
}

srv := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 30 * time.Second,
}
log.Infof("Serving pprof on http://%s", addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("pprof http server failed: %v", err)
}
}()

startPProfHTTP(log, addr, false)
time.Sleep(100 * time.Millisecond)
return noop

Expand Down Expand Up @@ -122,21 +99,7 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint:
}

case "trace":
go func() {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
srv := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 60 * time.Second,
}
log.Infof("Serving trace endpoint on http://%s/debug/pprof/trace", addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("pprof trace server failed: %v", err)
}
}()

startPProfHTTP(log, addr, true)
time.Sleep(100 * time.Millisecond)
return noop

Expand All @@ -149,3 +112,47 @@ func InitPProf(log *logging.Logger, mode string, addr string) func() { //nolint:

return noop
}

// startPProfHTTP launches a pprof HTTP server listening on addr in a
// background goroutine and returns immediately.
//
// When traceOnly is true only /debug/pprof/trace is registered; otherwise the
// index, cmdline, profile, symbol and trace handlers are registered together
// with the named runtime profiles (heap, goroutine, threadcreate, block,
// mutex, allocs).
//
// The serving goroutine locks itself to an OS thread before calling
// ListenAndServe. Note that this pins only that goroutine, i.e. the accept
// loop: net/http serves every accepted connection on its own goroutine, and
// those are scheduled by the Go runtime like any other goroutine.
//
// Serve errors other than http.ErrServerClosed are logged through log rather
// than returned. This function provides no way to shut the server down.
func startPProfHTTP(log *logging.Logger, addr string, traceOnly bool) {
	// net/http/pprof refuses a CPU-profile or trace request whose requested
	// duration is greater than or equal to the server's WriteTimeout. With a
	// 30s timeout the default 30-second CPU profile is rejected outright, so
	// use 60s for both modes to keep the default profile and sub-minute
	// traces servable.
	const pprofWriteTimeout = 60 * time.Second

	// Raise GOMAXPROCS by one so the pprof server has one more slot for
	// running Go code alongside application goroutines. Each call to this
	// function raises the limit by one again.
	runtime.GOMAXPROCS(runtime.GOMAXPROCS(0) + 1)

	go func() {
		// Wire this goroutine (the listener/accept loop) to its own OS
		// thread. It is intentionally never unlocked: the goroutine lives
		// for as long as the server does.
		runtime.LockOSThread()

		mux := http.NewServeMux()
		if traceOnly {
			mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
			log.Infof("Serving trace endpoint on http://%s/debug/pprof/trace (dedicated thread)", addr)
		} else {
			mux.HandleFunc("/debug/pprof/", pprof.Index)
			mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
			mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
			mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
			mux.HandleFunc("/debug/pprof/trace", pprof.Trace)

			for _, profile := range []string{"heap", "goroutine", "threadcreate", "block", "mutex", "allocs"} {
				mux.Handle("/debug/pprof/"+profile, pprof.Handler(profile))
			}
			log.Infof("Serving pprof on http://%s (dedicated thread)", addr)
		}

		srv := &http.Server{
			Addr:              addr,
			Handler:           mux,
			ReadHeaderTimeout: 5 * time.Second,
			WriteTimeout:      pprofWriteTimeout,
		}
		// ErrServerClosed signals a deliberate shutdown and is not an error.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Errorf("pprof http server failed: %v", err)
		}
	}()
}
1 change: 0 additions & 1 deletion pkg/dmsg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ func (s *Server) Serve(lis net.Listener, addr string) error {
return err
}

// TODO(evanlinjin): Implement proper load-balancing.
if s.SessionCount() >= s.maxSessions {
s.log.
WithField("max_sessions", s.maxSessions).
Expand Down
54 changes: 52 additions & 2 deletions pkg/dmsg/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
"time"

"github.com/hashicorp/yamux"
"github.com/sirupsen/logrus"
Expand All @@ -15,6 +16,16 @@ import (
"github.com/skycoin/dmsg/pkg/noise"
)

// Per-session stream-serving limits.
const (
	// streamErrorBackoff is how long the accept loop sleeps after a
	// non-fatal stream accept error before retrying, so that a persistent
	// error does not turn the loop into a busy spin.
	streamErrorBackoff = 50 * time.Millisecond

	// maxConcurrentStreams is the capacity of the per-session semaphore:
	// the number of streams one session may have in service at the same
	// time. Streams accepted while the semaphore is full are closed
	// immediately, so a single session cannot exhaust server resources.
	maxConcurrentStreams = 2048
)

// ServerSession represents a session from the perspective of a dmsg server.
type ServerSession struct {
*SessionCommon
Expand Down Expand Up @@ -45,6 +56,10 @@ func (ss *ServerSession) Close() error {
func (ss *ServerSession) Serve() {
ss.m.RecordSession(metrics.DeltaConnect) // record successful connection
defer ss.m.RecordSession(metrics.DeltaDisconnect) // record disconnection

// Semaphore to limit concurrent streams per session.
sem := make(chan struct{}, maxConcurrentStreams)

if ss.sm.smux != nil {
for {
sStr, err := ss.sm.smux.AcceptStream()
Expand All @@ -54,13 +69,24 @@ func (ss *ServerSession) Serve() {
return
}
ss.log.WithError(err).Warn("Failed to accept smux stream, continuing...")
time.Sleep(streamErrorBackoff)
continue
}

log := ss.log.WithField("smux_id", sStr.ID())
log.Info("Initiating stream.")

// Acquire semaphore slot; if full, reject the stream.
select {
case sem <- struct{}{}:
default:
log.Warn("Max concurrent streams reached, rejecting stream.")
sStr.Close() //nolint:errcheck,gosec
continue
}

log.Info("Initiating stream.")
go func(sStr *smux.Stream) {
defer func() { <-sem }()
defer func() {
if r := recover(); r != nil {
log.WithField("panic", r).Error("Recovered from panic in serveStream")
Expand All @@ -79,13 +105,24 @@ func (ss *ServerSession) Serve() {
return
}
ss.log.WithError(err).Warn("Failed to accept yamux stream, continuing...")
time.Sleep(streamErrorBackoff)
continue
}

log := ss.log.WithField("yamux_id", yStr.StreamID())
log.Info("Initiating stream.")

// Acquire semaphore slot; if full, reject the stream.
select {
case sem <- struct{}{}:
default:
log.Warn("Max concurrent streams reached, rejecting stream.")
yStr.Close() //nolint:errcheck,gosec
continue
}

log.Info("Initiating stream.")
go func(yStr *yamux.Stream) {
defer func() { <-sem }()
defer func() {
if r := recover(); r != nil {
log.WithField("panic", r).Error("Recovered from panic in serveStream")
Expand All @@ -101,6 +138,14 @@ func (ss *ServerSession) Serve() {
// struct

func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCloser, addr net.Addr) error {
// Set a deadline for the initial stream request read so a slow or
// malicious client cannot hold a goroutine and semaphore slot indefinitely.
if conn, ok := yStr.(net.Conn); ok {
if err := conn.SetReadDeadline(time.Now().Add(HandshakeTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
}

readRequest := func() (StreamRequest, error) {
obj, err := ss.readObject(yStr)
if err != nil {
Expand Down Expand Up @@ -183,6 +228,11 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCl
}
log.Debug("Forwarded stream response.")

// Clear the read deadline before the long-lived bidirectional copy.
if conn, ok := yStr.(net.Conn); ok {
conn.SetReadDeadline(time.Time{}) //nolint:errcheck,gosec
}

// Serve stream.
log.Info("Serving stream.")
ss.m.RecordStream(metrics.DeltaConnect) // record successful stream
Expand Down
Loading