cleanup and simplify middleware implementations, refactor some other code

This commit is contained in:
yusing 2024-12-16 10:19:14 +08:00
parent 8a9cb2527e
commit 59f4eaf3ea
34 changed files with 641 additions and 720 deletions

View file

@ -7,12 +7,12 @@ cli:
plugins: plugins:
sources: sources:
- id: trunk - id: trunk
ref: v1.6.5 ref: v1.6.6
uri: https://github.com/trunk-io/plugins uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes) # Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes: runtimes:
enabled: enabled:
- node@18.12.1 - node@18.20.5
- python@3.10.8 - python@3.10.8
- go@1.23.2 - go@1.23.2
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration) # This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
@ -23,16 +23,16 @@ lint:
enabled: enabled:
- hadolint@2.12.1-beta - hadolint@2.12.1-beta
- actionlint@1.7.4 - actionlint@1.7.4
- checkov@3.2.324 - checkov@3.2.334
- git-diff-check - git-diff-check
- gofmt@1.20.4 - gofmt@1.20.4
- golangci-lint@1.62.2 - golangci-lint@1.62.2
- osv-scanner@1.9.1 - osv-scanner@1.9.1
- oxipng@9.1.3 - oxipng@9.1.3
- prettier@3.4.1 - prettier@3.4.2
- shellcheck@0.10.0 - shellcheck@0.10.0
- shfmt@3.6.0 - shfmt@3.6.0
- trufflehog@3.84.1 - trufflehog@3.86.1
actions: actions:
disabled: disabled:
- trunk-announce - trunk-announce

View file

@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
} }
func rateLimited(f http.HandlerFunc) http.HandlerFunc { func rateLimited(f http.HandlerFunc) http.HandlerFunc {
m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{ m, err := middleware.RateLimiter.New(middleware.OptionsRaw{
"average": 10, "average": 10,
"burst": 10, "burst": 10,
}) })

View file

@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) {
// On nginx, when route for domain does not exist, it returns StatusBadGateway. // On nginx, when route for domain does not exist, it returns StatusBadGateway.
// Then scraper / scanners will know the subdomain is invalid. // Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if !middleware.ServeStaticErrorPageFile(w, r) { if served := middleware.ServeStaticErrorPageFile(w, r); !served {
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok { if ok {

View file

@ -16,6 +16,9 @@ const (
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme" HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host" HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port" HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
) )
func RemoveHop(h http.Header) { func RemoveHop(h http.Header) {

View file

@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl {
return impl return impl
} }
var err E.Error var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options) impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil { if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.l) E.LogError("invalid real_ip options, ignoring", err, &impl.l)
} }

View file

@ -4,48 +4,45 @@ import (
"net" "net"
"net/http" "net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
type cidrWhitelist struct { type (
cidrWhitelistOpts cidrWhitelist struct {
m *Middleware CIDRWhitelistOpts
*Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs cachedAddr F.Map[string, bool] // cache for trusted IPs
} }
CIDRWhitelistOpts struct {
type cidrWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"` Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
Message string Message string
} }
)
var ( var (
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist} CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = cidrWhitelistOpts{ cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{}, Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden, StatusCode: http.StatusForbidden,
Message: "IP not allowed", Message: "IP not allowed",
} }
) )
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { // setup implements MiddlewareWithSetup.
wl := new(cidrWhitelist) func (wl *cidrWhitelist) setup() {
wl.m = &Middleware{ wl.CIDRWhitelistOpts = cidrWhitelistDefaults
impl: wl,
before: wl.checkIP,
}
wl.cidrWhitelistOpts = cidrWhitelistDefaults
wl.cachedAddr = F.NewMapOf[string, bool]() wl.cachedAddr = F.NewMapOf[string, bool]()
err := Deserialize(opts, &wl.cidrWhitelistOpts)
if err != nil {
return nil, err
}
return wl.m, nil
} }
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) { // before implements RequestModifier.
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
return wl.checkIP(w, r)
}
// checkIP checks if the IP address is allowed.
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
var allow, ok bool var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok { if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr) ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req
ipStr = r.RemoteAddr ipStr = r.RemoteAddr
} }
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
for _, cidr := range wl.cidrWhitelistOpts.Allow { for _, cidr := range wl.CIDRWhitelistOpts.Allow {
if cidr.Contains(ip) { if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true) wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true allow = true
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr) wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break break
} }
} }
if !allow { if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false) wl.cachedAddr.Store(r.RemoteAddr, false)
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow) wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
} }
} }
if !allow { if !allow {
w.WriteHeader(wl.StatusCode) http.Error(w, wl.Message, wl.StatusCode)
w.Write([]byte(wl.Message)) return false
return
} }
next(w, r) return true
} }

View file

