// CircuitState identifies which phase a CircuitBreaker is in.
type CircuitState int

const (
	// Closed is the normal state: outcomes are recorded in the sliding
	// window and the breaker trips once the failure rate reaches the
	// configured threshold.
	Closed CircuitState = iota
	// Open means the breaker has tripped. Outcomes are ignored until the
	// cooldown has elapsed.
	Open
	// HalfOpen is the probing state entered after the cooldown. A single
	// failure re-opens the breaker; enough successes close it again.
	HalfOpen
)

// defaultWindowSize is used whenever a breaker is configured with a
// non-positive window size, which would otherwise make the window unusable.
const defaultWindowSize = 10

// Markers stored in the sliding window, oldest first.
const (
	outcomeSuccess byte = 'S'
	outcomeFailure byte = 'F'
)

// CircuitBreaker tracks recent request outcomes and decides whether traffic
// should be withheld. All methods lock mu, so a CircuitBreaker must not be
// copied once it is in use.
type CircuitBreaker struct {
	state             CircuitState
	window            []byte // recorded outcomes while Closed, oldest first
	windowSize        int    // maximum number of outcomes kept in window
	successes         int    // successes in window (Closed) or since probing began (HalfOpen)
	failures          int    // failures currently in window
	failureThreshold  int    // percentage of windowSize at which the breaker trips
	halfOpenThreshold int    // successes needed in HalfOpen to close again
	cooldownTimeout   time.Duration
	cooldownStartedAt time.Time
	mu                sync.Mutex
}

// CircuitBreakerOptions configures NewCircuitBreaker.
type CircuitBreakerOptions struct {
	WindowSize        int
	FailureThreshold  int
	HalfOpenThreshold int
	CooldownTimeout   time.Duration
}

// NewCircuitBreaker returns a breaker in the Closed state. A non-positive
// WindowSize is replaced by defaultWindowSize; previously such a value caused
// a panic on the first recorded outcome (or in make, when negative).
func NewCircuitBreaker(opts CircuitBreakerOptions) CircuitBreaker {
	size := opts.WindowSize
	if size <= 0 {
		size = defaultWindowSize
	}
	return CircuitBreaker{
		state:             Closed,
		window:            make([]byte, 0, size),
		windowSize:        size,
		failureThreshold:  opts.FailureThreshold,
		halfOpenThreshold: opts.HalfOpenThreshold,
		cooldownTimeout:   opts.CooldownTimeout,
	}
}

// IsOpen reports whether the breaker is currently Open. As a side effect, an
// Open breaker whose cooldown has elapsed is moved to HalfOpen, in which case
// false is returned.
func (cb *CircuitBreaker) IsOpen() bool {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	if cb.state != Open {
		return false
	}
	if time.Since(cb.cooldownStartedAt) > cb.cooldownTimeout {
		cb.state = HalfOpen
		return false
	}
	return true
}

// RecordFailure registers a failed request. It is ignored while Open, trips
// the breaker immediately while HalfOpen, and while Closed trips it once the
// failures in the window reach failureThreshold percent of the window size.
func (cb *CircuitBreaker) RecordFailure() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	switch cb.state {
	case Open:
		return
	case HalfOpen:
		cb.trip()
		return
	}
	cb.failures++
	cb.push(outcomeFailure)
	// The rate is measured against the full window size, not the number of
	// samples seen so far, so a partially filled window trips later.
	if cb.failures*100/cb.effectiveWindowSize() >= cb.failureThreshold {
		cb.trip()
	}
}

// RecordSuccess registers a successful request. It is ignored while Open,
// counts towards halfOpenThreshold while HalfOpen, and is added to the
// sliding window while Closed.
func (cb *CircuitBreaker) RecordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	switch cb.state {
	case Open:
		return
	case HalfOpen:
		cb.successes++
		if cb.successes >= cb.halfOpenThreshold {
			// Recovered: start over with an empty window.
			cb.state = Closed
			cb.successes, cb.failures = 0, 0
			cb.window = cb.window[:0]
		}
	default:
		cb.successes++
		cb.push(outcomeSuccess)
	}
}

