diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index aa35b9c..1b3cec0 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -8,7 +8,6 @@ import ( "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/logging/accesslog" - gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage" "github.com/yusing/go-proxy/internal/route/routes" @@ -69,19 +68,17 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request } func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if ep.accessLogger != nil { + w = accesslog.NewResponseRecorder(w) + defer ep.accessLogger.Log(r, w.(*accesslog.ResponseRecorder).Response()) + } mux, err := ep.findRouteFunc(r.Host) if err == nil { - if ep.accessLogger != nil { - w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { - ep.accessLogger.Log(r, resp) - return nil - }) - } if ep.middleware != nil { ep.middleware.ServeHTTP(mux.ServeHTTP, w, routes.WithRouteContext(r, mux)) - return + } else { + mux.ServeHTTP(w, r) } - mux.ServeHTTP(w, r) return } // Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway? diff --git a/internal/logging/accesslog/response_recorder.go b/internal/logging/accesslog/response_recorder.go new file mode 100644 index 0000000..4a3b96e --- /dev/null +++ b/internal/logging/accesslog/response_recorder.go @@ -0,0 +1,67 @@ +package accesslog + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +type ResponseRecorder struct { + w http.ResponseWriter + + resp http.Response +} + +func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { + return &ResponseRecorder{ + w: w, + resp: http.Response{ + StatusCode: http.StatusOK, + Header: w.Header(), + }, + } +} + +func (w *ResponseRecorder) Unwrap() http.ResponseWriter { + return w.w +} + +func (w *ResponseRecorder) Response() *http.Response { + return &w.resp +} + +func (w *ResponseRecorder) Header() http.Header { + return w.w.Header() +} + +func (w *ResponseRecorder) Write(b []byte) (int, error) { + n, err := w.w.Write(b) + w.resp.ContentLength += int64(n) + return n, err +} + +func (w *ResponseRecorder) WriteHeader(code int) { + w.w.WriteHeader(code) + + if code >= http.StatusContinue && code < http.StatusOK { + return + } + w.resp.StatusCode = code +} + +// Hijack hijacks the connection. +func (w *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := w.w.(http.Hijacker); ok { + return h.Hijack() + } + + return nil, nil, fmt.Errorf("not a hijacker: %T", w.w) +} + +// Flush sends any buffered data to the client. +func (w *ResponseRecorder) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +}