mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 12:42:34 +02:00
fix middleware tracer and cloudflareRealIP
This commit is contained in:
parent
783b352e3b
commit
87279688e6
12 changed files with 81 additions and 71 deletions
|
@ -11,7 +11,7 @@ import (
|
||||||
type (
|
type (
|
||||||
cidrWhitelist struct {
|
cidrWhitelist struct {
|
||||||
CIDRWhitelistOpts
|
CIDRWhitelistOpts
|
||||||
*Tracer
|
Tracer
|
||||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||||
}
|
}
|
||||||
CIDRWhitelistOpts struct {
|
CIDRWhitelistOpts struct {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
@ -52,6 +53,14 @@ func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool
|
||||||
return cri.realIP.before(w, r)
|
return cri.realIP.before(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cri *cloudflareRealIP) enableTrace() {
|
||||||
|
cri.realIP.enableTrace()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||||
|
return cri.realIP.getTracer()
|
||||||
|
}
|
||||||
|
|
||||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||||
return
|
return
|
||||||
|
@ -74,14 +83,17 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
} else {
|
} else {
|
||||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||||
err := errors.Join(
|
err := errors.Join(
|
||||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
|
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
|
||||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
|
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||||
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if len(cfCIDRs) == 0 {
|
||||||
|
logging.Warn().Msg("cloudflare CIDR range is empty")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfCIDRsLastUpdate = time.Now()
|
cfCIDRsLastUpdate = time.Now()
|
||||||
|
@ -89,7 +101,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||||
resp, err := http.Get(endpoint)
|
resp, err := http.Get(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -110,7 +122,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
||||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfCIDRs = append(cfCIDRs, cidr)
|
*cfCIDRs = append(*cfCIDRs, cidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
type (
|
type (
|
||||||
forwardAuth struct {
|
forwardAuth struct {
|
||||||
ForwardAuthOpts
|
ForwardAuthOpts
|
||||||
*Tracer
|
Tracer
|
||||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||||
}
|
}
|
||||||
ForwardAuthOpts struct {
|
ForwardAuthOpts struct {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
@ -32,7 +33,10 @@ type (
|
||||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||||
MiddlewareWithSetup interface{ setup() }
|
MiddlewareWithSetup interface{ setup() }
|
||||||
MiddlewareFinalizer interface{ finalize() }
|
MiddlewareFinalizer interface{ finalize() }
|
||||||
MiddlewareWithTracer *struct{ *Tracer }
|
MiddlewareWithTracer interface {
|
||||||
|
enableTrace()
|
||||||
|
getTracer() *Tracer
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewMiddleware[ImplType any]() *Middleware {
|
func NewMiddleware[ImplType any]() *Middleware {
|
||||||
|
@ -51,20 +55,21 @@ func NewMiddleware[ImplType any]() *Middleware {
|
||||||
|
|
||||||
func (m *Middleware) enableTrace() {
|
func (m *Middleware) enableTrace() {
|
||||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
tracer.Tracer = &Tracer{name: m.name}
|
tracer.enableTrace()
|
||||||
|
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) getTracer() *Tracer {
|
func (m *Middleware) getTracer() *Tracer {
|
||||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
return tracer.Tracer
|
return tracer.getTracer()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) setParent(parent *Middleware) {
|
func (m *Middleware) setParent(parent *Middleware) {
|
||||||
if tracer := m.getTracer(); tracer != nil {
|
if tracer := m.getTracer(); tracer != nil {
|
||||||
tracer.parent = parent.getTracer()
|
tracer.SetParent(parent.getTracer())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,9 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.IsDebug {
|
if common.IsDebug {
|
||||||
|
for _, child := range chain {
|
||||||
|
child.enableTrace()
|
||||||
|
}
|
||||||
m.enableTrace()
|
m.enableTrace()
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
@ -12,7 +9,31 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var allMiddlewares map[string]*Middleware
|
// snakes and cases will be stripped on `Get`
|
||||||
|
// so keys are lowercase without snake.
|
||||||
|
var allMiddlewares = map[string]*Middleware{
|
||||||
|
"redirecthttp": RedirectHTTP,
|
||||||
|
|
||||||
|
"request": ModifyRequest,
|
||||||
|
"modifyrequest": ModifyRequest,
|
||||||
|
"response": ModifyResponse,
|
||||||
|
"modifyresponse": ModifyResponse,
|
||||||
|
"setxforwarded": SetXForwarded,
|
||||||
|
"hidexforwarded": HideXForwarded,
|
||||||
|
|
||||||
|
"errorpage": CustomErrorPage,
|
||||||
|
"customerrorpage": CustomErrorPage,
|
||||||
|
|
||||||
|
"realip": RealIP,
|
||||||
|
"cloudflarerealip": CloudflareRealIP,
|
||||||
|
|
||||||
|
"cidrwhitelist": CIDRWhiteList,
|
||||||
|
"ratelimit": RateLimiter,
|
||||||
|
|
||||||
|
// !experimental
|
||||||
|
"forwardauth": ForwardAuth,
|
||||||
|
// "oauth2": OAuth2.m,
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||||
|
@ -33,46 +54,6 @@ func All() map[string]*Middleware {
|
||||||
return allMiddlewares
|
return allMiddlewares
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize middleware names and label parsers.
|
|
||||||
func init() {
|
|
||||||
// snakes and cases will be stripped on `Get`
|
|
||||||
// so keys are lowercase without snake.
|
|
||||||
allMiddlewares = map[string]*Middleware{
|
|
||||||
"redirecthttp": RedirectHTTP,
|
|
||||||
|
|
||||||
"request": ModifyRequest,
|
|
||||||
"modifyrequest": ModifyRequest,
|
|
||||||
"response": ModifyResponse,
|
|
||||||
"modifyresponse": ModifyResponse,
|
|
||||||
"setxforwarded": SetXForwarded,
|
|
||||||
"hidexforwarded": HideXForwarded,
|
|
||||||
|
|
||||||
"errorpage": CustomErrorPage,
|
|
||||||
"customerrorpage": CustomErrorPage,
|
|
||||||
|
|
||||||
"realip": RealIP,
|
|
||||||
"cloudflarerealip": CloudflareRealIP,
|
|
||||||
|
|
||||||
"cidrwhitelist": CIDRWhiteList,
|
|
||||||
"ratelimit": RateLimiter,
|
|
||||||
|
|
||||||
// !experimental
|
|
||||||
"forwardauth": ForwardAuth,
|
|
||||||
// "oauth2": OAuth2.m,
|
|
||||||
}
|
|
||||||
names := make(map[*Middleware][]string)
|
|
||||||
for name, m := range allMiddlewares {
|
|
||||||
names[m] = append(names[m], http.CanonicalHeaderKey(name))
|
|
||||||
}
|
|
||||||
for m, names := range names {
|
|
||||||
if len(names) > 1 {
|
|
||||||
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
|
|
||||||
} else {
|
|
||||||
m.name = names[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadComposeFiles() {
|
func LoadComposeFiles() {
|
||||||
errs := E.NewBuilder("middleware compile errors")
|
errs := E.NewBuilder("middleware compile errors")
|
||||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
type (
|
type (
|
||||||
modifyRequest struct {
|
modifyRequest struct {
|
||||||
ModifyRequestOpts
|
ModifyRequestOpts
|
||||||
*Tracer
|
Tracer
|
||||||
}
|
}
|
||||||
// order: set_headers -> add_headers -> hide_headers
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
ModifyRequestOpts struct {
|
ModifyRequestOpts struct {
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
type modifyResponse struct {
|
type modifyResponse struct {
|
||||||
ModifyRequestOpts
|
ModifyRequestOpts
|
||||||
*Tracer
|
Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModifyResponse = NewMiddleware[modifyResponse]()
|
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||||
|
|
|
@ -13,7 +13,7 @@ type (
|
||||||
requestMap = map[string]*rate.Limiter
|
requestMap = map[string]*rate.Limiter
|
||||||
rateLimiter struct {
|
rateLimiter struct {
|
||||||
RateLimiterOpts
|
RateLimiterOpts
|
||||||
*Tracer
|
Tracer
|
||||||
|
|
||||||
requestMap requestMap
|
requestMap requestMap
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type (
|
type (
|
||||||
realIP struct {
|
realIP struct {
|
||||||
RealIPOpts
|
RealIPOpts
|
||||||
*Tracer
|
Tracer
|
||||||
}
|
}
|
||||||
RealIPOpts struct {
|
RealIPOpts struct {
|
||||||
// Header is the name of the header to use for the real client IP
|
// Header is the name of the header to use for the real client IP
|
||||||
|
|
|
@ -9,41 +9,53 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tracer struct {
|
type Tracer struct {
|
||||||
name string
|
name string
|
||||||
parent *Tracer
|
enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tracer) Fullname() string {
|
func _() {
|
||||||
if t.parent != nil {
|
var _ MiddlewareWithTracer = &Tracer{}
|
||||||
return t.parent.Fullname() + "." + t.name
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) enableTrace() {
|
||||||
|
t.enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) getTracer() *Tracer {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) SetParent(parent *Tracer) {
|
||||||
|
if parent == nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return t.name
|
t.name = parent.name + "." + t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tracer) addTrace(msg string) *Trace {
|
func (t *Tracer) addTrace(msg string) *Trace {
|
||||||
return addTrace(&Trace{
|
return addTrace(&Trace{
|
||||||
Time: strutils.FormatTime(time.Now()),
|
Time: strutils.FormatTime(time.Now()),
|
||||||
Caller: t.Fullname(),
|
Caller: t.name,
|
||||||
Message: msg,
|
Message: msg,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||||
if t == nil {
|
if !t.enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return t.addTrace(fmt.Sprintf(msg, args...))
|
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||||
if t == nil {
|
if !t.enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return t.addTrace(msg).WithRequest(req)
|
return t.addTrace(msg).WithRequest(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||||
if t == nil {
|
if !t.enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return t.addTrace(msg).WithResponse(resp)
|
return t.addTrace(msg).WithResponse(resp)
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
@ -46,10 +45,8 @@ func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||||
|
|
||||||
func New(t reflect.Type) reflect.Value {
|
func New(t reflect.Type) reflect.Value {
|
||||||
if dv, ok := defaultValues.Load(t); ok {
|
if dv, ok := defaultValues.Load(t); ok {
|
||||||
logging.Debug().Str("type", t.String()).Msg("using default value")
|
|
||||||
return reflect.ValueOf(dv())
|
return reflect.ValueOf(dv())
|
||||||
}
|
}
|
||||||
logging.Debug().Str("type", t.String()).Msg("using zero value")
|
|
||||||
return reflect.New(t)
|
return reflect.New(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue