// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go) // Copyright (c) 2020-2024 Traefik Labs package http import ( "bufio" "fmt" "net" "net/http" ) type ( ModifyResponseFunc func(*http.Response) error ModifyResponseWriter struct { w http.ResponseWriter r *http.Request headerSent bool code int size int modifier ModifyResponseFunc modified bool modifierErr error } ) func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter { return &ModifyResponseWriter{ w: w, r: r, modifier: f, code: http.StatusOK, } } func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter { return w.w } func (w *ModifyResponseWriter) StatusCode() int { return w.code } func (w *ModifyResponseWriter) Size() int { return w.size } func (w *ModifyResponseWriter) WriteHeader(code int) { if w.headerSent { return } if code >= http.StatusContinue && code < http.StatusOK { w.w.WriteHeader(code) } defer func() { w.headerSent = true w.code = code }() if w.modifier == nil || w.modified { w.w.WriteHeader(code) return } resp := http.Response{ StatusCode: code, Header: w.w.Header(), Request: w.r, ContentLength: int64(w.size), } if err := w.modifier(&resp); err != nil { w.modifierErr = fmt.Errorf("response modifier error: %w", err) resp.Status = w.modifierErr.Error() w.w.WriteHeader(http.StatusInternalServerError) return } w.modified = true w.w.WriteHeader(code) } func (w *ModifyResponseWriter) Header() http.Header { return w.w.Header() } func (w *ModifyResponseWriter) Write(b []byte) (int, error) { w.WriteHeader(w.code) if w.modifierErr != nil { return 0, w.modifierErr } n, err := w.w.Write(b) w.size += n return n, err } // Hijack hijacks the connection. func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if h, ok := w.w.(http.Hijacker); ok { return h.Hijack() } return nil, nil, fmt.Errorf("not a hijacker: %T", w.w) } // Flush sends any buffered data to the client. func (w *ModifyResponseWriter) Flush() { if flusher, ok := w.w.(http.Flusher); ok { flusher.Flush() } }