diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index ae7e70d..d6726a6 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -11,7 +11,7 @@ import ( type ( cidrWhitelist struct { CIDRWhitelistOpts - *Tracer + Tracer cachedAddr F.Map[string, bool] // cache for trusted IPs } CIDRWhitelistOpts struct { diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index edee6e3..970d33e 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -11,6 +11,7 @@ import ( "time" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -52,6 +53,14 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool return cri.realIP.before(w, r) } +func (cri *cloudflareRealIP) enableTrace() { + cri.realIP.enableTrace() +} + +func (cri *cloudflareRealIP) getTracer() *Tracer { + return cri.realIP.getTracer() +} + func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { return @@ -74,14 +83,17 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { } else { cfCIDRs = make([]*types.CIDR, 0, 30) err := errors.Join( - fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs), - fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs), + fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs), + fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs), ) if err != nil { cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) return nil } + if len(cfCIDRs) == 0 { + logging.Warn().Msg("cloudflare CIDR range is empty") + } } cfCIDRsLastUpdate = time.Now() @@ -89,7 +101,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { return } -func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error { +func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error { resp, err := http.Get(endpoint) if err != nil { return err @@ -110,7 +122,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error { return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) } - cfCIDRs = append(cfCIDRs, cidr) + *cfCIDRs = append(*cfCIDRs, cidr) } return nil diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index c193f94..a505ff4 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -19,7 +19,7 @@ import ( type ( forwardAuth struct { ForwardAuthOpts - *Tracer + Tracer reqCookiesMap F.Map[*http.Request, []*http.Cookie] } ForwardAuthOpts struct { diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index b4f702f..ff854e9 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -7,6 +7,7 @@ import ( "strings" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/utils" ) @@ -32,7 +33,10 @@ type ( ResponseModifier interface{ modifyResponse(r *http.Response) error } MiddlewareWithSetup interface{ setup() } MiddlewareFinalizer interface{ finalize() } - MiddlewareWithTracer *struct{ *Tracer } + MiddlewareWithTracer interface { + enableTrace() + getTracer() *Tracer + } ) func NewMiddleware[ImplType any]() *Middleware { @@ -51,20 +55,21 @@ func NewMiddleware[ImplType any]() *Middleware { func (m *Middleware) enableTrace() { if tracer, ok := m.impl.(MiddlewareWithTracer); ok { - tracer.Tracer = &Tracer{name: m.name} + tracer.enableTrace() + logging.Debug().Msgf("middleware %s enabled trace", m.name) } } func (m *Middleware) getTracer() *Tracer { if tracer, ok := m.impl.(MiddlewareWithTracer); ok { - return tracer.Tracer + return tracer.getTracer() } return nil } func (m *Middleware) setParent(parent *Middleware) { if tracer := m.getTracer(); tracer != nil { - tracer.parent = parent.getTracer() + tracer.SetParent(parent.getTracer()) } } diff --git a/internal/net/http/middleware/middleware_chain.go b/internal/net/http/middleware/middleware_chain.go index da14287..932d278 100644 --- a/internal/net/http/middleware/middleware_chain.go +++ b/internal/net/http/middleware/middleware_chain.go @@ -28,6 +28,9 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware { } if common.IsDebug { + for _, child := range chain { + child.enableTrace() + } m.enableTrace() } return m diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index c3195e8..f7bab53 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -1,10 +1,7 @@ package middleware import ( - "fmt" - "net/http" "path" - "strings" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" @@ -12,7 +9,31 @@ import ( "github.com/yusing/go-proxy/internal/utils/strutils" ) -var allMiddlewares map[string]*Middleware +// snakes and cases will be stripped on `Get` +// so keys are lowercase without snake. +var allMiddlewares = map[string]*Middleware{ + "redirecthttp": RedirectHTTP, + + "request": ModifyRequest, + "modifyrequest": ModifyRequest, + "response": ModifyResponse, + "modifyresponse": ModifyResponse, + "setxforwarded": SetXForwarded, + "hidexforwarded": HideXForwarded, + + "errorpage": CustomErrorPage, + "customerrorpage": CustomErrorPage, + + "realip": RealIP, + "cloudflarerealip": CloudflareRealIP, + + "cidrwhitelist": CIDRWhiteList, + "ratelimit": RateLimiter, + + // !experimental + "forwardauth": ForwardAuth, + // "oauth2": OAuth2.m, +} var ( ErrUnknownMiddleware = E.New("unknown middleware") @@ -33,46 +54,6 @@ func All() map[string]*Middleware { return allMiddlewares } -// initialize middleware names and label parsers. -func init() { - // snakes and cases will be stripped on `Get` - // so keys are lowercase without snake. - allMiddlewares = map[string]*Middleware{ - "redirecthttp": RedirectHTTP, - - "request": ModifyRequest, - "modifyrequest": ModifyRequest, - "response": ModifyResponse, - "modifyresponse": ModifyResponse, - "setxforwarded": SetXForwarded, - "hidexforwarded": HideXForwarded, - - "errorpage": CustomErrorPage, - "customerrorpage": CustomErrorPage, - - "realip": RealIP, - "cloudflarerealip": CloudflareRealIP, - - "cidrwhitelist": CIDRWhiteList, - "ratelimit": RateLimiter, - - // !experimental - "forwardauth": ForwardAuth, - // "oauth2": OAuth2.m, - } - names := make(map[*Middleware][]string) - for name, m := range allMiddlewares { - names[m] = append(names[m], http.CanonicalHeaderKey(name)) - } - for m, names := range names { - if len(names) > 1 { - m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", ")) - } else { - m.name = names[0] - } - } -} - func LoadComposeFiles() { errs := E.NewBuilder("middleware compile errors") middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 1e2d42d..a622c3c 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -8,7 +8,7 @@ import ( type ( modifyRequest struct { ModifyRequestOpts - *Tracer + Tracer } // order: set_headers -> add_headers -> hide_headers ModifyRequestOpts struct { diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index 67f7395..3285d9b 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -6,7 +6,7 @@ import ( type modifyResponse struct { ModifyRequestOpts - *Tracer + Tracer } var ModifyResponse = NewMiddleware[modifyResponse]() diff --git a/internal/net/http/middleware/rate_limit.go b/internal/net/http/middleware/rate_limit.go index 41deca7..07557bf 100644 --- a/internal/net/http/middleware/rate_limit.go +++ b/internal/net/http/middleware/rate_limit.go @@ -13,7 +13,7 @@ type ( requestMap = map[string]*rate.Limiter rateLimiter struct { RateLimiterOpts - *Tracer + Tracer requestMap requestMap mu sync.Mutex diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 8c723d7..0b5a53d 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -13,7 +13,7 @@ import ( type ( realIP struct { RealIPOpts - *Tracer + Tracer } RealIPOpts struct { // Header is the name of the header to use for the real client IP diff --git a/internal/net/http/middleware/tracer.go b/internal/net/http/middleware/tracer.go index 94b7419..c99f08c 100644 --- a/internal/net/http/middleware/tracer.go +++ b/internal/net/http/middleware/tracer.go @@ -9,41 +9,53 @@ import ( ) type Tracer struct { - name string - parent *Tracer + name string + enabled bool } -func (t *Tracer) Fullname() string { - if t.parent != nil { - return t.parent.Fullname() + "." + t.name +func _() { + var _ MiddlewareWithTracer = &Tracer{} +} + +func (t *Tracer) enableTrace() { + t.enabled = true +} + +func (t *Tracer) getTracer() *Tracer { + return t +} + +func (t *Tracer) SetParent(parent *Tracer) { + if parent == nil { + return } - return t.name + t.name = parent.name + "." + t.name } func (t *Tracer) addTrace(msg string) *Trace { return addTrace(&Trace{ Time: strutils.FormatTime(time.Now()), - Caller: t.Fullname(), + Caller: t.name, Message: msg, }) } func (t *Tracer) AddTracef(msg string, args ...any) *Trace { - if t == nil { + if !t.enabled { return nil } return t.addTrace(fmt.Sprintf(msg, args...)) } func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace { - if t == nil { + if !t.enabled { return nil } return t.addTrace(msg).WithRequest(req) } func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace { - if t == nil { + if !t.enabled { return nil } return t.addTrace(msg).WithResponse(resp) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 08cfab8..a4f6202 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -13,7 +13,6 @@ import ( "github.com/go-playground/validator/v10" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/strutils" "gopkg.in/yaml.v3" @@ -46,10 +45,8 @@ func RegisterDefaultValueFactory[T any](factory func() *T) { func New(t reflect.Type) reflect.Value { if dv, ok := defaultValues.Load(t); ok { - logging.Debug().Str("type", t.String()).Msg("using default value") return reflect.ValueOf(dv()) } - logging.Debug().Str("type", t.String()).Msg("using zero value") return reflect.New(t) }