cleanup and simplify middleware implementations, refactor some other code

This commit is contained in:
yusing 2024-12-16 10:19:14 +08:00
parent 8a9cb2527e
commit 59f4eaf3ea
34 changed files with 641 additions and 720 deletions

View file

@ -7,12 +7,12 @@ cli:
plugins:
sources:
- id: trunk
ref: v1.6.5
ref: v1.6.6
uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes:
enabled:
- node@18.12.1
- node@18.20.5
- python@3.10.8
- go@1.23.2
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
@ -23,16 +23,16 @@ lint:
enabled:
- hadolint@2.12.1-beta
- actionlint@1.7.4
- checkov@3.2.324
- checkov@3.2.334
- git-diff-check
- gofmt@1.20.4
- golangci-lint@1.62.2
- osv-scanner@1.9.1
- oxipng@9.1.3
- prettier@3.4.1
- prettier@3.4.2
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.84.1
- trufflehog@3.86.1
actions:
disabled:
- trunk-announce

View file

@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
}
func rateLimited(f http.HandlerFunc) http.HandlerFunc {
m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{
m, err := middleware.RateLimiter.New(middleware.OptionsRaw{
"average": 10,
"burst": 10,
})

View file

@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) {
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if !middleware.ServeStaticErrorPageFile(w, r) {
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok {

View file

@ -16,6 +16,9 @@ const (
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
)
func RemoveHop(h http.Header) {

View file

@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl {
return impl
}
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
}

View file

@ -4,48 +4,45 @@ import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type cidrWhitelist struct {
cidrWhitelistOpts
m *Middleware
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
type cidrWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
Message string
}
type (
cidrWhitelist struct {
CIDRWhitelistOpts
*Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
Message string
}
)
var (
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
cidrWhitelistDefaults = cidrWhitelistOpts{
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
}
)
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist)
wl.m = &Middleware{
impl: wl,
before: wl.checkIP,
}
wl.cidrWhitelistOpts = cidrWhitelistDefaults
// setup implements MiddlewareWithSetup.
func (wl *cidrWhitelist) setup() {
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
wl.cachedAddr = F.NewMapOf[string, bool]()
err := Deserialize(opts, &wl.cidrWhitelistOpts)
if err != nil {
return nil, err
}
return wl.m, nil
}
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
// before implements RequestModifier.
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
return wl.checkIP(w, r)
}
// checkIP checks if the IP address is allowed.
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.cidrWhitelistOpts.Allow {
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
}
}
if !allow {
w.WriteHeader(wl.StatusCode)
w.Write([]byte(wl.Message))
return
http.Error(w, wl.Message, wl.StatusCode)
return false
}
next(w, r)
return true
}

View file

@ -17,27 +17,27 @@ var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
t.Run("valid", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"message": "test-message",
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/123"},
"message": "test-message",
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"status_code": 600,
"message": "test-message",

View file

@ -11,11 +11,14 @@ import (
"time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type cloudflareRealIP struct {
realIP realIP
}
const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
@ -29,26 +32,23 @@ var (
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
next(w, r)
},
}
cri.realIPOpts = realIPOpts{
// setup implements MiddlewareWithSetup.
func (cri *cloudflareRealIP) setup() {
cri.realIP.RealIPOpts = RealIPOpts{
Header: "CF-Connecting-IP",
Recursive: true,
}
return cri.m, nil
}
// before implements RequestModifier.
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.realIP.From = cidrs
}
return cri.realIP.before(w, r)
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {

View file

@ -12,45 +12,38 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
var CustomErrorPage *Middleware
type customErrorPage struct{}
func init() {
CustomErrorPage = customErrorPage()
var CustomErrorPage = NewMiddleware[customErrorPage]()
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
}
func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
// modifyResponse implements ResponseModifier.
func (customErrorPage) modifyResponse(resp *http.Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
} else {
logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return m
return nil
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
default:
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
}

View file

@ -1,5 +0,0 @@
package middleware
import E "github.com/yusing/go-proxy/internal/error"
var ErrZeroValue = E.New("cannot be zero")

View file

@ -12,16 +12,17 @@ import (
"strings"
"time"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
forwardAuth struct {
forwardAuthOpts
m *Middleware
ForwardAuthOpts
*Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
forwardAuthOpts struct {
ForwardAuthOpts struct {
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
@ -29,36 +30,30 @@ type (
}
)
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
return fa.m, nil
// setup implements MiddlewareWithSetup.
func (fa *forwardAuth) setup() {
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
}
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
// before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", url).WithError(err)
fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq)
faReq.Header.Set("X-Original-Url", originalURL)
fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", url).WithError(err)
fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTraceResponse("forward auth response", faResp)
fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
fa.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}
@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
if len(authCookies) > 0 {
fa.reqCookiesMap.Store(req, authCookies)
}
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
return true
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
// modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {

View file

@ -3,70 +3,110 @@ package middleware
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.Error
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
Request = http.Request
Response = gphttp.ProxyResponse
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(*Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
withOptions CloneWithOptFunc
impl any
parent *Middleware
children []*Middleware
trace bool
name string
construct ImplNewFunc
impl any
}
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
)
var Deserialize = U.Deserialize
func Rewrite(r RewriteFunc) BeforeFunc {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
r(req)
next(w, req)
func NewMiddleware[ImplType any]() *Middleware {
// type check
switch any(new(ImplType)).(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 {
return nil
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
}
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil {
if optsRaw != nil {
panic("bug: middleware already constructed")
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
mid.finalize()
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string {
return m.name
}
@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
next(w, r)
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.modifyResponse != nil {
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
return exec.modifyResponse(resp)
})
}
if m.before != nil {
m.before(next, w, r)
} else {
next(w, r)
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error")
@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
continue
}
m, err = m.WithOptionsClone(opts)
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
next(w, r)
},
}}, middlewares...))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
if mid.before != nil {
ori := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
mid.before(ori, w, r)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
}
}
}
if mid.modifyResponse != nil {
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *Response) error {
if err := mid.modifyResponse(res); err != nil {
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mid.modifyResponse
rp.ModifyResponse = mr.modifyResponse
}
}
}

