Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,48 @@ package backend
import (
"sync"
"sync/atomic"
"time"
)

// Backend is one upstream target. It bundles the target's address, a health
// flag, a routing weight, a counter of active connections and a circuit
// breaker.
type Backend struct {
// url is the backend address, copied from BackendOptions.Url by NewBackend.
url string
// healthy is an atomic health flag.
// NOTE(review): presumably what IsHealthy reports — confirm against the
// accessor, which is not visible here.
healthy atomic.Bool
// weight is copied from BackendOptions.Weight by NewBackend.
weight int
// activeConns is the atomic counter read by GetActiveConns.
activeConns atomic.Int64
// cb is the circuit breaker that IsOpen, RecordSuccess and RecordFailure
// delegate to.
cb CircuitBreaker
}

// BackendOptions configures a Backend built by NewBackend.
//
// The Circuit* fields tune the backend's circuit breaker. NewBackend replaces
// any of them left at its zero value with a default.
type BackendOptions struct {
	// Url is the backend address.
	Url string
	// Weight is the backend's routing weight.
	Weight int
	// CircuitWindowSize is the number of recent outcomes the breaker keeps
	// while closed (default 10).
	CircuitWindowSize int
	// CircuitFailureThreshold is the failure percentage, measured against
	// CircuitWindowSize, at which the breaker opens (default 50).
	CircuitFailureThreshold int
	// CircuitHalfOpenThreshold is the number of successes needed in the
	// half-open state before the breaker closes again (default 3).
	CircuitHalfOpenThreshold int
	// CircuitCooldownTimeout is how long the breaker stays open before it
	// moves to half-open (default 30s).
	CircuitCooldownTimeout time.Duration
}

// NewBackend builds a Backend from opts.
//
// Any circuit-breaker option that is zero or negative is replaced with its
// default: a window of 10 outcomes, a 50% failure threshold, 3 half-open
// successes and a 30-second cooldown. Negative values are treated like unset
// ones because they cannot describe a usable breaker; a negative window size
// in particular would panic when the window slice is allocated.
func NewBackend(opts BackendOptions) *Backend {
	const (
		defaultWindowSize        = 10
		defaultFailureThreshold  = 50
		defaultHalfOpenThreshold = 3
		defaultCooldownTimeout   = 30 * time.Second
	)

	if opts.CircuitWindowSize <= 0 {
		opts.CircuitWindowSize = defaultWindowSize
	}
	if opts.CircuitFailureThreshold <= 0 {
		opts.CircuitFailureThreshold = defaultFailureThreshold
	}
	if opts.CircuitHalfOpenThreshold <= 0 {
		opts.CircuitHalfOpenThreshold = defaultHalfOpenThreshold
	}
	if opts.CircuitCooldownTimeout <= 0 {
		opts.CircuitCooldownTimeout = defaultCooldownTimeout
	}

	return &Backend{
		url:    opts.Url,
		weight: opts.Weight,
		cb: NewCircuitBreaker(CircuitBreakerOptions{
			WindowSize:        opts.CircuitWindowSize,
			FailureThreshold:  opts.CircuitFailureThreshold,
			HalfOpenThreshold: opts.CircuitHalfOpenThreshold,
			CooldownTimeout:   opts.CircuitCooldownTimeout,
		}),
	}
}

Expand Down Expand Up @@ -52,6 +76,18 @@ func (b *Backend) GetActiveConns() int64 {
return b.activeConns.Load()
}

// IsOpen reports whether this backend's circuit breaker is currently open.
// It delegates to the breaker's IsOpen, which may also move an open circuit
// to half-open once its cooldown has elapsed.
func (b *Backend) IsOpen() bool {
return b.cb.IsOpen()
}

// RecordSuccess forwards a successful outcome to this backend's circuit
// breaker.
func (b *Backend) RecordSuccess() {
b.cb.RecordSuccess()
}

// RecordFailure forwards a failed outcome to this backend's circuit breaker.
func (b *Backend) RecordFailure() {
b.cb.RecordFailure()
}

