cleanup and simplify middleware implementations, refactor some other code

This commit is contained in:
yusing 2024-12-16 10:19:14 +08:00
parent 8a9cb2527e
commit 59f4eaf3ea
34 changed files with 641 additions and 720 deletions

View file

@ -7,12 +7,12 @@ cli:
plugins:
sources:
- id: trunk
ref: v1.6.5
ref: v1.6.6
uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes:
enabled:
- node@18.12.1
- node@18.20.5
- python@3.10.8
- go@1.23.2
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
@ -23,16 +23,16 @@ lint:
enabled:
- hadolint@2.12.1-beta
- actionlint@1.7.4
- checkov@3.2.324
- checkov@3.2.334
- git-diff-check
- gofmt@1.20.4
- golangci-lint@1.62.2
- osv-scanner@1.9.1
- oxipng@9.1.3
- prettier@3.4.1
- prettier@3.4.2
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.84.1
- trufflehog@3.86.1
actions:
disabled:
- trunk-announce

View file

@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
}
func rateLimited(f http.HandlerFunc) http.HandlerFunc {
m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{
m, err := middleware.RateLimiter.New(middleware.OptionsRaw{
"average": 10,
"burst": 10,
})

View file

@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) {
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if !middleware.ServeStaticErrorPageFile(w, r) {
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok {

View file

@ -16,6 +16,9 @@ const (
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
)
func RemoveHop(h http.Header) {

View file

@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl {
return impl
}
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
}

View file

@ -4,48 +4,45 @@ import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type cidrWhitelist struct {
cidrWhitelistOpts
m *Middleware
type (
cidrWhitelist struct {
CIDRWhitelistOpts
*Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
type cidrWhitelistOpts struct {
}
CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
Message string
}
}
)
var (
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
cidrWhitelistDefaults = cidrWhitelistOpts{
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
}
)
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist)
wl.m = &Middleware{
impl: wl,
before: wl.checkIP,
}
wl.cidrWhitelistOpts = cidrWhitelistDefaults
// setup implements MiddlewareWithSetup.
func (wl *cidrWhitelist) setup() {
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
wl.cachedAddr = F.NewMapOf[string, bool]()
err := Deserialize(opts, &wl.cidrWhitelistOpts)
if err != nil {
return nil, err
}
return wl.m, nil
}
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
// before implements RequestModifier.
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
return wl.checkIP(w, r)
}
// checkIP checks if the IP address is allowed.
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.cidrWhitelistOpts.Allow {
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
}
}
if !allow {
w.WriteHeader(wl.StatusCode)
w.Write([]byte(wl.Message))
return
http.Error(w, wl.Message, wl.StatusCode)
return false
}
next(w, r)
return true
}

View file

@ -17,27 +17,27 @@ var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
t.Run("valid", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"message": "test-message",
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/123"},
"message": "test-message",
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"status_code": 600,
"message": "test-message",

View file

@ -11,11 +11,14 @@ import (
"time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type cloudflareRealIP struct {
realIP realIP
}
const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
@ -29,26 +32,23 @@ var (
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
next(w, r)
},
}
cri.realIPOpts = realIPOpts{
// setup implements MiddlewareWithSetup.
func (cri *cloudflareRealIP) setup() {
cri.realIP.RealIPOpts = RealIPOpts{
Header: "CF-Connecting-IP",
Recursive: true,
}
return cri.m, nil
}
// before implements RequestModifier.
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.realIP.From = cidrs
}
return cri.realIP.before(w, r)
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {

View file

@ -12,45 +12,38 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
var CustomErrorPage *Middleware
type customErrorPage struct{}
func init() {
CustomErrorPage = customErrorPage()
var CustomErrorPage = NewMiddleware[customErrorPage]()
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
}
func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// modifyResponse implements ResponseModifier.
func (customErrorPage) modifyResponse(resp *http.Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
} else {
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
}
return m
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
default:
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
}

View file

@ -1,5 +0,0 @@
package middleware
import E "github.com/yusing/go-proxy/internal/error"
var ErrZeroValue = E.New("cannot be zero")

View file