View file

@ -2,11 +2,9 @@ package middleware
import (
"fmt"
"net/http"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
continue
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m, err := base.New(def)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
if chainErr.HasError() {
return nil, chainErr.Error()
}
return BuildMiddlewareFromChain(name, chain), nil
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
if len(befores) > 0 {
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return errs.Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
return NewMiddlewareChain(name, chain), nil
}

View file

@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type middlewareChain struct {
befores []RequestModifier
modResps []ResponseModifier
}
// TODO: check conflict or duplicates.
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
m := &Middleware{name: name, impl: chainMid}
for _, comp := range chain {
if before, ok := comp.impl.(RequestModifier); ok {
chainMid.befores = append(chainMid.befores, before)
}
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsDebug {
m.enableTrace()
}
return m
}
// before implements RequestModifier.
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
for _, b := range m.befores {
if proceedNext = b.before(w, r); !proceedNext {
return false
}
}
return true
}
// modifyResponse implements ResponseModifier.
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := E.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(E.From(err).Subjectf("%d", i))
}
}
return errs.Error()
}

View file

@ -3,45 +3,39 @@ package middleware
import (
"net/http"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyRequest struct {
modifyRequestOpts
m *Middleware
needVarSubstitution bool
ModifyRequestOpts
*Tracer
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
ModifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
needVarSubstitution bool
}
)
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
var ModifyRequest = NewMiddleware[modifyRequest]()
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
mr.m = &Middleware{
impl: mr,
before: Rewrite(func(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyHeaders(req, nil, req.Header)
mr.m.AddTraceRequest("after modify request", req)
}),
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
// finalize implements MiddlewareFinalizer.
func (mr *ModifyRequestOpts) finalize() {
mr.checkVarSubstitution()
return mr.m, nil
}
func (mr *modifyRequest) checkVarSubstitution() {
// before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
mr.modifyHeaders(r, nil, r.Header)
mr.AddTraceRequest("after modify request", r)
return true
}
func (mr *ModifyRequestOpts) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m {
if strings.ContainsRune(v, '$') {
@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() {
}
}
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
if !mr.needVarSubstitution {
for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = v
}()
@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt
}
} else {
for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = varReplace(req, resp, v)
}()

View file

@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.WithOptionsClone(opts)
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -2,32 +2,19 @@ package middleware
import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
)
type modifyResponse = modifyRequest
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
next(w, r)
},
modifyResponse: func(resp *Response) error {
mr.m.AddTraceResponse("before modify response", resp.Response)
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp.Response)
return nil
},
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution()
return mr.m, nil
type modifyResponse struct {
ModifyRequestOpts
*Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()
// modifyResponse implements ResponseModifier.
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
mr.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.AddTraceResponse("after modify response", resp)
return nil
}