type BackendPool struct {
bs []*Backend
mu sync.RWMutex
Expand Down
116 changes: 116 additions & 0 deletions backend/circuit_breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package backend

import (
"sync"
"time"
)

// CircuitState identifies the phase of the circuit-breaker state machine.
type CircuitState int

const (
	// Closed is the initial state: outcomes are recorded in a sliding window
	// and the circuit trips when the failure percentage reaches the threshold.
	Closed CircuitState = iota
	// Open is the tripped state: outcomes are ignored and IsOpen reports true
	// until the cooldown has elapsed.
	Open
	// HalfOpen is the probing state entered after the cooldown: enough
	// successes close the circuit, a single failure re-opens it.
	HalfOpen
)

// Markers stored in CircuitBreaker.window, one per recorded outcome.
const (
	outcomeSuccess byte = 'S'
	outcomeFailure byte = 'F'
)

// CircuitBreaker is a mutex-guarded circuit breaker with a fixed-size sliding
// window of recent outcomes.
//
// While Closed, successes and failures count the entries currently in window.
// While HalfOpen, successes counts consecutive probe successes and window is
// empty. A CircuitBreaker must not be copied after first use because it
// contains a sync.Mutex.
type CircuitBreaker struct {
	state             CircuitState
	window            []byte // oldest outcome first; at most windowSize entries
	windowSize        int
	successes         int
	failures          int
	failureThreshold  int // percentage of windowSize
	halfOpenThreshold int
	cooldownTimeout   time.Duration
	cooldownStartedAt time.Time // when the circuit last tripped
	mu                sync.Mutex
}

// CircuitBreakerOptions configures NewCircuitBreaker.
type CircuitBreakerOptions struct {
	// WindowSize is the number of outcomes kept while closed. Values below 1
	// are raised to 1.
	WindowSize int
	// FailureThreshold is the failure percentage, measured against
	// WindowSize, at which the circuit opens.
	FailureThreshold int
	// HalfOpenThreshold is the number of half-open successes needed to close
	// the circuit.
	HalfOpenThreshold int
	// CooldownTimeout is how long the circuit stays open before half-open.
	CooldownTimeout time.Duration
}

// NewCircuitBreaker returns a closed breaker configured from opts.
//
// A WindowSize below 1 is clamped to 1. Without the clamp, a zero size makes
// the first recorded outcome index an empty window, and a negative size
// panics when the window slice is allocated.
func NewCircuitBreaker(opts CircuitBreakerOptions) CircuitBreaker {
	size := opts.WindowSize
	if size < 1 {
		size = 1
	}
	return CircuitBreaker{
		state:             Closed,
		window:            make([]byte, 0, size),
		windowSize:        size,
		failureThreshold:  opts.FailureThreshold,
		halfOpenThreshold: opts.HalfOpenThreshold,
		cooldownTimeout:   opts.CooldownTimeout,
	}
}

// IsOpen reports whether the circuit is open.
//
// It is not a pure query: when the circuit is open and the cooldown has
// elapsed, it moves the breaker to HalfOpen and reports false.
func (cb *CircuitBreaker) IsOpen() bool {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	if cb.state != Open {
		return false
	}
	if time.Since(cb.cooldownStartedAt) > cb.cooldownTimeout {
		cb.state = HalfOpen
		return false
	}
	return true
}

// RecordFailure registers a failed outcome.
//
// It is ignored while Open. In HalfOpen a single failure re-opens the
// circuit. In Closed the failure enters the window and the circuit opens once
// failures*100/windowSize reaches the failure threshold. The denominator is
// the configured window size, not the number of outcomes seen so far.
func (cb *CircuitBreaker) RecordFailure() {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	switch cb.state {
	case Open:
		return
	case HalfOpen:
		cb.trip()
		return
	}

	cb.failures++
	cb.push(outcomeFailure)
	if cb.failures*100/cb.windowSize >= cb.failureThreshold {
		cb.trip()
	}
}

// RecordSuccess registers a successful outcome.
//
// It is ignored while Open. In HalfOpen it counts towards the half-open
// threshold and closes the circuit when that is reached. In Closed the
// success enters the window.
func (cb *CircuitBreaker) RecordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	switch cb.state {
	case Open:
		return
	case HalfOpen:
		cb.successes++
		if cb.successes >= cb.halfOpenThreshold {
			cb.state = Closed
			cb.successes = 0
		}
	default:
		cb.successes++
		cb.push(outcomeSuccess)
	}
}

// push stores outcome as the newest window entry. When the window is full it
// first evicts the oldest entry and removes it from the matching counter.
// The caller must hold cb.mu and must already have counted outcome.
func (cb *CircuitBreaker) push(outcome byte) {
	if len(cb.window) < cb.windowSize {
		cb.window = append(cb.window, outcome)
		return
	}

	switch cb.window[0] {
	case outcomeSuccess:
		cb.successes--
	case outcomeFailure:
		cb.failures--
	}
	copy(cb.window, cb.window[1:])
	cb.window[len(cb.window)-1] = outcome
}

