diff --git a/backend/backend.go b/backend/backend.go index c5a36d4..4e2cb40 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -8,20 +8,29 @@ import ( type Backend struct { url string healthy atomic.Bool + weight int } type BackendOptions struct { - Url string + Url string + Weight int } func NewBackend(opts BackendOptions) *Backend { - return &Backend{url: opts.Url} + return &Backend{ + url: opts.Url, + weight: opts.Weight, + } } func (b *Backend) GetUrl() string { return b.url } +func (b *Backend) GetWeight() int { + return b.weight +} + func (b *Backend) IsHealthy() bool { return b.healthy.Load() } diff --git a/main.go b/main.go index 42571f1..f09b179 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "load-balancer/proxy" "load-balancer/router" "load-balancer/router/roundrobin" + "load-balancer/router/weightedroundrobin" "log" "net" "os" @@ -29,7 +30,9 @@ func main() { backends := make([]*backend.Backend, 0, len(conf.Backends)) for _, b := range conf.Backends { backends = append(backends, backend.NewBackend(backend.BackendOptions{ - Url: b.Url})) + Url: b.Url, + Weight: b.Weight, + })) } bp := backend.NewBackendPool(backends) @@ -37,6 +40,8 @@ func main() { switch conf.Algorithm { case "round-robin": algo = roundrobin.NewRoundRobin(bp) + case "weighted-round-robin": + algo = weightedroundrobin.NewWeightedRoundRobin(bp) default: panic("unknown algorithm") } diff --git a/router/weightedroundrobin/weightedroundrobin.go b/router/weightedroundrobin/weightedroundrobin.go new file mode 100644 index 0000000..577e49e --- /dev/null +++ b/router/weightedroundrobin/weightedroundrobin.go @@ -0,0 +1,59 @@ +package weightedroundrobin + +import ( + "errors" + "load-balancer/backend" + "sync" +) + +var ErrNoBackends = errors.New("no backends available") + +type BackendPoolIO interface { + GetPool() []*backend.Backend +} + +type WeightedRoundRobin struct { + bs []*backend.Backend + index int + weight int + mu sync.Mutex +} + +func NewWeightedRoundRobin(bp BackendPoolIO) *WeightedRoundRobin { + if bp == nil { + panic("backend pool cannot be nil") + } + copiedBp := bp.GetPool() + bs := make([]*backend.Backend, 0, len(copiedBp)) + for _, b := range copiedBp { + if b.GetWeight() > 0 { + bs = append(bs, b) + } + } + if len(bs) < 1 { + panic("total backend pool weight cannot be zero") + } + return &WeightedRoundRobin{ + bs: bs, + weight: bs[0].GetWeight(), + } +} + +func (wrr *WeightedRoundRobin) GetBackend() (*backend.Backend, error) { + wrr.mu.Lock() + defer wrr.mu.Unlock() + for range len(wrr.bs) { + b := wrr.bs[wrr.index] + if b.IsHealthy() { + wrr.weight-- + if wrr.weight <= 0 { + wrr.index = (wrr.index + 1) % len(wrr.bs) + wrr.weight = wrr.bs[wrr.index].GetWeight() + } + return b, nil + } + wrr.index = (wrr.index + 1) % len(wrr.bs) + wrr.weight = wrr.bs[wrr.index].GetWeight() + } + return nil, ErrNoBackends +} diff --git a/router/weightedroundrobin/weightedroundrobin_test.go b/router/weightedroundrobin/weightedroundrobin_test.go new file mode 100644 index 0000000..00ce3d9 --- /dev/null +++ b/router/weightedroundrobin/weightedroundrobin_test.go @@ -0,0 +1,133 @@ +package weightedroundrobin + +import ( + "load-balancer/backend" + "sync" + "testing" +) + +type stubPool struct { + backends []*backend.Backend +} + +func (s *stubPool) GetPool() []*backend.Backend { + return s.backends +} + +func newHealthy(weight int) *backend.Backend { + b := backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0", Weight: weight}) + b.SetHealth(true) + return b +} + +func newUnhealthy(weight int) *backend.Backend { + return backend.NewBackend(backend.BackendOptions{Url: "127.0.0.1:0", Weight: weight}) +} + +func TestWeightDistribution(t *testing.T) { + a := newHealthy(3) + b := newHealthy(1) + wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{a, b}}) + + want := []*backend.Backend{a, a, a, b, a, a, a, b} + for i, expected := range want { + got, err := wrr.GetBackend() + if err != nil { + t.Fatalf("call %d: unexpected error: %v", i, err) + } + if got != expected { + t.Errorf("call %d: got wrong backend", i) + } + } +} + +func TestSkipsUnhealthyBackend(t *testing.T) { + healthy := newHealthy(2) + unhealthy := newUnhealthy(3) + wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{unhealthy, healthy}}) + + for i := range 4 { + got, err := wrr.GetBackend() + if err != nil { + t.Fatalf("call %d: unexpected error: %v", i, err) + } + if got != healthy { + t.Errorf("call %d: expected healthy backend", i) + } + } +} + +func TestReturnsErrNoBackendsWhenAllUnhealthy(t *testing.T) { + wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{ + newUnhealthy(2), + newUnhealthy(1), + }}) + + _, err := wrr.GetBackend() + if err != ErrNoBackends { + t.Errorf("expected ErrNoBackends, got %v", err) + } +} + +func TestWrapAround(t *testing.T) { + a := newHealthy(1) + b := newHealthy(1) + wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{a, b}}) + + got0, _ := wrr.GetBackend() + got1, _ := wrr.GetBackend() + got2, _ := wrr.GetBackend() + + if got0 != a { + t.Error("call 0: expected a") + } + if got1 != b { + t.Error("call 1: expected b") + } + if got2 != a { + t.Error("call 2: expected a (wrap-around)") + } +} + +func TestZeroWeightBackendsAreExcluded(t *testing.T) { + zero := newHealthy(0) + nonzero := newHealthy(2) + wrr := NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{zero, nonzero}}) + + for i := range 4 { + got, err := wrr.GetBackend() + if err != nil { + t.Fatalf("call %d: unexpected error: %v", i, err) + } + if got != nonzero { + t.Errorf("call %d: zero-weight backend should be excluded", i) + } + } +} + +func TestNewWeightedRoundRobinPanicsOnAllZeroWeights(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic when all backends have zero weight") + } + }() + NewWeightedRoundRobin(&stubPool{backends: []*backend.Backend{newHealthy(0)}}) +} + +func TestConcurrentSafety(t *testing.T) { + backends := make([]*backend.Backend, 3) + for i := range backends { + backends[i] = newHealthy(i + 1) + } + wrr := NewWeightedRoundRobin(&stubPool{backends: backends}) + + var wg sync.WaitGroup + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + wrr.GetBackend() + }() + } + wg.Wait() +}