View file

@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.WithOptionsClone(opts)
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -1,117 +0,0 @@
package middleware
// import (
// "encoding/json"
// "fmt"
// "net/http"
// "net/url"
// E "github.com/yusing/go-proxy/internal/error"
// )
// type oAuth2 struct {
// oAuth2Opts
// m *Middleware
// }
// type oAuth2Opts struct {
// ClientID string `validate:"required"`
// ClientSecret string `validate:"required"`
// AuthURL string `validate:"required"` // Authorization Endpoint
// TokenURL string `validate:"required"` // Token Endpoint
// }
// var OAuth2 = &oAuth2{
// m: &Middleware{withOptions: NewAuthentikOAuth2},
// }
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
// oauth := new(oAuth2)
// oauth.m = &Middleware{
// impl: oauth,
// before: oauth.handleOAuth2,
// }
// err := Deserialize(opts, &oauth.oAuth2Opts)
// if err != nil {
// return nil, err
// }
// return oauth.m, nil
// }
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// // Check if the user is authenticated (you may use session, cookie, etc.)
// if !userIsAuthenticated(r) {
// // TODO: Redirect to OAuth2 auth URL
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
// return
// }
// // If you have a token in the query string, process it
// if code := r.URL.Query().Get("code"); code != "" {
// // Exchange the authorization code for a token here
// // Use the TokenURL and authenticate the user
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
// if err != nil {
// // handle error
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
// return
// }
// // Save token and user info based on your requirements
// saveToken(rw, token)
// // Redirect to the originally requested URL
// http.Redirect(rw, r, "/", http.StatusFound)
// return
// }
// // If user is authenticated, go to the next handler
// next(rw, r)
// }
// func userIsAuthenticated(r *http.Request) bool {
// // Example: Check for a session or cookie
// session, err := r.Cookie("session_token")
// if err != nil || session.Value == "" {
// return false
// }
// // Validate the session_token if necessary
// return true
// }
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// // Prepare the request body
// data := url.Values{
// "client_id": {opts.ClientID},
// "client_secret": {opts.ClientSecret},
// "code": {code},
// "grant_type": {"authorization_code"},
// "redirect_uri": {requestURI},
// }
// resp, err := http.PostForm(opts.TokenURL, data)
// if err != nil {
// return "", fmt.Errorf("failed to request token: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
// }
// // Decode the response
// var tokenResp struct {
// AccessToken string `json:"access_token"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
// return "", fmt.Errorf("failed to decode token response: %w", err)
// }
// return tokenResp.AccessToken, nil
// }
// func saveToken(rw ResponseWriter, token string) {
// // Example: Save token in cookie
// http.SetCookie(rw, &http.Cookie{
// Name: "auth_token",
// Value: token,
// // set other properties as necessary, such as Secure and HttpOnly
// })
// }

View file

@ -6,68 +6,56 @@ import (
"sync"
"time"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
requestMap requestMap
newLimiter func() *rate.Limiter
m *Middleware
RateLimiterOpts
*Tracer
mu sync.Mutex
requestMap requestMap
mu sync.Mutex
}
rateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration
RateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration `validate:"min=1s"`
}
)
var (
RateLimiter = &Middleware{withOptions: NewRateLimiter}
rateLimiterOptsDefault = rateLimiterOpts{
RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second,
}
)
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
rl := new(rateLimiter)
opts := rateLimiterOptsDefault
err := Deserialize(optsRaw, &opts)
if err != nil {
return nil, err
}
switch {
case opts.Average == 0:
return nil, ErrZeroValue.Subject("average")
case opts.Burst == 0:
return nil, ErrZeroValue.Subject("burst")
case opts.Period == 0:
return nil, ErrZeroValue.Subject("period")
}
// setup implements MiddlewareWithSetup.
func (rl *rateLimiter) setup() {
rl.RateLimiterOpts = rateLimiterOptsDefault
rl.requestMap = make(requestMap, 0)
rl.newLimiter = func() *rate.Limiter {
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
}
rl.m = &Middleware{
impl: rl,
before: rl.limit,
}
return rl.m, nil
}
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
// before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
rl.mu.Lock()
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr)
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return
return false
}
limiter, ok := rl.requestMap[host]
@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request
rl.mu.Unlock()
if limiter.Allow() {
next(w, r)
return
return true
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
}

View file