// trip opens the circuit, starts the cooldown and clears the counters and
// the window. The window keeps its backing array so tripping does not
// allocate. The caller must hold cb.mu.
func (cb *CircuitBreaker) trip() {
	cb.state = Open
	cb.cooldownStartedAt = time.Now()
	cb.successes, cb.failures = 0, 0
	cb.window = cb.window[:0]
}
134 changes: 134 additions & 0 deletions backend/circuit_breaker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package backend

import (
"sync"
"testing"
"time"
)

// newTestCB returns a closed breaker assembled field by field for the tests:
// a window of four outcomes, a 50% failure threshold, two half-open successes
// to close, and a one-second cooldown.
func newTestCB() CircuitBreaker {
	const size = 4
	return CircuitBreaker{
		state:             Closed,
		window:            make([]byte, 0, size),
		windowSize:        size,
		failureThreshold:  50,
		halfOpenThreshold: 2,
		cooldownTimeout:   time.Second,
	}
}

// TestClosedToOpenOnFailureThreshold checks that a closed breaker opens once
// the failure share of the window reaches the configured threshold.
func TestClosedToOpenOnFailureThreshold(t *testing.T) {
	cb := newTestCB()

	// Fill the window of 4 with two successes followed by two failures,
	// which is exactly a 50% failure rate.
	for i := 0; i < 2; i++ {
		cb.RecordSuccess()
	}
	for i := 0; i < 2; i++ {
		cb.RecordFailure()
	}

	if open := cb.IsOpen(); !open {
		t.Error("expected circuit to be open after failure threshold exceeded")
	}
}

// TestStaysClosedBelowThreshold checks that one failure among three successes
// (25% of the window) leaves the breaker closed.
func TestStaysClosedBelowThreshold(t *testing.T) {
	cb := newTestCB()

	for i := 0; i < 3; i++ {
		cb.RecordSuccess()
	}
	cb.RecordFailure()

	if open := cb.IsOpen(); open {
		t.Error("expected circuit to stay closed below failure threshold")
	}
}

// TestOpenToHalfOpenAfterCooldown trips the breaker, backdates the cooldown
// start past the timeout, and checks that IsOpen moves it to HalfOpen.
func TestOpenToHalfOpenAfterCooldown(t *testing.T) {
	cb := newTestCB()

	// Trip the breaker: two successes then two failures in a window of 4.
	for i := 0; i < 2; i++ {
		cb.RecordSuccess()
	}
	for i := 0; i < 2; i++ {
		cb.RecordFailure()
	}

	// Pretend the circuit tripped two seconds ago so the one-second
	// cooldown has already expired.
	cb.mu.Lock()
	cb.cooldownStartedAt = time.Now().Add(-2 * time.Second)
	cb.mu.Unlock()

	if open := cb.IsOpen(); open {
		t.Error("expected circuit to transition to half-open after cooldown")
	}

	got := func() CircuitState {
		cb.mu.Lock()
		defer cb.mu.Unlock()
		return cb.state
	}()
	if got != HalfOpen {
		t.Errorf("expected HalfOpen state, got %v", got)
	}
}

// TestHalfOpenToClosedOnSuccesses checks that a half-open breaker closes
// after recording as many successes as the half-open threshold.
func TestHalfOpenToClosedOnSuccesses(t *testing.T) {
	cb := newTestCB()

	// Start directly in the probing state.
	func() {
		cb.mu.Lock()
		defer cb.mu.Unlock()
		cb.state = HalfOpen
	}()

	for i := 0; i < 2; i++ {
		cb.RecordSuccess()
	}

	if open := cb.IsOpen(); open {
		t.Error("expected circuit to close after half-open threshold successes")
	}

	got := func() CircuitState {
		cb.mu.Lock()
		defer cb.mu.Unlock()
		return cb.state
	}()
	if got != Closed {
		t.Errorf("expected Closed state, got %v", got)
	}
}

// TestHalfOpenToOpenOnFailure checks that a single failure recorded while
// half-open sends the breaker back to open.
func TestHalfOpenToOpenOnFailure(t *testing.T) {
	cb := newTestCB()

	// Start directly in the probing state.
	func() {
		cb.mu.Lock()
		defer cb.mu.Unlock()
		cb.state = HalfOpen
	}()

	cb.RecordFailure()

	if open := cb.IsOpen(); !open {
		t.Error("expected circuit to reopen on failure in half-open state")
	}
}