@ -17,27 +17,27 @@ var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) { func TestCIDRWhitelistValidation(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"}, "allow": []string{"1.2.3.4/32"},
"message": "test-message", "message": "test-message",
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
}) })
t.Run("missing allow", func(t *testing.T) { t.Run("missing allow", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
"message": "test-message", "message": "test-message",
}) })
ExpectError(t, utils.ErrValidationError, err) ExpectError(t, utils.ErrValidationError, err)
}) })
t.Run("invalid cidr", func(t *testing.T) { t.Run("invalid cidr", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/123"}, "allow": []string{"1.2.3.4/123"},
"message": "test-message", "message": "test-message",
}) })
ExpectErrorT[*net.ParseError](t, err) ExpectErrorT[*net.ParseError](t, err)
}) })
t.Run("invalid status code", func(t *testing.T) { t.Run("invalid status code", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"1.2.3.4/32"}, "allow": []string{"1.2.3.4/32"},
"status_code": 600, "status_code": 600,
"message": "test-message", "message": "test-message",

View file

@ -11,11 +11,14 @@ import (
"time" "time"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type cloudflareRealIP struct {
realIP realIP
}
const ( const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4" cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6" cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
@ -29,26 +32,23 @@ var (
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger() cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
) )
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP} var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) { // setup implements MiddlewareWithSetup.
cri := new(realIP) func (cri *cloudflareRealIP) setup() {
cri.m = &Middleware{ cri.realIP.RealIPOpts = RealIPOpts{
impl: cri,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
next(w, r)
},
}
cri.realIPOpts = realIPOpts{
Header: "CF-Connecting-IP", Header: "CF-Connecting-IP",
Recursive: true, Recursive: true,
} }
return cri.m, nil }
// before implements RequestModifier.
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.realIP.From = cidrs
}
return cri.realIP.before(w, r)
} }
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {

View file

@ -12,45 +12,38 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
) )
var CustomErrorPage *Middleware type customErrorPage struct{}
func init() { var CustomErrorPage = NewMiddleware[customErrorPage]()
CustomErrorPage = customErrorPage()
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
} }
func customErrorPage() *Middleware { // modifyResponse implements ResponseModifier.
m := &Middleware{ func (customErrorPage) modifyResponse(resp *http.Response) error {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type // only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header) contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok { if ok {
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode) logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */ _, _ = io.Copy(io.Discard, resp.Body) // drain the original body
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close() resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage)) resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8") resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
} else { } else {
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode) logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
} }
return nil return nil
} }
return nil return nil
}
return m
} }
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
path := r.URL.Path path := r.URL.Path
if path != "" && path[0] != '/' { if path != "" && path[0] != '/' {
path = "/" + path path = "/" + path
@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
switch ext { switch ext {
case ".html": case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
case ".js": case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8") w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
case ".css": case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8") w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
default: default:
logger.Error().Msgf("unexpected file type %q for %s", ext, filename) logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
} }

View file

@ -1,5 +0,0 @@
package middleware
import E "github.com/yusing/go-proxy/internal/error"
var ErrZeroValue = E.New("cannot be zero")

View file

@ -12,16 +12,17 @@ import (
"strings" "strings"
"time" "time"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
) )
type ( type (
forwardAuth struct { forwardAuth struct {
forwardAuthOpts ForwardAuthOpts
m *Middleware *Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
} }
forwardAuthOpts struct { ForwardAuthOpts struct {
Address string `validate:"url,required"` Address string `validate:"url,required"`
TrustForwardHeader bool TrustForwardHeader bool
AuthResponseHeaders []string AuthResponseHeaders []string
@ -29,36 +30,30 @@ type (
} }
) )
var ForwardAuth = &Middleware{withOptions: NewForwardAuth} var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{ var faHTTPClient = &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error { CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
} }
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) { // setup implements MiddlewareWithSetup.
fa := new(forwardAuth) func (fa *forwardAuth) setup() {
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil { fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
return nil, err
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
return fa.m, nil
} }
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { // before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header) gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect // Construct original URL for the redirect
// scheme := "http" scheme := "http"
// if req.TLS != nil { if req.TLS != nil {
// scheme = "https" scheme = "https"
// } }
// originalURL := scheme + "://" + req.Host + req.RequestURI originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address url := fa.Address
faReq, err := http.NewRequestWithContext( faReq, err := http.NewRequestWithContext(
@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil, nil,
) )
if err != nil { if err != nil {
fa.m.AddTracef("new request err to %s", url).WithError(err) fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq) fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik // Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL) faReq.Header.Set("X-Original-Url", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq) fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq) faResp, err := faHTTPClient.Do(faReq)
if err != nil { if err != nil {
fa.m.AddTracef("failed to call %s", url).WithError(err) fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body) body, err := io.ReadAll(faResp.Body)
if err != nil { if err != nil {
fa.m.AddTracef("failed to read response body from %s", url).WithError(err) fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTraceResponse("forward auth response", faResp) fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header) gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header()) gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location() redirectURL, err := faResp.Location()
if err != nil { if err != nil {
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} else if redirectURL.String() != "" { } else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String()) w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("%s", "redirect to "+redirectURL.String()) fa.AddTracef("%s", "redirect to "+redirectURL.String())
} }
w.WriteHeader(faResp.StatusCode) w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil { if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
} }
return return
} }
@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
authCookies := faResp.Cookies() authCookies := faResp.Cookies()
if len(authCookies) == 0 { if len(authCookies) > 0 {
next.ServeHTTP(w, req) fa.reqCookiesMap.Store(req, authCookies)
return
} }
return true
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
} }
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) { // modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 { if len(fa.AddAuthCookiesToResponse) == 0 {
return return
} }
@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
} }
} }
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) { func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader { if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok { if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {

View file

@ -3,11 +3,12 @@ package middleware
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect"
"strings"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
type ( type (
@ -15,58 +16,97 @@ type (
ReverseProxy = gphttp.ReverseProxy ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest ProxyRequest = gphttp.ProxyRequest
Request = http.Request
Response = gphttp.ProxyResponse
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(*Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
ImplNewFunc = func() any
OptionsRaw = map[string]any OptionsRaw = map[string]any
Middleware struct { Middleware struct {
_ U.NoCopy
zerolog.Logger
name string name string
construct ImplNewFunc
before BeforeFunc // runs before ReverseProxy.ServeHTTP
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
withOptions CloneWithOptFunc
impl any impl any
parent *Middleware
children []*Middleware
trace bool
} }
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
) )
var Deserialize = U.Deserialize func NewMiddleware[ImplType any]() *Middleware {
// type check
func Rewrite(r RewriteFunc) BeforeFunc { switch any(new(ImplType)).(type) {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) { case RequestModifier:
r(req) case ResponseModifier:
next(w, req) default:
panic("must implement RequestModifier or ResponseModifier")
} }
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 {
return nil
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
}
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil {
if optsRaw != nil {
panic("bug: middleware already constructed")
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
mid.finalize()
return mid, nil
} }
func (m *Middleware) Name() string { func (m *Middleware) Name() string {
return m.name return m.name
} }
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string { func (m *Middleware) String() string {
return m.name return m.name
} }
@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ") }, "", " ")
} }
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if m.withOptions != nil { if exec, ok := m.impl.(RequestModifier); ok {
m, err := m.withOptions(optsRaw) if proceed := exec.before(w, r); !proceed {
if err != nil { return
return nil, err
} }
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
} }
next(w, r)
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
} }
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) { func (m *Middleware) ModifyResponse(resp *http.Response) error {
if m.before != nil { if exec, ok := m.impl.(ResponseModifier); ok {
m.before(next, w, r) return exec.modifyResponse(resp)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
} }
return nil return nil
} }
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) { func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if m.modifyResponse != nil { if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r}) return exec.modifyResponse(resp)
}) })
} }
if m.before != nil { if exec, ok := m.impl.(RequestModifier); ok {
m.before(next, w, r) if proceed := exec.before(w, r); !proceed {
} else { return
next(w, r)
} }
}
next(w, r)
} }
// TODO: check conflict or duplicates. // TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap)) middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error") errs := E.NewBuilder("middlewares compile error")
@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
continue continue
} }
m, err = m.WithOptionsClone(opts) m, err = m.New(opts)
if err != nil { if err != nil {
invalidOpts.Add(err.Subject(name)) invalidOpts.Add(err.Subject(name))
continue continue
@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap) middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil { if err != nil {
return return
} }
@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
} }
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{ middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { mid := NewMiddlewareChain(rp.TargetName, middlewares)
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname()) if before, ok := mid.impl.(RequestModifier); ok {
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port()) next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r) next(w, r)
}, }
}}, middlewares...))
if mid.before != nil {
ori := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
mid.before(ori, w, r)
} }
} }
if mid.modifyResponse != nil { if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil { if rp.ModifyResponse != nil {
ori := rp.ModifyResponse ori := rp.ModifyResponse
rp.ModifyResponse = func(res *Response) error { rp.ModifyResponse = func(res *http.Response) error {
if err := mid.modifyResponse(res); err != nil { if err := mr.modifyResponse(res); err != nil {
return err return err
} }
return ori(res) return ori(res)
} }
} else { } else {
rp.ModifyResponse = mid.modifyResponse rp.ModifyResponse = mr.modifyResponse
} }
} }
} }

