From 5c9083a5dfa70be9bf0f49ee7eb11a5ac5f80e8d Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 13 Feb 2025 07:19:40 +0800 Subject: [PATCH] remove forwardAuth middleware --- internal/net/http/middleware/forward_auth.go | 221 ------------------- internal/net/http/middleware/middlewares.go | 4 - 2 files changed, 225 deletions(-) delete mode 100644 internal/net/http/middleware/forward_auth.go diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go deleted file mode 100644 index a505ff4..0000000 --- a/internal/net/http/middleware/forward_auth.go +++ /dev/null @@ -1,221 +0,0 @@ -// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go) -// Copyright (c) 2020-2024 Traefik Labs -// Copyright (c) 2024 yusing - -package middleware - -import ( - "io" - "net" - "net/http" - "slices" - "strings" - "time" - - gphttp "github.com/yusing/go-proxy/internal/net/http" - F "github.com/yusing/go-proxy/internal/utils/functional" -) - -type ( - forwardAuth struct { - ForwardAuthOpts - Tracer - reqCookiesMap F.Map[*http.Request, []*http.Cookie] - } - ForwardAuthOpts struct { - Address string `validate:"url,required"` - TrustForwardHeader bool - AuthResponseHeaders []string - AddAuthCookiesToResponse []string - } -) - -var ForwardAuth = NewMiddleware[forwardAuth]() - -var faHTTPClient = &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(r *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, -} - -// setup implements MiddlewareWithSetup. -func (fa *forwardAuth) setup() { - fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]() -} - -// before implements RequestModifier. -func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) { - gphttp.RemoveHop(req.Header) - - // Construct original URL for the redirect - scheme := "http" - if req.TLS != nil { - scheme = "https" - } - originalURL := scheme + "://" + req.Host + req.RequestURI - - url := fa.Address - faReq, err := http.NewRequestWithContext( - req.Context(), - http.MethodGet, - url, - nil, - ) - if err != nil { - fa.AddTracef("new request err to %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - gphttp.CopyHeader(faReq.Header, req.Header) - gphttp.RemoveHop(faReq.Header) - - faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) - fa.setAuthHeaders(req, faReq) - // Set headers needed by Authentik - faReq.Header.Set("X-Original-Url", originalURL) - fa.AddTraceRequest("forward auth request", faReq) - - faResp, err := faHTTPClient.Do(faReq) - if err != nil { - fa.AddTracef("failed to call %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer faResp.Body.Close() - - body, err := io.ReadAll(faResp.Body) - if err != nil { - fa.AddTracef("failed to read response body from %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { - fa.AddTraceResponse("forward auth response", faResp) - gphttp.CopyHeader(w.Header(), faResp.Header) - gphttp.RemoveHop(w.Header()) - - redirectURL, err := faResp.Location() - if err != nil { - fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) - w.WriteHeader(http.StatusInternalServerError) - return - } else if redirectURL.String() != "" { - w.Header().Set("Location", redirectURL.String()) - fa.AddTracef("%s", "redirect to "+redirectURL.String()) - } - - w.WriteHeader(faResp.StatusCode) - - if _, err = w.Write(body); err != nil { - fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) - } - return - } - - for _, key := range fa.AuthResponseHeaders { - key := http.CanonicalHeaderKey(key) - req.Header.Del(key) - if len(faResp.Header[key]) > 0 { - req.Header[key] = append([]string(nil), faResp.Header[key]...) - } - } - - req.RequestURI = req.URL.RequestURI() - - authCookies := faResp.Cookies() - - if len(authCookies) > 0 { - fa.reqCookiesMap.Store(req, authCookies) - } - return true -} - -// modifyResponse implements ResponseModifier. -func (fa *forwardAuth) modifyResponse(resp *http.Response) error { - if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok { - fa.setAuthCookies(resp, cookies) - fa.reqCookiesMap.Delete(resp.Request) - } - return nil -} - -func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) { - if len(fa.AddAuthCookiesToResponse) == 0 { - return - } - - cookies := resp.Cookies() - resp.Header.Del("Set-Cookie") - - for _, cookie := range cookies { - if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { - // this cookie is not an auth cookie, so add it back - resp.Header.Add("Set-Cookie", cookie.String()) - } - } - - for _, cookie := range authCookies { - if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { - // this cookie is an auth cookie, so add to resp - resp.Header.Add("Set-Cookie", cookie.String()) - } - } -} - -func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) { - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - if fa.TrustForwardHeader { - if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - } - faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP) - } - - xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod) - switch { - case xMethod != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod) - case req.Method != "": - faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) - default: - faReq.Header.Del(gphttp.HeaderXForwardedMethod) - } - - xfp := req.Header.Get(gphttp.HeaderXForwardedProto) - switch { - case xfp != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp) - case req.TLS != nil: - faReq.Header.Set(gphttp.HeaderXForwardedProto, "https") - default: - faReq.Header.Set(gphttp.HeaderXForwardedProto, "http") - } - - if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader { - faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp) - } - - xfh := req.Header.Get(gphttp.HeaderXForwardedHost) - switch { - case xfh != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh) - case req.Host != "": - faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) - default: - faReq.Header.Del(gphttp.HeaderXForwardedHost) - } - - xfURI := req.Header.Get(gphttp.HeaderXForwardedURI) - switch { - case xfURI != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI) - case req.URL.RequestURI() != "": - faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI()) - default: - faReq.Header.Del(gphttp.HeaderXForwardedURI) - } -} diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 261954b..a5050be 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -32,10 +32,6 @@ var allMiddlewares = map[string]*Middleware{ "cidrwhitelist": CIDRWhiteList, "ratelimit": RateLimiter, - - // !experimental - "forwardauth": ForwardAuth, - // "oauth2": OAuth2.m, } var (