diff --git a/listener/listener.go b/listener/listener.go index 5252152..d33ce3e 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -1,8 +1,11 @@ package listener import ( + "context" + "errors" "fmt" "net" + "sync" ) type ProxyIO interface { @@ -10,29 +13,69 @@ type ProxyIO interface { } type Listener struct { - proxy ProxyIO + proxy ProxyIO + ln net.Listener + activeConns map[net.Conn]struct{} + mu sync.Mutex + wg sync.WaitGroup } func NewListener(px ProxyIO) *Listener { - return &Listener{proxy: px} + return &Listener{ + proxy: px, + activeConns: make(map[net.Conn]struct{}), + } } -func (l *Listener) Listen(port int64) { - ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - if err != nil { - panic(err) - } +func (l *Listener) Listen(ln net.Listener) { + l.ln = ln defer ln.Close() for { conn, err := ln.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } continue } + l.mu.Lock() + l.activeConns[conn] = struct{}{} + l.mu.Unlock() + l.wg.Add(1) go func() { - err = l.proxy.Handle(conn) + defer l.wg.Done() + defer conn.Close() + defer func() { + l.mu.Lock() + delete(l.activeConns, conn) + l.mu.Unlock() + }() + err := l.proxy.Handle(conn) if err != nil { fmt.Println(err) } }() } } + +func (l *Listener) GracefulShutdown(ctx context.Context) { + l.ln.Close() + + done := make(chan struct{}) + go func() { + l.wg.Wait() + close(done) + }() + + select { + case <-done: + return + case <-ctx.Done(): + l.mu.Lock() + for conn := range l.activeConns { + conn.Close() + } + l.mu.Unlock() + <-done + } +} diff --git a/listener/listener_test.go b/listener/listener_test.go new file mode 100644 index 0000000..88166fd --- /dev/null +++ b/listener/listener_test.go @@ -0,0 +1,158 @@ +package listener + +import ( + "context" + "net" + "sync" + "testing" + "time" +) + +type stubProxy struct { + called chan struct{} + block chan struct{} +} + +func newStubProxy() *stubProxy { + return &stubProxy{ + called: make(chan struct{}, 1), + block: make(chan struct{}), + } +} + +func (s *stubProxy) Handle(conn net.Conn) error { + s.called <- struct{}{} + <-s.block + return nil +} + +func startListener(t *testing.T, l *Listener) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + go l.Listen(ln) + return ln.Addr().String() +} + +func TestListenAcceptsConnection(t *testing.T) { + proxy := newStubProxy() + l := NewListener(proxy) + addr := startListener(t, l) + defer l.GracefulShutdown(context.Background()) + + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + close(proxy.block) + + select { + case <-proxy.called: + case <-time.After(time.Second): + t.Error("proxy Handle was not called") + } +} + +func TestConcurrentListenAcceptsConnection(t *testing.T) { + count := 100 + proxy := &stubProxy{ + called: make(chan struct{}, count), + block: make(chan struct{}), + } + close(proxy.block) + l := NewListener(proxy) + addr := startListener(t, l) + defer l.GracefulShutdown(context.Background()) + + wg := sync.WaitGroup{} + for i := 0; i < count; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := net.Dial("tcp", addr) + if err != nil { + return + } + conn.Close() + }() + } + wg.Wait() + + for i := 0; i < 100; i++ { + select { + case <-proxy.called: + case <-time.After(time.Second): + t.Errorf("only %d of 100 connections were handled", i) + return + } + } +} + +func TestGracefulShutdownWaitsForConnections(t *testing.T) { + proxy := newStubProxy() + l := NewListener(proxy) + addr := startListener(t, l) + + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + <-proxy.called + + done := make(chan struct{}) + go func() { + l.GracefulShutdown(context.Background()) + close(done) + }() + + select { + case <-done: + t.Error("GracefulShutdown returned before connection finished") + case <-time.After(50 * time.Millisecond): + } + + close(proxy.block) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("GracefulShutdown did not return after connection finished") + } +} + +func TestGracefulShutdownForcesCloseOnDeadline(t *testing.T) { + proxy := newStubProxy() + l := NewListener(proxy) + addr := startListener(t, l) + + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + <-proxy.called + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + l.GracefulShutdown(ctx) + close(done) + }() + + <-ctx.Done() + close(proxy.block) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("GracefulShutdown did not force close on deadline") + } +} diff --git a/main.go b/main.go index d056418..76da389 100644 --- a/main.go +++ b/main.go @@ -1,17 +1,27 @@ package main import ( + "context" "fmt" "load-balancer/backend" "load-balancer/listener" "load-balancer/proxy" "load-balancer/router" "load-balancer/router/roundrobin" + "net" + "os" + "os/signal" + "syscall" + "time" ) func main() { port := 8080 host := "[::1]" + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + panic(err) + } b := backend.NewBackend("localhost:80") b1 := backend.NewBackend("localhost:8081") bp := backend.NewBackendPool([]*backend.Backend{b, b1}) @@ -19,5 +29,15 @@ func main() { r := router.NewRouter(host+fmt.Sprintf(":%d", port), algo) p := proxy.NewProxy(r) l := listener.NewListener(p) - l.Listen(int64(port)) + + go l.Listen(ln) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + <-sigCh + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + l.GracefulShutdown(ctx) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 3f44c57..9648a3a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -27,7 +27,6 @@ func NewProxy(rt RouterIO) *Proxy { } func (p *Proxy) Handle(conn net.Conn) error { - defer conn.Close() localAddr := conn.LocalAddr().String() b := p.router.Route(localAddr) if b == nil {