View file

@ -2,11 +2,9 @@ package middleware
import ( import (
"fmt" "fmt"
"net/http"
"os" "os"
"path" "path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
continue continue
} }
delete(def, "use") delete(def, "use")
m, err := base.WithOptionsClone(def) m, err := base.New(def)
if err != nil { if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i)) chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue continue
@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
if chainErr.HasError() { if chainErr.HasError() {
return nil, chainErr.Error() return nil, chainErr.Error()
} }
return BuildMiddlewareFromChain(name, chain), nil return NewMiddlewareChain(name, chain), nil
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
if len(befores) > 0 {
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return errs.Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
} }

View file

@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type middlewareChain struct {
befores []RequestModifier
modResps []ResponseModifier
}
// TODO: check conflict or duplicates.
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
m := &Middleware{name: name, impl: chainMid}
for _, comp := range chain {
if before, ok := comp.impl.(RequestModifier); ok {
chainMid.befores = append(chainMid.befores, before)
}
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsDebug {
m.enableTrace()
}
return m
}
// before implements RequestModifier.
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
for _, b := range m.befores {
if proceedNext = b.before(w, r); !proceedNext {
return false
}
}
return true
}
// modifyResponse implements ResponseModifier.
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := E.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(E.From(err).Subjectf("%d", i))
}
}
return errs.Error()
}

