From 59f4eaf3eace9e188140717f7e538ac30130ac1e Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 16 Dec 2024 10:19:14 +0800 Subject: [PATCH] cleanup and simplify middleware implementations, refactor some other code --- .trunk/trunk.yaml | 10 +- internal/api/handler.go | 2 +- internal/entrypoint/entrypoint.go | 2 +- internal/net/http/header_utils.go | 3 + internal/net/http/loadbalancer/ip_hash.go | 2 +- .../net/http/middleware/cidr_whitelist.go | 64 +++--- .../http/middleware/cidr_whitelist_test.go | 8 +- .../net/http/middleware/cloudflare_real_ip.go | 34 +-- .../net/http/middleware/custom_error_page.go | 61 +++-- internal/net/http/middleware/errors.go | 5 - internal/net/http/middleware/forward_auth.go | 81 ++++--- internal/net/http/middleware/middleware.go | 213 ++++++++++-------- .../net/http/middleware/middleware_builder.go | 57 +---- .../net/http/middleware/middleware_chain.go | 58 +++++ .../net/http/middleware/modify_request.go | 46 ++-- .../http/middleware/modify_request_test.go | 2 +- .../net/http/middleware/modify_response.go | 39 ++-- .../http/middleware/modify_response_test.go | 2 +- internal/net/http/middleware/oauth2.go | 117 ---------- internal/net/http/middleware/rate_limit.go | 66 +++--- .../net/http/middleware/rate_limit_test.go | 2 +- internal/net/http/middleware/real_ip.go | 81 ++++--- internal/net/http/middleware/real_ip_test.go | 11 +- internal/net/http/middleware/redirect_http.go | 33 +-- .../http/middleware/set_upstream_headers.go | 34 +++ internal/net/http/middleware/test_utils.go | 2 +- internal/net/http/middleware/trace.go | 63 ++---- internal/net/http/middleware/tracer.go | 50 ++++ internal/net/http/middleware/vars.go | 50 ++-- internal/net/http/middleware/x_forwarded.go | 47 ++-- internal/net/http/proxy_response.go | 8 - internal/net/http/reverse_proxy_mod.go | 22 +- internal/task/task_test.go | 51 +++-- internal/utils/serialization_test.go | 35 +-- 34 files changed, 641 insertions(+), 720 deletions(-) delete mode 100644 internal/net/http/middleware/errors.go create mode 100644 internal/net/http/middleware/middleware_chain.go delete mode 100644 internal/net/http/middleware/oauth2.go create mode 100644 internal/net/http/middleware/set_upstream_headers.go create mode 100644 internal/net/http/middleware/tracer.go delete mode 100644 internal/net/http/proxy_response.go diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 7d5ac8d..8f50f86 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -7,12 +7,12 @@ cli: plugins: sources: - id: trunk - ref: v1.6.5 + ref: v1.6.6 uri: https://github.com/trunk-io/plugins # Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes) runtimes: enabled: - - node@18.12.1 + - node@18.20.5 - python@3.10.8 - go@1.23.2 # This is the section where you manage your linters. (https://docs.trunk.io/check/configuration) @@ -23,16 +23,16 @@ lint: enabled: - hadolint@2.12.1-beta - actionlint@1.7.4 - - checkov@3.2.324 + - checkov@3.2.334 - git-diff-check - gofmt@1.20.4 - golangci-lint@1.62.2 - osv-scanner@1.9.1 - oxipng@9.1.3 - - prettier@3.4.1 + - prettier@3.4.2 - shellcheck@0.10.0 - shfmt@3.6.0 - - trufflehog@3.84.1 + - trufflehog@3.86.1 actions: disabled: - trunk-announce diff --git a/internal/api/handler.go b/internal/api/handler.go index 45dd5f9..26e5762 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { } func rateLimited(f http.HandlerFunc) http.HandlerFunc { - m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{ + m, err := middleware.RateLimiter.New(middleware.OptionsRaw{ "average": 10, "burst": 10, }) diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 5f66479..5e66c97 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) { // On nginx, when route for domain does not exist, it returns StatusBadGateway. // Then scraper / scanners will know the subdomain is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. - if !middleware.ServeStaticErrorPageFile(w, r) { + if served := middleware.ServeStaticErrorPageFile(w, r); !served { logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { diff --git a/internal/net/http/header_utils.go b/internal/net/http/header_utils.go index 4becb8e..d804d15 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/http/header_utils.go @@ -16,6 +16,9 @@ const ( HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme" HeaderUpstreamHost = "X-GoDoxy-Upstream-Host" HeaderUpstreamPort = "X-GoDoxy-Upstream-Port" + + HeaderContentType = "Content-Type" + HeaderContentLength = "Content-Length" ) func RemoveHop(h http.Header) { diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index 3df48fd..cbb6ab0 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl { return impl } var err E.Error - impl.realIP, err = middleware.NewRealIP(lb.Options) + impl.realIP, err = middleware.RealIP.New(lb.Options) if err != nil { E.LogError("invalid real_ip options, ignoring", err, &impl.l) } diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index 8945030..ae7e70d 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -4,48 +4,45 @@ import ( "net" "net/http" - E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" F "github.com/yusing/go-proxy/internal/utils/functional" ) -type cidrWhitelist struct { - cidrWhitelistOpts - m *Middleware - cachedAddr F.Map[string, bool] // cache for trusted IPs -} - -type cidrWhitelistOpts struct { - Allow []*types.CIDR `validate:"min=1"` - StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` - Message string -} +type ( + cidrWhitelist struct { + CIDRWhitelistOpts + *Tracer + cachedAddr F.Map[string, bool] // cache for trusted IPs + } + CIDRWhitelistOpts struct { + Allow []*types.CIDR `validate:"min=1"` + StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` + Message string + } +) var ( - CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist} - cidrWhitelistDefaults = cidrWhitelistOpts{ + CIDRWhiteList = NewMiddleware[cidrWhitelist]() + cidrWhitelistDefaults = CIDRWhitelistOpts{ Allow: []*types.CIDR{}, StatusCode: http.StatusForbidden, Message: "IP not allowed", } ) -func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { - wl := new(cidrWhitelist) - wl.m = &Middleware{ - impl: wl, - before: wl.checkIP, - } - wl.cidrWhitelistOpts = cidrWhitelistDefaults +// setup implements MiddlewareWithSetup. +func (wl *cidrWhitelist) setup() { + wl.CIDRWhitelistOpts = cidrWhitelistDefaults wl.cachedAddr = F.NewMapOf[string, bool]() - err := Deserialize(opts, &wl.cidrWhitelistOpts) - if err != nil { - return nil, err - } - return wl.m, nil } -func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) { +// before implements RequestModifier. +func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool { + return wl.checkIP(w, r) +} + +// checkIP checks if the IP address is allowed. +func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool { var allow, ok bool if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok { ipStr, _, err := net.SplitHostPort(r.RemoteAddr) @@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req ipStr = r.RemoteAddr } ip := net.ParseIP(ipStr) - for _, cidr := range wl.cidrWhitelistOpts.Allow { + for _, cidr := range wl.CIDRWhitelistOpts.Allow { if cidr.Contains(ip) { wl.cachedAddr.Store(r.RemoteAddr, true) allow = true - wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr) + wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr) break } } if !allow { wl.cachedAddr.Store(r.RemoteAddr, false) - wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow) + wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow) } } if !allow { - w.WriteHeader(wl.StatusCode) - w.Write([]byte(wl.Message)) - return + http.Error(w, wl.Message, wl.StatusCode) + return false } - next(w, r) + return true } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index c1d5d6c..36f051e 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -17,27 +17,27 @@ var deny, accept *Middleware func TestCIDRWhitelistValidation(t *testing.T) { t.Run("valid", func(t *testing.T) { - _, err := NewCIDRWhitelist(OptionsRaw{ + _, err := CIDRWhiteList.New(OptionsRaw{ "allow": []string{"1.2.3.4/32"}, "message": "test-message", }) ExpectNoError(t, err) }) t.Run("missing allow", func(t *testing.T) { - _, err := NewCIDRWhitelist(OptionsRaw{ + _, err := CIDRWhiteList.New(OptionsRaw{ "message": "test-message", }) ExpectError(t, utils.ErrValidationError, err) }) t.Run("invalid cidr", func(t *testing.T) { - _, err := NewCIDRWhitelist(OptionsRaw{ + _, err := CIDRWhiteList.New(OptionsRaw{ "allow": []string{"1.2.3.4/123"}, "message": "test-message", }) ExpectErrorT[*net.ParseError](t, err) }) t.Run("invalid status code", func(t *testing.T) { - _, err := NewCIDRWhitelist(OptionsRaw{ + _, err := CIDRWhiteList.New(OptionsRaw{ "allow": []string{"1.2.3.4/32"}, "status_code": 600, "message": "test-message", diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 05af4f4..97bd311 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -11,11 +11,14 @@ import ( "time" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) +type cloudflareRealIP struct { + realIP realIP +} + const ( cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4" cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6" @@ -29,26 +32,23 @@ var ( cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger() ) -var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP} +var CloudflareRealIP = NewMiddleware[cloudflareRealIP]() -func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) { - cri := new(realIP) - cri.m = &Middleware{ - impl: cri, - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - cidrs := tryFetchCFCIDR() - if cidrs != nil { - cri.From = cidrs - } - cri.setRealIP(r) - next(w, r) - }, - } - cri.realIPOpts = realIPOpts{ +// setup implements MiddlewareWithSetup. +func (cri *cloudflareRealIP) setup() { + cri.realIP.RealIPOpts = RealIPOpts{ Header: "CF-Connecting-IP", Recursive: true, } - return cri.m, nil +} + +// before implements RequestModifier. +func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool { + cidrs := tryFetchCFCIDR() + if cidrs != nil { + cri.realIP.From = cidrs + } + return cri.realIP.before(w, r) } func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index 6b764e2..9748980 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -12,45 +12,38 @@ import ( "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" ) -var CustomErrorPage *Middleware +type customErrorPage struct{} -func init() { - CustomErrorPage = customErrorPage() +var CustomErrorPage = NewMiddleware[customErrorPage]() + +// before implements RequestModifier. +func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + return !ServeStaticErrorPageFile(w, r) } -func customErrorPage() *Middleware { - m := &Middleware{ - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - if !ServeStaticErrorPageFile(w, r) { - next(w, r) - } - }, - } - m.modifyResponse = func(resp *Response) error { - // only handles non-success status code and html/plain content type - contentType := gphttp.GetContentType(resp.Header) - if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { - errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) - if ok { - CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode) - /* trunk-ignore(golangci-lint/errcheck) */ - io.Copy(io.Discard, resp.Body) // drain the original body - resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(errorPage)) - resp.ContentLength = int64(len(errorPage)) - resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) - resp.Header.Set("Content-Type", "text/html; charset=utf-8") - } else { - CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode) - } - return nil +// modifyResponse implements ResponseModifier. +func (customErrorPage) modifyResponse(resp *http.Response) error { + // only handles non-success status code and html/plain content type + contentType := gphttp.GetContentType(resp.Header) + if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { + errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) + if ok { + logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode) + _, _ = io.Copy(io.Discard, resp.Body) // drain the original body + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(errorPage)) + resp.ContentLength = int64(len(errorPage)) + resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) + resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") + } else { + logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } return nil } - return m + return nil } -func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { +func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) { path := r.URL.Path if path != "" && path[0] != '/' { path = "/" + path @@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { ext := filepath.Ext(filename) switch ext { case ".html": - w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") case ".js": - w.Header().Set("Content-Type", "application/javascript; charset=utf-8") + w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") case ".css": - w.Header().Set("Content-Type", "text/css; charset=utf-8") + w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") default: logger.Error().Msgf("unexpected file type %q for %s", ext, filename) } diff --git a/internal/net/http/middleware/errors.go b/internal/net/http/middleware/errors.go deleted file mode 100644 index faf4038..0000000 --- a/internal/net/http/middleware/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package middleware - -import E "github.com/yusing/go-proxy/internal/error" - -var ErrZeroValue = E.New("cannot be zero") diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 6a038d5..c193f94 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -12,16 +12,17 @@ import ( "strings" "time" - E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" + F "github.com/yusing/go-proxy/internal/utils/functional" ) type ( forwardAuth struct { - forwardAuthOpts - m *Middleware + ForwardAuthOpts + *Tracer + reqCookiesMap F.Map[*http.Request, []*http.Cookie] } - forwardAuthOpts struct { + ForwardAuthOpts struct { Address string `validate:"url,required"` TrustForwardHeader bool AuthResponseHeaders []string @@ -29,36 +30,30 @@ type ( } ) -var ForwardAuth = &Middleware{withOptions: NewForwardAuth} +var ForwardAuth = NewMiddleware[forwardAuth]() var faHTTPClient = &http.Client{ Timeout: 30 * time.Second, - CheckRedirect: func(r *Request, via []*Request) error { + CheckRedirect: func(r *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } -func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) { - fa := new(forwardAuth) - if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil { - return nil, err - } - fa.m = &Middleware{ - impl: fa, - before: fa.forward, - } - return fa.m, nil +// setup implements MiddlewareWithSetup. +func (fa *forwardAuth) setup() { + fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]() } -func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { +// before implements RequestModifier. +func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) { gphttp.RemoveHop(req.Header) // Construct original URL for the redirect - // scheme := "http" - // if req.TLS != nil { - // scheme = "https" - // } - // originalURL := scheme + "://" + req.Host + req.RequestURI + scheme := "http" + if req.TLS != nil { + scheme = "https" + } + originalURL := scheme + "://" + req.Host + req.RequestURI url := fa.Address faReq, err := http.NewRequestWithContext( @@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req nil, ) if err != nil { - fa.m.AddTracef("new request err to %s", url).WithError(err) + fa.AddTracef("new request err to %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) fa.setAuthHeaders(req, faReq) // Set headers needed by Authentik - // faReq.Header.Set("X-Original-URL", originalURL) - fa.m.AddTraceRequest("forward auth request", faReq) + faReq.Header.Set("X-Original-Url", originalURL) + fa.AddTraceRequest("forward auth request", faReq) faResp, err := faHTTPClient.Do(faReq) if err != nil { - fa.m.AddTracef("failed to call %s", url).WithError(err) + fa.AddTracef("failed to call %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req body, err := io.ReadAll(faResp.Body) if err != nil { - fa.m.AddTracef("failed to read response body from %s", url).WithError(err) + fa.AddTracef("failed to read response body from %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { - fa.m.AddTraceResponse("forward auth response", faResp) + fa.AddTraceResponse("forward auth response", faResp) gphttp.CopyHeader(w.Header(), faResp.Header) gphttp.RemoveHop(w.Header()) redirectURL, err := faResp.Location() if err != nil { - fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) + fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) w.WriteHeader(http.StatusInternalServerError) return } else if redirectURL.String() != "" { w.Header().Set("Location", redirectURL.String()) - fa.m.AddTracef("%s", "redirect to "+redirectURL.String()) + fa.AddTracef("%s", "redirect to "+redirectURL.String()) } w.WriteHeader(faResp.StatusCode) if _, err = w.Write(body); err != nil { - fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) + fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) } return } @@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req authCookies := faResp.Cookies() - if len(authCookies) == 0 { - next.ServeHTTP(w, req) - return + if len(authCookies) > 0 { + fa.reqCookiesMap.Store(req, authCookies) } - - next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error { - fa.setAuthCookies(resp, authCookies) - return nil - }), req) + return true } -func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) { +// modifyResponse implements ResponseModifier. +func (fa *forwardAuth) modifyResponse(resp *http.Response) error { + if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok { + fa.setAuthCookies(resp, cookies) + fa.reqCookiesMap.Delete(resp.Request) + } + return nil +} + +func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) { if len(fa.AddAuthCookiesToResponse) == 0 { return } @@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie } } -func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) { +func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) { if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if fa.TrustForwardHeader { if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok { diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 4a01dad..b4f702f 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -3,70 +3,110 @@ package middleware import ( "encoding/json" "net/http" + "reflect" + "strings" - "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils" ) type ( Error = E.Error - ReverseProxy = gphttp.ReverseProxy - ProxyRequest = gphttp.ProxyRequest - Request = http.Request - Response = gphttp.ProxyResponse - ResponseWriter = http.ResponseWriter - Header = http.Header - Cookie = http.Cookie + ReverseProxy = gphttp.ReverseProxy + ProxyRequest = gphttp.ProxyRequest - BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request) - RewriteFunc func(req *Request) - ModifyResponseFunc func(*Response) error - CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error) - - OptionsRaw = map[string]any + ImplNewFunc = func() any + OptionsRaw = map[string]any Middleware struct { - _ U.NoCopy - - zerolog.Logger - - name string - - before BeforeFunc // runs before ReverseProxy.ServeHTTP - modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse - - withOptions CloneWithOptFunc - impl any - - parent *Middleware - children []*Middleware - trace bool + name string + construct ImplNewFunc + impl any } + + RequestModifier interface { + before(w http.ResponseWriter, r *http.Request) (proceed bool) + } + ResponseModifier interface{ modifyResponse(r *http.Response) error } + MiddlewareWithSetup interface{ setup() } + MiddlewareFinalizer interface{ finalize() } + MiddlewareWithTracer *struct{ *Tracer } ) -var Deserialize = U.Deserialize - -func Rewrite(r RewriteFunc) BeforeFunc { - return func(next http.HandlerFunc, w ResponseWriter, req *Request) { - r(req) - next(w, req) +func NewMiddleware[ImplType any]() *Middleware { + // type check + switch any(new(ImplType)).(type) { + case RequestModifier: + case ResponseModifier: + default: + panic("must implement RequestModifier or ResponseModifier") } + return &Middleware{ + name: strings.ToLower(reflect.TypeFor[ImplType]().Name()), + construct: func() any { return new(ImplType) }, + } +} + +func (m *Middleware) enableTrace() { + if tracer, ok := m.impl.(MiddlewareWithTracer); ok { + tracer.Tracer = &Tracer{name: m.name} + } +} + +func (m *Middleware) getTracer() *Tracer { + if tracer, ok := m.impl.(MiddlewareWithTracer); ok { + return tracer.Tracer + } + return nil +} + +func (m *Middleware) setParent(parent *Middleware) { + if tracer := m.getTracer(); tracer != nil { + tracer.parent = parent.getTracer() + } +} + +func (m *Middleware) setup() { + if setup, ok := m.impl.(MiddlewareWithSetup); ok { + setup.setup() + } +} + +func (m *Middleware) apply(optsRaw OptionsRaw) E.Error { + if len(optsRaw) == 0 { + return nil + } + return utils.Deserialize(optsRaw, m.impl) +} + +func (m *Middleware) finalize() { + if finalizer, ok := m.impl.(MiddlewareFinalizer); ok { + finalizer.finalize() + } +} + +func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { + if m.construct == nil { + if optsRaw != nil { + panic("bug: middleware already constructed") + } + return m, nil + } + mid := &Middleware{name: m.name, impl: m.construct()} + mid.setup() + if err := mid.apply(optsRaw); err != nil { + return nil, err + } + mid.finalize() + return mid, nil } func (m *Middleware) Name() string { return m.name } -func (m *Middleware) Fullname() string { - if m.parent != nil { - return m.parent.Fullname() + "." + m.name - } - return m.name -} - func (m *Middleware) String() string { return m.name } @@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) { }, "", " ") } -func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { - if m.withOptions != nil { - m, err := m.withOptions(optsRaw) - if err != nil { - return nil, err +func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) { + if exec, ok := m.impl.(RequestModifier); ok { + if proceed := exec.before(w, r); !proceed { + return } - m.Logger = logger.With().Str("name", m.name).Logger() - return m, nil } - - // WithOptionsClone is called only once - // set withOptions and labelParser will not be used after that - return &Middleware{ - Logger: logger.With().Str("name", m.name).Logger(), - name: m.name, - before: m.before, - modifyResponse: m.modifyResponse, - impl: m.impl, - parent: m.parent, - children: m.children, - }, nil + next(w, r) } -func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) { - if m.before != nil { - m.before(next, w, r) - } -} - -func (m *Middleware) ModifyResponse(resp *Response) error { - if m.modifyResponse != nil { - return m.modifyResponse(resp) +func (m *Middleware) ModifyResponse(resp *http.Response) error { + if exec, ok := m.impl.(ResponseModifier); ok { + return exec.modifyResponse(resp) } return nil } -func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) { - if m.modifyResponse != nil { +func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) { + if exec, ok := m.impl.(ResponseModifier); ok { w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { - return m.modifyResponse(&Response{Response: resp, OriginalRequest: r}) + return exec.modifyResponse(resp) }) } - if m.before != nil { - m.before(next, w, r) - } else { - next(w, r) + if exec, ok := m.impl.(RequestModifier); ok { + if proceed := exec.before(w, r); !proceed { + return + } } + next(w, r) } // TODO: check conflict or duplicates. -func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { +func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { middlewares := make([]*Middleware, 0, len(middlewaresMap)) errs := E.NewBuilder("middlewares compile error") @@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E continue } - m, err = m.WithOptionsClone(opts) + m, err = m.New(opts) if err != nil { invalidOpts.Add(err.Subject(name)) continue @@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { var middlewares []*Middleware - middlewares, err = createMiddlewares(middlewaresMap) + middlewares, err = compileMiddlewares(middlewaresMap) if err != nil { return } @@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) ( } func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { - mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{ - name: "set_upstream_headers", - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme) - r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname()) - r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port()) - next(w, r) - }, - }}, middlewares...)) + middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...) - if mid.before != nil { - ori := rp.HandlerFunc - rp.HandlerFunc = func(w http.ResponseWriter, r *Request) { - mid.before(ori, w, r) + mid := NewMiddlewareChain(rp.TargetName, middlewares) + + if before, ok := mid.impl.(RequestModifier); ok { + next := rp.HandlerFunc + rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + if proceed := before.before(w, r); proceed { + next(w, r) + } } } - if mid.modifyResponse != nil { + if mr, ok := mid.impl.(ResponseModifier); ok { if rp.ModifyResponse != nil { ori := rp.ModifyResponse - rp.ModifyResponse = func(res *Response) error { - if err := mid.modifyResponse(res); err != nil { + rp.ModifyResponse = func(res *http.Response) error { + if err := mr.modifyResponse(res); err != nil { return err } return ori(res) } } else { - rp.ModifyResponse = mid.modifyResponse + rp.ModifyResponse = mr.modifyResponse } } } diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index d85dc02..8f5aabf 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -2,11 +2,9 @@ package middleware import ( "fmt" - "net/http" "os" "path" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) @@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar continue } delete(def, "use") - m, err := base.WithOptionsClone(def) + m, err := base.New(def) if err != nil { chainErr.Add(err.Subjectf("%s[%d]", name, i)) continue @@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar if chainErr.HasError() { return nil, chainErr.Error() } - return BuildMiddlewareFromChain(name, chain), nil -} - -// TODO: check conflict or duplicates. -func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { - m := &Middleware{name: name, children: chain} - - var befores []*Middleware - var modResps []*Middleware - - for _, comp := range chain { - if comp.before != nil { - befores = append(befores, comp) - } - if comp.modifyResponse != nil { - modResps = append(modResps, comp) - } - comp.parent = m - } - - if len(befores) > 0 { - m.before = buildBefores(befores) - } - if len(modResps) > 0 { - m.modifyResponse = func(res *Response) error { - errs := E.NewBuilder("modify response errors") - for _, mr := range modResps { - if err := mr.modifyResponse(res); err != nil { - errs.Add(E.From(err).Subject(mr.name)) - } - } - return errs.Error() - } - } - - if common.IsDebug { - m.EnableTrace() - m.AddTracef("middleware created") - } - return m -} - -func buildBefores(befores []*Middleware) BeforeFunc { - if len(befores) == 1 { - return befores[0].before - } - nextBefores := buildBefores(befores[1:]) - return func(next http.HandlerFunc, w ResponseWriter, r *Request) { - befores[0].before(func(w ResponseWriter, r *Request) { - nextBefores(next, w, r) - }, w, r) - } + return NewMiddlewareChain(name, chain), nil } diff --git a/internal/net/http/middleware/middleware_chain.go b/internal/net/http/middleware/middleware_chain.go new file mode 100644 index 0000000..da14287 --- /dev/null +++ b/internal/net/http/middleware/middleware_chain.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + + "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" +) + +type middlewareChain struct { + befores []RequestModifier + modResps []ResponseModifier +} + +// TODO: check conflict or duplicates. +func NewMiddlewareChain(name string, chain []*Middleware) *Middleware { + chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}} + m := &Middleware{name: name, impl: chainMid} + + for _, comp := range chain { + if before, ok := comp.impl.(RequestModifier); ok { + chainMid.befores = append(chainMid.befores, before) + } + if mr, ok := comp.impl.(ResponseModifier); ok { + chainMid.modResps = append(chainMid.modResps, mr) + } + comp.setParent(m) + } + + if common.IsDebug { + m.enableTrace() + } + return m +} + +// before implements RequestModifier. +func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) { + for _, b := range m.befores { + if proceedNext = b.before(w, r); !proceedNext { + return false + } + } + return true +} + +// modifyResponse implements ResponseModifier. +func (m *middlewareChain) modifyResponse(resp *http.Response) error { + if len(m.modResps) == 0 { + return nil + } + errs := E.NewBuilder("modify response errors") + for i, mr := range m.modResps { + if err := mr.modifyResponse(resp); err != nil { + errs.Add(E.From(err).Subjectf("%d", i)) + } + } + return errs.Error() +} diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 334c18c..1e2d42d 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -3,45 +3,39 @@ package middleware import ( "net/http" "strings" - - E "github.com/yusing/go-proxy/internal/error" ) type ( modifyRequest struct { - modifyRequestOpts - m *Middleware - needVarSubstitution bool + ModifyRequestOpts + *Tracer } // order: set_headers -> add_headers -> hide_headers - modifyRequestOpts struct { + ModifyRequestOpts struct { SetHeaders map[string]string AddHeaders map[string]string HideHeaders []string + + needVarSubstitution bool } ) -var ModifyRequest = &Middleware{withOptions: NewModifyRequest} +var ModifyRequest = NewMiddleware[modifyRequest]() -func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { - mr := new(modifyRequest) - mr.m = &Middleware{ - impl: mr, - before: Rewrite(func(req *Request) { - mr.m.AddTraceRequest("before modify request", req) - mr.modifyHeaders(req, nil, req.Header) - mr.m.AddTraceRequest("after modify request", req) - }), - } - err := Deserialize(optsRaw, &mr.modifyRequestOpts) - if err != nil { - return nil, err - } +// finalize implements MiddlewareFinalizer. +func (mr *ModifyRequestOpts) finalize() { mr.checkVarSubstitution() - return mr.m, nil } -func (mr *modifyRequest) checkVarSubstitution() { +// before implements RequestModifier. +func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + mr.AddTraceRequest("before modify request", r) + mr.modifyHeaders(r, nil, r.Header) + mr.AddTraceRequest("after modify request", r) + return true +} + +func (mr *ModifyRequestOpts) checkVarSubstitution() { for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} { for _, v := range m { if strings.ContainsRune(v, '$') { @@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() { } } -func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) { +func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) { if !mr.needVarSubstitution { for k, v := range mr.SetHeaders { - if req != nil && strings.ToLower(k) == "host" { + if req != nil && strings.EqualFold(k, "host") { defer func() { req.Host = v }() @@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt } } else { for k, v := range mr.SetHeaders { - if req != nil && strings.ToLower(k) == "host" { + if req != nil && strings.EqualFold(k, "host") { defer func() { req.Host = varReplace(req, resp, v) }() diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 422c53d..0496a20 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyRequest.WithOptionsClone(opts) + mr, err := ModifyRequest.New(opts) ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index b5de559..67f7395 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -2,32 +2,19 @@ package middleware import ( "net/http" - - E "github.com/yusing/go-proxy/internal/error" ) -type modifyResponse = modifyRequest - -var ModifyResponse = &Middleware{withOptions: NewModifyResponse} - -func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { - mr := new(modifyResponse) - mr.m = &Middleware{ - impl: mr, - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - next(w, r) - }, - modifyResponse: func(resp *Response) error { - mr.m.AddTraceResponse("before modify response", resp.Response) - mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header) - mr.m.AddTraceResponse("after modify response", resp.Response) - return nil - }, - } - err := Deserialize(optsRaw, &mr.modifyRequestOpts) - if err != nil { - return nil, err - } - mr.checkVarSubstitution() - return mr.m, nil +type modifyResponse struct { + ModifyRequestOpts + *Tracer +} + +var ModifyResponse = NewMiddleware[modifyResponse]() + +// modifyResponse implements ResponseModifier. +func (mr *modifyResponse) modifyResponse(resp *http.Response) error { + mr.AddTraceResponse("before modify response", resp) + mr.modifyHeaders(resp.Request, resp, resp.Header) + mr.AddTraceResponse("after modify response", resp) + return nil } diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 0e7ef15..c45f343 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyResponse.WithOptionsClone(opts) + mr, err := ModifyResponse.New(opts) ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go deleted file mode 100644 index 10ea6d6..0000000 --- a/internal/net/http/middleware/oauth2.go +++ /dev/null @@ -1,117 +0,0 @@ -package middleware - -// import ( -// "encoding/json" -// "fmt" -// "net/http" -// "net/url" - -// E "github.com/yusing/go-proxy/internal/error" -// ) - -// type oAuth2 struct { -// oAuth2Opts -// m *Middleware -// } - -// type oAuth2Opts struct { -// ClientID string `validate:"required"` -// ClientSecret string `validate:"required"` -// AuthURL string `validate:"required"` // Authorization Endpoint -// TokenURL string `validate:"required"` // Token Endpoint -// } - -// var OAuth2 = &oAuth2{ -// m: &Middleware{withOptions: NewAuthentikOAuth2}, -// } - -// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { -// oauth := new(oAuth2) -// oauth.m = &Middleware{ -// impl: oauth, -// before: oauth.handleOAuth2, -// } -// err := Deserialize(opts, &oauth.oAuth2Opts) -// if err != nil { -// return nil, err -// } -// return oauth.m, nil -// } - -// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { -// // Check if the user is authenticated (you may use session, cookie, etc.) -// if !userIsAuthenticated(r) { -// // TODO: Redirect to OAuth2 auth URL -// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", -// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) -// return -// } - -// // If you have a token in the query string, process it -// if code := r.URL.Query().Get("code"); code != "" { -// // Exchange the authorization code for a token here -// // Use the TokenURL and authenticate the user -// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI) -// if err != nil { -// // handle error -// http.Error(rw, "failed to get token", http.StatusUnauthorized) -// return -// } - -// // Save token and user info based on your requirements -// saveToken(rw, token) - -// // Redirect to the originally requested URL -// http.Redirect(rw, r, "/", http.StatusFound) -// return -// } - -// // If user is authenticated, go to the next handler -// next(rw, r) -// } - -// func userIsAuthenticated(r *http.Request) bool { -// // Example: Check for a session or cookie -// session, err := r.Cookie("session_token") -// if err != nil || session.Value == "" { -// return false -// } -// // Validate the session_token if necessary -// return true -// } - -// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { -// // Prepare the request body -// data := url.Values{ -// "client_id": {opts.ClientID}, -// "client_secret": {opts.ClientSecret}, -// "code": {code}, -// "grant_type": {"authorization_code"}, -// "redirect_uri": {requestURI}, -// } -// resp, err := http.PostForm(opts.TokenURL, data) -// if err != nil { -// return "", fmt.Errorf("failed to request token: %w", err) -// } -// defer resp.Body.Close() -// if resp.StatusCode != http.StatusOK { -// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) -// } -// // Decode the response -// var tokenResp struct { -// AccessToken string `json:"access_token"` -// } -// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { -// return "", fmt.Errorf("failed to decode token response: %w", err) -// } -// return tokenResp.AccessToken, nil -// } - -// func saveToken(rw ResponseWriter, token string) { -// // Example: Save token in cookie -// http.SetCookie(rw, &http.Cookie{ -// Name: "auth_token", -// Value: token, -// // set other properties as necessary, such as Secure and HttpOnly -// }) -// } diff --git a/internal/net/http/middleware/rate_limit.go b/internal/net/http/middleware/rate_limit.go index eacc770..41deca7 100644 --- a/internal/net/http/middleware/rate_limit.go +++ b/internal/net/http/middleware/rate_limit.go @@ -6,68 +6,56 @@ import ( "sync" "time" - E "github.com/yusing/go-proxy/internal/error" "golang.org/x/time/rate" ) type ( requestMap = map[string]*rate.Limiter rateLimiter struct { - requestMap requestMap - newLimiter func() *rate.Limiter - m *Middleware + RateLimiterOpts + *Tracer - mu sync.Mutex + requestMap requestMap + mu sync.Mutex } - rateLimiterOpts struct { - Average int `validate:"min=1,required"` - Burst int `validate:"min=1,required"` - Period time.Duration + RateLimiterOpts struct { + Average int `validate:"min=1,required"` + Burst int `validate:"min=1,required"` + Period time.Duration `validate:"min=1s"` } ) var ( - RateLimiter = &Middleware{withOptions: NewRateLimiter} - rateLimiterOptsDefault = rateLimiterOpts{ + RateLimiter = NewMiddleware[rateLimiter]() + rateLimiterOptsDefault = RateLimiterOpts{ Period: time.Second, } ) -func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) { - rl := new(rateLimiter) - opts := rateLimiterOptsDefault - err := Deserialize(optsRaw, &opts) - if err != nil { - return nil, err - } - switch { - case opts.Average == 0: - return nil, ErrZeroValue.Subject("average") - case opts.Burst == 0: - return nil, ErrZeroValue.Subject("burst") - case opts.Period == 0: - return nil, ErrZeroValue.Subject("period") - } +// setup implements MiddlewareWithSetup. +func (rl *rateLimiter) setup() { + rl.RateLimiterOpts = rateLimiterOptsDefault rl.requestMap = make(requestMap, 0) - rl.newLimiter = func() *rate.Limiter { - return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst) - } - rl.m = &Middleware{ - impl: rl, - before: rl.limit, - } - return rl.m, nil } -func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) { +// before implements RequestModifier. +func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool { + return rl.limit(w, r) +} + +func (rl *rateLimiter) newLimiter() *rate.Limiter { + return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst) +} + +func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool { rl.mu.Lock() host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { - rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr) + rl.AddTracef("unable to parse remote address %s", r.RemoteAddr) http.Error(w, "Internal error", http.StatusInternalServerError) - return + return false } limiter, ok := rl.requestMap[host] @@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request rl.mu.Unlock() if limiter.Allow() { - next(w, r) - return + return true } http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return false } diff --git a/internal/net/http/middleware/rate_limit_test.go b/internal/net/http/middleware/rate_limit_test.go index ec21781..1264997 100644 --- a/internal/net/http/middleware/rate_limit_test.go +++ b/internal/net/http/middleware/rate_limit_test.go @@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) { "period": "1s", } - rl, err := NewRateLimiter(opts) + rl, err := RateLimiter.New(opts) ExpectNoError(t, err) for range 10 { result, err := newMiddlewareTest(rl, nil) diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 0faa857..da3f3e0 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -2,58 +2,53 @@ package middleware import ( "net" + "net/http" - E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" ) // https://nginx.org/en/docs/http/ngx_http_realip_module.html -type realIP struct { - realIPOpts - m *Middleware -} - -type realIPOpts struct { - // Header is the name of the header to use for the real client IP - Header string `validate:"required"` - // From is a list of Address / CIDRs to trust - From []*types.CIDR `validate:"min=1"` - /* - If recursive search is disabled, - the original client address that matches one of the trusted addresses is replaced by - the last address sent in the request header field defined by the Header field. - If recursive search is enabled, - the original client address that matches one of the trusted addresses is replaced by - the last non-trusted address sent in the request header field. - */ - Recursive bool -} +type ( + realIP struct { + RealIPOpts + *Tracer + } + RealIPOpts struct { + // Header is the name of the header to use for the real client IP + Header string `validate:"required"` + // From is a list of Address / CIDRs to trust + From []*types.CIDR `validate:"required,min=1"` + /* + If recursive search is disabled, + the original client address that matches one of the trusted addresses is replaced by + the last address sent in the request header field defined by the Header field. + If recursive search is enabled, + the original client address that matches one of the trusted addresses is replaced by + the last non-trusted address sent in the request header field. + */ + Recursive bool + } +) var ( - RealIP = &Middleware{withOptions: NewRealIP} - realIPOptsDefault = realIPOpts{ + RealIP = NewMiddleware[realIP]() + realIPOptsDefault = RealIPOpts{ Header: "X-Real-IP", From: []*types.CIDR{}, } ) -func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) { - riWithOpts := new(realIP) - riWithOpts.m = &Middleware{ - impl: riWithOpts, - before: Rewrite(riWithOpts.setRealIP), - } - riWithOpts.realIPOpts = realIPOptsDefault - err := Deserialize(opts, &riWithOpts.realIPOpts) - if err != nil { - return nil, err - } - if len(riWithOpts.From) == 0 { - return nil, E.New("no allowed CIDRs").Subject("from") - } - return riWithOpts.m, nil +// setup implements MiddlewareWithSetup. +func (ri *realIP) setup() { + ri.RealIPOpts = realIPOptsDefault +} + +// before implements RequestModifier. +func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool { + ri.setRealIP(r) + return true } func (ri *realIP) isInCIDRList(ip net.IP) bool { @@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool { return false } -func (ri *realIP) setRealIP(req *Request) { +func (ri *realIP) setRealIP(req *http.Request) { clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { clientIPStr = req.RemoteAddr @@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) { } } if !isTrusted { - ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From) + ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From) return } @@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) { var lastNonTrustedIP string if len(realIPs) == 0 { - ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req) + ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req) return } @@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) { } if lastNonTrustedIP == "" { - ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs) + ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs) return } req.RemoteAddr = lastNonTrustedIP req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) - ri.m.AddTracef("set real ip %s", lastNonTrustedIP) + ri.AddTracef("set real ip %s", lastNonTrustedIP) } diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index 89a85e1..02f5bd5 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) { }, "recursive": true, } - optExpected := &realIPOpts{ + optExpected := &RealIPOpts{ Header: gphttp.HeaderXRealIP, From: []*types.CIDR{ { @@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) { Recursive: true, } - ri, err := NewRealIP(opts) + ri, err := RealIP.New(opts) ExpectNoError(t, err) ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) @@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) { optsMr := OptionsRaw{ "set_headers": map[string]string{testHeader: testRealIP}, } - realip, err := NewRealIP(opts) + realip, err := RealIP.New(opts) ExpectNoError(t, err) - mr, err := NewModifyRequest(optsMr) + mr, err := ModifyRequest.New(optsMr) ExpectNoError(t, err) - mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) + mid := NewMiddlewareChain("test", []*Middleware{mr, realip}) result, err := newMiddlewareTest(mid, nil) ExpectNoError(t, err) t.Log(traces) ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) - ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP) } diff --git a/internal/net/http/middleware/redirect_http.go b/internal/net/http/middleware/redirect_http.go index e82b573..1ab7197 100644 --- a/internal/net/http/middleware/redirect_http.go +++ b/internal/net/http/middleware/redirect_http.go @@ -7,19 +7,22 @@ import ( "github.com/yusing/go-proxy/internal/common" ) -var RedirectHTTP = &Middleware{ - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - if r.TLS == nil { - r.URL.Scheme = "https" - host := r.Host - if i := strings.Index(host, ":"); i != -1 { - host = host[:i] // strip port number if present - } - r.URL.Host = host + ":" + common.ProxyHTTPSPort - logger.Info().Str("url", r.URL.String()).Msg("redirect to https") - http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) - return - } - next(w, r) - }, +type redirectHTTP struct{} + +var RedirectHTTP = NewMiddleware[redirectHTTP]() + +// before implements RequestModifier. +func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + if r.TLS != nil { + return true + } + r.URL.Scheme = "https" + host := r.Host + if i := strings.Index(host, ":"); i != -1 { + host = host[:i] // strip port number if present + } + r.URL.Host = host + ":" + common.ProxyHTTPSPort + logger.Debug().Str("url", r.URL.String()).Msg("redirect to https") + http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) + return true } diff --git a/internal/net/http/middleware/set_upstream_headers.go b/internal/net/http/middleware/set_upstream_headers.go new file mode 100644 index 0000000..fdde0f5 --- /dev/null +++ b/internal/net/http/middleware/set_upstream_headers.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "net/http" + + gphttp "github.com/yusing/go-proxy/internal/net/http" +) + +// internal use only. +type setUpstreamHeaders struct { + Scheme, Host, Port string +} + +var suh = NewMiddleware[setUpstreamHeaders]() + +func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware { + m, err := suh.New(OptionsRaw{ + "scheme": rp.TargetURL.Scheme, + "host": rp.TargetURL.Hostname(), + "port": rp.TargetURL.Port(), + }) + if err != nil { + panic(err) + } + return m +} + +// before implements RequestModifier. +func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme) + r.Header.Set(gphttp.HeaderUpstreamHost, s.Host) + r.Header.Set(gphttp.HeaderUpstreamPort, s.Port) + return true +} diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index d7f292b..eb6fdf3 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr) - mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) + mid, setOptErr := middleware.New(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr } diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index 5c46464..c3b0c73 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -1,27 +1,25 @@ package middleware import ( - "fmt" "net/http" "sync" - "time" gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/utils/strutils" ) -type Trace struct { - Time string `json:"time,omitempty"` - Caller string `json:"caller,omitempty"` - URL string `json:"url,omitempty"` - Message string `json:"msg"` - ReqHeaders map[string]string `json:"req_headers,omitempty"` - RespHeaders map[string]string `json:"resp_headers,omitempty"` - RespStatus int `json:"resp_status,omitempty"` - Additional map[string]any `json:"additional,omitempty"` -} - -type Traces []*Trace +type ( + Trace struct { + Time string `json:"time,omitempty"` + Caller string `json:"caller,omitempty"` + URL string `json:"url,omitempty"` + Message string `json:"msg"` + ReqHeaders map[string]string `json:"req_headers,omitempty"` + RespHeaders map[string]string `json:"resp_headers,omitempty"` + RespStatus int `json:"resp_status,omitempty"` + Additional map[string]any `json:"additional,omitempty"` + } + Traces []*Trace +) var ( traces = make(Traces, 0) @@ -34,7 +32,7 @@ func GetAllTrace() []*Trace { return traces } -func (tr *Trace) WithRequest(req *Request) *Trace { +func (tr *Trace) WithRequest(req *http.Request) *Trace { if tr == nil { return nil } @@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace { return tr } -func (m *Middleware) EnableTrace() { - m.trace = true - for _, child := range m.children { - child.parent = m - child.EnableTrace() - } -} - -func (m *Middleware) AddTracef(msg string, args ...any) *Trace { - if !m.trace { - return nil - } - return addTrace(&Trace{ - Time: strutils.FormatTime(time.Now()), - Caller: m.Fullname(), - Message: fmt.Sprintf(msg, args...), - }) -} - -func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace { - if !m.trace { - return nil - } - return m.AddTracef("%s", msg).WithRequest(req) -} - -func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace { - if !m.trace { - return nil - } - return m.AddTracef("%s", msg).WithResponse(resp) -} - func addTrace(t *Trace) *Trace { tracesMu.Lock() defer tracesMu.Unlock() diff --git a/internal/net/http/middleware/tracer.go b/internal/net/http/middleware/tracer.go new file mode 100644 index 0000000..94b7419 --- /dev/null +++ b/internal/net/http/middleware/tracer.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "fmt" + "net/http" + "time" + + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type Tracer struct { + name string + parent *Tracer +} + +func (t *Tracer) Fullname() string { + if t.parent != nil { + return t.parent.Fullname() + "." + t.name + } + return t.name +} + +func (t *Tracer) addTrace(msg string) *Trace { + return addTrace(&Trace{ + Time: strutils.FormatTime(time.Now()), + Caller: t.Fullname(), + Message: msg, + }) +} + +func (t *Tracer) AddTracef(msg string, args ...any) *Trace { + if t == nil { + return nil + } + return t.addTrace(fmt.Sprintf(msg, args...)) +} + +func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace { + if t == nil { + return nil + } + return t.addTrace(msg).WithRequest(req) +} + +func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace { + if t == nil { + return nil + } + return t.addTrace(msg).WithResponse(resp) +} diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go index 106293f..7ef8c91 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/http/middleware/vars.go @@ -11,8 +11,8 @@ import ( ) type ( - reqVarGetter func(*Request) string - respVarGetter func(*Response) string + reqVarGetter func(*http.Request) string + respVarGetter func(*http.Response) string ) var ( @@ -49,50 +49,50 @@ const ( ) var staticReqVarSubsMap = map[string]reqVarGetter{ - VarRequestMethod: func(req *Request) string { return req.Method }, - VarRequestScheme: func(req *Request) string { + VarRequestMethod: func(req *http.Request) string { return req.Method }, + VarRequestScheme: func(req *http.Request) string { if req.TLS != nil { return "https" } return "http" }, - VarRequestHost: func(req *Request) string { + VarRequestHost: func(req *http.Request) string { reqHost, _, err := net.SplitHostPort(req.Host) if err != nil { return req.Host } return reqHost }, - VarRequestPort: func(req *Request) string { + VarRequestPort: func(req *http.Request) string { _, reqPort, _ := net.SplitHostPort(req.Host) return reqPort }, - VarRequestAddr: func(req *Request) string { return req.Host }, - VarRequestPath: func(req *Request) string { return req.URL.Path }, - VarRequestQuery: func(req *Request) string { return req.URL.RawQuery }, - VarRequestURL: func(req *Request) string { return req.URL.String() }, - VarRequestURI: func(req *Request) string { return req.URL.RequestURI() }, - VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") }, - VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, - VarRemoteHost: func(req *Request) string { + VarRequestAddr: func(req *http.Request) string { return req.Host }, + VarRequestPath: func(req *http.Request) string { return req.URL.Path }, + VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery }, + VarRequestURL: func(req *http.Request) string { return req.URL.String() }, + VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() }, + VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") }, + VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) }, + VarRemoteHost: func(req *http.Request) string { clientIP, _, err := net.SplitHostPort(req.RemoteAddr) if err == nil { return clientIP } return "" }, - VarRemotePort: func(req *Request) string { + VarRemotePort: func(req *http.Request) string { _, clientPort, err := net.SplitHostPort(req.RemoteAddr) if err == nil { return clientPort } return "" }, - VarRemoteAddr: func(req *Request) string { return req.RemoteAddr }, - VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, - VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, - VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, - VarUpstreamAddr: func(req *Request) string { + VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, + VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, + VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, + VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + VarUpstreamAddr: func(req *http.Request) string { upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upPort := req.Header.Get(gphttp.HeaderUpstreamPort) if upPort != "" { @@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ } return upHost }, - VarUpstreamURL: func(req *Request) string { + VarUpstreamURL: func(req *http.Request) string { upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) if upScheme == "" { return "" @@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ } var staticRespVarSubsMap = map[string]respVarGetter{ - VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") }, - VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) }, - VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) }, + VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") }, + VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) }, + VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) }, } -func varReplace(req *Request, resp *Response, s string) string { +func varReplace(req *http.Request, resp *http.Response, s string) string { if req != nil { // Replace query parameters s = reArg.ReplaceAllStringFunc(s, func(match string) string { diff --git a/internal/net/http/middleware/x_forwarded.go b/internal/net/http/middleware/x_forwarded.go index 9728214..ff8a558 100644 --- a/internal/net/http/middleware/x_forwarded.go +++ b/internal/net/http/middleware/x_forwarded.go @@ -2,27 +2,44 @@ package middleware import ( "net" + "net/http" "strings" gphttp "github.com/yusing/go-proxy/internal/net/http" ) -var SetXForwarded = &Middleware{ - before: Rewrite(func(req *Request) { - req.Header.Del(gphttp.HeaderXForwardedFor) - clientIP, _, err := net.SplitHostPort(req.RemoteAddr) - if err == nil { - req.Header.Set(gphttp.HeaderXForwardedFor, clientIP) - } - }), +type ( + setXForwarded struct{} + hideXForwarded struct{} +) + +var ( + SetXForwarded = NewMiddleware[setXForwarded]() + HideXForwarded = NewMiddleware[hideXForwarded]() +) + +// before implements RequestModifier. +func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + r.Header.Del(gphttp.HeaderXForwardedFor) + clientIP, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + r.Header.Set(gphttp.HeaderXForwardedFor, clientIP) + } + return true } -var HideXForwarded = &Middleware{ - before: Rewrite(func(req *Request) { - for k := range req.Header { - if strings.HasPrefix(k, "X-Forwarded-") { - req.Header.Del(k) - } +// before implements RequestModifier. +func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + toDelete := make([]string, 0, len(r.Header)) + for k := range r.Header { + if strings.HasPrefix(k, "X-Forwarded-") { + toDelete = append(toDelete, k) } - }), + } + + for _, k := range toDelete { + r.Header.Del(k) + } + + return true } diff --git a/internal/net/http/proxy_response.go b/internal/net/http/proxy_response.go deleted file mode 100644 index 7a5c87c..0000000 --- a/internal/net/http/proxy_response.go +++ /dev/null @@ -1,8 +0,0 @@ -package http - -import "net/http" - -type ProxyResponse struct { - *http.Response - OriginalRequest *http.Request -} diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 2d7e935..5cc6245 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -87,7 +87,7 @@ type ReverseProxy struct { // If ModifyResponse returns an error, ErrorHandler is called // with its error value. If ErrorHandler is nil, its default // implementation is used. - ModifyResponse func(*ProxyResponse) error + ModifyResponse func(*http.Response) error HandlerFunc http.HandlerFunc @@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err // modifyResponse conditionally runs the optional ModifyResponse hook // and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool { +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool { if p.ModifyResponse == nil { return true } - if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil { + res.Request = origReq + err := p.ModifyResponse(res) + res.Request = req + if err != nil { res.Body.Close() p.errorHandler(rw, req, err, true) return false @@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response } func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - // req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme) - // req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname()) - // req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port()) p.HandlerFunc(rw, req) } @@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { res = &http.Response{ Status: http.StatusText(http.StatusBadGateway), StatusCode: http.StatusBadGateway, - Proto: outreq.Proto, - ProtoMajor: outreq.ProtoMajor, - ProtoMinor: outreq.ProtoMinor, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, Header: http.Header{}, Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))), - Request: outreq, - TLS: outreq.TLS, + Request: req, + TLS: req.TLS, } } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index c2076f1..752343d 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -9,26 +9,31 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestTaskCreation(t *testing.T) { - rootTask := GlobalTask("root-task") - subTask := rootTask.Subtask("subtask") +const ( + rootTaskName = "root-task" + subTaskName = "subtask" +) - ExpectEqual(t, "root-task", rootTask.Name()) - ExpectEqual(t, "subtask", subTask.Name()) +func TestTaskCreation(t *testing.T) { + rootTask := GlobalTask(rootTaskName) + subTask := rootTask.Subtask(subTaskName) + + ExpectEqual(t, rootTaskName, rootTask.Name()) + ExpectEqual(t, subTaskName, subTask.Name()) } func TestTaskCancellation(t *testing.T) { subTaskDone := make(chan struct{}) - rootTask := GlobalTask("root-task") - subTask := rootTask.Subtask("subtask") + rootTask := GlobalTask(rootTaskName) + subTask := rootTask.Subtask(subTaskName) go func() { subTask.Wait() close(subTaskDone) }() - go rootTask.Finish("done") + go rootTask.Finish(nil) select { case <-subTaskDone: @@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) { } func TestOnComplete(t *testing.T) { - rootTask := GlobalTask("root-task") - task := rootTask.Subtask("test") + rootTask := GlobalTask(rootTaskName) + task := rootTask.Subtask(subTaskName) var value atomic.Int32 task.OnFinished("set value", func() { value.Store(1234) }) - task.Finish("done") + task.Finish(nil) ExpectEqual(t, value.Load(), 1234) } @@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) { testResetGlobalTask() defer CancelGlobalContext() - rootTask := GlobalTask("root-task") + rootTask := GlobalTask(rootTaskName) finished1, finished2 := false, false - subTask1 := rootTask.Subtask("subtask1") - subTask2 := rootTask.Subtask("subtask2") - subTask1.OnFinished("set finished", func() { + subTask1 := rootTask.Subtask(subTaskName) + subTask2 := rootTask.Subtask(subTaskName) + subTask1.OnFinished("", func() { finished1 = true }) - subTask2.OnFinished("set finished", func() { + subTask2.OnFinished("", func() { finished2 = true }) go func() { time.Sleep(500 * time.Millisecond) - subTask1.Finish("done") + subTask1.Finish(nil) }() go func() { time.Sleep(500 * time.Millisecond) - subTask2.Finish("done") + subTask2.Finish(nil) }() go func() { subTask1.Wait() subTask2.Wait() - rootTask.Finish("done") + rootTask.Finish(nil) }() - GlobalContextWait(1 * time.Second) + _ = GlobalContextWait(1 * time.Second) ExpectTrue(t, finished1) ExpectTrue(t, finished2) ExpectError(t, context.Canceled, rootTask.Context().Err()) @@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) { func TestTimeoutOnGlobalContextWait(t *testing.T) { testResetGlobalTask() - rootTask := GlobalTask("root-task") - rootTask.Subtask("subtask") + rootTask := GlobalTask(rootTaskName) + rootTask.Subtask(subTaskName) ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond)) } @@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) { testResetGlobalTask() taskDone := make(chan struct{}) - rootTask := GlobalTask("root-task") + rootTask := GlobalTask(rootTaskName) go func() { rootTask.Wait() diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index a868357..6ea2a70 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -19,23 +19,24 @@ func TestSerializeDeserialize(t *testing.T) { MIS map[int]string } - var testStruct = S{ - I: 1, - S: "hello", - IS: []int{1, 2, 3}, - SS: []string{"a", "b", "c"}, - MSI: map[string]int{"a": 1, "b": 2, "c": 3}, - MIS: map[int]string{1: "a", 2: "b", 3: "c"}, - } - - var testStructSerialized = map[string]any{ - "I": 1, - "S": "hello", - "IS": []int{1, 2, 3}, - "SS": []string{"a", "b", "c"}, - "MSI": map[string]int{"a": 1, "b": 2, "c": 3}, - "MIS": map[int]string{1: "a", 2: "b", 3: "c"}, - } + var ( + testStruct = S{ + I: 1, + S: "hello", + IS: []int{1, 2, 3}, + SS: []string{"a", "b", "c"}, + MSI: map[string]int{"a": 1, "b": 2, "c": 3}, + MIS: map[int]string{1: "a", 2: "b", 3: "c"}, + } + testStructSerialized = map[string]any{ + "I": 1, + "S": "hello", + "IS": []int{1, 2, 3}, + "SS": []string{"a", "b", "c"}, + "MSI": map[string]int{"a": 1, "b": 2, "c": 3}, + "MIS": map[int]string{1: "a", 2: "b", 3: "c"}, + } + ) t.Run("serialize", func(t *testing.T) { s, err := Serialize(testStruct)