@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) {
"period": "1s",
}
rl, err := NewRateLimiter(opts)
rl, err := RateLimiter.New(opts)
ExpectNoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)

View file

@ -2,58 +2,53 @@ package middleware
import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type realIP struct {
realIPOpts
m *Middleware
}
type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
the last address sent in the request header field defined by the Header field.
If recursive search is enabled,
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
type (
realIP struct {
RealIPOpts
*Tracer
}
RealIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"required,min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
the last address sent in the request header field defined by the Header field.
If recursive search is enabled,
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
)
var (
RealIP = &Middleware{withOptions: NewRealIP}
realIPOptsDefault = realIPOpts{
RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
}
)
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,
before: Rewrite(riWithOpts.setRealIP),
}
riWithOpts.realIPOpts = realIPOptsDefault
err := Deserialize(opts, &riWithOpts.realIPOpts)
if err != nil {
return nil, err
}
if len(riWithOpts.From) == 0 {
return nil, E.New("no allowed CIDRs").Subject("from")
}
return riWithOpts.m, nil
// setup implements MiddlewareWithSetup.
func (ri *realIP) setup() {
ri.RealIPOpts = realIPOptsDefault
}
// before implements RequestModifier.
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
ri.setRealIP(r)
return true
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
return false
}
func (ri *realIP) setRealIP(req *Request) {
func (ri *realIP) setRealIP(req *http.Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIPStr = req.RemoteAddr
@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) {
}
}
if !isTrusted {
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) {
var lastNonTrustedIP string
if len(realIPs) == 0 {
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) {
}
if lastNonTrustedIP == "" {
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) {
},
"recursive": true,
}
optExpected := &realIPOpts{
optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP,
From: []*types.CIDR{
{
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
Recursive: true,
}
ri, err := NewRealIP(opts)
ri, err := RealIP.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) {
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
realip, err := RealIP.New(opts)
ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr)
mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
}

View file

@ -7,19 +7,22 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var RedirectHTTP = &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
host := r.Host
if i := strings.Index(host, ":"); i != -1 {
host = host[:i] // strip port number if present
}
r.URL.Host = host + ":" + common.ProxyHTTPSPort
logger.Info().Str("url", r.URL.String()).Msg("redirect to https")
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next(w, r)
},
type redirectHTTP struct{}
var RedirectHTTP = NewMiddleware[redirectHTTP]()
// before implements RequestModifier.
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if r.TLS != nil {
return true
}
r.URL.Scheme = "https"
host := r.Host
if i := strings.Index(host, ":"); i != -1 {
host = host[:i] // strip port number if present
}
r.URL.Host = host + ":" + common.ProxyHTTPSPort
logger.Debug().Str("url", r.URL.String()).Msg("redirect to https")
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return true
}

View file

@ -0,0 +1,34 @@
package middleware
import (
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
// internal use only.
type setUpstreamHeaders struct {
Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}

View file

@ -1,27 +1,25 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
type Traces []*Trace
type (
Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
Traces []*Trace
)
var (
traces = make(Traces, 0)
@ -34,7 +32,7 @@ func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *Request) *Trace {
func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil {
return nil
}
@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace {
return tr
}
func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
child.parent = m
child.EnableTrace()
}
}
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
if !m.trace {
return nil
}
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithResponse(resp)
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()