View file

@ -3,45 +3,39 @@ package middleware
import ( import (
"net/http" "net/http"
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error"
) )
type ( type (
modifyRequest struct { modifyRequest struct {
modifyRequestOpts ModifyRequestOpts
m *Middleware *Tracer
needVarSubstitution bool
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct { ModifyRequestOpts struct {
SetHeaders map[string]string SetHeaders map[string]string
AddHeaders map[string]string AddHeaders map[string]string
HideHeaders []string HideHeaders []string
needVarSubstitution bool
} }
) )
var ModifyRequest = &Middleware{withOptions: NewModifyRequest} var ModifyRequest = NewMiddleware[modifyRequest]()
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { // finalize implements MiddlewareFinalizer.
mr := new(modifyRequest) func (mr *ModifyRequestOpts) finalize() {
mr.m = &Middleware{
impl: mr,
before: Rewrite(func(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyHeaders(req, nil, req.Header)
mr.m.AddTraceRequest("after modify request", req)
}),
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution() mr.checkVarSubstitution()
return mr.m, nil
} }
func (mr *modifyRequest) checkVarSubstitution() { // before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
mr.modifyHeaders(r, nil, r.Header)
mr.AddTraceRequest("after modify request", r)
return true
}
func (mr *ModifyRequestOpts) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} { for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m { for _, v := range m {
if strings.ContainsRune(v, '$') { if strings.ContainsRune(v, '$') {
@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() {
} }
} }
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) { func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
if !mr.needVarSubstitution { if !mr.needVarSubstitution {
for k, v := range mr.SetHeaders { for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" { if req != nil && strings.EqualFold(k, "host") {
defer func() { defer func() {
req.Host = v req.Host = v
}() }()
@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt
} }
} else { } else {
for k, v := range mr.SetHeaders { for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" { if req != nil && strings.EqualFold(k, "host") {
defer func() { defer func() {
req.Host = varReplace(req, resp, v) req.Host = varReplace(req, resp, v)
}() }()

View file

@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) {
} }
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.WithOptionsClone(opts) mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -2,32 +2,19 @@ package middleware
import ( import (
"net/http" "net/http"
E "github.com/yusing/go-proxy/internal/error"
) )
type modifyResponse = modifyRequest type modifyResponse struct {
ModifyRequestOpts
var ModifyResponse = &Middleware{withOptions: NewModifyResponse} *Tracer
}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse) var ModifyResponse = NewMiddleware[modifyResponse]()
mr.m = &Middleware{
impl: mr, // modifyResponse implements ResponseModifier.
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
next(w, r) mr.AddTraceResponse("before modify response", resp)
}, mr.modifyHeaders(resp.Request, resp, resp.Header)
modifyResponse: func(resp *Response) error { mr.AddTraceResponse("after modify response", resp)
mr.m.AddTraceResponse("before modify response", resp.Response) return nil
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp.Response)
return nil
},
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution()
return mr.m, nil
} }

View file

