Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
)

type Backend struct {
url string
healthy atomic.Bool
weight int
url string
healthy atomic.Bool
weight int
activeConns atomic.Int64
}

type BackendOptions struct {
Expand Down Expand Up @@ -39,6 +40,18 @@ func (b *Backend) SetHealth(isHealthy bool) {
b.healthy.Store(isHealthy)
}

func (b *Backend) IncrActiveConns() {
b.activeConns.Add(1)
}

func (b *Backend) DecrActiveConns() {
b.activeConns.Add(-1)
}

func (b *Backend) GetActiveConns() int64 {
return b.activeConns.Load()
}

type BackendPool struct {
bs []*Backend
mu sync.RWMutex
Expand Down
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"load-balancer/listener"
"load-balancer/proxy"
"load-balancer/router"
"load-balancer/router/leastconnections"
"load-balancer/router/roundrobin"
"load-balancer/router/weightedroundrobin"
"log"
Expand Down Expand Up @@ -42,6 +43,8 @@ func main() {
algo = roundrobin.NewRoundRobin(bp)
case "weighted-round-robin":
algo = weightedroundrobin.NewWeightedRoundRobin(bp)
case "least_connections":
algo = leastconnections.NewLeastConnections(bp)
default:
panic("unknown algorithm")
}
Expand Down
8 changes: 8 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ func (p *Proxy) Handle(conn net.Conn) error {
if err != nil {
return fmt.Errorf("routing: %w", err)
}

backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout)
if err != nil {
return fmt.Errorf("dial %s: %w", b.GetUrl(), err)
}

b.IncrActiveConns()
defer b.DecrActiveConns()
defer backendConn.Close()

err = conn.SetDeadline(time.Now().Add(p.connTimeout))
if err != nil {
return fmt.Errorf("set deadline %s: %w", conn.RemoteAddr(), err)
Expand All @@ -53,6 +58,7 @@ func (p *Proxy) Handle(conn net.Conn) error {
if err != nil {
return fmt.Errorf("set deadline %s: %w", backendConn.RemoteAddr(), err)
}

ch := make(chan error, 2)
go func() {
_, localErr := io.Copy(backendConn, conn)
Expand All @@ -68,10 +74,12 @@ func (p *Proxy) Handle(conn net.Conn) error {
}
ch <- localErr
}()

for range 2 {
if err = <-ch; err != nil {
return err
}
}

return nil
}
62 changes: 62 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,65 @@ func TestHandleBackendUnreachable(t *testing.T) {
t.Error("expected error when backend is unreachable")
}
}

func TestActiveConnsIncrementedDuringConnection(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()

connEstablished := make(chan struct{})
unblock := make(chan struct{})
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
close(connEstablished)
<-unblock
}()

b := backend.NewBackend(backend.BackendOptions{Url: ln.Addr().String()})
p := NewProxy(&stubRouter{b: b}, defaultOpts())

clientConn, proxyConn := net.Pipe()
defer clientConn.Close()

done := make(chan struct{})
go func() {
p.Handle(proxyConn)
close(done)
}()

<-connEstablished
deadline := time.After(time.Second)
for b.GetActiveConns() != 1 {
select {
case <-deadline:
t.Fatalf("timed out waiting for active conn increment, got %d", b.GetActiveConns())
default:
}
}

close(unblock)
clientConn.Close()
<-done

if b.GetActiveConns() != 0 {
t.Errorf("expected 0 active conns after Handle returned, got %d", b.GetActiveConns())
}
}

func TestActiveConnsNotIncrementedOnFailedDial(t *testing.T) {
b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:1"})
p := NewProxy(&stubRouter{b: b}, defaultOpts())
_, proxyConn := net.Pipe()

p.Handle(proxyConn)

if b.GetActiveConns() != 0 {
t.Errorf("expected 0 active conns after failed dial, got %d", b.GetActiveConns())
}
}
43 changes: 43 additions & 0 deletions router/leastconnections/leastconnections.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package leastconnections

import (
"errors"
"load-balancer/backend"
"math"
)

var ErrNoBackends = errors.New("no backends available")

type BackendPool interface {
GetPool() []*backend.Backend
}

type LeastConnections struct {
bp BackendPool
}

func NewLeastConnections(bp BackendPool) *LeastConnections {
if bp == nil {
panic("backend pool cannot be nil")
}
return &LeastConnections{bp: bp}
}

func (ls *LeastConnections) GetBackend() (*backend.Backend, error) {
bp := ls.bp.GetPool()
var best *backend.Backend
bestConns := int64(math.MaxInt64)
for _, b := range bp {
if !b.IsHealthy() {
continue
}
if activeConns := b.GetActiveConns(); activeConns < bestConns {
bestConns = activeConns
best = b
}
}
if best == nil {
return nil, ErrNoBackends
}
return best, nil
}
131 changes: 131 additions & 0 deletions router/leastconnections/leastconnections_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package leastconnections

import (
"load-balancer/backend"
"sync"
"testing"
)

type stubPool struct {
backends []*backend.Backend
}

func (s *stubPool) GetPool() []*backend.Backend {
return s.backends
}

func newHealthy(activeConns int64) *backend.Backend {
b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0"})
b.SetHealth(true)
for range activeConns {
b.IncrActiveConns()
}
return b
}

func newUnhealthy(activeConns int64) *backend.Backend {
b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0"})
for range activeConns {
b.IncrActiveConns()
}
return b
}

func TestPicksLeastConnectedBackend(t *testing.T) {
busy := newHealthy(5)
idle := newHealthy(1)
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{busy, idle}})

got, err := lc.GetBackend()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != idle {
t.Error("expected backend with fewest connections")
}
}

func TestSkipsUnhealthyBackend(t *testing.T) {
unhealthy := newUnhealthy(0)
healthy := newHealthy(5)
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{unhealthy, healthy}})

got, err := lc.GetBackend()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != healthy {
t.Error("expected healthy backend even with higher conn count")
}
}

func TestReturnsErrNoBackendsWhenAllUnhealthy(t *testing.T) {
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{
newUnhealthy(0),
newUnhealthy(2),
}})

_, err := lc.GetBackend()
if err != ErrNoBackends {
t.Errorf("expected ErrNoBackends, got %v", err)
}
}

func TestReturnsErrNoBackendsOnEmptyPool(t *testing.T) {
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{}})

_, err := lc.GetBackend()
if err != ErrNoBackends {
t.Errorf("expected ErrNoBackends, got %v", err)
}
}

func TestPicksOnlyHealthyBackendRegardlessOfConnCount(t *testing.T) {
healthy := newHealthy(100)
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{
newUnhealthy(0),
newUnhealthy(0),
healthy,
}})

got, err := lc.GetBackend()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != healthy {
t.Error("expected the only healthy backend")
}
}

func TestEqualConnectionsReturnsHealthyBackend(t *testing.T) {
a := newHealthy(3)
b := newHealthy(3)
lc := NewLeastConnections(&stubPool{backends: []*backend.Backend{a, b}})

got, err := lc.GetBackend()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !got.IsHealthy() {
t.Error("expected a healthy backend")
}
}

func TestConcurrentSafety(t *testing.T) {
backends := []*backend.Backend{
newHealthy(0),
newHealthy(1),
newHealthy(2),
}
lc := NewLeastConnections(&stubPool{backends: backends})

var wg sync.WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
lc.GetBackend()
}()
}
wg.Wait()
}
Loading