diff --git a/backend/backend.go b/backend/backend.go index 4e2cb40..20bcf27 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -6,9 +6,10 @@ import ( ) type Backend struct { - url string - healthy atomic.Bool - weight int + url string + healthy atomic.Bool + weight int + activeConns atomic.Int64 } type BackendOptions struct { @@ -39,6 +40,18 @@ func (b *Backend) SetHealth(isHealthy bool) { b.healthy.Store(isHealthy) } +func (b *Backend) IncrActiveConns() { + b.activeConns.Add(1) +} + +func (b *Backend) DecrActiveConns() { + b.activeConns.Add(-1) +} + +func (b *Backend) GetActiveConns() int64 { + return b.activeConns.Load() +} + type BackendPool struct { bs []*Backend mu sync.RWMutex diff --git a/main.go b/main.go index f09b179..c15861e 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "load-balancer/listener" "load-balancer/proxy" "load-balancer/router" + "load-balancer/router/leastconnections" "load-balancer/router/roundrobin" "load-balancer/router/weightedroundrobin" "log" @@ -42,6 +43,8 @@ func main() { algo = roundrobin.NewRoundRobin(bp) case "weighted-round-robin": algo = weightedroundrobin.NewWeightedRoundRobin(bp) + case "least_connections": + algo = leastconnections.NewLeastConnections(bp) default: panic("unknown algorithm") } diff --git a/proxy/proxy.go b/proxy/proxy.go index cf90d3f..a39baee 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -40,11 +40,16 @@ func (p *Proxy) Handle(conn net.Conn) error { if err != nil { return fmt.Errorf("routing: %w", err) } + backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout) if err != nil { return fmt.Errorf("dial %s: %w", b.GetUrl(), err) } + + b.IncrActiveConns() + defer b.DecrActiveConns() defer backendConn.Close() + err = conn.SetDeadline(time.Now().Add(p.connTimeout)) if err != nil { return fmt.Errorf("set deadline %s: %w", conn.RemoteAddr(), err) @@ -53,6 +58,7 @@ func (p *Proxy) Handle(conn net.Conn) error { if err != nil { return fmt.Errorf("set deadline %s: %w", backendConn.RemoteAddr(), err) } + ch := make(chan error, 2) go func() { _, localErr := io.Copy(backendConn, conn) @@ -68,10 +74,12 @@ func (p *Proxy) Handle(conn net.Conn) error { } ch <- localErr }() + for range 2 { if err = <-ch; err != nil { return err } } + return nil } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 0ef81f5..0d3f8ea 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -128,3 +128,65 @@ func TestHandleBackendUnreachable(t *testing.T) { t.Error("expected error when backend is unreachable") } } + +func TestActiveConnsIncrementedDuringConnection(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + connEstablished := make(chan struct{}) + unblock := make(chan struct{}) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + close(connEstablished) + <-unblock + }() + + b := backend.NewBackend(backend.BackendOptions{Url: ln.Addr().String()}) + p := NewProxy(&stubRouter{b: b}, defaultOpts()) + + clientConn, proxyConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + p.Handle(proxyConn) + close(done) + }() + + <-connEstablished + deadline := time.After(time.Second) + for b.GetActiveConns() != 1 { + select { + case <-deadline: + t.Fatalf("timed out waiting for active conn increment, got %d", b.GetActiveConns()) + default: + } + } + + close(unblock) + clientConn.Close() + <-done + + if b.GetActiveConns() != 0 { + t.Errorf("expected 0 active conns after Handle returned, got %d", b.GetActiveConns()) + } +} + +func TestActiveConnsNotIncrementedOnFailedDial(t *testing.T) { + b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:1"}) + p := NewProxy(&stubRouter{b: b}, defaultOpts()) + _, proxyConn := net.Pipe() + + p.Handle(proxyConn) + + if b.GetActiveConns() != 0 { + t.Errorf("expected 0 active conns after failed dial, got %d", b.GetActiveConns()) + } +} diff --git a/router/leastconnections/leastconnections.go b/router/leastconnections/leastconnections.go new file mode 100644 index 0000000..981b0b7 --- /dev/null +++ b/router/leastconnections/leastconnections.go @@ -0,0 +1,43 @@ +package leastconnections + +import ( + "errors" + "load-balancer/backend" + "math" +) + +var ErrNoBackends = errors.New("no backends available") + +type BackendPool interface { + GetPool() []*backend.Backend +} + +type LeastConnections struct { + bp BackendPool +} + +func NewLeastConnections(bp BackendPool) *LeastConnections { + if bp == nil { + panic("backend pool cannot be nil") + } + return &LeastConnections{bp: bp} +} + +func (ls *LeastConnections) GetBackend() (*backend.Backend, error) { + bp := ls.bp.GetPool() + var best *backend.Backend + bestConns := int64(math.MaxInt64) + for _, b := range bp { + if !b.IsHealthy() { + continue + } + if activeConns := b.GetActiveConns(); activeConns < bestConns { + bestConns = activeConns + best = b + } + } + if best == nil { + return nil, ErrNoBackends + } + return best, nil +} diff --git a/router/leastconnections/leastconnections_test.go b/router/leastconnections/leastconnections_test.go new file mode 100644 index 0000000..53ce8e6 --- /dev/null +++ b/router/leastconnections/leastconnections_test.go @@ -0,0 +1,131 @@ +package leastconnections + +import ( + "load-balancer/backend" + "sync" + "testing" +) + +type stubPool struct { + backends []*backend.Backend +} + +func (s *stubPool) GetPool() []*backend.Backend { + return s.backends +} + +func newHealthy(activeConns int64) *backend.Backend { + b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0"}) + b.SetHealth(true) + for range activeConns { + b.IncrActiveConns() + } + return b +} + +func newUnhealthy(activeConns int64) *backend.Backend { + b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0"}) + for range activeConns { + b.IncrActiveConns() + } + return b +} + +func TestPicksLeastConnectedBackend(t *testing.T) { + busy := newHealthy(5) + idle := newHealthy(1) + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{busy, idle}}) + + got, err := lc.GetBackend() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != idle { + t.Error("expected backend with fewest connections") + } +} + +func TestSkipsUnhealthyBackend(t *testing.T) { + unhealthy := newUnhealthy(0) + healthy := newHealthy(5) + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{unhealthy, healthy}}) + + got, err := lc.GetBackend() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != healthy { + t.Error("expected healthy backend even with higher conn count") + } +} + +func TestReturnsErrNoBackendsWhenAllUnhealthy(t *testing.T) { + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{ + newUnhealthy(0), + newUnhealthy(2), + }}) + + _, err := lc.GetBackend() + if err != ErrNoBackends { + t.Errorf("expected ErrNoBackends, got %v", err) + } +} + +func TestReturnsErrNoBackendsOnEmptyPool(t *testing.T) { + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{}}) + + _, err := lc.GetBackend() + if err != ErrNoBackends { + t.Errorf("expected ErrNoBackends, got %v", err) + } +} + +func TestPicksOnlyHealthyBackendRegardlessOfConnCount(t *testing.T) { + healthy := newHealthy(100) + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{ + newUnhealthy(0), + newUnhealthy(0), + healthy, + }}) + + got, err := lc.GetBackend() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != healthy { + t.Error("expected the only healthy backend") + } +} + +func TestEqualConnectionsReturnsHealthyBackend(t *testing.T) { + a := newHealthy(3) + b := newHealthy(3) + lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{a, b}}) + + got, err := lc.GetBackend() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !got.IsHealthy() { + t.Error("expected a healthy backend") + } +} + +func TestConcurrentSafety(t *testing.T) { + backends := []*backend.Backend{ + newHealthy(0), + newHealthy(1), + newHealthy(2), + } + lc := NewLeastConnections(&stubPool{backends: backends}) + + var wg sync.WaitGroup + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + lc.GetBackend() + }() + } + wg.Wait() +}