diff --git a/backend/backend.go b/backend/backend.go index 18d8655..4c62b9e 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -1,10 +1,13 @@ package backend -import "sync" +import ( + "sync" + "sync/atomic" +) type Backend struct { url string - healthy bool + healthy atomic.Bool } func NewBackend(url string) *Backend { @@ -16,7 +19,7 @@ func (b *Backend) GetUrl() string { } func (b *Backend) IsHealthy() bool { - return b.healthy + return b.healthy.Load() } type BackendPool struct { @@ -25,6 +28,9 @@ type BackendPool struct { } func NewBackendPool(bs []*Backend) *BackendPool { + if len(bs) < 1 { + panic("backend pool is empty") + } return &BackendPool{bs: bs} } diff --git a/listener/listener.go b/listener/listener.go index d33ce3e..27bbd57 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -3,9 +3,10 @@ package listener import ( "context" "errors" - "fmt" + "log" "net" "sync" + "time" ) type ProxyIO interface { @@ -21,6 +22,9 @@ type Listener struct { } func NewListener(px ProxyIO) *Listener { + if px == nil { + panic("proxy cannot be nil") + } return &Listener{ proxy: px, activeConns: make(map[net.Conn]struct{}), @@ -36,6 +40,8 @@ func (l *Listener) Listen(ln net.Listener) { if errors.Is(err, net.ErrClosed) { return } + log.Printf("accept error: %v", err) + time.Sleep(1 * time.Second) continue } l.mu.Lock() @@ -44,6 +50,12 @@ func (l *Listener) Listen(ln net.Listener) { l.wg.Add(1) go func() { defer l.wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf("panic handling connection %s: %v", + conn.RemoteAddr(), r) + } + }() defer conn.Close() defer func() { l.mu.Lock() @@ -52,7 +64,7 @@ func (l *Listener) Listen(ln net.Listener) { }() err := l.proxy.Handle(conn) if err != nil { - fmt.Println(err) + log.Printf("connection %s: %v", conn.RemoteAddr(), err) } }() } diff --git a/main.go b/main.go index 76da389..4114e32 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "load-balancer/proxy" "load-balancer/router" "load-balancer/router/roundrobin" + "log" "net" "os" "os/signal" @@ -20,7 +21,7 @@ func main() { host := "[::1]" ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { - panic(err) + log.Fatalf("failed to listen on port %d: %v", port, err) } b := backend.NewBackend("localhost:80") b1 := backend.NewBackend("localhost:8081") diff --git a/proxy/proxy.go b/proxy/proxy.go index 9648a3a..2ded4af 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -9,7 +9,7 @@ import ( ) type RouterIO interface { - Route(string) *backend.Backend + Route(string) (*backend.Backend, error) } type Proxy struct { @@ -19,6 +19,9 @@ type Proxy struct { } func NewProxy(rt RouterIO) *Proxy { + if rt == nil { + panic("router cannot be nil") + } return &Proxy{ router: rt, dialTimeout: 10 * time.Second, @@ -28,30 +31,36 @@ func NewProxy(rt RouterIO) *Proxy { func (p *Proxy) Handle(conn net.Conn) error { localAddr := conn.LocalAddr().String() - b := p.router.Route(localAddr) - if b == nil { - return fmt.Errorf("no available backend") + b, err := p.router.Route(localAddr) + if err != nil { + return fmt.Errorf("routing: %w", err) } backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout) if err != nil { - return err + return fmt.Errorf("dial %s: %w", b.GetUrl(), err) } defer backendConn.Close() err = conn.SetDeadline(time.Now().Add(p.connTimeout)) if err != nil { - return err + return fmt.Errorf("set deadline %s: %w", conn.RemoteAddr(), err) } err = backendConn.SetDeadline(time.Now().Add(p.connTimeout)) if err != nil { - return err + return fmt.Errorf("set deadline %s: %w", backendConn.RemoteAddr(), err) } ch := make(chan error, 2) go func() { _, localErr := io.Copy(backendConn, conn) + if localErr != nil { + localErr = fmt.Errorf("copy from %s to %s: %w", conn.RemoteAddr(), backendConn.RemoteAddr(), localErr) + } ch <- localErr }() go func() { _, localErr := io.Copy(conn, backendConn) + if localErr != nil { + localErr = fmt.Errorf("copy from %s to %s: %w", backendConn.RemoteAddr(), conn.RemoteAddr(), localErr) + } ch <- localErr }() for range 2 { diff --git a/router/roundrobin/round_robin.go b/router/roundrobin/round_robin.go index a4069ba..f526e94 100644 --- a/router/roundrobin/round_robin.go +++ b/router/roundrobin/round_robin.go @@ -1,10 +1,13 @@ package roundrobin import ( + "errors" "load-balancer/backend" "sync" ) +var ErrNoBackends = errors.New("no backends available") + type BackendPoolIO interface { GetPool() []*backend.Backend } @@ -16,14 +19,22 @@ type RoundRobin struct { } func NewRoundRobin(bs BackendPoolIO) *RoundRobin { + if bs == nil { + panic("backend pool cannot be nil") + } return &RoundRobin{bp: bs} } -func (rr *RoundRobin) GetBackend() *backend.Backend { +func (rr *RoundRobin) GetBackend() (*backend.Backend, error) { rr.mu.Lock() defer rr.mu.Unlock() bp := rr.bp.GetPool() - b := bp[rr.index] - rr.index = (rr.index + 1) % len(bp) - return b + for range len(bp) { + b := bp[rr.index] + rr.index = (rr.index + 1) % len(bp) + if b.IsHealthy() { + return b, nil + } + } + return nil, ErrNoBackends } diff --git a/router/router.go b/router/router.go index e74dcb9..c1ca71f 100644 --- a/router/router.go +++ b/router/router.go @@ -3,7 +3,7 @@ package router import "load-balancer/backend" type AlgoIO interface { - GetBackend() *backend.Backend + GetBackend() (*backend.Backend, error) } type Router struct { @@ -11,11 +11,14 @@ type Router struct { } func NewRouter(path string, be AlgoIO) *Router { + if be == nil { + panic("algorithm cannot be nil") + } router := make(map[string]AlgoIO) router[path] = be return &Router{router: router} } -func (r *Router) Route(path string) *backend.Backend { +func (r *Router) Route(path string) (*backend.Backend, error) { return r.router[path].GetBackend() }