Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions listener/listener.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,81 @@
package listener

import (
"context"
"errors"
"fmt"
"net"
"sync"
)

type ProxyIO interface {
Handle(net.Conn) error
}

type Listener struct {
proxy ProxyIO
proxy ProxyIO
ln net.Listener
activeConns map[net.Conn]struct{}
mu sync.Mutex
wg sync.WaitGroup
}

func NewListener(px ProxyIO) *Listener {
return &Listener{proxy: px}
return &Listener{
proxy: px,
activeConns: make(map[net.Conn]struct{}),
}
}

func (l *Listener) Listen(port int64) {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
panic(err)
}
func (l *Listener) Listen(ln net.Listener) {
l.ln = ln
defer ln.Close()
for {
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
continue
}
l.mu.Lock()
l.activeConns[conn] = struct{}{}
l.mu.Unlock()
l.wg.Add(1)
go func() {
err = l.proxy.Handle(conn)
defer l.wg.Done()
defer conn.Close()
defer func() {
l.mu.Lock()
delete(l.activeConns, conn)
l.mu.Unlock()
}()
err := l.proxy.Handle(conn)
if err != nil {
fmt.Println(err)
}
}()
}
}

func (l *Listener) GracefulShutdown(ctx context.Context) {
l.ln.Close()

done := make(chan struct{})
go func() {
l.wg.Wait()
close(done)
}()

select {
case <-done:
return
case <-ctx.Done():
l.mu.Lock()
for conn := range l.activeConns {
conn.Close()
}
l.mu.Unlock()
<-done
}
}
158 changes: 158 additions & 0 deletions listener/listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package listener

import (
"context"
"net"
"sync"
"testing"
"time"
)

type stubProxy struct {
called chan struct{}
block chan struct{}
}

func newStubProxy() *stubProxy {
return &stubProxy{
called: make(chan struct{}, 1),
block: make(chan struct{}),
}
}

func (s *stubProxy) Handle(conn net.Conn) error {
s.called <- struct{}{}
<-s.block
return nil
}

func startListener(t *testing.T, l *Listener) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
go l.Listen(ln)
return ln.Addr().String()
}

func TestListenAcceptsConnection(t *testing.T) {
proxy := newStubProxy()
l := NewListener(proxy)
addr := startListener(t, l)
defer l.GracefulShutdown(context.Background())

conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
close(proxy.block)

select {
case <-proxy.called:
case <-time.After(time.Second):
t.Error("proxy Handle was not called")
}
}

func TestConcurrentListenAcceptsConnection(t *testing.T) {
count := 100
proxy := &stubProxy{
called: make(chan struct{}, count),
block: make(chan struct{}),
}
close(proxy.block)
l := NewListener(proxy)
addr := startListener(t, l)
defer l.GracefulShutdown(context.Background())

wg := sync.WaitGroup{}
for i := 0; i < count; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
if err != nil {
return
}
conn.Close()
}()
}
wg.Wait()

for i := 0; i < 100; i++ {
select {
case <-proxy.called:
case <-time.After(time.Second):
t.Errorf("only %d of 100 connections were handled", i)
return
}
}
}

func TestGracefulShutdownWaitsForConnections(t *testing.T) {
proxy := newStubProxy()
l := NewListener(proxy)
addr := startListener(t, l)

conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

<-proxy.called

done := make(chan struct{})
go func() {
l.GracefulShutdown(context.Background())
close(done)
}()

select {
case <-done:
t.Error("GracefulShutdown returned before connection finished")
case <-time.After(50 * time.Millisecond):
}

close(proxy.block)

select {
case <-done:
case <-time.After(time.Second):
t.Error("GracefulShutdown did not return after connection finished")
}
}

func TestGracefulShutdownForcesCloseOnDeadline(t *testing.T) {
proxy := newStubProxy()
l := NewListener(proxy)
addr := startListener(t, l)

conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

<-proxy.called

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()

done := make(chan struct{})
go func() {
l.GracefulShutdown(ctx)
close(done)
}()

<-ctx.Done()
close(proxy.block)

select {
case <-done:
case <-time.After(time.Second):
t.Error("GracefulShutdown did not force close on deadline")
}
}
22 changes: 21 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
package main

import (
"context"
"fmt"
"load-balancer/backend"
"load-balancer/listener"
"load-balancer/proxy"
"load-balancer/router"
"load-balancer/router/roundrobin"
"net"
"os"
"os/signal"
"syscall"
"time"
)

func main() {
port := 8080
host := "[::1]"
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
panic(err)
}
b := backend.NewBackend("localhost:80")
b1 := backend.NewBackend("localhost:8081")
bp := backend.NewBackendPool([]*backend.Backend{b, b1})
algo := roundrobin.NewRoundRobin(bp)
r := router.NewRouter(host+fmt.Sprintf(":%d", port), algo)
p := proxy.NewProxy(r)
l := listener.NewListener(p)
l.Listen(int64(port))

go l.Listen(ln)

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
<-sigCh

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

l.GracefulShutdown(ctx)
}
1 change: 0 additions & 1 deletion proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func NewProxy(rt RouterIO) *Proxy {
}

func (p *Proxy) Handle(conn net.Conn) error {
defer conn.Close()
localAddr := conn.LocalAddr().String()
b := p.router.Route(localAddr)
if b == nil {
Expand Down
Loading