View file

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
parent *Tracer
}
func (t *Tracer) Fullname() string {
if t.parent != nil {
return t.parent.Fullname() + "." + t.name
}
return t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.Fullname(),
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if t == nil {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}

View file

@ -11,8 +11,8 @@ import (
)
type (
reqVarGetter func(*Request) string
respVarGetter func(*Response) string
reqVarGetter func(*http.Request) string
respVarGetter func(*http.Response) string
)
var (
@ -49,50 +49,50 @@ const (
)
var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *Request) string { return req.Method },
VarRequestScheme: func(req *Request) string {
VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *Request) string {
VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
VarRequestPort: func(req *Request) string {
VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
VarRequestAddr: func(req *Request) string { return req.Host },
VarRequestPath: func(req *Request) string { return req.URL.Path },
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *Request) string { return req.URL.String() },
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *Request) string {
VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return ""
},
VarRemotePort: func(req *Request) string {
VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientPort
}
return ""
},
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *Request) string {
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" {
@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
return upHost
},
VarUpstreamURL: func(req *Request) string {
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return ""
@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *Request, resp *Response, s string) string {
func varReplace(req *http.Request, resp *http.Response, s string) string {
if req != nil {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {

View file

@ -2,27 +2,44 @@ package middleware
import (
"net"
"net/http"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
var SetXForwarded = &Middleware{
before: Rewrite(func(req *Request) {
req.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
}),
type (
setXForwarded struct{}
hideXForwarded struct{}
)
var (
SetXForwarded = NewMiddleware[setXForwarded]()
HideXForwarded = NewMiddleware[hideXForwarded]()
)
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
return true
}
var HideXForwarded = &Middleware{
before: Rewrite(func(req *Request) {
for k := range req.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
req.Header.Del(k)
}
// before implements RequestModifier.
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
toDelete := make([]string, 0, len(r.Header))
for k := range r.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
toDelete = append(toDelete, k)
}
}),
}
for _, k := range toDelete {
r.Header.Del(k)
}
return true
}

View file

@ -1,8 +0,0 @@
package http
import "net/http"
type ProxyResponse struct {
*http.Response
OriginalRequest *http.Request
}

View file

@ -87,7 +87,7 @@ type ReverseProxy struct {
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*ProxyResponse) error
ModifyResponse func(*http.Response) error
HandlerFunc http.HandlerFunc
@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
res.Request = origReq
err := p.ModifyResponse(res)
res.Request = req
if err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
p.HandlerFunc(rw, req)
}
@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq,
TLS: outreq.TLS,
Request: req,
TLS: req.TLS,
}
}

View file

@ -9,26 +9,31 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
const (
rootTaskName = "root-task"
subTaskName = "subtask"
)
ExpectEqual(t, "root-task", rootTask.Name())
ExpectEqual(t, "subtask", subTask.Name())
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
ExpectEqual(t, rootTaskName, rootTask.Name())
ExpectEqual(t, subTaskName, subTask.Name())
}
func TestTaskCancellation(t *testing.T) {
subTaskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
go func() {
subTask.Wait()
close(subTaskDone)
}()
go rootTask.Finish("done")
go rootTask.Finish(nil)
select {
case <-subTaskDone:
@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) {
}
func TestOnComplete(t *testing.T) {
rootTask := GlobalTask("root-task")
task := rootTask.Subtask("test")
rootTask := GlobalTask(rootTaskName)
task := rootTask.Subtask(subTaskName)
var value atomic.Int32
task.OnFinished("set value", func() {
value.Store(1234)
})
task.Finish("done")
task.Finish(nil)
ExpectEqual(t, value.Load(), 1234)
}
@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) {
testResetGlobalTask()
defer CancelGlobalContext()
rootTask := GlobalTask("root-task")
rootTask := GlobalTask(rootTaskName)
finished1, finished2 := false, false
subTask1 := rootTask.Subtask("subtask1")
subTask2 := rootTask.Subtask("subtask2")
subTask1.OnFinished("set finished", func() {
subTask1 := rootTask.Subtask(subTaskName)
subTask2 := rootTask.Subtask(subTaskName)
subTask1.OnFinished("", func() {
finished1 = true
})
subTask2.OnFinished("set finished", func() {
subTask2.OnFinished("", func() {
finished2 = true
})
go func() {
time.Sleep(500 * time.Millisecond)
subTask1.Finish("done")
subTask1.Finish(nil)
}()
go func() {
time.Sleep(500 * time.Millisecond)
subTask2.Finish("done")
subTask2.Finish(nil)
}()
go func() {
subTask1.Wait()
subTask2.Wait()
rootTask.Finish("done")
rootTask.Finish(nil)
}()
GlobalContextWait(1 * time.Second)
_ = GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) {
func TestTimeoutOnGlobalContextWait(t *testing.T) {
testResetGlobalTask()
rootTask := GlobalTask("root-task")
rootTask.Subtask("subtask")
rootTask := GlobalTask(rootTaskName)
rootTask.Subtask(subTaskName)
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
}
@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) {
testResetGlobalTask()
taskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
rootTask := GlobalTask(rootTaskName)
go func() {
rootTask.Wait()

View file

@ -19,23 +19,24 @@ func TestSerializeDeserialize(t *testing.T) {
MIS map[int]string
}
var testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
SS: []string{"a", "b", "c"},
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
var testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
"SS": []string{"a", "b", "c"},
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
var (
testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
SS: []string{"a", "b", "c"},
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
"SS": []string{"a", "b", "c"},
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
)
t.Run("serialize", func(t *testing.T) {
s, err := Serialize(testStruct)