fix middleware tracer and cloudflareRealIP

This commit is contained in:
yusing 2024-12-18 09:03:12 +08:00
parent 783b352e3b
commit 87279688e6
12 changed files with 81 additions and 71 deletions

View file

@ -11,7 +11,7 @@ import (
type (
cidrWhitelist struct {
CIDRWhitelistOpts
*Tracer
Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -52,6 +53,14 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool
return cri.realIP.before(w, r)
}
func (cri *cloudflareRealIP) enableTrace() {
cri.realIP.enableTrace()
}
func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
@ -74,14 +83,17 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
} else {
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
if len(cfCIDRs) == 0 {
logging.Warn().Msg("cloudflare CIDR range is empty")
}
}
cfCIDRsLastUpdate = time.Now()
@ -89,7 +101,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
@ -110,7 +122,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
}
cfCIDRs = append(cfCIDRs, cidr)
*cfCIDRs = append(*cfCIDRs, cidr)
}
return nil

View file

@ -19,7 +19,7 @@ import (
type (
forwardAuth struct {
ForwardAuthOpts
*Tracer
Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
ForwardAuthOpts struct {

View file

@ -7,6 +7,7 @@ import (
"strings"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils"
)
@ -32,7 +33,10 @@ type (
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
MiddlewareWithTracer interface {
enableTrace()
getTracer() *Tracer
}
)
func NewMiddleware[ImplType any]() *Middleware {
@ -51,20 +55,21 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
tracer.enableTrace()
logging.Debug().Msgf("middleware %s enabled trace", m.name)
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
return tracer.getTracer()
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
tracer.SetParent(parent.getTracer())
}
}

View file

@ -28,6 +28,9 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
}
if common.IsDebug {
for _, child := range chain {
child.enableTrace()
}
m.enableTrace()
}
return m

View file

@ -1,10 +1,7 @@
package middleware
import (
"fmt"
"net/http"
"path"
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
@ -12,32 +9,9 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var allMiddlewares map[string]*Middleware
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return allMiddlewares
}
// initialize middleware names and label parsers.
func init() {
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
allMiddlewares = map[string]*Middleware{
var allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP,
"request": ModifyRequest,
@ -60,17 +34,24 @@ func init() {
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name))
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return allMiddlewares
}
func LoadComposeFiles() {

View file

@ -8,7 +8,7 @@ import (
type (
modifyRequest struct {
ModifyRequestOpts
*Tracer
Tracer
}
// order: set_headers -> add_headers -> hide_headers
ModifyRequestOpts struct {

View file

@ -6,7 +6,7 @@ import (
type modifyResponse struct {
ModifyRequestOpts
*Tracer
Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()

View file

@ -13,7 +13,7 @@ type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
RateLimiterOpts
*Tracer
Tracer
requestMap requestMap
mu sync.Mutex

View file

@ -13,7 +13,7 @@ import (
type (
realIP struct {
RealIPOpts
*Tracer
Tracer
}
RealIPOpts struct {
// Header is the name of the header to use for the real client IP

View file

@ -10,40 +10,52 @@ import (
type Tracer struct {
name string
parent *Tracer
enabled bool
}
func (t *Tracer) Fullname() string {
if t.parent != nil {
return t.parent.Fullname() + "." + t.name
func _() {
var _ MiddlewareWithTracer = &Tracer{}
}
return t.name
func (t *Tracer) enableTrace() {
t.enabled = true
}
func (t *Tracer) getTracer() *Tracer {
return t
}
func (t *Tracer) SetParent(parent *Tracer) {
if parent == nil {
return
}
t.name = parent.name + "." + t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.Fullname(),
Caller: t.name,
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if t == nil {
if !t.enabled {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if t == nil {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if t == nil {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithResponse(resp)

View file

@ -13,7 +13,6 @@ import (
"github.com/go-playground/validator/v10"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
"gopkg.in/yaml.v3"
@ -46,10 +45,8 @@ func RegisterDefaultValueFactory[T any](factory func() *T) {
func New(t reflect.Type) reflect.Value {
if dv, ok := defaultValues.Load(t); ok {
logging.Debug().Str("type", t.String()).Msg("using default value")
return reflect.ValueOf(dv())
}
logging.Debug().Str("type", t.String()).Msg("using zero value")
return reflect.New(t)
}