// TestRecordsIgnoredWhenOpen checks that outcomes reported while the breaker
// is open leave both counters untouched.
func TestRecordsIgnoredWhenOpen(t *testing.T) {
	cb := newTestCB()

	// Force the breaker open with a cooldown that has only just started.
	cb.mu.Lock()
	cb.state, cb.cooldownStartedAt = Open, time.Now()
	cb.mu.Unlock()

	cb.RecordFailure()
	cb.RecordSuccess()

	cb.mu.Lock()
	failures, successes := cb.failures, cb.successes
	cb.mu.Unlock()

	if !(failures == 0 && successes == 0) {
		t.Errorf("expected no records in open state, got failures=%d successes=%d", failures, successes)
	}
}

// TestConcurrentSafety calls every breaker method from many goroutines at
// once. It asserts nothing itself; it exists so the race detector can check
// the locking.
func TestConcurrentSafety(t *testing.T) {
	cb := newTestCB()

	// The three operations launched in each round, in launch order.
	ops := []func(){
		cb.RecordFailure,
		cb.RecordSuccess,
		func() { cb.IsOpen() },
	}

	var wg sync.WaitGroup
	for range 100 {
		for _, op := range ops {
			wg.Add(1)
			go func(run func()) {
				defer wg.Done()
				run()
			}(op)
		}
	}
	wg.Wait()
}
5 changes: 5 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (p *Proxy) Handle(conn net.Conn) error {

backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout)
if err != nil {
b.RecordFailure()
return fmt.Errorf("dial %s: %w", b.GetUrl(), err)
}

Expand All @@ -56,6 +57,7 @@ func (p *Proxy) Handle(conn net.Conn) error {
}
err = backendConn.SetDeadline(time.Now().Add(p.connTimeout))
if err != nil {
b.RecordFailure()
return fmt.Errorf("set deadline %s: %w", backendConn.RemoteAddr(), err)
}

Expand All @@ -77,9 +79,12 @@ func (p *Proxy) Handle(conn net.Conn) error {

for range 2 {
if err = <-ch; err != nil {
b.RecordFailure()
return err
}
}

b.RecordSuccess()

return nil
}
23 changes: 23 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,26 @@ func TestActiveConnsNotIncrementedOnFailedDial(t *testing.T) {
t.Errorf("expected 0 active conns after failed dial, got %d", b.GetActiveConns())
}
}

// TestCircuitOpensAfterRepeatedDialFailures points the proxy at an address
// that is expected to refuse connections and checks that two failed dials
// open the backend's circuit (window of 4, 50% threshold).
func TestCircuitOpensAfterRepeatedDialFailures(t *testing.T) {
	b := backend.NewBackend(backend.BackendOptions{
		Url:                     "127.0.0.1:1",
		CircuitWindowSize:       4,
		CircuitFailureThreshold: 50,
		CircuitCooldownTimeout:  time.Minute,
	})
	p := NewProxy(&stubRouter{b: b}, ProxyOptions{
		DialTimeout: 100 * time.Millisecond,
		ConnTimeout: 30 * time.Second,
	})

	for range 2 {
		clientConn, proxyConn := net.Pipe()
		err := p.Handle(proxyConn)
		// Close both pipe ends so neither side outlives the iteration.
		proxyConn.Close()
		clientConn.Close()
		// The scenario only makes sense if the dial really failed; fail
		// here with a clear message instead of at the final assertion.
		if err == nil {
			t.Fatal("expected Handle to return an error when the backend cannot be dialed")
		}
	}

	if !b.IsOpen() {
		t.Error("expected circuit to be open after repeated dial failures")
	}
}
2 changes: 1 addition & 1 deletion router/leastconnections/leastconnections.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (ls *LeastConnections) GetBackend() (*backend.Backend, error) {
var best *backend.Backend
bestConns := int64(math.MaxInt64)
for _, b := range bp {
if !b.IsHealthy() {
if !b.IsHealthy() || b.IsOpen() {
continue
}
if activeConns := b.GetActiveConns(); activeConns < bestConns {
Expand Down
2 changes: 1 addition & 1 deletion router/roundrobin/round_robin.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (rr *RoundRobin) GetBackend() (*backend.Backend, error) {
for range len(bp) {
b := bp[rr.index]
rr.index = (rr.index + 1) % len(bp)
if b.IsHealthy() {
if b.IsHealthy() && !b.IsOpen() {
return b, nil
}
}
Expand Down
Loading
Loading