@ -12,16 +12,17 @@ import (
"strings"
"time"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
forwardAuth struct {
forwardAuthOpts
m *Middleware
ForwardAuthOpts
*Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
forwardAuthOpts struct {
ForwardAuthOpts struct {
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
@ -29,36 +30,30 @@ type (
}
)
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
return fa.m, nil
// setup implements MiddlewareWithSetup.
func (fa *forwardAuth) setup() {
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
}
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
// before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", url).WithError(err)
fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq)
faReq.Header.Set("X-Original-Url", originalURL)
fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", url).WithError(err)
fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTraceResponse("forward auth response", faResp)
fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
fa.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}
@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
if len(authCookies) > 0 {
fa.reqCookiesMap.Store(req, authCookies)
}
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
return true
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
// modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {

View file

@ -3,11 +3,12 @@ package middleware
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
)
type (
@ -15,58 +16,97 @@ type (
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
Request = http.Request
Response = gphttp.ProxyResponse
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(*Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
withOptions CloneWithOptFunc
construct ImplNewFunc
impl any
parent *Middleware
children []*Middleware
trace bool
}
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
)
var Deserialize = U.Deserialize
func Rewrite(r RewriteFunc) BeforeFunc {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
r(req)
next(w, req)
func NewMiddleware[ImplType any]() *Middleware {
// type check
switch any(new(ImplType)).(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 {
return nil
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
}
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil {
if optsRaw != nil {
panic("bug: middleware already constructed")
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
mid.finalize()
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string {
return m.name
}
@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
next(w, r)
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.modifyResponse != nil {
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
return exec.modifyResponse(resp)
})
}
if m.before != nil {
m.before(next, w, r)
} else {
next(w, r)
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error")
@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
continue
}
m, err = m.WithOptionsClone(opts)
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
},
}}, middlewares...))
if mid.before != nil {
ori := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
mid.before(ori, w, r)
}
}
}
if mid.modifyResponse != nil {
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *Response) error {
if err := mid.modifyResponse(res); err != nil {
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mid.modifyResponse
rp.ModifyResponse = mr.modifyResponse
}
}
}

View file

@ -2,11 +2,9 @@ package middleware
import (
"fmt"
"net/http"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
continue
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m, err := base.New(def)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
if chainErr.HasError() {
return nil, chainErr.Error()
}
return BuildMiddlewareFromChain(name, chain), nil
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
if len(befores) > 0 {
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return errs.Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
return NewMiddlewareChain(name, chain), nil
}

View file

@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type middlewareChain struct {
befores []RequestModifier
modResps []ResponseModifier
}
// TODO: check conflict or duplicates.
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
m := &Middleware{name: name, impl: chainMid}
for _, comp := range chain {
if before, ok := comp.impl.(RequestModifier); ok {
chainMid.befores = append(chainMid.befores, before)
}
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsDebug {
m.enableTrace()
}
return m
}
// before implements RequestModifier.
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
for _, b := range m.befores {
if proceedNext = b.before(w, r); !proceedNext {
return false
}
}
return true
}
// modifyResponse implements ResponseModifier.
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := E.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(E.From(err).Subjectf("%d", i))
}
}
return errs.Error()
}

View file

@ -3,45 +3,39 @@ package middleware
import (
"net/http"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyRequest struct {
modifyRequestOpts
m *Middleware
needVarSubstitution bool
ModifyRequestOpts
*Tracer
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
ModifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
needVarSubstitution bool
}
)
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
var ModifyRequest = NewMiddleware[modifyRequest]()
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
mr.m = &Middleware{
impl: mr,
before: Rewrite(func(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyHeaders(req, nil, req.Header)
mr.m.AddTraceRequest("after modify request", req)
}),
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
// finalize implements MiddlewareFinalizer.
func (mr *ModifyRequestOpts) finalize() {
mr.checkVarSubstitution()
return mr.m, nil
}
func (mr *modifyRequest) checkVarSubstitution() {
// before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
mr.modifyHeaders(r, nil, r.Header)
mr.AddTraceRequest("after modify request", r)
return true
}
func (mr *ModifyRequestOpts) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m {
if strings.ContainsRune(v, '$') {
@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() {
}
}
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
if !mr.needVarSubstitution {
for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = v
}()
@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt
}
} else {
for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = varReplace(req, resp, v)
}()

View file

@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.WithOptionsClone(opts)
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -2,32 +2,19 @@ package middleware
import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
)
type modifyResponse = modifyRequest
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
next(w, r)
},
modifyResponse: func(resp *Response) error {
mr.m.AddTraceResponse("before modify response", resp.Response)
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp.Response)
return nil
},
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution()
return mr.m, nil
type modifyResponse struct {
ModifyRequestOpts
*Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()
// modifyResponse implements ResponseModifier.
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
mr.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.AddTraceResponse("after modify response", resp)
return nil
}