@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) {
} }
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.WithOptionsClone(opts) mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -1,117 +0,0 @@
package middleware
// import (
// "encoding/json"
// "fmt"
// "net/http"
// "net/url"
// E "github.com/yusing/go-proxy/internal/error"
// )
// type oAuth2 struct {
// oAuth2Opts
// m *Middleware
// }
// type oAuth2Opts struct {
// ClientID string `validate:"required"`
// ClientSecret string `validate:"required"`
// AuthURL string `validate:"required"` // Authorization Endpoint
// TokenURL string `validate:"required"` // Token Endpoint
// }
// var OAuth2 = &oAuth2{
// m: &Middleware{withOptions: NewAuthentikOAuth2},
// }
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
// oauth := new(oAuth2)
// oauth.m = &Middleware{
// impl: oauth,
// before: oauth.handleOAuth2,
// }
// err := Deserialize(opts, &oauth.oAuth2Opts)
// if err != nil {
// return nil, err
// }
// return oauth.m, nil
// }
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// // Check if the user is authenticated (you may use session, cookie, etc.)
// if !userIsAuthenticated(r) {
// // TODO: Redirect to OAuth2 auth URL
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
// return
// }
// // If you have a token in the query string, process it
// if code := r.URL.Query().Get("code"); code != "" {
// // Exchange the authorization code for a token here
// // Use the TokenURL and authenticate the user
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
// if err != nil {
// // handle error
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
// return
// }
// // Save token and user info based on your requirements
// saveToken(rw, token)
// // Redirect to the originally requested URL
// http.Redirect(rw, r, "/", http.StatusFound)
// return
// }
// // If user is authenticated, go to the next handler
// next(rw, r)
// }
// func userIsAuthenticated(r *http.Request) bool {
// // Example: Check for a session or cookie
// session, err := r.Cookie("session_token")
// if err != nil || session.Value == "" {
// return false
// }
// // Validate the session_token if necessary
// return true
// }
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// // Prepare the request body
// data := url.Values{
// "client_id": {opts.ClientID},
// "client_secret": {opts.ClientSecret},
// "code": {code},
// "grant_type": {"authorization_code"},
// "redirect_uri": {requestURI},
// }
// resp, err := http.PostForm(opts.TokenURL, data)
// if err != nil {
// return "", fmt.Errorf("failed to request token: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
// }
// // Decode the response
// var tokenResp struct {
// AccessToken string `json:"access_token"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
// return "", fmt.Errorf("failed to decode token response: %w", err)
// }
// return tokenResp.AccessToken, nil
// }
// func saveToken(rw ResponseWriter, token string) {
// // Example: Save token in cookie
// http.SetCookie(rw, &http.Cookie{
// Name: "auth_token",
// Value: token,
// // set other properties as necessary, such as Secure and HttpOnly
// })
// }

View file

@ -6,68 +6,56 @@ import (
"sync" "sync"
"time" "time"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
type ( type (
requestMap = map[string]*rate.Limiter requestMap = map[string]*rate.Limiter
rateLimiter struct { rateLimiter struct {
requestMap requestMap RateLimiterOpts
newLimiter func() *rate.Limiter *Tracer
m *Middleware
requestMap requestMap
mu sync.Mutex mu sync.Mutex
} }
rateLimiterOpts struct { RateLimiterOpts struct {
Average int `validate:"min=1,required"` Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"` Burst int `validate:"min=1,required"`
Period time.Duration Period time.Duration `validate:"min=1s"`
} }
) )
var ( var (
RateLimiter = &Middleware{withOptions: NewRateLimiter} RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = rateLimiterOpts{ rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second, Period: time.Second,
} }
) )
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) { // setup implements MiddlewareWithSetup.
rl := new(rateLimiter) func (rl *rateLimiter) setup() {
opts := rateLimiterOptsDefault rl.RateLimiterOpts = rateLimiterOptsDefault
err := Deserialize(optsRaw, &opts)
if err != nil {
return nil, err
}
switch {
case opts.Average == 0:
return nil, ErrZeroValue.Subject("average")
case opts.Burst == 0:
return nil, ErrZeroValue.Subject("burst")
case opts.Period == 0:
return nil, ErrZeroValue.Subject("period")
}
rl.requestMap = make(requestMap, 0) rl.requestMap = make(requestMap, 0)
rl.newLimiter = func() *rate.Limiter {
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
}
rl.m = &Middleware{
impl: rl,
before: rl.limit,
}
return rl.m, nil
} }
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) { // before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
rl.mu.Lock() rl.mu.Lock()
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr) rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError) http.Error(w, "Internal error", http.StatusInternalServerError)
return return false
} }
limiter, ok := rl.requestMap[host] limiter, ok := rl.requestMap[host]
@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request
rl.mu.Unlock() rl.mu.Unlock()
if limiter.Allow() { if limiter.Allow() {
next(w, r) return true
return
} }
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
} }

View file

