Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,29 @@ import (
type Backend struct {
url string
healthy atomic.Bool
weight int
}

type BackendOptions struct {
Url string
Url string
Weight int
}

func NewBackend(opts BackendOptions) *Backend {
return &Backend{url: opts.Url}
return &Backend{
url: opts.Url,
weight: opts.Weight,
}
}

func (b *Backend) GetUrl() string {
return b.url
}

func (b *Backend) GetWeight() int {
return b.weight
}

func (b *Backend) IsHealthy() bool {
return b.healthy.Load()
}
Expand Down
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"load-balancer/proxy"
"load-balancer/router"
"load-balancer/router/roundrobin"
"load-balancer/router/weightedroundrobin"
"log"
"net"
"os"
Expand All @@ -29,14 +30,18 @@ func main() {
backends := make([]*backend.Backend, 0, len(conf.Backends))
for _, b := range conf.Backends {
backends = append(backends, backend.NewBackend(backend.BackendOptions{
Url: b.Url}))
Url: b.Url,
Weight: b.Weight,
}))
}
bp := backend.NewBackendPool(backends)

var algo router.AlgoIO
switch conf.Algorithm {
case "round-robin":
algo = roundrobin.NewRoundRobin(bp)
case "weighted-round-robin":
algo = weightedroundrobin.NewWeightedRoundRobin(bp)
default:
panic("unknown algorithm")
}
Expand Down
59 changes: 59 additions & 0 deletions router/weightedroundrobin/weightedroundrobin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package weightedroundrobin

import (
"errors"
"load-balancer/backend"
"sync"
)

var ErrNoBackends = errors.New("no backends available")

type BackendPoolIO interface {
GetPool() []*backend.Backend
}

type WeightedRoundRobin struct {
bs []*backend.Backend
index int
weight int
mu sync.Mutex
}

func NewWeightedRoundRobin(bp BackendPoolIO) *WeightedRoundRobin {
if bp == nil {
panic("backend pool cannot be nil")
}
copiedBp := bp.GetPool()
bs := make([]*backend.Backend, 0, len(copiedBp))
for _, b := range copiedBp {
if b.GetWeight() > 0 {
bs = append(bs, b)
}
}
if len(bs) < 1 {
panic("total backend pool weight cannot be zero")
}
return &WeightedRoundRobin{
bs: bs,
weight: bs[0].GetWeight(),
}
}

func (wrr *WeightedRoundRobin) GetBackend() (*backend.Backend, error) {
wrr.mu.Lock()
defer wrr.mu.Unlock()
for range len(wrr.bs) {
b := wrr.bs[wrr.index]
if b.IsHealthy() {
wrr.weight--
if wrr.weight <= 0 {
wrr.index = (wrr.index + 1) % len(wrr.bs)
wrr.weight = wrr.bs[wrr.index].GetWeight()
}
return b, nil
}
wrr.index = (wrr.index + 1) % len(wrr.bs)
wrr.weight = wrr.bs[wrr.index].GetWeight()
}
return nil, ErrNoBackends
}
133 changes: 133 additions & 0 deletions router/weightedroundrobin/weightedroundrobin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package weightedroundrobin

import (
"load-balancer/backend"
"sync"
"testing"
)

type stubPool struct {
backends []*backend.Backend
}

func (s *stubPool) GetPool() []*backend.Backend {
return s.backends
}

func newHealthy(weight int) *backend.Backend {
b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0", Weight: weight})
b.SetHealth(true)
return b
}

func newUnhealthy(weight int) *backend.Backend {
return backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0", Weight: weight})
}

func TestWeightDistribution(t *testing.T) {
a := newHealthy(3)
b := newHealthy(1)
wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{a, b}})

want := []*backend.Backend{a, a, a, b, a, a, a, b}
for i, expected := range want {
got, err := wrr.GetBackend()
if err != nil {
t.Fatalf("call %d: unexpected error: %v", i, err)
}
if got != expected {
t.Errorf("call %d: got wrong backend", i)
}
}
}

func TestSkipsUnhealthyBackend(t *testing.T) {
healthy := newHealthy(2)
unhealthy := newUnhealthy(3)
wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{unhealthy, healthy}})

for i := range 4 {
got, err := wrr.GetBackend()
if err != nil {
t.Fatalf("call %d: unexpected error: %v", i, err)
}
if got != healthy {
t.Errorf("call %d: expected healthy backend", i)
}
}
}

func TestReturnsErrNoBackendsWhenAllUnhealthy(t *testing.T) {
wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{
newUnhealthy(2),
newUnhealthy(1),
}})

_, err := wrr.GetBackend()
if err != ErrNoBackends {
t.Errorf("expected ErrNoBackends, got %v", err)
}
}

func TestWrapAround(t *testing.T) {
a := newHealthy(1)
b := newHealthy(1)
wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{a, b}})

got0, _ := wrr.GetBackend()
got1, _ := wrr.GetBackend()
got2, _ := wrr.GetBackend()

if got0 != a {
t.Error("call 0: expected a")
}
if got1 != b {
t.Error("call 1: expected b")
}
if got2 != a {
t.Error("call 2: expected a (wrap-around)")
}
}

func TestZeroWeightBackendsAreExcluded(t *testing.T) {
zero := newHealthy(0)
nonzero := newHealthy(2)
wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{zero, nonzero}})

for i := range 4 {
got, err := wrr.GetBackend()
if err != nil {
t.Fatalf("call %d: unexpected error: %v", i, err)
}
if got != nonzero {
t.Errorf("call %d: zero-weight backend should be excluded", i)
}
}
}

func TestNewWeightedRoundRobinPanicsOnAllZeroWeights(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic when all backends have zero weight")
}
}()
NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{newHealthy(0)}})
}

func TestConcurrentSafety(t *testing.T) {
backends := make([]*backend.Backend, 3)
for i := range backends {
backends[i] = newHealthy(i + 1)
}
wrr := NewWeightedRoundRobin(&stubPool{backends: backends})

var wg sync.WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
wrr.GetBackend()
}()
}
wg.Wait()
}
Loading