View file

@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.WithOptionsClone(opts)
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -1,117 +0,0 @@
package middleware
// import (
// "encoding/json"
// "fmt"
// "net/http"
// "net/url"
// E "github.com/yusing/go-proxy/internal/error"
// )
// type oAuth2 struct {
// oAuth2Opts
// m *Middleware
// }
// type oAuth2Opts struct {
// ClientID string `validate:"required"`
// ClientSecret string `validate:"required"`
// AuthURL string `validate:"required"` // Authorization Endpoint
// TokenURL string `validate:"required"` // Token Endpoint
// }
// var OAuth2 = &oAuth2{
// m: &Middleware{withOptions: NewAuthentikOAuth2},
// }
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
// oauth := new(oAuth2)
// oauth.m = &Middleware{
// impl: oauth,
// before: oauth.handleOAuth2,
// }
// err := Deserialize(opts, &oauth.oAuth2Opts)
// if err != nil {
// return nil, err
// }
// return oauth.m, nil
// }
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// // Check if the user is authenticated (you may use session, cookie, etc.)
// if !userIsAuthenticated(r) {
// // TODO: Redirect to OAuth2 auth URL
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
// return
// }
// // If you have a token in the query string, process it
// if code := r.URL.Query().Get("code"); code != "" {
// // Exchange the authorization code for a token here
// // Use the TokenURL and authenticate the user
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
// if err != nil {
// // handle error
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
// return
// }
// // Save token and user info based on your requirements
// saveToken(rw, token)
// // Redirect to the originally requested URL
// http.Redirect(rw, r, "/", http.StatusFound)
// return
// }
// // If user is authenticated, go to the next handler
// next(rw, r)
// }
// func userIsAuthenticated(r *http.Request) bool {
// // Example: Check for a session or cookie
// session, err := r.Cookie("session_token")
// if err != nil || session.Value == "" {
// return false
// }
// // Validate the session_token if necessary
// return true
// }
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// // Prepare the request body
// data := url.Values{
// "client_id": {opts.ClientID},
// "client_secret": {opts.ClientSecret},
// "code": {code},
// "grant_type": {"authorization_code"},
// "redirect_uri": {requestURI},
// }
// resp, err := http.PostForm(opts.TokenURL, data)
// if err != nil {
// return "", fmt.Errorf("failed to request token: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
// }
// // Decode the response
// var tokenResp struct {
// AccessToken string `json:"access_token"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
// return "", fmt.Errorf("failed to decode token response: %w", err)
// }
// return tokenResp.AccessToken, nil
// }
// func saveToken(rw ResponseWriter, token string) {
// // Example: Save token in cookie
// http.SetCookie(rw, &http.Cookie{
// Name: "auth_token",
// Value: token,
// // set other properties as necessary, such as Secure and HttpOnly
// })
// }

View file

@ -6,68 +6,56 @@ import (
"sync"
"time"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
requestMap requestMap
newLimiter func() *rate.Limiter
m *Middleware
RateLimiterOpts
*Tracer
requestMap requestMap
mu sync.Mutex
}
rateLimiterOpts struct {
RateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration
Period time.Duration `validate:"min=1s"`
}
)
var (
RateLimiter = &Middleware{withOptions: NewRateLimiter}
rateLimiterOptsDefault = rateLimiterOpts{
RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second,
}
)
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
rl := new(rateLimiter)
opts := rateLimiterOptsDefault
err := Deserialize(optsRaw, &opts)
if err != nil {
return nil, err
}
switch {
case opts.Average == 0:
return nil, ErrZeroValue.Subject("average")
case opts.Burst == 0:
return nil, ErrZeroValue.Subject("burst")
case opts.Period == 0:
return nil, ErrZeroValue.Subject("period")
}
// setup implements MiddlewareWithSetup.
func (rl *rateLimiter) setup() {
rl.RateLimiterOpts = rateLimiterOptsDefault
rl.requestMap = make(requestMap, 0)
rl.newLimiter = func() *rate.Limiter {
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
}
rl.m = &Middleware{
impl: rl,
before: rl.limit,
}
return rl.m, nil
}
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
// before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
rl.mu.Lock()
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr)
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return
return false
}
limiter, ok := rl.requestMap[host]
@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request
rl.mu.Unlock()
if limiter.Allow() {
next(w, r)
return
return true
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
}

View file