@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) {
"period": "1s", "period": "1s",
} }
rl, err := NewRateLimiter(opts) rl, err := RateLimiter.New(opts)
ExpectNoError(t, err) ExpectNoError(t, err)
for range 10 { for range 10 {
result, err := newMiddlewareTest(rl, nil) result, err := newMiddlewareTest(rl, nil)

View file

@ -2,24 +2,24 @@ package middleware
import ( import (
"net" "net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
// https://nginx.org/en/docs/http/ngx_http_realip_module.html // https://nginx.org/en/docs/http/ngx_http_realip_module.html
type realIP struct { type (
realIPOpts realIP struct {
m *Middleware RealIPOpts
} *Tracer
}
type realIPOpts struct { RealIPOpts struct {
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP
Header string `validate:"required"` Header string `validate:"required"`
// From is a list of Address / CIDRs to trust // From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"min=1"` From []*types.CIDR `validate:"required,min=1"`
/* /*
If recursive search is disabled, If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
@ -29,31 +29,26 @@ type realIPOpts struct {
the last non-trusted address sent in the request header field. the last non-trusted address sent in the request header field.
*/ */
Recursive bool Recursive bool
} }
)
var ( var (
RealIP = &Middleware{withOptions: NewRealIP} RealIP = NewMiddleware[realIP]()
realIPOptsDefault = realIPOpts{ realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP", Header: "X-Real-IP",
From: []*types.CIDR{}, From: []*types.CIDR{},
} }
) )
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) { // setup implements MiddlewareWithSetup.
riWithOpts := new(realIP) func (ri *realIP) setup() {
riWithOpts.m = &Middleware{ ri.RealIPOpts = realIPOptsDefault
impl: riWithOpts, }
before: Rewrite(riWithOpts.setRealIP),
} // before implements RequestModifier.
riWithOpts.realIPOpts = realIPOptsDefault func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
err := Deserialize(opts, &riWithOpts.realIPOpts) ri.setRealIP(r)
if err != nil { return true
return nil, err
}
if len(riWithOpts.From) == 0 {
return nil, E.New("no allowed CIDRs").Subject("from")
}
return riWithOpts.m, nil
} }
func (ri *realIP) isInCIDRList(ip net.IP) bool { func (ri *realIP) isInCIDRList(ip net.IP) bool {
@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
return false return false
} }
func (ri *realIP) setRealIP(req *Request) { func (ri *realIP) setRealIP(req *http.Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr) clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil { if err != nil {
clientIPStr = req.RemoteAddr clientIPStr = req.RemoteAddr
@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) {
} }
} }
if !isTrusted { if !isTrusted {
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From) ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return return
} }
@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) {
var lastNonTrustedIP string var lastNonTrustedIP string
if len(realIPs) == 0 { if len(realIPs) == 0 {
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req) ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return return
} }
@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) {
} }
if lastNonTrustedIP == "" { if lastNonTrustedIP == "" {
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs) ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return return
} }
req.RemoteAddr = lastNonTrustedIP req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP) ri.AddTracef("set real ip %s", lastNonTrustedIP)
} }

View file

@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) {
}, },
"recursive": true, "recursive": true,
} }
optExpected := &realIPOpts{ optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP, Header: gphttp.HeaderXRealIP,
From: []*types.CIDR{ From: []*types.CIDR{
{ {
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
Recursive: true, Recursive: true,
} }
ri, err := NewRealIP(opts) ri, err := RealIP.New(opts)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) {
optsMr := OptionsRaw{ optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP}, "set_headers": map[string]string{testHeader: testRealIP},
} }
realip, err := NewRealIP(opts) realip, err := RealIP.New(opts)
ExpectNoError(t, err) ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr) mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err) ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil) result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err) ExpectNoError(t, err)
t.Log(traces) t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
} }

View file

@ -7,19 +7,22 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
) )
var RedirectHTTP = &Middleware{ type redirectHTTP struct{}
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if r.TLS == nil { var RedirectHTTP = NewMiddleware[redirectHTTP]()
// before implements RequestModifier.
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if r.TLS != nil {
return true
}
r.URL.Scheme = "https" r.URL.Scheme = "https"
host := r.Host host := r.Host
if i := strings.Index(host, ":"); i != -1 { if i := strings.Index(host, ":"); i != -1 {
host = host[:i] // strip port number if present host = host[:i] // strip port number if present
} }
r.URL.Host = host + ":" + common.ProxyHTTPSPort r.URL.Host = host + ":" + common.ProxyHTTPSPort
logger.Info().Str("url", r.URL.String()).Msg("redirect to https") logger.Debug().Str("url", r.URL.String()).Msg("redirect to https")
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return return true
}
next(w, r)
},
} }

View file

