fix middleware tracer and cloudflareRealIP

This commit is contained in:
yusing 2024-12-18 09:03:12 +08:00
parent 783b352e3b
commit 87279688e6
12 changed files with 81 additions and 71 deletions

View file

@ -11,7 +11,7 @@ import (
type ( type (
cidrWhitelist struct { cidrWhitelist struct {
CIDRWhitelistOpts CIDRWhitelistOpts
*Tracer Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs cachedAddr F.Map[string, bool] // cache for trusted IPs
} }
CIDRWhitelistOpts struct { CIDRWhitelistOpts struct {

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -52,6 +53,14 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool
return cri.realIP.before(w, r) return cri.realIP.before(w, r)
} }
func (cri *cloudflareRealIP) enableTrace() {
cri.realIP.enableTrace()
}
func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return return
@ -74,14 +83,17 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
} else { } else {
cfCIDRs = make([]*types.CIDR, 0, 30) cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join( err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs), fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs), fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
) )
if err != nil { if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil return nil
} }
if len(cfCIDRs) == 0 {
logging.Warn().Msg("cloudflare CIDR range is empty")
}
} }
cfCIDRsLastUpdate = time.Now() cfCIDRsLastUpdate = time.Now()
@ -89,7 +101,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
return return
} }
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error { func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint) resp, err := http.Get(endpoint)
if err != nil { if err != nil {
return err return err
@ -110,7 +122,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} }
cfCIDRs = append(cfCIDRs, cidr) *cfCIDRs = append(*cfCIDRs, cidr)
} }
return nil return nil

View file

@ -19,7 +19,7 @@ import (
type ( type (
forwardAuth struct { forwardAuth struct {
ForwardAuthOpts ForwardAuthOpts
*Tracer Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie] reqCookiesMap F.Map[*http.Request, []*http.Cookie]
} }
ForwardAuthOpts struct { ForwardAuthOpts struct {

View file

@ -7,6 +7,7 @@ import (
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -32,7 +33,10 @@ type (
ResponseModifier interface{ modifyResponse(r *http.Response) error } ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() } MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() } MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer } MiddlewareWithTracer interface {
enableTrace()
getTracer() *Tracer
}
) )
func NewMiddleware[ImplType any]() *Middleware { func NewMiddleware[ImplType any]() *Middleware {
@ -51,20 +55,21 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() { func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok { if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name} tracer.enableTrace()
logging.Debug().Msgf("middleware %s enabled trace", m.name)
} }
} }
func (m *Middleware) getTracer() *Tracer { func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok { if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer return tracer.getTracer()
} }
return nil return nil
} }
func (m *Middleware) setParent(parent *Middleware) { func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil { if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer() tracer.SetParent(parent.getTracer())
} }
} }

View file

@ -28,6 +28,9 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
} }
if common.IsDebug { if common.IsDebug {
for _, child := range chain {
child.enableTrace()
}
m.enableTrace() m.enableTrace()
} }
return m return m

View file

@ -1,10 +1,7 @@
package middleware package middleware
import ( import (
"fmt"
"net/http"
"path" "path"
"strings"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -12,7 +9,31 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
var allMiddlewares map[string]*Middleware // snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
var allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP,
"request": ModifyRequest,
"modifyrequest": ModifyRequest,
"response": ModifyResponse,
"modifyresponse": ModifyResponse,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
}
var ( var (
ErrUnknownMiddleware = E.New("unknown middleware") ErrUnknownMiddleware = E.New("unknown middleware")
@ -33,46 +54,6 @@ func All() map[string]*Middleware {
return allMiddlewares return allMiddlewares
} }
// initialize middleware names and label parsers.
func init() {
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP,
"request": ModifyRequest,
"modifyrequest": ModifyRequest,
"response": ModifyResponse,
"modifyresponse": ModifyResponse,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name))
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
}
}
}
func LoadComposeFiles() { func LoadComposeFiles() {
errs := E.NewBuilder("middleware compile errors") errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)

View file

@ -8,7 +8,7 @@ import (
type ( type (
modifyRequest struct { modifyRequest struct {
ModifyRequestOpts ModifyRequestOpts
*Tracer Tracer
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
ModifyRequestOpts struct { ModifyRequestOpts struct {

View file

@ -6,7 +6,7 @@ import (
type modifyResponse struct { type modifyResponse struct {
ModifyRequestOpts ModifyRequestOpts
*Tracer Tracer
} }
var ModifyResponse = NewMiddleware[modifyResponse]() var ModifyResponse = NewMiddleware[modifyResponse]()

View file

@ -13,7 +13,7 @@ type (
requestMap = map[string]*rate.Limiter requestMap = map[string]*rate.Limiter
rateLimiter struct { rateLimiter struct {
RateLimiterOpts RateLimiterOpts
*Tracer Tracer
requestMap requestMap requestMap requestMap
mu sync.Mutex mu sync.Mutex

View file

@ -13,7 +13,7 @@ import (
type ( type (
realIP struct { realIP struct {
RealIPOpts RealIPOpts
*Tracer Tracer
} }
RealIPOpts struct { RealIPOpts struct {
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP

View file

@ -9,41 +9,53 @@ import (
) )
type Tracer struct { type Tracer struct {
name string name string
parent *Tracer enabled bool
} }
func (t *Tracer) Fullname() string { func _() {
if t.parent != nil { var _ MiddlewareWithTracer = &Tracer{}
return t.parent.Fullname() + "." + t.name }
func (t *Tracer) enableTrace() {
t.enabled = true
}
func (t *Tracer) getTracer() *Tracer {
return t
}
func (t *Tracer) SetParent(parent *Tracer) {
if parent == nil {
return
} }
return t.name t.name = parent.name + "." + t.name
} }
func (t *Tracer) addTrace(msg string) *Trace { func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{ return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()), Time: strutils.FormatTime(time.Now()),
Caller: t.Fullname(), Caller: t.name,
Message: msg, Message: msg,
}) })
} }
func (t *Tracer) AddTracef(msg string, args ...any) *Trace { func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if t == nil { if !t.enabled {
return nil return nil
} }
return t.addTrace(fmt.Sprintf(msg, args...)) return t.addTrace(fmt.Sprintf(msg, args...))
} }
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace { func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if t == nil { if !t.enabled {
return nil return nil
} }
return t.addTrace(msg).WithRequest(req) return t.addTrace(msg).WithRequest(req)
} }
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace { func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if t == nil { if !t.enabled {
return nil return nil
} }
return t.addTrace(msg).WithResponse(resp) return t.addTrace(msg).WithResponse(resp)

View file

@ -13,7 +13,6 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@ -46,10 +45,8 @@ func RegisterDefaultValueFactory[T any](factory func() *T) {
func New(t reflect.Type) reflect.Value { func New(t reflect.Type) reflect.Value {
if dv, ok := defaultValues.Load(t); ok { if dv, ok := defaultValues.Load(t); ok {
logging.Debug().Str("type", t.String()).Msg("using default value")
return reflect.ValueOf(dv()) return reflect.ValueOf(dv())
} }
logging.Debug().Str("type", t.String()).Msg("using zero value")
return reflect.New(t) return reflect.New(t)
} }