// effectiveWindowSize returns windowSize, falling back to defaultWindowSize
// for breakers built without the constructor (zero value or struct literal).
func (cb *CircuitBreaker) effectiveWindowSize() int {
	if cb.windowSize <= 0 {
		return defaultWindowSize
	}
	return cb.windowSize
}

// push appends outcome to the window. When the window is full the oldest
// entry is dropped first and its counter decremented. The caller must hold mu
// and must already have incremented the counter for outcome.
func (cb *CircuitBreaker) push(outcome byte) {
	if len(cb.window) < cb.effectiveWindowSize() {
		cb.window = append(cb.window, outcome)
		return
	}
	switch cb.window[0] {
	case outcomeSuccess:
		cb.successes--
	case outcomeFailure:
		cb.failures--
	}
	copy(cb.window, cb.window[1:])
	cb.window[len(cb.window)-1] = outcome
}

// trip moves the breaker to Open, starts the cooldown and clears all
// recorded outcomes. The caller must hold mu.
func (cb *CircuitBreaker) trip() {
	cb.state = Open
	cb.cooldownStartedAt = time.Now()
	cb.successes, cb.failures = 0, 0
	cb.window = cb.window[:0]
}
a/backend/circuit_breaker_test.go b/backend/circuit_breaker_test.go new file mode 100644 index 0000000..54c9343 --- /dev/null +++ b/backend/circuit_breaker_test.go @@ -0,0 +1,134 @@ +package backend + +import ( + "sync" + "testing" + "time" +) + +func newTestCB() CircuitBreaker { + return CircuitBreaker{ + state: Closed, + window: make([]byte, 0, 4), + windowSize: 4, + failureThreshold: 50, + halfOpenThreshold: 2, + cooldownTimeout: time.Second, + } +} + +func TestClosedToOpenOnFailureThreshold(t *testing.T) { + cb := newTestCB() + + // 2 successes + 2 failures = 50% failure rate in window of 4 + cb.RecordSuccess() + cb.RecordSuccess() + cb.RecordFailure() + cb.RecordFailure() + + if !cb.IsOpen() { + t.Error("expected circuit to be open after failure threshold exceeded") + } +} + +func TestStaysClosedBelowThreshold(t *testing.T) { + cb := newTestCB() + + cb.RecordSuccess() + cb.RecordSuccess() + cb.RecordSuccess() + cb.RecordFailure() + + if cb.IsOpen() { + t.Error("expected circuit to stay closed below failure threshold") + } +} + +func TestOpenToHalfOpenAfterCooldown(t *testing.T) { + cb := newTestCB() + + cb.RecordSuccess() + cb.RecordSuccess() + cb.RecordFailure() + cb.RecordFailure() + + // manually expire the cooldown + cb.mu.Lock() + cb.cooldownStartedAt = time.Now().Add(-2 * time.Second) + cb.mu.Unlock() + + if cb.IsOpen() { + t.Error("expected circuit to transition to half-open after cooldown") + } + cb.mu.Lock() + state := cb.state + cb.mu.Unlock() + if state != HalfOpen { + t.Errorf("expected HalfOpen state, got %v", state) + } +} + +func TestHalfOpenToClosedOnSuccesses(t *testing.T) { + cb := newTestCB() + cb.mu.Lock() + cb.state = HalfOpen + cb.mu.Unlock() + + cb.RecordSuccess() + cb.RecordSuccess() + + if cb.IsOpen() { + t.Error("expected circuit to close after half-open threshold successes") + } + cb.mu.Lock() + state := cb.state + cb.mu.Unlock() + if state != Closed { + t.Errorf("expected Closed state, got %v", state) + } +} + +func 
TestHalfOpenToOpenOnFailure(t *testing.T) { + cb := newTestCB() + cb.mu.Lock() + cb.state = HalfOpen + cb.mu.Unlock() + + cb.RecordFailure() + + if !cb.IsOpen() { + t.Error("expected circuit to reopen on failure in half-open state") + } +} + +func TestRecordsIgnoredWhenOpen(t *testing.T) { + cb := newTestCB() + cb.mu.Lock() + cb.state = Open + cb.cooldownStartedAt = time.Now() + cb.mu.Unlock() + + cb.RecordFailure() + cb.RecordSuccess() + + cb.mu.Lock() + failures := cb.failures + successes := cb.successes + cb.mu.Unlock() + + if failures != 0 || successes != 0 { + t.Errorf("expected no records in open state, got failures=%d successes=%d", failures, successes) + } +} + +func TestConcurrentSafety(t *testing.T) { + cb := newTestCB() + var wg sync.WaitGroup + for range 100 { + wg.Add(3) + go func() { defer wg.Done(); cb.RecordFailure() }() + go func() { defer wg.Done(); cb.RecordSuccess() }() + go func() { defer wg.Done(); cb.IsOpen() }() + } + wg.Wait() +} diff --git a/proxy/proxy.go b/proxy/proxy.go index a39baee..7b489a1 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -43,6 +43,7 @@ func (p *Proxy) Handle(conn net.Conn) error { backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout) if err != nil { + b.RecordFailure() return fmt.Errorf("dial %s: %w", b.GetUrl(), err) } @@ -56,6 +57,7 @@ func (p *Proxy) Handle(conn net.Conn) error { } err = backendConn.SetDeadline(time.Now().Add(p.connTimeout)) if err != nil { + b.RecordFailure() return fmt.Errorf("set deadline %s: %w", backendConn.RemoteAddr(), err) } @@ -77,9 +79,12 @@ func (p *Proxy) Handle(conn net.Conn) error { for range 2 { if err = <-ch; err != nil { + b.RecordFailure() return err } } + b.RecordSuccess() + return nil } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 0d3f8ea..acb9732 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -190,3 +190,26 @@ func TestActiveConnsNotIncrementedOnFailedDial(t *testing.T) { t.Errorf("expected 0 active conns after failed dial, got 
%d", b.GetActiveConns()) } } + +func TestCircuitOpensAfterRepeatedDialFailures(t *testing.T) { + b := backend.NewBackend(backend.BackendOptions{ + Url: "127.0.0.1:1", + CircuitWindowSize: 4, + CircuitFailureThreshold: 50, + CircuitCooldownTimeout: time.Minute, + }) + p := NewProxy(&stubRouter{b: b}, ProxyOptions{ + DialTimeout: 100 * time.Millisecond, + ConnTimeout: 30 * time.Second, + }) + + for range 2 { + _, proxyConn := net.Pipe() + p.Handle(proxyConn) + proxyConn.Close() + } + + if !b.IsOpen() { + t.Error("expected circuit to be open after repeated dial failures") + } +} diff --git a/router/leastconnections/leastconnections.go b/router/leastconnections/leastconnections.go index 981b0b7..58dd1dc 100644 --- a/router/leastconnections/leastconnections.go +++ b/router/leastconnections/leastconnections.go @@ -28,7 +28,7 @@ func (ls *LeastConnections) GetBackend() (*backend.Backend, error) { var best *backend.Backend bestConns := int64(math.MaxInt64) for _, b := range bp { - if !b.IsHealthy() { + if !b.IsHealthy() || b.IsOpen() { continue } if activeConns := b.GetActiveConns(); activeConns < bestConns { diff --git a/router/roundrobin/round_robin.go b/router/roundrobin/round_robin.go index 581f32c..b70668a 100644 --- a/router/roundrobin/round_robin.go +++ b/router/roundrobin/round_robin.go @@ -32,7 +32,7 @@ func (rr *RoundRobin) GetBackend() (*backend.Backend, error) { for range len(bp) { b := bp[rr.index] rr.index = (rr.index + 1) % len(bp) - if b.IsHealthy() { + if b.IsHealthy() && !b.IsOpen() { return b, nil } } diff --git a/router/weightedroundrobin/weightedroundrobin.go b/router/weightedroundrobin/weightedroundrobin.go index 577e49e..7ddb7ef 100644 --- a/router/weightedroundrobin/weightedroundrobin.go +++ b/router/weightedroundrobin/weightedroundrobin.go @@ -44,7 +44,7 @@ func (wrr *WeightedRoundRobin) GetBackend() (*backend.Backend, error) { defer wrr.mu.Unlock() for range len(wrr.bs) { b := wrr.bs[wrr.index] - if b.IsHealthy() { + if b.IsHealthy() && 
!b.IsOpen() { wrr.weight-- if wrr.weight <= 0 { wrr.index = (wrr.index + 1) % len(wrr.bs)