Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions handler/accesslog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package handler

import (
"bufio"
"fmt"
"net"
"net/http"
"strconv"
"time"
)

type accessLogResponseWriter struct {
http.ResponseWriter
status int
}

func newAccessLogResponseWriter(w http.ResponseWriter) *accessLogResponseWriter {
return &accessLogResponseWriter{ResponseWriter: w}
}

func (w *accessLogResponseWriter) WriteHeader(status int) {
if w.status == 0 {
w.status = status
w.ResponseWriter.WriteHeader(status)
}
}

func (w *accessLogResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
return w.ResponseWriter.Write(p)
}

func (w *accessLogResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

func (w *accessLogResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("connection doesn't support hijacking")
}
return hj.Hijack()
}

func markAccessLogStatus(w http.ResponseWriter, status int) {
if lw, ok := w.(*accessLogResponseWriter); ok && lw.status == 0 {
lw.status = status
}
}

func accessLogTarget(req *http.Request) string {
if req.Method == http.MethodConnect {
if req.RequestURI != "" {
return req.RequestURI
}
if req.URL != nil && req.URL.Host != "" {
return req.URL.Host
}
}
if req.URL == nil {
return ""
}
return req.URL.String()
}

func accessLogStatus(status int) string {
if status == 0 {
return "-"
}
text := http.StatusText(status)
if text == "" {
return strconv.Itoa(status)
}
return fmt.Sprintf("%d %s", status, text)
}

func accessLogDuration(start time.Time) time.Duration {
return time.Since(start).Round(time.Millisecond)
}
14 changes: 14 additions & 0 deletions handler/direct_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package handler

import (
"bytes"
"context"
"errors"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

derrors "github.com/SenseUnit/dumbproxy/dialer/errors"
clog "github.com/SenseUnit/dumbproxy/log"
)

type staticReject struct {
Expand Down Expand Up @@ -109,9 +113,11 @@ func TestDirectResponse(t *testing.T) {
}

func TestAccessReject(t *testing.T) {
var logBuf bytes.Buffer
proxy := NewProxyHandler(&Config{
Dialer: deniedDialer{},
AccessReject: staticReject{status: http.StatusTeapot, body: "access response"},
Logger: clog.NewCondLogger(log.New(&logBuf, "", 0), clog.INFO),
})
rr := httptest.NewRecorder()
req := &http.Request{
Expand All @@ -131,4 +137,12 @@ func TestAccessReject(t *testing.T) {
if rr.Body.String() != "access response" {
t.Fatalf("body = %q, want access response", rr.Body.String())
}

logOutput := logBuf.String()
if got := strings.Count(logOutput, "INFO Request:"); got != 1 {
t.Fatalf("INFO Request log count = %d, want 1\nlogs:\n%s", got, logOutput)
}
if !strings.Contains(logOutput, "CONNECT openrouter.ai:443 418 I'm a teapot dur=") {
t.Fatalf("access log is missing status or duration\nlogs:\n%s", logOutput)
}
}
16 changes: 14 additions & 2 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"strings"
"sync"
"time"

"github.com/SenseUnit/dumbproxy/auth"
"github.com/SenseUnit/dumbproxy/dialer"
Expand Down Expand Up @@ -112,6 +113,7 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, u
return
}
defer localconn.Close()
markAccessLogStatus(wr, http.StatusOK)

if buffered := rw.Reader.Buffered(); buffered > 0 {
s.logger.Debug("saving %d bytes buffered in bufio.ReadWriter", buffered)
Expand Down Expand Up @@ -191,7 +193,6 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request,
return
}
defer resp.Body.Close()
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
delHopHeaders(resp.Header)
copyHeader(wr.Header(), resp.Header)
wr.WriteHeader(resp.StatusCode)
Expand Down Expand Up @@ -238,10 +239,21 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
return
}

start := time.Now()
wr = newAccessLogResponseWriter(wr)
ctx := req.Context()
username, ok := s.auth.Validate(ctx, wr, req)
localAddr := getLocalAddr(req.Context())
s.logger.Info("Request: %v => %v %q %v %v %v", req.RemoteAddr, localAddr, username, req.Proto, req.Method, req.URL)
target := accessLogTarget(req)
defer func() {
status := 0
if lw, ok := wr.(*accessLogResponseWriter); ok {
status = lw.status
}
s.logger.Info("Request: %v => %v %q %v %v %v %s dur=%v",
req.RemoteAddr, localAddr, username, req.Proto, req.Method, target,
accessLogStatus(status), accessLogDuration(start))
}()

if !ok {
return
Expand Down