@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) {
"period": "1s",
}
rl, err := NewRateLimiter(opts)
rl, err := RateLimiter.New(opts)
ExpectNoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)

View file

@ -2,24 +2,24 @@ package middleware
import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type realIP struct {
realIPOpts
m *Middleware
}
type realIPOpts struct {
type (
realIP struct {
RealIPOpts
*Tracer
}
RealIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"min=1"`
From []*types.CIDR `validate:"required,min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@ -29,31 +29,26 @@ type realIPOpts struct {
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
}
)
var (
RealIP = &Middleware{withOptions: NewRealIP}
realIPOptsDefault = realIPOpts{
RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
}
)
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,
before: Rewrite(riWithOpts.setRealIP),
}
riWithOpts.realIPOpts = realIPOptsDefault
err := Deserialize(opts, &riWithOpts.realIPOpts)
if err != nil {
return nil, err
}
if len(riWithOpts.From) == 0 {
return nil, E.New("no allowed CIDRs").Subject("from")
}
return riWithOpts.m, nil
// setup implements MiddlewareWithSetup.
func (ri *realIP) setup() {
ri.RealIPOpts = realIPOptsDefault
}
// before implements RequestModifier.
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
ri.setRealIP(r)
return true
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
return false
}
func (ri *realIP) setRealIP(req *Request) {
func (ri *realIP) setRealIP(req *http.Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIPStr = req.RemoteAddr
@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) {
}
}
if !isTrusted {
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) {
var lastNonTrustedIP string
if len(realIPs) == 0 {
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) {
}
if lastNonTrustedIP == "" {
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) {
},
"recursive": true,
}
optExpected := &realIPOpts{
optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP,
From: []*types.CIDR{
{
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
Recursive: true,
}
ri, err := NewRealIP(opts)
ri, err := RealIP.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) {
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
realip, err := RealIP.New(opts)
ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr)
mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
}

View file

@ -7,19 +7,22 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var RedirectHTTP = &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if r.TLS == nil {
type redirectHTTP struct{}
var RedirectHTTP = NewMiddleware[redirectHTTP]()
// before implements RequestModifier.
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if r.TLS != nil {
return true
}
r.URL.Scheme = "https"
host := r.Host
if i := strings.Index(host, ":"); i != -1 {
host = host[:i] // strip port number if present
}
r.URL.Host = host + ":" + common.ProxyHTTPSPort
logger.Info().Str("url", r.URL.String()).Msg("redirect to https")
logger.Debug().Str("url", r.URL.String()).Msg("redirect to https")
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next(w, r)
},
return true
}

View file

