refactor: remove Tracer from middleware implementations and related debugging functionality

This commit is contained in:
yusing 2025-05-29 20:27:25 +08:00
parent 24ba4c2a46
commit 72923b8cfa
11 changed files with 3 additions and 214 deletions

View file

@ -14,7 +14,6 @@ import (
type (
cidrWhitelist struct {
CIDRWhitelistOpts
Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {
@ -64,13 +63,11 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
}
}
if !allow {

View file

@ -45,7 +45,7 @@ var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
// setup implements MiddlewareWithSetup.
func (cri *cloudflareRealIP) setup() {
cri.realIP.RealIPOpts = RealIPOpts{
cri.realIP = realIP{
Header: "CF-Connecting-IP",
Recursive: true,
}
@ -60,14 +60,6 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool
return cri.realIP.before(w, r)
}
func (cri *cloudflareRealIP) enableTrace() {
cri.realIP.enableTrace()
}
func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return

View file

@ -8,7 +8,6 @@ import (
"sort"
"strings"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
@ -52,10 +51,6 @@ type (
MiddlewareFinalizerWithError interface {
finalize() error
}
MiddlewareWithTracer interface {
enableTrace()
getTracer() *Tracer
}
)
const DefaultPriority = 10
@ -84,26 +79,6 @@ func NewMiddleware[ImplType any]() *Middleware {
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
log.Trace().Msgf("middleware %s enabled trace", m.name)
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.getTracer()
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.SetParent(parent.getTracer())
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()

View file

@ -3,7 +3,6 @@ package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
)
@ -24,14 +23,6 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsTrace {
for _, child := range chain {
child.enableTrace()
}
m.enableTrace()
}
return m
}

View file

@ -9,7 +9,6 @@ import (
type (
modifyRequest struct {
ModifyRequestOpts
Tracer
}
// order: add_prefix -> set_headers -> add_headers -> hide_headers
ModifyRequestOpts struct {
@ -31,8 +30,6 @@ func (mr *ModifyRequestOpts) finalize() {
// before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
if len(mr.AddPrefix) != 0 {
mr.addPrefix(r, r.URL.Path)
}
@ -41,7 +38,6 @@ func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed
} else {
mr.modifyHeadersWithVarSubstitution(r, nil, r.Header)
}
mr.AddTraceRequest("after modify request", r)
return true
}

View file

@ -6,19 +6,16 @@ import (
type modifyResponse struct {
ModifyRequestOpts
Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()
// modifyResponse implements ResponseModifier.
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
mr.AddTraceResponse("before modify response", resp)
if !mr.needVarSubstitution {
mr.modifyHeaders(resp.Request, resp.Header)
} else {
mr.modifyHeadersWithVarSubstitution(resp.Request, resp, resp.Header)
}
mr.AddTraceResponse("after modify response", resp)
return nil
}

View file

@ -13,7 +13,6 @@ type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
RateLimiterOpts
Tracer
requestMap requestMap
mu sync.Mutex
@ -51,7 +50,6 @@ func (rl *rateLimiter) newLimiter() *rate.Limiter {
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return false
}

View file

@ -11,10 +11,7 @@ import (
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type (
realIP struct {
RealIPOpts
Tracer
}
realIP RealIPOpts
RealIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
@ -42,7 +39,7 @@ var (
// setup implements MiddlewareWithSetup.
func (ri *realIP) setup() {
ri.RealIPOpts = realIPOptsDefault
*ri = realIP(realIPOptsDefault)
}
// before implements RequestModifier.
@ -77,7 +74,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
}
}
if !isTrusted {
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
@ -90,7 +86,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
}
if len(realIPs) == 0 {
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
@ -105,12 +100,10 @@ func (ri *realIP) setRealIP(req *http.Request) {
}
if lastNonTrustedIP == "" {
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -71,7 +71,6 @@ func TestSetRealIP(t *testing.T) {
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
}

View file

@ -1,87 +0,0 @@
package middleware
import (
"net/http"
"sync"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
Traces []*Trace
)
var (
traces = make(Traces, 0)
tracesMu sync.Mutex
)
const MaxTraceNum = 100
func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr
}
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
if tr == nil {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}
func (tr *Trace) With(what string, additional any) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional[what] = additional
return tr
}
func (tr *Trace) WithError(err error) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional["error"] = err.Error()
return tr
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()
if len(traces) > MaxTraceNum {
traces = traces[1:]
}
traces = append(traces, t)
return t
}

View file

@ -1,62 +0,0 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
enabled bool
}
func _() {
var _ MiddlewareWithTracer = &Tracer{}
}
func (t *Tracer) enableTrace() {
t.enabled = true
}
func (t *Tracer) getTracer() *Tracer {
return t
}
func (t *Tracer) SetParent(parent *Tracer) {
if parent == nil {
return
}
t.name = parent.name + "." + t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.name,
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}