mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
cleanup and simplify middleware implementations, refactor some other code
This commit is contained in:
parent
8a9cb2527e
commit
59f4eaf3ea
34 changed files with 641 additions and 720 deletions
|
@ -7,12 +7,12 @@ cli:
|
|||
plugins:
|
||||
sources:
|
||||
- id: trunk
|
||||
ref: v1.6.5
|
||||
ref: v1.6.6
|
||||
uri: https://github.com/trunk-io/plugins
|
||||
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
|
||||
runtimes:
|
||||
enabled:
|
||||
- node@18.12.1
|
||||
- node@18.20.5
|
||||
- python@3.10.8
|
||||
- go@1.23.2
|
||||
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
|
||||
|
@ -23,16 +23,16 @@ lint:
|
|||
enabled:
|
||||
- hadolint@2.12.1-beta
|
||||
- actionlint@1.7.4
|
||||
- checkov@3.2.324
|
||||
- checkov@3.2.334
|
||||
- git-diff-check
|
||||
- gofmt@1.20.4
|
||||
- golangci-lint@1.62.2
|
||||
- osv-scanner@1.9.1
|
||||
- oxipng@9.1.3
|
||||
- prettier@3.4.1
|
||||
- prettier@3.4.2
|
||||
- shellcheck@0.10.0
|
||||
- shfmt@3.6.0
|
||||
- trufflehog@3.84.1
|
||||
- trufflehog@3.86.1
|
||||
actions:
|
||||
disabled:
|
||||
- trunk-announce
|
||||
|
|
|
@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
|
||||
func rateLimited(f http.HandlerFunc) http.HandlerFunc {
|
||||
m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{
|
||||
m, err := middleware.RateLimiter.New(middleware.OptionsRaw{
|
||||
"average": 10,
|
||||
"burst": 10,
|
||||
})
|
||||
|
|
|
@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) {
|
|||
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
|
||||
// Then scraper / scanners will know the subdomain is invalid.
|
||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||
if !middleware.ServeStaticErrorPageFile(w, r) {
|
||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||
if ok {
|
||||
|
|
|
@ -16,6 +16,9 @@ const (
|
|||
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
|
||||
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
|
||||
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
|
||||
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderContentLength = "Content-Length"
|
||||
)
|
||||
|
||||
func RemoveHop(h http.Header) {
|
||||
|
|
|
@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl {
|
|||
return impl
|
||||
}
|
||||
var err E.Error
|
||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
||||
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||
if err != nil {
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
}
|
||||
|
|
|
@ -4,48 +4,45 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type cidrWhitelist struct {
|
||||
cidrWhitelistOpts
|
||||
m *Middleware
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
|
||||
type cidrWhitelistOpts struct {
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
|
||||
Message string
|
||||
}
|
||||
type (
|
||||
cidrWhitelist struct {
|
||||
CIDRWhitelistOpts
|
||||
*Tracer
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
CIDRWhitelistOpts struct {
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
|
||||
Message string
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
|
||||
cidrWhitelistDefaults = cidrWhitelistOpts{
|
||||
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
|
||||
cidrWhitelistDefaults = CIDRWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
}
|
||||
)
|
||||
|
||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
wl := new(cidrWhitelist)
|
||||
wl.m = &Middleware{
|
||||
impl: wl,
|
||||
before: wl.checkIP,
|
||||
}
|
||||
wl.cidrWhitelistOpts = cidrWhitelistDefaults
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (wl *cidrWhitelist) setup() {
|
||||
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
|
||||
wl.cachedAddr = F.NewMapOf[string, bool]()
|
||||
err := Deserialize(opts, &wl.cidrWhitelistOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wl.m, nil
|
||||
}
|
||||
|
||||
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
// before implements RequestModifier.
|
||||
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return wl.checkIP(w, r)
|
||||
}
|
||||
|
||||
// checkIP checks if the IP address is allowed.
|
||||
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||
var allow, ok bool
|
||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
|
@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req
|
|||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
||||
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
|
||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
w.WriteHeader(wl.StatusCode)
|
||||
w.Write([]byte(wl.Message))
|
||||
return
|
||||
http.Error(w, wl.Message, wl.StatusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -17,27 +17,27 @@ var deny, accept *Middleware
|
|||
|
||||
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/32"},
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
t.Run("missing allow", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid cidr", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/123"},
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectErrorT[*net.ParseError](t, err)
|
||||
})
|
||||
t.Run("invalid status code", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/32"},
|
||||
"status_code": 600,
|
||||
"message": "test-message",
|
||||
|
|
|
@ -11,11 +11,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type cloudflareRealIP struct {
|
||||
realIP realIP
|
||||
}
|
||||
|
||||
const (
|
||||
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
|
||||
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
|
||||
|
@ -29,26 +32,23 @@ var (
|
|||
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
||||
)
|
||||
|
||||
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
|
||||
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
|
||||
|
||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
||||
cri := new(realIP)
|
||||
cri.m = &Middleware{
|
||||
impl: cri,
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.From = cidrs
|
||||
}
|
||||
cri.setRealIP(r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
cri.realIPOpts = realIPOpts{
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (cri *cloudflareRealIP) setup() {
|
||||
cri.realIP.RealIPOpts = RealIPOpts{
|
||||
Header: "CF-Connecting-IP",
|
||||
Recursive: true,
|
||||
}
|
||||
return cri.m, nil
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.realIP.From = cidrs
|
||||
}
|
||||
return cri.realIP.before(w, r)
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
|
|
|
@ -12,45 +12,38 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||
)
|
||||
|
||||
var CustomErrorPage *Middleware
|
||||
type customErrorPage struct{}
|
||||
|
||||
func init() {
|
||||
CustomErrorPage = customErrorPage()
|
||||
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
return !ServeStaticErrorPageFile(w, r)
|
||||
}
|
||||
|
||||
func customErrorPage() *Middleware {
|
||||
m := &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
}
|
||||
m.modifyResponse = func(resp *Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
} else {
|
||||
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||
} else {
|
||||
logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
return nil
|
||||
}
|
||||
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
|
||||
path := r.URL.Path
|
||||
if path != "" && path[0] != '/' {
|
||||
path = "/" + path
|
||||
|
@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
|||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
|
||||
default:
|
||||
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import E "github.com/yusing/go-proxy/internal/error"
|
||||
|
||||
var ErrZeroValue = E.New("cannot be zero")
|
|
@ -12,16 +12,17 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
forwardAuth struct {
|
||||
forwardAuthOpts
|
||||
m *Middleware
|
||||
ForwardAuthOpts
|
||||
*Tracer
|
||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
ForwardAuthOpts struct {
|
||||
Address string `validate:"url,required"`
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
|
@ -29,36 +30,30 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
|
||||
var ForwardAuth = NewMiddleware[forwardAuth]()
|
||||
|
||||
var faHTTPClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: func(r *Request, via []*Request) error {
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fa.m = &Middleware{
|
||||
impl: fa,
|
||||
before: fa.forward,
|
||||
}
|
||||
return fa.m, nil
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (fa *forwardAuth) setup() {
|
||||
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
// before implements RequestModifier.
|
||||
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
|
||||
gphttp.RemoveHop(req.Header)
|
||||
|
||||
// Construct original URL for the redirect
|
||||
// scheme := "http"
|
||||
// if req.TLS != nil {
|
||||
// scheme = "https"
|
||||
// }
|
||||
// originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||
scheme := "http"
|
||||
if req.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||
|
||||
url := fa.Address
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
|
@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("new request err to %s", url).WithError(err)
|
||||
fa.AddTracef("new request err to %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||
fa.setAuthHeaders(req, faReq)
|
||||
// Set headers needed by Authentik
|
||||
// faReq.Header.Set("X-Original-URL", originalURL)
|
||||
fa.m.AddTraceRequest("forward auth request", faReq)
|
||||
faReq.Header.Set("X-Original-Url", originalURL)
|
||||
fa.AddTraceRequest("forward auth request", faReq)
|
||||
|
||||
faResp, err := faHTTPClient.Do(faReq)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to call %s", url).WithError(err)
|
||||
fa.AddTracef("failed to call %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
fa.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||
fa.m.AddTraceResponse("forward auth response", faResp)
|
||||
fa.AddTraceResponse("forward auth response", faResp)
|
||||
gphttp.CopyHeader(w.Header(), faResp.Header)
|
||||
gphttp.RemoveHop(w.Header())
|
||||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||
fa.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||
}
|
||||
|
||||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
|
||||
authCookies := faResp.Cookies()
|
||||
|
||||
if len(authCookies) == 0 {
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
if len(authCookies) > 0 {
|
||||
fa.reqCookiesMap.Store(req, authCookies)
|
||||
}
|
||||
|
||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
|
||||
fa.setAuthCookies(resp, authCookies)
|
||||
return nil
|
||||
}), req)
|
||||
return true
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
|
||||
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
|
||||
fa.setAuthCookies(resp, cookies)
|
||||
fa.reqCookiesMap.Delete(resp.Request)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
|
||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||
return
|
||||
}
|
||||
|
@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
|
|||
}
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if fa.TrustForwardHeader {
|
||||
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
||||
|
|
|
@ -3,70 +3,110 @@ package middleware
|
|||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = E.Error
|
||||
|
||||
ReverseProxy = gphttp.ReverseProxy
|
||||
ProxyRequest = gphttp.ProxyRequest
|
||||
Request = http.Request
|
||||
Response = gphttp.ProxyResponse
|
||||
ResponseWriter = http.ResponseWriter
|
||||
Header = http.Header
|
||||
Cookie = http.Cookie
|
||||
ReverseProxy = gphttp.ReverseProxy
|
||||
ProxyRequest = gphttp.ProxyRequest
|
||||
|
||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc func(*Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||
|
||||
OptionsRaw = map[string]any
|
||||
ImplNewFunc = func() any
|
||||
OptionsRaw = map[string]any
|
||||
|
||||
Middleware struct {
|
||||
_ U.NoCopy
|
||||
|
||||
zerolog.Logger
|
||||
|
||||
name string
|
||||
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||
|
||||
withOptions CloneWithOptFunc
|
||||
impl any
|
||||
|
||||
parent *Middleware
|
||||
children []*Middleware
|
||||
trace bool
|
||||
name string
|
||||
construct ImplNewFunc
|
||||
impl any
|
||||
}
|
||||
|
||||
RequestModifier interface {
|
||||
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
}
|
||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||
MiddlewareWithSetup interface{ setup() }
|
||||
MiddlewareFinalizer interface{ finalize() }
|
||||
MiddlewareWithTracer *struct{ *Tracer }
|
||||
)
|
||||
|
||||
var Deserialize = U.Deserialize
|
||||
|
||||
func Rewrite(r RewriteFunc) BeforeFunc {
|
||||
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
r(req)
|
||||
next(w, req)
|
||||
func NewMiddleware[ImplType any]() *Middleware {
|
||||
// type check
|
||||
switch any(new(ImplType)).(type) {
|
||||
case RequestModifier:
|
||||
case ResponseModifier:
|
||||
default:
|
||||
panic("must implement RequestModifier or ResponseModifier")
|
||||
}
|
||||
return &Middleware{
|
||||
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||
construct: func() any { return new(ImplType) },
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.Tracer = &Tracer{name: m.name}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getTracer() *Tracer {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
return tracer.Tracer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) setParent(parent *Middleware) {
|
||||
if tracer := m.getTracer(); tracer != nil {
|
||||
tracer.parent = parent.getTracer()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) setup() {
|
||||
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
||||
setup.setup()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
||||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
return utils.Deserialize(optsRaw, m.impl)
|
||||
}
|
||||
|
||||
func (m *Middleware) finalize() {
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||
finalizer.finalize()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
if m.construct == nil {
|
||||
if optsRaw != nil {
|
||||
panic("bug: middleware already constructed")
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
mid := &Middleware{name: m.name, impl: m.construct()}
|
||||
mid.setup()
|
||||
if err := mid.apply(optsRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mid.finalize()
|
||||
return mid, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) Fullname() string {
|
||||
if m.parent != nil {
|
||||
return m.parent.Fullname() + "." + m.name
|
||||
}
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) String() string {
|
||||
return m.name
|
||||
}
|
||||
|
@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
if m.withOptions != nil {
|
||||
m, err := m.withOptions(optsRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
m.Logger = logger.With().Str("name", m.name).Logger()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{
|
||||
Logger: logger.With().Str("name", m.name).Logger(),
|
||||
name: m.name,
|
||||
before: m.before,
|
||||
modifyResponse: m.modifyResponse,
|
||||
impl: m.impl,
|
||||
parent: m.parent,
|
||||
children: m.children,
|
||||
}, nil
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if m.before != nil {
|
||||
m.before(next, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyResponse(resp *Response) error {
|
||||
if m.modifyResponse != nil {
|
||||
return m.modifyResponse(resp)
|
||||
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
return exec.modifyResponse(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if m.modifyResponse != nil {
|
||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
|
||||
return exec.modifyResponse(resp)
|
||||
})
|
||||
}
|
||||
if m.before != nil {
|
||||
m.before(next, w, r)
|
||||
} else {
|
||||
next(w, r)
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
errs := E.NewBuilder("middlewares compile error")
|
||||
|
@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
|
|||
continue
|
||||
}
|
||||
|
||||
m, err = m.WithOptionsClone(opts)
|
||||
m, err = m.New(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
|
@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
|
|||
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = createMiddlewares(middlewaresMap)
|
||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
|||
}
|
||||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
|
||||
name: "set_upstream_headers",
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
|
||||
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
|
||||
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
|
||||
next(w, r)
|
||||
},
|
||||
}}, middlewares...))
|
||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||
|
||||
if mid.before != nil {
|
||||
ori := rp.HandlerFunc
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
|
||||
mid.before(ori, w, r)
|
||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||
|
||||
if before, ok := mid.impl.(RequestModifier); ok {
|
||||
next := rp.HandlerFunc
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
if proceed := before.before(w, r); proceed {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mid.modifyResponse != nil {
|
||||
if mr, ok := mid.impl.(ResponseModifier); ok {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *Response) error {
|
||||
if err := mid.modifyResponse(res); err != nil {
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
return err
|
||||
}
|
||||
return ori(res)
|
||||
}
|
||||
} else {
|
||||
rp.ModifyResponse = mid.modifyResponse
|
||||
rp.ModifyResponse = mr.modifyResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,9 @@ package middleware
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
|
|||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
m, err := base.New(def)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
|
@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
|
|||
if chainErr.HasError() {
|
||||
return nil, chainErr.Error()
|
||||
}
|
||||
return BuildMiddlewareFromChain(name, chain), nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
m := &Middleware{name: name, children: chain}
|
||||
|
||||
var befores []*Middleware
|
||||
var modResps []*Middleware
|
||||
|
||||
for _, comp := range chain {
|
||||
if comp.before != nil {
|
||||
befores = append(befores, comp)
|
||||
}
|
||||
if comp.modifyResponse != nil {
|
||||
modResps = append(modResps, comp)
|
||||
}
|
||||
comp.parent = m
|
||||
}
|
||||
|
||||
if len(befores) > 0 {
|
||||
m.before = buildBefores(befores)
|
||||
}
|
||||
if len(modResps) > 0 {
|
||||
m.modifyResponse = func(res *Response) error {
|
||||
errs := E.NewBuilder("modify response errors")
|
||||
for _, mr := range modResps {
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
errs.Add(E.From(err).Subject(mr.name))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
m.EnableTrace()
|
||||
m.AddTracef("middleware created")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func buildBefores(befores []*Middleware) BeforeFunc {
|
||||
if len(befores) == 1 {
|
||||
return befores[0].before
|
||||
}
|
||||
nextBefores := buildBefores(befores[1:])
|
||||
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
befores[0].before(func(w ResponseWriter, r *Request) {
|
||||
nextBefores(next, w, r)
|
||||
}, w, r)
|
||||
}
|
||||
return NewMiddlewareChain(name, chain), nil
|
||||
}
|
||||
|
|
58
internal/net/http/middleware/middleware_chain.go
Normal file
58
internal/net/http/middleware/middleware_chain.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type middlewareChain struct {
|
||||
befores []RequestModifier
|
||||
modResps []ResponseModifier
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
||||
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
|
||||
m := &Middleware{name: name, impl: chainMid}
|
||||
|
||||
for _, comp := range chain {
|
||||
if before, ok := comp.impl.(RequestModifier); ok {
|
||||
chainMid.befores = append(chainMid.befores, before)
|
||||
}
|
||||
if mr, ok := comp.impl.(ResponseModifier); ok {
|
||||
chainMid.modResps = append(chainMid.modResps, mr)
|
||||
}
|
||||
comp.setParent(m)
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
m.enableTrace()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||
for _, b := range m.befores {
|
||||
if proceedNext = b.before(w, r); !proceedNext {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := E.NewBuilder("modify response errors")
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
errs.Add(E.From(err).Subjectf("%d", i))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
|
@ -3,45 +3,39 @@ package middleware
|
|||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyRequest struct {
|
||||
modifyRequestOpts
|
||||
m *Middleware
|
||||
needVarSubstitution bool
|
||||
ModifyRequestOpts
|
||||
*Tracer
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
ModifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
|
||||
needVarSubstitution bool
|
||||
}
|
||||
)
|
||||
|
||||
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
|
||||
var ModifyRequest = NewMiddleware[modifyRequest]()
|
||||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyRequest)
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
before: Rewrite(func(req *Request) {
|
||||
mr.m.AddTraceRequest("before modify request", req)
|
||||
mr.modifyHeaders(req, nil, req.Header)
|
||||
mr.m.AddTraceRequest("after modify request", req)
|
||||
}),
|
||||
}
|
||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// finalize implements MiddlewareFinalizer.
|
||||
func (mr *ModifyRequestOpts) finalize() {
|
||||
mr.checkVarSubstitution()
|
||||
return mr.m, nil
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) checkVarSubstitution() {
|
||||
// before implements RequestModifier.
|
||||
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
mr.AddTraceRequest("before modify request", r)
|
||||
mr.modifyHeaders(r, nil, r.Header)
|
||||
mr.AddTraceRequest("after modify request", r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mr *ModifyRequestOpts) checkVarSubstitution() {
|
||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||
for _, v := range m {
|
||||
if strings.ContainsRune(v, '$') {
|
||||
|
@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() {
|
|||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
|
||||
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
|
||||
if !mr.needVarSubstitution {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.ToLower(k) == "host" {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = v
|
||||
}()
|
||||
|
@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt
|
|||
}
|
||||
} else {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.ToLower(k) == "host" {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = varReplace(req, resp, v)
|
||||
}()
|
||||
|
|
|
@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.WithOptionsClone(opts)
|
||||
mr, err := ModifyRequest.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
|
|
|
@ -2,32 +2,19 @@ package middleware
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type modifyResponse = modifyRequest
|
||||
|
||||
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
|
||||
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
next(w, r)
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp.Response)
|
||||
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
|
||||
mr.m.AddTraceResponse("after modify response", resp.Response)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mr.checkVarSubstitution()
|
||||
return mr.m, nil
|
||||
type modifyResponse struct {
|
||||
ModifyRequestOpts
|
||||
*Tracer
|
||||
}
|
||||
|
||||
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
mr.AddTraceResponse("before modify response", resp)
|
||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||
mr.AddTraceResponse("after modify response", resp)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.WithOptionsClone(opts)
|
||||
mr, err := ModifyResponse.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
package middleware
|
||||
|
||||
// import (
|
||||
// "encoding/json"
|
||||
// "fmt"
|
||||
// "net/http"
|
||||
// "net/url"
|
||||
|
||||
// E "github.com/yusing/go-proxy/internal/error"
|
||||
// )
|
||||
|
||||
// type oAuth2 struct {
|
||||
// oAuth2Opts
|
||||
// m *Middleware
|
||||
// }
|
||||
|
||||
// type oAuth2Opts struct {
|
||||
// ClientID string `validate:"required"`
|
||||
// ClientSecret string `validate:"required"`
|
||||
// AuthURL string `validate:"required"` // Authorization Endpoint
|
||||
// TokenURL string `validate:"required"` // Token Endpoint
|
||||
// }
|
||||
|
||||
// var OAuth2 = &oAuth2{
|
||||
// m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
// }
|
||||
|
||||
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
// oauth := new(oAuth2)
|
||||
// oauth.m = &Middleware{
|
||||
// impl: oauth,
|
||||
// before: oauth.handleOAuth2,
|
||||
// }
|
||||
// err := Deserialize(opts, &oauth.oAuth2Opts)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return oauth.m, nil
|
||||
// }
|
||||
|
||||
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||
// // Check if the user is authenticated (you may use session, cookie, etc.)
|
||||
// if !userIsAuthenticated(r) {
|
||||
// // TODO: Redirect to OAuth2 auth URL
|
||||
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||
// return
|
||||
// }
|
||||
|
||||
// // If you have a token in the query string, process it
|
||||
// if code := r.URL.Query().Get("code"); code != "" {
|
||||
// // Exchange the authorization code for a token here
|
||||
// // Use the TokenURL and authenticate the user
|
||||
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
|
||||
// // Save token and user info based on your requirements
|
||||
// saveToken(rw, token)
|
||||
|
||||
// // Redirect to the originally requested URL
|
||||
// http.Redirect(rw, r, "/", http.StatusFound)
|
||||
// return
|
||||
// }
|
||||
|
||||
// // If user is authenticated, go to the next handler
|
||||
// next(rw, r)
|
||||
// }
|
||||
|
||||
// func userIsAuthenticated(r *http.Request) bool {
|
||||
// // Example: Check for a session or cookie
|
||||
// session, err := r.Cookie("session_token")
|
||||
// if err != nil || session.Value == "" {
|
||||
// return false
|
||||
// }
|
||||
// // Validate the session_token if necessary
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
|
||||
// // Prepare the request body
|
||||
// data := url.Values{
|
||||
// "client_id": {opts.ClientID},
|
||||
// "client_secret": {opts.ClientSecret},
|
||||
// "code": {code},
|
||||
// "grant_type": {"authorization_code"},
|
||||
// "redirect_uri": {requestURI},
|
||||
// }
|
||||
// resp, err := http.PostForm(opts.TokenURL, data)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("failed to request token: %w", err)
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
// if resp.StatusCode != http.StatusOK {
|
||||
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||
// }
|
||||
// // Decode the response
|
||||
// var tokenResp struct {
|
||||
// AccessToken string `json:"access_token"`
|
||||
// }
|
||||
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
// return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
// }
|
||||
// return tokenResp.AccessToken, nil
|
||||
// }
|
||||
|
||||
// func saveToken(rw ResponseWriter, token string) {
|
||||
// // Example: Save token in cookie
|
||||
// http.SetCookie(rw, &http.Cookie{
|
||||
// Name: "auth_token",
|
||||
// Value: token,
|
||||
// // set other properties as necessary, such as Secure and HttpOnly
|
||||
// })
|
||||
// }
|
|
@ -6,68 +6,56 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
requestMap = map[string]*rate.Limiter
|
||||
rateLimiter struct {
|
||||
requestMap requestMap
|
||||
newLimiter func() *rate.Limiter
|
||||
m *Middleware
|
||||
RateLimiterOpts
|
||||
*Tracer
|
||||
|
||||
mu sync.Mutex
|
||||
requestMap requestMap
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
rateLimiterOpts struct {
|
||||
Average int `validate:"min=1,required"`
|
||||
Burst int `validate:"min=1,required"`
|
||||
Period time.Duration
|
||||
RateLimiterOpts struct {
|
||||
Average int `validate:"min=1,required"`
|
||||
Burst int `validate:"min=1,required"`
|
||||
Period time.Duration `validate:"min=1s"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RateLimiter = &Middleware{withOptions: NewRateLimiter}
|
||||
rateLimiterOptsDefault = rateLimiterOpts{
|
||||
RateLimiter = NewMiddleware[rateLimiter]()
|
||||
rateLimiterOptsDefault = RateLimiterOpts{
|
||||
Period: time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
rl := new(rateLimiter)
|
||||
opts := rateLimiterOptsDefault
|
||||
err := Deserialize(optsRaw, &opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case opts.Average == 0:
|
||||
return nil, ErrZeroValue.Subject("average")
|
||||
case opts.Burst == 0:
|
||||
return nil, ErrZeroValue.Subject("burst")
|
||||
case opts.Period == 0:
|
||||
return nil, ErrZeroValue.Subject("period")
|
||||
}
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (rl *rateLimiter) setup() {
|
||||
rl.RateLimiterOpts = rateLimiterOptsDefault
|
||||
rl.requestMap = make(requestMap, 0)
|
||||
rl.newLimiter = func() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
|
||||
}
|
||||
rl.m = &Middleware{
|
||||
impl: rl,
|
||||
before: rl.limit,
|
||||
}
|
||||
return rl.m, nil
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
// before implements RequestModifier.
|
||||
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return rl.limit(w, r)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) newLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
|
||||
rl.mu.Lock()
|
||||
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr)
|
||||
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
|
||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
limiter, ok := rl.requestMap[host]
|
||||
|
@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request
|
|||
rl.mu.Unlock()
|
||||
|
||||
if limiter.Allow() {
|
||||
next(w, r)
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) {
|
|||
"period": "1s",
|
||||
}
|
||||
|
||||
rl, err := NewRateLimiter(opts)
|
||||
rl, err := RateLimiter.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(rl, nil)
|
||||
|
|
|
@ -2,58 +2,53 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
||||
type realIP struct {
|
||||
realIPOpts
|
||||
m *Middleware
|
||||
}
|
||||
|
||||
type realIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `validate:"min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last address sent in the request header field defined by the Header field.
|
||||
If recursive search is enabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool
|
||||
}
|
||||
type (
|
||||
realIP struct {
|
||||
RealIPOpts
|
||||
*Tracer
|
||||
}
|
||||
RealIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `validate:"required,min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last address sent in the request header field defined by the Header field.
|
||||
If recursive search is enabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RealIP = &Middleware{withOptions: NewRealIP}
|
||||
realIPOptsDefault = realIPOpts{
|
||||
RealIP = NewMiddleware[realIP]()
|
||||
realIPOptsDefault = RealIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*types.CIDR{},
|
||||
}
|
||||
)
|
||||
|
||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
riWithOpts := new(realIP)
|
||||
riWithOpts.m = &Middleware{
|
||||
impl: riWithOpts,
|
||||
before: Rewrite(riWithOpts.setRealIP),
|
||||
}
|
||||
riWithOpts.realIPOpts = realIPOptsDefault
|
||||
err := Deserialize(opts, &riWithOpts.realIPOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(riWithOpts.From) == 0 {
|
||||
return nil, E.New("no allowed CIDRs").Subject("from")
|
||||
}
|
||||
return riWithOpts.m, nil
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (ri *realIP) setup() {
|
||||
ri.RealIPOpts = realIPOptsDefault
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
ri.setRealIP(r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||
|
@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (ri *realIP) setRealIP(req *Request) {
|
||||
func (ri *realIP) setRealIP(req *http.Request) {
|
||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
clientIPStr = req.RemoteAddr
|
||||
|
@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
}
|
||||
}
|
||||
if !isTrusted {
|
||||
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
var lastNonTrustedIP string
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
}
|
||||
|
||||
if lastNonTrustedIP == "" {
|
||||
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
return
|
||||
}
|
||||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
|
||||
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
|||
},
|
||||
"recursive": true,
|
||||
}
|
||||
optExpected := &realIPOpts{
|
||||
optExpected := &RealIPOpts{
|
||||
Header: gphttp.HeaderXRealIP,
|
||||
From: []*types.CIDR{
|
||||
{
|
||||
|
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
|||
Recursive: true,
|
||||
}
|
||||
|
||||
ri, err := NewRealIP(opts)
|
||||
ri, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
|
@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) {
|
|||
optsMr := OptionsRaw{
|
||||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := NewRealIP(opts)
|
||||
realip, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mr, err := NewModifyRequest(optsMr)
|
||||
mr, err := ModifyRequest.New(optsMr)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
||||
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err)
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
|
||||
}
|
||||
|
|
|
@ -7,19 +7,22 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var RedirectHTTP = &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if r.TLS == nil {
|
||||
r.URL.Scheme = "https"
|
||||
host := r.Host
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i] // strip port number if present
|
||||
}
|
||||
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
||||
logger.Info().Str("url", r.URL.String()).Msg("redirect to https")
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
},
|
||||
type redirectHTTP struct{}
|
||||
|
||||
var RedirectHTTP = NewMiddleware[redirectHTTP]()
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
r.URL.Scheme = "https"
|
||||
host := r.Host
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i] // strip port number if present
|
||||
}
|
||||
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
||||
logger.Debug().Str("url", r.URL.String()).Msg("redirect to https")
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return true
|
||||
}
|
||||
|
|
34
internal/net/http/middleware/set_upstream_headers.go
Normal file
34
internal/net/http/middleware/set_upstream_headers.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
|
||||
// internal use only.
|
||||
type setUpstreamHeaders struct {
|
||||
Scheme, Host, Port string
|
||||
}
|
||||
|
||||
var suh = NewMiddleware[setUpstreamHeaders]()
|
||||
|
||||
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware {
|
||||
m, err := suh.New(OptionsRaw{
|
||||
"scheme": rp.TargetURL.Scheme,
|
||||
"host": rp.TargetURL.Hostname(),
|
||||
"port": rp.TargetURL.Port(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
|
||||
return true
|
||||
}
|
|
@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
|||
|
||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
||||
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
|
|
|
@ -1,27 +1,25 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders map[string]string `json:"req_headers,omitempty"`
|
||||
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
||||
RespStatus int `json:"resp_status,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
|
||||
type Traces []*Trace
|
||||
type (
|
||||
Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders map[string]string `json:"req_headers,omitempty"`
|
||||
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
||||
RespStatus int `json:"resp_status,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
Traces []*Trace
|
||||
)
|
||||
|
||||
var (
|
||||
traces = make(Traces, 0)
|
||||
|
@ -34,7 +32,7 @@ func GetAllTrace() []*Trace {
|
|||
return traces
|
||||
}
|
||||
|
||||
func (tr *Trace) WithRequest(req *Request) *Trace {
|
||||
func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace {
|
|||
return tr
|
||||
}
|
||||
|
||||
func (m *Middleware) EnableTrace() {
|
||||
m.trace = true
|
||||
for _, child := range m.children {
|
||||
child.parent = m
|
||||
child.EnableTrace()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return addTrace(&Trace{
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: m.Fullname(),
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return m.AddTracef("%s", msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return m.AddTracef("%s", msg).WithResponse(resp)
|
||||
}
|
||||
|
||||
func addTrace(t *Trace) *Trace {
|
||||
tracesMu.Lock()
|
||||
defer tracesMu.Unlock()
|
||||
|
|
50
internal/net/http/middleware/tracer.go
Normal file
50
internal/net/http/middleware/tracer.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
name string
|
||||
parent *Tracer
|
||||
}
|
||||
|
||||
func (t *Tracer) Fullname() string {
|
||||
if t.parent != nil {
|
||||
return t.parent.Fullname() + "." + t.name
|
||||
}
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *Tracer) addTrace(msg string) *Trace {
|
||||
return addTrace(&Trace{
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: t.Fullname(),
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithResponse(resp)
|
||||
}
|
|
@ -11,8 +11,8 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
reqVarGetter func(*Request) string
|
||||
respVarGetter func(*Response) string
|
||||
reqVarGetter func(*http.Request) string
|
||||
respVarGetter func(*http.Response) string
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -49,50 +49,50 @@ const (
|
|||
)
|
||||
|
||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
VarRequestMethod: func(req *Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *Request) string {
|
||||
VarRequestMethod: func(req *http.Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *http.Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
},
|
||||
VarRequestHost: func(req *Request) string {
|
||||
VarRequestHost: func(req *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return req.Host
|
||||
}
|
||||
return reqHost
|
||||
},
|
||||
VarRequestPort: func(req *Request) string {
|
||||
VarRequestPort: func(req *http.Request) string {
|
||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||
return reqPort
|
||||
},
|
||||
VarRequestAddr: func(req *Request) string { return req.Host },
|
||||
VarRequestPath: func(req *Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteHost: func(req *Request) string {
|
||||
VarRequestAddr: func(req *http.Request) string { return req.Host },
|
||||
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteHost: func(req *http.Request) string {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientIP
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemotePort: func(req *Request) string {
|
||||
VarRemotePort: func(req *http.Request) string {
|
||||
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientPort
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
|
||||
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *Request) string {
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *http.Request) string {
|
||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
|
@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
}
|
||||
return upHost
|
||||
},
|
||||
VarUpstreamURL: func(req *Request) string {
|
||||
VarUpstreamURL: func(req *http.Request) string {
|
||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
|
@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
}
|
||||
|
||||
func varReplace(req *Request, resp *Response, s string) string {
|
||||
func varReplace(req *http.Request, resp *http.Response, s string) string {
|
||||
if req != nil {
|
||||
// Replace query parameters
|
||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||
|
|
|
@ -2,27 +2,44 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
|
||||
var SetXForwarded = &Middleware{
|
||||
before: Rewrite(func(req *Request) {
|
||||
req.Header.Del(gphttp.HeaderXForwardedFor)
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
}),
|
||||
type (
|
||||
setXForwarded struct{}
|
||||
hideXForwarded struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
SetXForwarded = NewMiddleware[setXForwarded]()
|
||||
HideXForwarded = NewMiddleware[hideXForwarded]()
|
||||
)
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Del(gphttp.HeaderXForwardedFor)
|
||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var HideXForwarded = &Middleware{
|
||||
before: Rewrite(func(req *Request) {
|
||||
for k := range req.Header {
|
||||
if strings.HasPrefix(k, "X-Forwarded-") {
|
||||
req.Header.Del(k)
|
||||
}
|
||||
// before implements RequestModifier.
|
||||
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
toDelete := make([]string, 0, len(r.Header))
|
||||
for k := range r.Header {
|
||||
if strings.HasPrefix(k, "X-Forwarded-") {
|
||||
toDelete = append(toDelete, k)
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
for _, k := range toDelete {
|
||||
r.Header.Del(k)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type ProxyResponse struct {
|
||||
*http.Response
|
||||
OriginalRequest *http.Request
|
||||
}
|
|
@ -87,7 +87,7 @@ type ReverseProxy struct {
|
|||
// If ModifyResponse returns an error, ErrorHandler is called
|
||||
// with its error value. If ErrorHandler is nil, its default
|
||||
// implementation is used.
|
||||
ModifyResponse func(*ProxyResponse) error
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
HandlerFunc http.HandlerFunc
|
||||
|
||||
|
@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
|||
|
||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||
// and reports whether the request should proceed.
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
|
||||
if p.ModifyResponse == nil {
|
||||
return true
|
||||
}
|
||||
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
|
||||
res.Request = origReq
|
||||
err := p.ModifyResponse(res)
|
||||
res.Request = req
|
||||
if err != nil {
|
||||
res.Body.Close()
|
||||
p.errorHandler(rw, req, err, true)
|
||||
return false
|
||||
|
@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
|
|||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
|
||||
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
|
||||
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
|
||||
p.HandlerFunc(rw, req)
|
||||
}
|
||||
|
||||
|
@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
res = &http.Response{
|
||||
Status: http.StatusText(http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Proto: outreq.Proto,
|
||||
ProtoMajor: outreq.ProtoMajor,
|
||||
ProtoMinor: outreq.ProtoMinor,
|
||||
Proto: req.Proto,
|
||||
ProtoMajor: req.ProtoMajor,
|
||||
ProtoMinor: req.ProtoMinor,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||
Request: outreq,
|
||||
TLS: outreq.TLS,
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,26 +9,31 @@ import (
|
|||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
rootTask := GlobalTask("root-task")
|
||||
subTask := rootTask.Subtask("subtask")
|
||||
const (
|
||||
rootTaskName = "root-task"
|
||||
subTaskName = "subtask"
|
||||
)
|
||||
|
||||
ExpectEqual(t, "root-task", rootTask.Name())
|
||||
ExpectEqual(t, "subtask", subTask.Name())
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
subTask := rootTask.Subtask(subTaskName)
|
||||
|
||||
ExpectEqual(t, rootTaskName, rootTask.Name())
|
||||
ExpectEqual(t, subTaskName, subTask.Name())
|
||||
}
|
||||
|
||||
func TestTaskCancellation(t *testing.T) {
|
||||
subTaskDone := make(chan struct{})
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
subTask := rootTask.Subtask("subtask")
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
subTask := rootTask.Subtask(subTaskName)
|
||||
|
||||
go func() {
|
||||
subTask.Wait()
|
||||
close(subTaskDone)
|
||||
}()
|
||||
|
||||
go rootTask.Finish("done")
|
||||
go rootTask.Finish(nil)
|
||||
|
||||
select {
|
||||
case <-subTaskDone:
|
||||
|
@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOnComplete(t *testing.T) {
|
||||
rootTask := GlobalTask("root-task")
|
||||
task := rootTask.Subtask("test")
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
task := rootTask.Subtask(subTaskName)
|
||||
|
||||
var value atomic.Int32
|
||||
task.OnFinished("set value", func() {
|
||||
value.Store(1234)
|
||||
})
|
||||
task.Finish("done")
|
||||
task.Finish(nil)
|
||||
ExpectEqual(t, value.Load(), 1234)
|
||||
}
|
||||
|
||||
|
@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) {
|
|||
testResetGlobalTask()
|
||||
defer CancelGlobalContext()
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
|
||||
finished1, finished2 := false, false
|
||||
|
||||
subTask1 := rootTask.Subtask("subtask1")
|
||||
subTask2 := rootTask.Subtask("subtask2")
|
||||
subTask1.OnFinished("set finished", func() {
|
||||
subTask1 := rootTask.Subtask(subTaskName)
|
||||
subTask2 := rootTask.Subtask(subTaskName)
|
||||
subTask1.OnFinished("", func() {
|
||||
finished1 = true
|
||||
})
|
||||
subTask2.OnFinished("set finished", func() {
|
||||
subTask2.OnFinished("", func() {
|
||||
finished2 = true
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask1.Finish("done")
|
||||
subTask1.Finish(nil)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask2.Finish("done")
|
||||
subTask2.Finish(nil)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
subTask1.Wait()
|
||||
subTask2.Wait()
|
||||
rootTask.Finish("done")
|
||||
rootTask.Finish(nil)
|
||||
}()
|
||||
|
||||
GlobalContextWait(1 * time.Second)
|
||||
_ = GlobalContextWait(1 * time.Second)
|
||||
ExpectTrue(t, finished1)
|
||||
ExpectTrue(t, finished2)
|
||||
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
||||
|
@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) {
|
|||
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
||||
testResetGlobalTask()
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
rootTask.Subtask("subtask")
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
rootTask.Subtask(subTaskName)
|
||||
|
||||
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
|
||||
}
|
||||
|
@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) {
|
|||
testResetGlobalTask()
|
||||
|
||||
taskDone := make(chan struct{})
|
||||
rootTask := GlobalTask("root-task")
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
|
||||
go func() {
|
||||
rootTask.Wait()
|
||||
|
|
|
@ -19,23 +19,24 @@ func TestSerializeDeserialize(t *testing.T) {
|
|||
MIS map[int]string
|
||||
}
|
||||
|
||||
var testStruct = S{
|
||||
I: 1,
|
||||
S: "hello",
|
||||
IS: []int{1, 2, 3},
|
||||
SS: []string{"a", "b", "c"},
|
||||
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
|
||||
var testStructSerialized = map[string]any{
|
||||
"I": 1,
|
||||
"S": "hello",
|
||||
"IS": []int{1, 2, 3},
|
||||
"SS": []string{"a", "b", "c"},
|
||||
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
var (
|
||||
testStruct = S{
|
||||
I: 1,
|
||||
S: "hello",
|
||||
IS: []int{1, 2, 3},
|
||||
SS: []string{"a", "b", "c"},
|
||||
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
testStructSerialized = map[string]any{
|
||||
"I": 1,
|
||||
"S": "hello",
|
||||
"IS": []int{1, 2, 3},
|
||||
"SS": []string{"a", "b", "c"},
|
||||
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("serialize", func(t *testing.T) {
|
||||
s, err := Serialize(testStruct)
|
||||
|
|
Loading…
Add table
Reference in a new issue