@ -0,0 +1,34 @@
package middleware
import (
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
// internal use only.
type setUpstreamHeaders struct {
Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}

View file

@ -1,16 +1,14 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Trace struct {
type (
Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
@ -19,9 +17,9 @@ type Trace struct {
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
type Traces []*Trace
}
Traces []*Trace
)
var (
traces = make(Traces, 0)
@ -34,7 +32,7 @@ func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *Request) *Trace {
func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil {
return nil
}
@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace {
return tr
}
func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
child.parent = m
child.EnableTrace()
}
}
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
if !m.trace {
return nil
}
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithResponse(resp)
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()

View file

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
parent *Tracer
}
func (t *Tracer) Fullname() string {
if t.parent != nil {
return t.parent.Fullname() + "." + t.name
}
return t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.Fullname(),
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if t == nil {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}

View file

@ -11,8 +11,8 @@ import (
)
type (
reqVarGetter func(*Request) string
respVarGetter func(*Response) string
reqVarGetter func(*http.Request) string
respVarGetter func(*http.Response) string
)
var (
@ -49,50 +49,50 @@ const (
)
var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *Request) string { return req.Method },
VarRequestScheme: func(req *Request) string {
VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *Request) string {
VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
VarRequestPort: func(req *Request) string {
VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
VarRequestAddr: func(req *Request) string { return req.Host },
VarRequestPath: func(req *Request) string { return req.URL.Path },
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *Request) string { return req.URL.String() },
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *Request) string {
VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return ""
},
VarRemotePort: func(req *Request) string {
VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientPort
}
return ""
},
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *Request) string {
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" {
@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
return upHost
},
VarUpstreamURL: func(req *Request) string {
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return ""
@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *Request, resp *Response, s string) string {
func varReplace(req *http.Request, resp *http.Response, s string) string {
if req != nil {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {

View file

@ -2,27 +2,44 @@ package middleware
import (
"net"
"net/http"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
var SetXForwarded = &Middleware{
before: Rewrite(func(req *Request) {
req.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
type (
setXForwarded struct{}
hideXForwarded struct{}
)
var (
SetXForwarded = NewMiddleware[setXForwarded]()
HideXForwarded = NewMiddleware[hideXForwarded]()
)
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
}),
return true
}
var HideXForwarded = &Middleware{
before: Rewrite(func(req *Request) {
for k := range req.Header {
// before implements RequestModifier.
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
toDelete := make([]string, 0, len(r.Header))
for k := range r.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
req.Header.Del(k)
toDelete = append(toDelete, k)
}
}
}),
for _, k := range toDelete {
r.Header.Del(k)
}
return true
}

View file

@ -1,8 +0,0 @@
package http
import "net/http"
type ProxyResponse struct {
*http.Response
OriginalRequest *http.Request
}

View file

@ -87,7 +87,7 @@ type ReverseProxy struct {
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*ProxyResponse) error
ModifyResponse func(*http.Response) error
HandlerFunc http.HandlerFunc
@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
res.Request = origReq
err := p.ModifyResponse(res)
res.Request = req
if err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
p.HandlerFunc(rw, req)
}
@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq,
TLS: outreq.TLS,
Request: req,
TLS: req.TLS,
}
}

View file

@ -9,26 +9,31 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
const (
rootTaskName = "root-task"
subTaskName = "subtask"
)
ExpectEqual(t, "root-task", rootTask.Name())
ExpectEqual(t, "subtask", subTask.Name())
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
ExpectEqual(t, rootTaskName, rootTask.Name())
ExpectEqual(t, subTaskName, subTask.Name())
}
func TestTaskCancellation(t *testing.T) {
subTaskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
go func() {
subTask.Wait()
close(subTaskDone)
}()
go rootTask.Finish("done")
go rootTask.Finish(nil)
select {
case <-subTaskDone:
@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) {
}
func TestOnComplete(t *testing.T) {
rootTask := GlobalTask("root-task")
task := rootTask.Subtask("test")
rootTask := GlobalTask(rootTaskName)
task := rootTask.Subtask(subTaskName)
var value atomic.Int32
task.OnFinished("set value", func() {
value.Store(1234)
})
task.Finish("done")
task.Finish(nil)
ExpectEqual(t, value.Load(), 1234)
}
@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) {
testResetGlobalTask()
defer CancelGlobalContext()
rootTask := GlobalTask("root-task")
rootTask := GlobalTask(rootTaskName)
finished1, finished2 := false, false
subTask1 := rootTask.Subtask("subtask1")
subTask2 := rootTask.Subtask("subtask2")
subTask1.OnFinished("set finished", func() {
subTask1 := rootTask.Subtask(subTaskName)
subTask2 := rootTask.Subtask(subTaskName)
subTask1.OnFinished("", func() {
finished1 = true
})
subTask2.OnFinished("set finished", func() {
subTask2.OnFinished("", func() {
finished2 = true
})
go func() {
time.Sleep(500 * time.Millisecond)
subTask1.Finish("done")
subTask1.Finish(nil)
}()
go func() {
time.Sleep(500 * time.Millisecond)
subTask2.Finish("done")
subTask2.Finish(nil)
}()
go func() {
subTask1.Wait()
subTask2.Wait()
rootTask.Finish("done")
rootTask.Finish(nil)
}()
GlobalContextWait(1 * time.Second)
_ = GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) {
func TestTimeoutOnGlobalContextWait(t *testing.T) {
testResetGlobalTask()
rootTask := GlobalTask("root-task")
rootTask.Subtask("subtask")
rootTask := GlobalTask(rootTaskName)
rootTask.Subtask(subTaskName)
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
}
@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) {
testResetGlobalTask()
taskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
rootTask := GlobalTask(rootTaskName)
go func() {
rootTask.Wait()

View file

@ -19,7 +19,8 @@ func TestSerializeDeserialize(t *testing.T) {
MIS map[int]string
}
var testStruct = S{
var (
testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
@ -27,8 +28,7 @@ func TestSerializeDeserialize(t *testing.T) {
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
var testStructSerialized = map[string]any{
testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
@ -36,6 +36,7 @@ func TestSerializeDeserialize(t *testing.T) {
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
)
t.Run("serialize", func(t *testing.T) {
s, err := Serialize(testStruct)