mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 12:42:34 +02:00
fix middleware tracer and cloudflareRealIP
This commit is contained in:
parent
783b352e3b
commit
87279688e6
12 changed files with 81 additions and 71 deletions
|
@ -11,7 +11,7 @@ import (
|
|||
type (
|
||||
cidrWhitelist struct {
|
||||
CIDRWhitelistOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
CIDRWhitelistOpts struct {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
@ -52,6 +53,14 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool
|
|||
return cri.realIP.before(w, r)
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) enableTrace() {
|
||||
cri.realIP.enableTrace()
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||
return cri.realIP.getTracer()
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
|
@ -74,14 +83,17 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
} else {
|
||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||
err := errors.Join(
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
if len(cfCIDRs) == 0 {
|
||||
logging.Warn().Msg("cloudflare CIDR range is empty")
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
|
@ -89,7 +101,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -110,7 +122,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
|||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
}
|
||||
|
||||
cfCIDRs = append(cfCIDRs, cidr)
|
||||
*cfCIDRs = append(*cfCIDRs, cidr)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
type (
|
||||
forwardAuth struct {
|
||||
ForwardAuthOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||
}
|
||||
ForwardAuthOpts struct {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
@ -32,7 +33,10 @@ type (
|
|||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||
MiddlewareWithSetup interface{ setup() }
|
||||
MiddlewareFinalizer interface{ finalize() }
|
||||
MiddlewareWithTracer *struct{ *Tracer }
|
||||
MiddlewareWithTracer interface {
|
||||
enableTrace()
|
||||
getTracer() *Tracer
|
||||
}
|
||||
)
|
||||
|
||||
func NewMiddleware[ImplType any]() *Middleware {
|
||||
|
@ -51,20 +55,21 @@ func NewMiddleware[ImplType any]() *Middleware {
|
|||
|
||||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.Tracer = &Tracer{name: m.name}
|
||||
tracer.enableTrace()
|
||||
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getTracer() *Tracer {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
return tracer.Tracer
|
||||
return tracer.getTracer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) setParent(parent *Middleware) {
|
||||
if tracer := m.getTracer(); tracer != nil {
|
||||
tracer.parent = parent.getTracer()
|
||||
tracer.SetParent(parent.getTracer())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,9 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
|||
}
|
||||
|
||||
if common.IsDebug {
|
||||
for _, child := range chain {
|
||||
child.enableTrace()
|
||||
}
|
||||
m.enableTrace()
|
||||
}
|
||||
return m
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -12,32 +9,9 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var allMiddlewares map[string]*Middleware
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
// initialize middleware names and label parsers.
|
||||
func init() {
|
||||
// snakes and cases will be stripped on `Get`
|
||||
// so keys are lowercase without snake.
|
||||
allMiddlewares = map[string]*Middleware{
|
||||
var allMiddlewares = map[string]*Middleware{
|
||||
"redirecthttp": RedirectHTTP,
|
||||
|
||||
"request": ModifyRequest,
|
||||
|
@ -60,17 +34,24 @@ func init() {
|
|||
"forwardauth": ForwardAuth,
|
||||
// "oauth2": OAuth2.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range allMiddlewares {
|
||||
names[m] = append(names[m], http.CanonicalHeaderKey(name))
|
||||
}
|
||||
for m, names := range names {
|
||||
if len(names) > 1 {
|
||||
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
|
||||
} else {
|
||||
m.name = names[0]
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
type (
|
||||
modifyRequest struct {
|
||||
ModifyRequestOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
ModifyRequestOpts struct {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type modifyResponse struct {
|
||||
ModifyRequestOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
}
|
||||
|
||||
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||
|
|
|
@ -13,7 +13,7 @@ type (
|
|||
requestMap = map[string]*rate.Limiter
|
||||
rateLimiter struct {
|
||||
RateLimiterOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
|
||||
requestMap requestMap
|
||||
mu sync.Mutex
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
type (
|
||||
realIP struct {
|
||||
RealIPOpts
|
||||
*Tracer
|
||||
Tracer
|
||||
}
|
||||
RealIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
|
|
|
@ -10,40 +10,52 @@ import (
|
|||
|
||||
type Tracer struct {
|
||||
name string
|
||||
parent *Tracer
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (t *Tracer) Fullname() string {
|
||||
if t.parent != nil {
|
||||
return t.parent.Fullname() + "." + t.name
|
||||
func _() {
|
||||
var _ MiddlewareWithTracer = &Tracer{}
|
||||
}
|
||||
return t.name
|
||||
|
||||
func (t *Tracer) enableTrace() {
|
||||
t.enabled = true
|
||||
}
|
||||
|
||||
func (t *Tracer) getTracer() *Tracer {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tracer) SetParent(parent *Tracer) {
|
||||
if parent == nil {
|
||||
return
|
||||
}
|
||||
t.name = parent.name + "." + t.name
|
||||
}
|
||||
|
||||
func (t *Tracer) addTrace(msg string) *Trace {
|
||||
return addTrace(&Trace{
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: t.Fullname(),
|
||||
Caller: t.name,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||
if t == nil {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||
if t == nil {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if t == nil {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithResponse(resp)
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
"github.com/go-playground/validator/v10"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
@ -46,10 +45,8 @@ func RegisterDefaultValueFactory[T any](factory func() *T) {
|
|||
|
||||
func New(t reflect.Type) reflect.Value {
|
||||
if dv, ok := defaultValues.Load(t); ok {
|
||||
logging.Debug().Str("type", t.String()).Msg("using default value")
|
||||
return reflect.ValueOf(dv())
|
||||
}
|
||||
logging.Debug().Str("type", t.String()).Msg("using zero value")
|
||||
return reflect.New(t)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue