diff --git a/listener/listener.go b/listener/listener.go index adf0a23..5252152 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -29,7 +29,6 @@ func (l *Listener) Listen(port int64) { continue } go func() { - defer conn.Close() err = l.proxy.Handle(conn) if err != nil { fmt.Println(err) diff --git a/proxy/proxy.go b/proxy/proxy.go index a2b6131..3f44c57 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,6 +5,7 @@ import ( "io" "load-balancer/backend" "net" + "time" ) type RouterIO interface { @@ -12,26 +13,52 @@ type RouterIO interface { } type Proxy struct { - router RouterIO + router RouterIO + dialTimeout time.Duration + connTimeout time.Duration } func NewProxy(rt RouterIO) *Proxy { - return &Proxy{router: rt} + return &Proxy{ + router: rt, + dialTimeout: 10 * time.Second, + connTimeout: 20 * time.Second, + } } func (p *Proxy) Handle(conn net.Conn) error { + defer conn.Close() localAddr := conn.LocalAddr().String() b := p.router.Route(localAddr) if b == nil { return fmt.Errorf("no available backend") } - fmt.Println(b.GetUrl()) - backendConn, err := net.Dial("tcp", b.GetUrl()) + backendConn, err := net.DialTimeout("tcp", b.GetUrl(), p.dialTimeout) if err != nil { return err } defer backendConn.Close() - go io.Copy(backendConn, conn) - io.Copy(conn, backendConn) + err = conn.SetDeadline(time.Now().Add(p.connTimeout)) + if err != nil { + return err + } + err = backendConn.SetDeadline(time.Now().Add(p.connTimeout)) + if err != nil { + return err + } + ch := make(chan error, 2) + go func() { + _, localErr := io.Copy(backendConn, conn) + ch <- localErr + }() + go func() { + _, localErr := io.Copy(conn, backendConn) + ch <- localErr + }() + for range 2 { + if err = <-ch; err != nil { + return err + } + } return nil } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 0000000..69c0632 --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,118 @@ +package proxy + +import ( + "errors" + "io" + "load-balancer/backend" + "net" + "testing" + "time" +) + +type stubRouter struct { + b *backend.Backend +} + +func (s *stubRouter) Route(_ string) *backend.Backend { + return s.b +} + +func startEchoServer(t *testing.T) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ln.Close() }) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + io.Copy(conn, conn) + }() + return ln.Addr().String() +} + +func TestHandleForwardsData(t *testing.T) { + addr := startEchoServer(t) + p := NewProxy(&stubRouter{b: backend.NewBackend(addr)}) + + clientConn, proxyConn := net.Pipe() + defer clientConn.Close() + go p.Handle(proxyConn) + + if _, err := clientConn.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + + buf := make([]byte, 5) + if _, err := io.ReadFull(clientConn, buf); err != nil { + t.Fatal(err) + } + if string(buf) != "hello" { + t.Errorf("got %q, want %q", string(buf), "hello") + } +} + +func TestHandleNoBackend(t *testing.T) { + p := NewProxy(&stubRouter{b: nil}) + _, proxyConn := net.Pipe() + + err := p.Handle(proxyConn) + if err == nil { + t.Error("expected error when no backend available") + } +} + +func TestHandleDeadlineExceeded(t *testing.T) { + // backend accepts but never sends anything — deadline fires first + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // block forever — never read or write + select {} + }() + + p := &Proxy{ + router: &stubRouter{b: backend.NewBackend(ln.Addr().String())}, + dialTimeout: 10 * time.Second, + connTimeout: 5 * time.Millisecond, + } + + clientConn, proxyConn := net.Pipe() + defer clientConn.Close() + + err = p.Handle(proxyConn) + + var netErr net.Error + if !errors.As(err, &netErr) || !netErr.Timeout() { + t.Errorf("expected timeout error, got %v", err) + } + + // proxyConn was closed by defer conn.Close() inside Handle + // so writing to clientConn should now fail + _, writeErr := clientConn.Write([]byte("x")) + if writeErr == nil { + t.Error("expected clientConn write to fail after Handle returned") + } +} + +func TestHandleBackendUnreachable(t *testing.T) { + p := NewProxy(&stubRouter{b: backend.NewBackend("127.0.0.1:1")}) + _, proxyConn := net.Pipe() + + err := p.Handle(proxyConn) + if err == nil { + t.Error("expected error when backend is unreachable") + } +}