@ -0,0 +1,34 @@
package middleware
import (
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
// internal use only.
type setUpstreamHeaders struct {
Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr) rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil { if setOptErr != nil {
return nil, setOptErr return nil, setOptErr
} }

View file

@ -1,16 +1,14 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
"sync" "sync"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type Trace struct { type (
Trace struct {
Time string `json:"time,omitempty"` Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"` Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
@ -19,9 +17,9 @@ type Trace struct {
RespHeaders map[string]string `json:"resp_headers,omitempty"` RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"` RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"` Additional map[string]any `json:"additional,omitempty"`
} }
Traces []*Trace
type Traces []*Trace )
var ( var (
traces = make(Traces, 0) traces = make(Traces, 0)
@ -34,7 +32,7 @@ func GetAllTrace() []*Trace {
return traces return traces
} }
func (tr *Trace) WithRequest(req *Request) *Trace { func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil { if tr == nil {
return nil return nil
} }
@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace {
return tr return tr
} }
func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
child.parent = m
child.EnableTrace()
}
}
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
if !m.trace {
return nil
}
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithResponse(resp)
}
func addTrace(t *Trace) *Trace { func addTrace(t *Trace) *Trace {
tracesMu.Lock() tracesMu.Lock()
defer tracesMu.Unlock() defer tracesMu.Unlock()

View file

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
parent *Tracer
}
func (t *Tracer) Fullname() string {
if t.parent != nil {
return t.parent.Fullname() + "." + t.name
}
return t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.Fullname(),
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if t == nil {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if t == nil {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}

View file

@ -11,8 +11,8 @@ import (
) )
type ( type (
reqVarGetter func(*Request) string reqVarGetter func(*http.Request) string
respVarGetter func(*Response) string respVarGetter func(*http.Response) string
) )
var ( var (
@ -49,50 +49,50 @@ const (
) )
var staticReqVarSubsMap = map[string]reqVarGetter{ var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *Request) string { return req.Method }, VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *Request) string { VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil { if req.TLS != nil {
return "https" return "https"
} }
return "http" return "http"
}, },
VarRequestHost: func(req *Request) string { VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host) reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil { if err != nil {
return req.Host return req.Host
} }
return reqHost return reqHost
}, },
VarRequestPort: func(req *Request) string { VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host) _, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort return reqPort
}, },
VarRequestAddr: func(req *Request) string { return req.Host }, VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *Request) string { return req.URL.Path }, VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery }, VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *Request) string { return req.URL.String() }, VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() }, VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") }, VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *Request) string { VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr) clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil { if err == nil {
return clientIP return clientIP
} }
return "" return ""
}, },
VarRemotePort: func(req *Request) string { VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr) _, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil { if err == nil {
return clientPort return clientPort
} }
return "" return ""
}, },
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *Request) string { VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" { if upPort != "" {
@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
} }
return upHost return upHost
}, },
VarUpstreamURL: func(req *Request) string { VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" { if upScheme == "" {
return "" return ""
@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
} }
var staticRespVarSubsMap = map[string]respVarGetter{ var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") }, VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) }, VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) }, VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
} }
func varReplace(req *Request, resp *Response, s string) string { func varReplace(req *http.Request, resp *http.Response, s string) string {
if req != nil { if req != nil {
// Replace query parameters // Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string { s = reArg.ReplaceAllStringFunc(s, func(match string) string {

View file

@ -2,27 +2,44 @@ package middleware
import ( import (
"net" "net"
"net/http"
"strings" "strings"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
) )
var SetXForwarded = &Middleware{ type (
before: Rewrite(func(req *Request) { setXForwarded struct{}
req.Header.Del(gphttp.HeaderXForwardedFor) hideXForwarded struct{}
clientIP, _, err := net.SplitHostPort(req.RemoteAddr) )
var (
SetXForwarded = NewMiddleware[setXForwarded]()
HideXForwarded = NewMiddleware[hideXForwarded]()
)
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil { if err == nil {
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP) r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
} }
}), return true
} }
var HideXForwarded = &Middleware{ // before implements RequestModifier.
before: Rewrite(func(req *Request) { func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
for k := range req.Header { toDelete := make([]string, 0, len(r.Header))
for k := range r.Header {
if strings.HasPrefix(k, "X-Forwarded-") { if strings.HasPrefix(k, "X-Forwarded-") {
req.Header.Del(k) toDelete = append(toDelete, k)
} }
} }
}),
for _, k := range toDelete {
r.Header.Del(k)
}
return true
} }

View file

@ -1,8 +0,0 @@
package http
import "net/http"
type ProxyResponse struct {
*http.Response
OriginalRequest *http.Request
}

View file

@ -87,7 +87,7 @@ type ReverseProxy struct {
// If ModifyResponse returns an error, ErrorHandler is called // If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default // with its error value. If ErrorHandler is nil, its default
// implementation is used. // implementation is used.
ModifyResponse func(*ProxyResponse) error ModifyResponse func(*http.Response) error
HandlerFunc http.HandlerFunc HandlerFunc http.HandlerFunc
@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
// modifyResponse conditionally runs the optional ModifyResponse hook // modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed. // and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool { func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
if p.ModifyResponse == nil { if p.ModifyResponse == nil {
return true return true
} }
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil { res.Request = origReq
err := p.ModifyResponse(res)
res.Request = req
if err != nil {
res.Body.Close() res.Body.Close()
p.errorHandler(rw, req, err, true) p.errorHandler(rw, req, err, true)
return false return false
@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
p.HandlerFunc(rw, req) p.HandlerFunc(rw, req)
} }
@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
res = &http.Response{ res = &http.Response{
Status: http.StatusText(http.StatusBadGateway), Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway, StatusCode: http.StatusBadGateway,
Proto: outreq.Proto, Proto: req.Proto,
ProtoMajor: outreq.ProtoMajor, ProtoMajor: req.ProtoMajor,
ProtoMinor: outreq.ProtoMinor, ProtoMinor: req.ProtoMinor,
Header: http.Header{}, Header: http.Header{},
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))), Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq, Request: req,
TLS: outreq.TLS, TLS: req.TLS,
} }
} }

View file

@ -9,26 +9,31 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestTaskCreation(t *testing.T) { const (
rootTask := GlobalTask("root-task") rootTaskName = "root-task"
subTask := rootTask.Subtask("subtask") subTaskName = "subtask"
)
ExpectEqual(t, "root-task", rootTask.Name()) func TestTaskCreation(t *testing.T) {
ExpectEqual(t, "subtask", subTask.Name()) rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
ExpectEqual(t, rootTaskName, rootTask.Name())
ExpectEqual(t, subTaskName, subTask.Name())
} }
func TestTaskCancellation(t *testing.T) { func TestTaskCancellation(t *testing.T) {
subTaskDone := make(chan struct{}) subTaskDone := make(chan struct{})
rootTask := GlobalTask("root-task") rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask("subtask") subTask := rootTask.Subtask(subTaskName)
go func() { go func() {
subTask.Wait() subTask.Wait()
close(subTaskDone) close(subTaskDone)
}() }()
go rootTask.Finish("done") go rootTask.Finish(nil)
select { select {
case <-subTaskDone: case <-subTaskDone:
@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) {
} }
func TestOnComplete(t *testing.T) { func TestOnComplete(t *testing.T) {
rootTask := GlobalTask("root-task") rootTask := GlobalTask(rootTaskName)
task := rootTask.Subtask("test") task := rootTask.Subtask(subTaskName)
var value atomic.Int32 var value atomic.Int32
task.OnFinished("set value", func() { task.OnFinished("set value", func() {
value.Store(1234) value.Store(1234)
}) })
task.Finish("done") task.Finish(nil)
ExpectEqual(t, value.Load(), 1234) ExpectEqual(t, value.Load(), 1234)
} }
@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) {
testResetGlobalTask() testResetGlobalTask()
defer CancelGlobalContext() defer CancelGlobalContext()
rootTask := GlobalTask("root-task") rootTask := GlobalTask(rootTaskName)
finished1, finished2 := false, false finished1, finished2 := false, false
subTask1 := rootTask.Subtask("subtask1") subTask1 := rootTask.Subtask(subTaskName)
subTask2 := rootTask.Subtask("subtask2") subTask2 := rootTask.Subtask(subTaskName)
subTask1.OnFinished("set finished", func() { subTask1.OnFinished("", func() {
finished1 = true finished1 = true
}) })
subTask2.OnFinished("set finished", func() { subTask2.OnFinished("", func() {
finished2 = true finished2 = true
}) })
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
subTask1.Finish("done") subTask1.Finish(nil)
}() }()
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
subTask2.Finish("done") subTask2.Finish(nil)
}() }()
go func() { go func() {
subTask1.Wait() subTask1.Wait()
subTask2.Wait() subTask2.Wait()
rootTask.Finish("done") rootTask.Finish(nil)
}() }()
GlobalContextWait(1 * time.Second) _ = GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1) ExpectTrue(t, finished1)
ExpectTrue(t, finished2) ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err()) ExpectError(t, context.Canceled, rootTask.Context().Err())
@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) {
func TestTimeoutOnGlobalContextWait(t *testing.T) { func TestTimeoutOnGlobalContextWait(t *testing.T) {
testResetGlobalTask() testResetGlobalTask()
rootTask := GlobalTask("root-task") rootTask := GlobalTask(rootTaskName)
rootTask.Subtask("subtask") rootTask.Subtask(subTaskName)
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond)) ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
} }
@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) {
testResetGlobalTask() testResetGlobalTask()
taskDone := make(chan struct{}) taskDone := make(chan struct{})
rootTask := GlobalTask("root-task") rootTask := GlobalTask(rootTaskName)
go func() { go func() {
rootTask.Wait() rootTask.Wait()

View file

@ -19,7 +19,8 @@ func TestSerializeDeserialize(t *testing.T) {
MIS map[int]string MIS map[int]string
} }
var testStruct = S{ var (
testStruct = S{
I: 1, I: 1,
S: "hello", S: "hello",
IS: []int{1, 2, 3}, IS: []int{1, 2, 3},
@ -27,8 +28,7 @@ func TestSerializeDeserialize(t *testing.T) {
MSI: map[string]int{"a": 1, "b": 2, "c": 3}, MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"}, MIS: map[int]string{1: "a", 2: "b", 3: "c"},
} }
testStructSerialized = map[string]any{
var testStructSerialized = map[string]any{
"I": 1, "I": 1,
"S": "hello", "S": "hello",
"IS": []int{1, 2, 3}, "IS": []int{1, 2, 3},
@ -36,6 +36,7 @@ func TestSerializeDeserialize(t *testing.T) {
"MSI": map[string]int{"a": 1, "b": 2, "c": 3}, "MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"}, "MIS": map[int]string{1: "a", 2: "b", 3: "c"},
} }
)
t.Run("serialize", func(t *testing.T) { t.Run("serialize", func(t *testing.T) {
s, err := Serialize(testStruct) s, err := Serialize(testStruct)