mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-06 22:44:03 +02:00
cleanup and simplify middleware implementations, refactor some other code
This commit is contained in:
parent
8a9cb2527e
commit
59f4eaf3ea
34 changed files with 641 additions and 720 deletions
|
@ -7,12 +7,12 @@ cli:
|
||||||
plugins:
|
plugins:
|
||||||
sources:
|
sources:
|
||||||
- id: trunk
|
- id: trunk
|
||||||
ref: v1.6.5
|
ref: v1.6.6
|
||||||
uri: https://github.com/trunk-io/plugins
|
uri: https://github.com/trunk-io/plugins
|
||||||
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
|
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
|
||||||
runtimes:
|
runtimes:
|
||||||
enabled:
|
enabled:
|
||||||
- node@18.12.1
|
- node@18.20.5
|
||||||
- python@3.10.8
|
- python@3.10.8
|
||||||
- go@1.23.2
|
- go@1.23.2
|
||||||
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
|
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
|
||||||
|
@ -23,16 +23,16 @@ lint:
|
||||||
enabled:
|
enabled:
|
||||||
- hadolint@2.12.1-beta
|
- hadolint@2.12.1-beta
|
||||||
- actionlint@1.7.4
|
- actionlint@1.7.4
|
||||||
- checkov@3.2.324
|
- checkov@3.2.334
|
||||||
- git-diff-check
|
- git-diff-check
|
||||||
- gofmt@1.20.4
|
- gofmt@1.20.4
|
||||||
- golangci-lint@1.62.2
|
- golangci-lint@1.62.2
|
||||||
- osv-scanner@1.9.1
|
- osv-scanner@1.9.1
|
||||||
- oxipng@9.1.3
|
- oxipng@9.1.3
|
||||||
- prettier@3.4.1
|
- prettier@3.4.2
|
||||||
- shellcheck@0.10.0
|
- shellcheck@0.10.0
|
||||||
- shfmt@3.6.0
|
- shfmt@3.6.0
|
||||||
- trufflehog@3.84.1
|
- trufflehog@3.86.1
|
||||||
actions:
|
actions:
|
||||||
disabled:
|
disabled:
|
||||||
- trunk-announce
|
- trunk-announce
|
||||||
|
|
|
@ -60,7 +60,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func rateLimited(f http.HandlerFunc) http.HandlerFunc {
|
func rateLimited(f http.HandlerFunc) http.HandlerFunc {
|
||||||
m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{
|
m, err := middleware.RateLimiter.New(middleware.OptionsRaw{
|
||||||
"average": 10,
|
"average": 10,
|
||||||
"burst": 10,
|
"burst": 10,
|
||||||
})
|
})
|
||||||
|
|
|
@ -75,7 +75,7 @@ func Handler(w http.ResponseWriter, r *http.Request) {
|
||||||
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
|
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
|
||||||
// Then scraper / scanners will know the subdomain is invalid.
|
// Then scraper / scanners will know the subdomain is invalid.
|
||||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||||
if !middleware.ServeStaticErrorPageFile(w, r) {
|
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
||||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||||
if ok {
|
if ok {
|
||||||
|
|
|
@ -16,6 +16,9 @@ const (
|
||||||
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
|
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
|
||||||
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
|
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
|
||||||
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
|
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
|
||||||
|
|
||||||
|
HeaderContentType = "Content-Type"
|
||||||
|
HeaderContentLength = "Content-Length"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RemoveHop(h http.Header) {
|
func RemoveHop(h http.Header) {
|
||||||
|
|
|
@ -24,7 +24,7 @@ func (lb *LoadBalancer) newIPHash() impl {
|
||||||
return impl
|
return impl
|
||||||
}
|
}
|
||||||
var err E.Error
|
var err E.Error
|
||||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,48 +4,45 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cidrWhitelist struct {
|
type (
|
||||||
cidrWhitelistOpts
|
cidrWhitelist struct {
|
||||||
m *Middleware
|
CIDRWhitelistOpts
|
||||||
|
*Tracer
|
||||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||||
}
|
}
|
||||||
|
CIDRWhitelistOpts struct {
|
||||||
type cidrWhitelistOpts struct {
|
|
||||||
Allow []*types.CIDR `validate:"min=1"`
|
Allow []*types.CIDR `validate:"min=1"`
|
||||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
|
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"`
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
|
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
|
||||||
cidrWhitelistDefaults = cidrWhitelistOpts{
|
cidrWhitelistDefaults = CIDRWhitelistOpts{
|
||||||
Allow: []*types.CIDR{},
|
Allow: []*types.CIDR{},
|
||||||
StatusCode: http.StatusForbidden,
|
StatusCode: http.StatusForbidden,
|
||||||
Message: "IP not allowed",
|
Message: "IP not allowed",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
// setup implements MiddlewareWithSetup.
|
||||||
wl := new(cidrWhitelist)
|
func (wl *cidrWhitelist) setup() {
|
||||||
wl.m = &Middleware{
|
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
|
||||||
impl: wl,
|
|
||||||
before: wl.checkIP,
|
|
||||||
}
|
|
||||||
wl.cidrWhitelistOpts = cidrWhitelistDefaults
|
|
||||||
wl.cachedAddr = F.NewMapOf[string, bool]()
|
wl.cachedAddr = F.NewMapOf[string, bool]()
|
||||||
err := Deserialize(opts, &wl.cidrWhitelistOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return wl.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
// before implements RequestModifier.
|
||||||
|
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
return wl.checkIP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkIP checks if the IP address is allowed.
|
||||||
|
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||||
var allow, ok bool
|
var allow, ok bool
|
||||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
@ -53,24 +50,23 @@ func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Req
|
||||||
ipStr = r.RemoteAddr
|
ipStr = r.RemoteAddr
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(ipStr)
|
ip := net.ParseIP(ipStr)
|
||||||
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
||||||
if cidr.Contains(ip) {
|
if cidr.Contains(ip) {
|
||||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||||
allow = true
|
allow = true
|
||||||
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||||
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
|
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
w.WriteHeader(wl.StatusCode)
|
http.Error(w, wl.Message, wl.StatusCode)
|
||||||
w.Write([]byte(wl.Message))
|
return false
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
next(w, r)
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,27 +17,27 @@ var deny, accept *Middleware
|
||||||
|
|
||||||
func TestCIDRWhitelistValidation(t *testing.T) {
|
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
"allow": []string{"1.2.3.4/32"},
|
"allow": []string{"1.2.3.4/32"},
|
||||||
"message": "test-message",
|
"message": "test-message",
|
||||||
})
|
})
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
})
|
})
|
||||||
t.Run("missing allow", func(t *testing.T) {
|
t.Run("missing allow", func(t *testing.T) {
|
||||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
"message": "test-message",
|
"message": "test-message",
|
||||||
})
|
})
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
ExpectError(t, utils.ErrValidationError, err)
|
||||||
})
|
})
|
||||||
t.Run("invalid cidr", func(t *testing.T) {
|
t.Run("invalid cidr", func(t *testing.T) {
|
||||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
"allow": []string{"1.2.3.4/123"},
|
"allow": []string{"1.2.3.4/123"},
|
||||||
"message": "test-message",
|
"message": "test-message",
|
||||||
})
|
})
|
||||||
ExpectErrorT[*net.ParseError](t, err)
|
ExpectErrorT[*net.ParseError](t, err)
|
||||||
})
|
})
|
||||||
t.Run("invalid status code", func(t *testing.T) {
|
t.Run("invalid status code", func(t *testing.T) {
|
||||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
"allow": []string{"1.2.3.4/32"},
|
"allow": []string{"1.2.3.4/32"},
|
||||||
"status_code": 600,
|
"status_code": 600,
|
||||||
"message": "test-message",
|
"message": "test-message",
|
||||||
|
|
|
@ -11,11 +11,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type cloudflareRealIP struct {
|
||||||
|
realIP realIP
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
|
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
|
||||||
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
|
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
|
||||||
|
@ -29,26 +32,23 @@ var (
|
||||||
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
||||||
)
|
)
|
||||||
|
|
||||||
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
|
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
|
||||||
|
|
||||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
// setup implements MiddlewareWithSetup.
|
||||||
cri := new(realIP)
|
func (cri *cloudflareRealIP) setup() {
|
||||||
cri.m = &Middleware{
|
cri.realIP.RealIPOpts = RealIPOpts{
|
||||||
impl: cri,
|
|
||||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
|
||||||
cidrs := tryFetchCFCIDR()
|
|
||||||
if cidrs != nil {
|
|
||||||
cri.From = cidrs
|
|
||||||
}
|
|
||||||
cri.setRealIP(r)
|
|
||||||
next(w, r)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cri.realIPOpts = realIPOpts{
|
|
||||||
Header: "CF-Connecting-IP",
|
Header: "CF-Connecting-IP",
|
||||||
Recursive: true,
|
Recursive: true,
|
||||||
}
|
}
|
||||||
return cri.m, nil
|
}
|
||||||
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cidrs := tryFetchCFCIDR()
|
||||||
|
if cidrs != nil {
|
||||||
|
cri.realIP.From = cidrs
|
||||||
|
}
|
||||||
|
return cri.realIP.before(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
|
|
|
@ -12,45 +12,38 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var CustomErrorPage *Middleware
|
type customErrorPage struct{}
|
||||||
|
|
||||||
func init() {
|
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||||
CustomErrorPage = customErrorPage()
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
return !ServeStaticErrorPageFile(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func customErrorPage() *Middleware {
|
// modifyResponse implements ResponseModifier.
|
||||||
m := &Middleware{
|
func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
|
||||||
if !ServeStaticErrorPageFile(w, r) {
|
|
||||||
next(w, r)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
m.modifyResponse = func(resp *Response) error {
|
|
||||||
// only handles non-success status code and html/plain content type
|
// only handles non-success status code and html/plain content type
|
||||||
contentType := gphttp.GetContentType(resp.Header)
|
contentType := gphttp.GetContentType(resp.Header)
|
||||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||||
if ok {
|
if ok {
|
||||||
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||||
/* trunk-ignore(golangci-lint/errcheck) */
|
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||||
io.Copy(io.Discard, resp.Body) // drain the original body
|
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||||
resp.ContentLength = int64(len(errorPage))
|
resp.ContentLength = int64(len(errorPage))
|
||||||
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
|
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||||
} else {
|
} else {
|
||||||
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
if path != "" && path[0] != '/' {
|
if path != "" && path[0] != '/' {
|
||||||
path = "/" + path
|
path = "/" + path
|
||||||
|
@ -65,11 +58,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||||
ext := filepath.Ext(filename)
|
ext := filepath.Ext(filename)
|
||||||
switch ext {
|
switch ext {
|
||||||
case ".html":
|
case ".html":
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||||
case ".js":
|
case ".js":
|
||||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
|
||||||
case ".css":
|
case ".css":
|
||||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
|
||||||
default:
|
default:
|
||||||
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
|
|
||||||
var ErrZeroValue = E.New("cannot be zero")
|
|
|
@ -12,16 +12,17 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
forwardAuth struct {
|
forwardAuth struct {
|
||||||
forwardAuthOpts
|
ForwardAuthOpts
|
||||||
m *Middleware
|
*Tracer
|
||||||
|
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||||
}
|
}
|
||||||
forwardAuthOpts struct {
|
ForwardAuthOpts struct {
|
||||||
Address string `validate:"url,required"`
|
Address string `validate:"url,required"`
|
||||||
TrustForwardHeader bool
|
TrustForwardHeader bool
|
||||||
AuthResponseHeaders []string
|
AuthResponseHeaders []string
|
||||||
|
@ -29,36 +30,30 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
|
var ForwardAuth = NewMiddleware[forwardAuth]()
|
||||||
|
|
||||||
var faHTTPClient = &http.Client{
|
var faHTTPClient = &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
CheckRedirect: func(r *Request, via []*Request) error {
|
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
// setup implements MiddlewareWithSetup.
|
||||||
fa := new(forwardAuth)
|
func (fa *forwardAuth) setup() {
|
||||||
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
fa.m = &Middleware{
|
|
||||||
impl: fa,
|
|
||||||
before: fa.forward,
|
|
||||||
}
|
|
||||||
return fa.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
// before implements RequestModifier.
|
||||||
|
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
|
||||||
gphttp.RemoveHop(req.Header)
|
gphttp.RemoveHop(req.Header)
|
||||||
|
|
||||||
// Construct original URL for the redirect
|
// Construct original URL for the redirect
|
||||||
// scheme := "http"
|
scheme := "http"
|
||||||
// if req.TLS != nil {
|
if req.TLS != nil {
|
||||||
// scheme = "https"
|
scheme = "https"
|
||||||
// }
|
}
|
||||||
// originalURL := scheme + "://" + req.Host + req.RequestURI
|
originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||||
|
|
||||||
url := fa.Address
|
url := fa.Address
|
||||||
faReq, err := http.NewRequestWithContext(
|
faReq, err := http.NewRequestWithContext(
|
||||||
|
@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("new request err to %s", url).WithError(err)
|
fa.AddTracef("new request err to %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||||
fa.setAuthHeaders(req, faReq)
|
fa.setAuthHeaders(req, faReq)
|
||||||
// Set headers needed by Authentik
|
// Set headers needed by Authentik
|
||||||
// faReq.Header.Set("X-Original-URL", originalURL)
|
faReq.Header.Set("X-Original-Url", originalURL)
|
||||||
fa.m.AddTraceRequest("forward auth request", faReq)
|
fa.AddTraceRequest("forward auth request", faReq)
|
||||||
|
|
||||||
faResp, err := faHTTPClient.Do(faReq)
|
faResp, err := faHTTPClient.Do(faReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to call %s", url).WithError(err)
|
fa.AddTracef("failed to call %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
|
|
||||||
body, err := io.ReadAll(faResp.Body)
|
body, err := io.ReadAll(faResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
|
fa.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
fa.m.AddTraceResponse("forward auth response", faResp)
|
fa.AddTraceResponse("forward auth response", faResp)
|
||||||
gphttp.CopyHeader(w.Header(), faResp.Header)
|
gphttp.CopyHeader(w.Header(), faResp.Header)
|
||||||
gphttp.RemoveHop(w.Header())
|
gphttp.RemoveHop(w.Header())
|
||||||
|
|
||||||
redirectURL, err := faResp.Location()
|
redirectURL, err := faResp.Location()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
} else if redirectURL.String() != "" {
|
} else if redirectURL.String() != "" {
|
||||||
w.Header().Set("Location", redirectURL.String())
|
w.Header().Set("Location", redirectURL.String())
|
||||||
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
|
fa.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(faResp.StatusCode)
|
w.WriteHeader(faResp.StatusCode)
|
||||||
|
|
||||||
if _, err = w.Write(body); err != nil {
|
if _, err = w.Write(body); err != nil {
|
||||||
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
|
|
||||||
authCookies := faResp.Cookies()
|
authCookies := faResp.Cookies()
|
||||||
|
|
||||||
if len(authCookies) == 0 {
|
if len(authCookies) > 0 {
|
||||||
next.ServeHTTP(w, req)
|
fa.reqCookiesMap.Store(req, authCookies)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
|
|
||||||
fa.setAuthCookies(resp, authCookies)
|
|
||||||
return nil
|
|
||||||
}), req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
|
// modifyResponse implements ResponseModifier.
|
||||||
|
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
|
||||||
|
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
|
||||||
|
fa.setAuthCookies(resp, cookies)
|
||||||
|
fa.reqCookiesMap.Delete(resp.Request)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
|
||||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
|
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
|
||||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
if fa.TrustForwardHeader {
|
if fa.TrustForwardHeader {
|
||||||
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
||||||
|
|
|
@ -3,11 +3,12 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -15,58 +16,97 @@ type (
|
||||||
|
|
||||||
ReverseProxy = gphttp.ReverseProxy
|
ReverseProxy = gphttp.ReverseProxy
|
||||||
ProxyRequest = gphttp.ProxyRequest
|
ProxyRequest = gphttp.ProxyRequest
|
||||||
Request = http.Request
|
|
||||||
Response = gphttp.ProxyResponse
|
|
||||||
ResponseWriter = http.ResponseWriter
|
|
||||||
Header = http.Header
|
|
||||||
Cookie = http.Cookie
|
|
||||||
|
|
||||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
|
||||||
RewriteFunc func(req *Request)
|
|
||||||
ModifyResponseFunc func(*Response) error
|
|
||||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
|
||||||
|
|
||||||
|
ImplNewFunc = func() any
|
||||||
OptionsRaw = map[string]any
|
OptionsRaw = map[string]any
|
||||||
|
|
||||||
Middleware struct {
|
Middleware struct {
|
||||||
_ U.NoCopy
|
|
||||||
|
|
||||||
zerolog.Logger
|
|
||||||
|
|
||||||
name string
|
name string
|
||||||
|
construct ImplNewFunc
|
||||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
|
||||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
|
||||||
|
|
||||||
withOptions CloneWithOptFunc
|
|
||||||
impl any
|
impl any
|
||||||
|
|
||||||
parent *Middleware
|
|
||||||
children []*Middleware
|
|
||||||
trace bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RequestModifier interface {
|
||||||
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
|
}
|
||||||
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||||
|
MiddlewareWithSetup interface{ setup() }
|
||||||
|
MiddlewareFinalizer interface{ finalize() }
|
||||||
|
MiddlewareWithTracer *struct{ *Tracer }
|
||||||
)
|
)
|
||||||
|
|
||||||
var Deserialize = U.Deserialize
|
func NewMiddleware[ImplType any]() *Middleware {
|
||||||
|
// type check
|
||||||
func Rewrite(r RewriteFunc) BeforeFunc {
|
switch any(new(ImplType)).(type) {
|
||||||
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
case RequestModifier:
|
||||||
r(req)
|
case ResponseModifier:
|
||||||
next(w, req)
|
default:
|
||||||
|
panic("must implement RequestModifier or ResponseModifier")
|
||||||
}
|
}
|
||||||
|
return &Middleware{
|
||||||
|
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||||
|
construct: func() any { return new(ImplType) },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) enableTrace() {
|
||||||
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
|
tracer.Tracer = &Tracer{name: m.name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) getTracer() *Tracer {
|
||||||
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
|
return tracer.Tracer
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) setParent(parent *Middleware) {
|
||||||
|
if tracer := m.getTracer(); tracer != nil {
|
||||||
|
tracer.parent = parent.getTracer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) setup() {
|
||||||
|
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
||||||
|
setup.setup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
||||||
|
if len(optsRaw) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return utils.Deserialize(optsRaw, m.impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) finalize() {
|
||||||
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||||
|
finalizer.finalize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
|
if m.construct == nil {
|
||||||
|
if optsRaw != nil {
|
||||||
|
panic("bug: middleware already constructed")
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
mid := &Middleware{name: m.name, impl: m.construct()}
|
||||||
|
mid.setup()
|
||||||
|
if err := mid.apply(optsRaw); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mid.finalize()
|
||||||
|
return mid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) Name() string {
|
func (m *Middleware) Name() string {
|
||||||
return m.name
|
return m.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) Fullname() string {
|
|
||||||
if m.parent != nil {
|
|
||||||
return m.parent.Fullname() + "." + m.name
|
|
||||||
}
|
|
||||||
return m.name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) String() string {
|
func (m *Middleware) String() string {
|
||||||
return m.name
|
return m.name
|
||||||
}
|
}
|
||||||
|
@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||||
}, "", " ")
|
}, "", " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||||
if m.withOptions != nil {
|
if exec, ok := m.impl.(RequestModifier); ok {
|
||||||
m, err := m.withOptions(optsRaw)
|
if proceed := exec.before(w, r); !proceed {
|
||||||
if err != nil {
|
return
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
m.Logger = logger.With().Str("name", m.name).Logger()
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
next(w, r)
|
||||||
// WithOptionsClone is called only once
|
|
||||||
// set withOptions and labelParser will not be used after that
|
|
||||||
return &Middleware{
|
|
||||||
Logger: logger.With().Str("name", m.name).Logger(),
|
|
||||||
name: m.name,
|
|
||||||
before: m.before,
|
|
||||||
modifyResponse: m.modifyResponse,
|
|
||||||
impl: m.impl,
|
|
||||||
parent: m.parent,
|
|
||||||
children: m.children,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
||||||
if m.before != nil {
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||||
m.before(next, w, r)
|
return exec.modifyResponse(resp)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) ModifyResponse(resp *Response) error {
|
|
||||||
if m.modifyResponse != nil {
|
|
||||||
return m.modifyResponse(resp)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||||
if m.modifyResponse != nil {
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||||
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
|
return exec.modifyResponse(resp)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if m.before != nil {
|
if exec, ok := m.impl.(RequestModifier); ok {
|
||||||
m.before(next, w, r)
|
if proceed := exec.before(w, r); !proceed {
|
||||||
} else {
|
return
|
||||||
next(w, r)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check conflict or duplicates.
|
// TODO: check conflict or duplicates.
|
||||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||||
|
|
||||||
errs := E.NewBuilder("middlewares compile error")
|
errs := E.NewBuilder("middlewares compile error")
|
||||||
|
@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err = m.WithOptionsClone(opts)
|
m, err = m.New(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
invalidOpts.Add(err.Subject(name))
|
invalidOpts.Add(err.Subject(name))
|
||||||
continue
|
continue
|
||||||
|
@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
|
||||||
|
|
||||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||||
var middlewares []*Middleware
|
var middlewares []*Middleware
|
||||||
middlewares, err = createMiddlewares(middlewaresMap)
|
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
|
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||||
name: "set_upstream_headers",
|
|
||||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||||
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
|
|
||||||
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
|
if before, ok := mid.impl.(RequestModifier); ok {
|
||||||
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
|
next := rp.HandlerFunc
|
||||||
|
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if proceed := before.before(w, r); proceed {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
},
|
}
|
||||||
}}, middlewares...))
|
|
||||||
|
|
||||||
if mid.before != nil {
|
|
||||||
ori := rp.HandlerFunc
|
|
||||||
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
|
|
||||||
mid.before(ori, w, r)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mid.modifyResponse != nil {
|
if mr, ok := mid.impl.(ResponseModifier); ok {
|
||||||
if rp.ModifyResponse != nil {
|
if rp.ModifyResponse != nil {
|
||||||
ori := rp.ModifyResponse
|
ori := rp.ModifyResponse
|
||||||
rp.ModifyResponse = func(res *Response) error {
|
rp.ModifyResponse = func(res *http.Response) error {
|
||||||
if err := mid.modifyResponse(res); err != nil {
|
if err := mr.modifyResponse(res); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return ori(res)
|
return ori(res)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rp.ModifyResponse = mid.modifyResponse
|
rp.ModifyResponse = mr.modifyResponse
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,9 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
@ -56,7 +54,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
delete(def, "use")
|
delete(def, "use")
|
||||||
m, err := base.WithOptionsClone(def)
|
m, err := base.New(def)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||||
continue
|
continue
|
||||||
|
@ -67,56 +65,5 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
|
||||||
if chainErr.HasError() {
|
if chainErr.HasError() {
|
||||||
return nil, chainErr.Error()
|
return nil, chainErr.Error()
|
||||||
}
|
}
|
||||||
return BuildMiddlewareFromChain(name, chain), nil
|
return NewMiddlewareChain(name, chain), nil
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: check conflict or duplicates.
|
|
||||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
|
||||||
m := &Middleware{name: name, children: chain}
|
|
||||||
|
|
||||||
var befores []*Middleware
|
|
||||||
var modResps []*Middleware
|
|
||||||
|
|
||||||
for _, comp := range chain {
|
|
||||||
if comp.before != nil {
|
|
||||||
befores = append(befores, comp)
|
|
||||||
}
|
|
||||||
if comp.modifyResponse != nil {
|
|
||||||
modResps = append(modResps, comp)
|
|
||||||
}
|
|
||||||
comp.parent = m
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(befores) > 0 {
|
|
||||||
m.before = buildBefores(befores)
|
|
||||||
}
|
|
||||||
if len(modResps) > 0 {
|
|
||||||
m.modifyResponse = func(res *Response) error {
|
|
||||||
errs := E.NewBuilder("modify response errors")
|
|
||||||
for _, mr := range modResps {
|
|
||||||
if err := mr.modifyResponse(res); err != nil {
|
|
||||||
errs.Add(E.From(err).Subject(mr.name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errs.Error()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if common.IsDebug {
|
|
||||||
m.EnableTrace()
|
|
||||||
m.AddTracef("middleware created")
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildBefores(befores []*Middleware) BeforeFunc {
|
|
||||||
if len(befores) == 1 {
|
|
||||||
return befores[0].before
|
|
||||||
}
|
|
||||||
nextBefores := buildBefores(befores[1:])
|
|
||||||
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
|
||||||
befores[0].before(func(w ResponseWriter, r *Request) {
|
|
||||||
nextBefores(next, w, r)
|
|
||||||
}, w, r)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
58
internal/net/http/middleware/middleware_chain.go
Normal file
58
internal/net/http/middleware/middleware_chain.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type middlewareChain struct {
|
||||||
|
befores []RequestModifier
|
||||||
|
modResps []ResponseModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check conflict or duplicates.
|
||||||
|
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
||||||
|
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
|
||||||
|
m := &Middleware{name: name, impl: chainMid}
|
||||||
|
|
||||||
|
for _, comp := range chain {
|
||||||
|
if before, ok := comp.impl.(RequestModifier); ok {
|
||||||
|
chainMid.befores = append(chainMid.befores, before)
|
||||||
|
}
|
||||||
|
if mr, ok := comp.impl.(ResponseModifier); ok {
|
||||||
|
chainMid.modResps = append(chainMid.modResps, mr)
|
||||||
|
}
|
||||||
|
comp.setParent(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsDebug {
|
||||||
|
m.enableTrace()
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||||
|
for _, b := range m.befores {
|
||||||
|
if proceedNext = b.before(w, r); !proceedNext {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// modifyResponse implements ResponseModifier.
|
||||||
|
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||||
|
if len(m.modResps) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errs := E.NewBuilder("modify response errors")
|
||||||
|
for i, mr := range m.modResps {
|
||||||
|
if err := mr.modifyResponse(resp); err != nil {
|
||||||
|
errs.Add(E.From(err).Subjectf("%d", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errs.Error()
|
||||||
|
}
|
|
@ -3,45 +3,39 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
modifyRequest struct {
|
modifyRequest struct {
|
||||||
modifyRequestOpts
|
ModifyRequestOpts
|
||||||
m *Middleware
|
*Tracer
|
||||||
needVarSubstitution bool
|
|
||||||
}
|
}
|
||||||
// order: set_headers -> add_headers -> hide_headers
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
modifyRequestOpts struct {
|
ModifyRequestOpts struct {
|
||||||
SetHeaders map[string]string
|
SetHeaders map[string]string
|
||||||
AddHeaders map[string]string
|
AddHeaders map[string]string
|
||||||
HideHeaders []string
|
HideHeaders []string
|
||||||
|
|
||||||
|
needVarSubstitution bool
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
|
var ModifyRequest = NewMiddleware[modifyRequest]()
|
||||||
|
|
||||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
// finalize implements MiddlewareFinalizer.
|
||||||
mr := new(modifyRequest)
|
func (mr *ModifyRequestOpts) finalize() {
|
||||||
mr.m = &Middleware{
|
|
||||||
impl: mr,
|
|
||||||
before: Rewrite(func(req *Request) {
|
|
||||||
mr.m.AddTraceRequest("before modify request", req)
|
|
||||||
mr.modifyHeaders(req, nil, req.Header)
|
|
||||||
mr.m.AddTraceRequest("after modify request", req)
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mr.checkVarSubstitution()
|
mr.checkVarSubstitution()
|
||||||
return mr.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mr *modifyRequest) checkVarSubstitution() {
|
// before implements RequestModifier.
|
||||||
|
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
mr.AddTraceRequest("before modify request", r)
|
||||||
|
mr.modifyHeaders(r, nil, r.Header)
|
||||||
|
mr.AddTraceRequest("after modify request", r)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *ModifyRequestOpts) checkVarSubstitution() {
|
||||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||||
for _, v := range m {
|
for _, v := range m {
|
||||||
if strings.ContainsRune(v, '$') {
|
if strings.ContainsRune(v, '$') {
|
||||||
|
@ -52,10 +46,10 @@ func (mr *modifyRequest) checkVarSubstitution() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
|
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
|
||||||
if !mr.needVarSubstitution {
|
if !mr.needVarSubstitution {
|
||||||
for k, v := range mr.SetHeaders {
|
for k, v := range mr.SetHeaders {
|
||||||
if req != nil && strings.ToLower(k) == "host" {
|
if req != nil && strings.EqualFold(k, "host") {
|
||||||
defer func() {
|
defer func() {
|
||||||
req.Host = v
|
req.Host = v
|
||||||
}()
|
}()
|
||||||
|
@ -67,7 +61,7 @@ func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers htt
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for k, v := range mr.SetHeaders {
|
for k, v := range mr.SetHeaders {
|
||||||
if req != nil && strings.ToLower(k) == "host" {
|
if req != nil && strings.EqualFold(k, "host") {
|
||||||
defer func() {
|
defer func() {
|
||||||
req.Host = varReplace(req, resp, v)
|
req.Host = varReplace(req, resp, v)
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -43,7 +43,7 @@ func TestModifyRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("set_options", func(t *testing.T) {
|
t.Run("set_options", func(t *testing.T) {
|
||||||
mr, err := ModifyRequest.WithOptionsClone(opts)
|
mr, err := ModifyRequest.New(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
|
|
@ -2,32 +2,19 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type modifyResponse = modifyRequest
|
type modifyResponse struct {
|
||||||
|
ModifyRequestOpts
|
||||||
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
|
*Tracer
|
||||||
|
}
|
||||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
|
||||||
mr := new(modifyResponse)
|
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||||
mr.m = &Middleware{
|
|
||||||
impl: mr,
|
// modifyResponse implements ResponseModifier.
|
||||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||||
next(w, r)
|
mr.AddTraceResponse("before modify response", resp)
|
||||||
},
|
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||||
modifyResponse: func(resp *Response) error {
|
mr.AddTraceResponse("after modify response", resp)
|
||||||
mr.m.AddTraceResponse("before modify response", resp.Response)
|
return nil
|
||||||
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
|
|
||||||
mr.m.AddTraceResponse("after modify response", resp.Response)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mr.checkVarSubstitution()
|
|
||||||
return mr.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ func TestModifyResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("set_options", func(t *testing.T) {
|
t.Run("set_options", func(t *testing.T) {
|
||||||
mr, err := ModifyResponse.WithOptionsClone(opts)
|
mr, err := ModifyResponse.New(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
|
|
@ -1,117 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "encoding/json"
|
|
||||||
// "fmt"
|
|
||||||
// "net/http"
|
|
||||||
// "net/url"
|
|
||||||
|
|
||||||
// E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
// )
|
|
||||||
|
|
||||||
// type oAuth2 struct {
|
|
||||||
// oAuth2Opts
|
|
||||||
// m *Middleware
|
|
||||||
// }
|
|
||||||
|
|
||||||
// type oAuth2Opts struct {
|
|
||||||
// ClientID string `validate:"required"`
|
|
||||||
// ClientSecret string `validate:"required"`
|
|
||||||
// AuthURL string `validate:"required"` // Authorization Endpoint
|
|
||||||
// TokenURL string `validate:"required"` // Token Endpoint
|
|
||||||
// }
|
|
||||||
|
|
||||||
// var OAuth2 = &oAuth2{
|
|
||||||
// m: &Middleware{withOptions: NewAuthentikOAuth2},
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
|
|
||||||
// oauth := new(oAuth2)
|
|
||||||
// oauth.m = &Middleware{
|
|
||||||
// impl: oauth,
|
|
||||||
// before: oauth.handleOAuth2,
|
|
||||||
// }
|
|
||||||
// err := Deserialize(opts, &oauth.oAuth2Opts)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
// return oauth.m, nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
|
||||||
// // Check if the user is authenticated (you may use session, cookie, etc.)
|
|
||||||
// if !userIsAuthenticated(r) {
|
|
||||||
// // TODO: Redirect to OAuth2 auth URL
|
|
||||||
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
|
||||||
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // If you have a token in the query string, process it
|
|
||||||
// if code := r.URL.Query().Get("code"); code != "" {
|
|
||||||
// // Exchange the authorization code for a token here
|
|
||||||
// // Use the TokenURL and authenticate the user
|
|
||||||
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
|
|
||||||
// if err != nil {
|
|
||||||
// // handle error
|
|
||||||
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Save token and user info based on your requirements
|
|
||||||
// saveToken(rw, token)
|
|
||||||
|
|
||||||
// // Redirect to the originally requested URL
|
|
||||||
// http.Redirect(rw, r, "/", http.StatusFound)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // If user is authenticated, go to the next handler
|
|
||||||
// next(rw, r)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func userIsAuthenticated(r *http.Request) bool {
|
|
||||||
// // Example: Check for a session or cookie
|
|
||||||
// session, err := r.Cookie("session_token")
|
|
||||||
// if err != nil || session.Value == "" {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
// // Validate the session_token if necessary
|
|
||||||
// return true
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
|
|
||||||
// // Prepare the request body
|
|
||||||
// data := url.Values{
|
|
||||||
// "client_id": {opts.ClientID},
|
|
||||||
// "client_secret": {opts.ClientSecret},
|
|
||||||
// "code": {code},
|
|
||||||
// "grant_type": {"authorization_code"},
|
|
||||||
// "redirect_uri": {requestURI},
|
|
||||||
// }
|
|
||||||
// resp, err := http.PostForm(opts.TokenURL, data)
|
|
||||||
// if err != nil {
|
|
||||||
// return "", fmt.Errorf("failed to request token: %w", err)
|
|
||||||
// }
|
|
||||||
// defer resp.Body.Close()
|
|
||||||
// if resp.StatusCode != http.StatusOK {
|
|
||||||
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
|
||||||
// }
|
|
||||||
// // Decode the response
|
|
||||||
// var tokenResp struct {
|
|
||||||
// AccessToken string `json:"access_token"`
|
|
||||||
// }
|
|
||||||
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
||||||
// return "", fmt.Errorf("failed to decode token response: %w", err)
|
|
||||||
// }
|
|
||||||
// return tokenResp.AccessToken, nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func saveToken(rw ResponseWriter, token string) {
|
|
||||||
// // Example: Save token in cookie
|
|
||||||
// http.SetCookie(rw, &http.Cookie{
|
|
||||||
// Name: "auth_token",
|
|
||||||
// Value: token,
|
|
||||||
// // set other properties as necessary, such as Secure and HttpOnly
|
|
||||||
// })
|
|
||||||
// }
|
|
|
@ -6,68 +6,56 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
requestMap = map[string]*rate.Limiter
|
requestMap = map[string]*rate.Limiter
|
||||||
rateLimiter struct {
|
rateLimiter struct {
|
||||||
requestMap requestMap
|
RateLimiterOpts
|
||||||
newLimiter func() *rate.Limiter
|
*Tracer
|
||||||
m *Middleware
|
|
||||||
|
|
||||||
|
requestMap requestMap
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
rateLimiterOpts struct {
|
RateLimiterOpts struct {
|
||||||
Average int `validate:"min=1,required"`
|
Average int `validate:"min=1,required"`
|
||||||
Burst int `validate:"min=1,required"`
|
Burst int `validate:"min=1,required"`
|
||||||
Period time.Duration
|
Period time.Duration `validate:"min=1s"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
RateLimiter = &Middleware{withOptions: NewRateLimiter}
|
RateLimiter = NewMiddleware[rateLimiter]()
|
||||||
rateLimiterOptsDefault = rateLimiterOpts{
|
rateLimiterOptsDefault = RateLimiterOpts{
|
||||||
Period: time.Second,
|
Period: time.Second,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
// setup implements MiddlewareWithSetup.
|
||||||
rl := new(rateLimiter)
|
func (rl *rateLimiter) setup() {
|
||||||
opts := rateLimiterOptsDefault
|
rl.RateLimiterOpts = rateLimiterOptsDefault
|
||||||
err := Deserialize(optsRaw, &opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case opts.Average == 0:
|
|
||||||
return nil, ErrZeroValue.Subject("average")
|
|
||||||
case opts.Burst == 0:
|
|
||||||
return nil, ErrZeroValue.Subject("burst")
|
|
||||||
case opts.Period == 0:
|
|
||||||
return nil, ErrZeroValue.Subject("period")
|
|
||||||
}
|
|
||||||
rl.requestMap = make(requestMap, 0)
|
rl.requestMap = make(requestMap, 0)
|
||||||
rl.newLimiter = func() *rate.Limiter {
|
|
||||||
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
|
|
||||||
}
|
|
||||||
rl.m = &Middleware{
|
|
||||||
impl: rl,
|
|
||||||
before: rl.limit,
|
|
||||||
}
|
|
||||||
return rl.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
// before implements RequestModifier.
|
||||||
|
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
return rl.limit(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) newLimiter() *rate.Limiter {
|
||||||
|
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rl.m.Debug().Msgf("unable to parse remote address %s", r.RemoteAddr)
|
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
|
||||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
limiter, ok := rl.requestMap[host]
|
limiter, ok := rl.requestMap[host]
|
||||||
|
@ -79,9 +67,9 @@ func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request
|
||||||
rl.mu.Unlock()
|
rl.mu.Unlock()
|
||||||
|
|
||||||
if limiter.Allow() {
|
if limiter.Allow() {
|
||||||
next(w, r)
|
return true
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ func TestRateLimit(t *testing.T) {
|
||||||
"period": "1s",
|
"period": "1s",
|
||||||
}
|
}
|
||||||
|
|
||||||
rl, err := NewRateLimiter(opts)
|
rl, err := RateLimiter.New(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
for range 10 {
|
for range 10 {
|
||||||
result, err := newMiddlewareTest(rl, nil)
|
result, err := newMiddlewareTest(rl, nil)
|
||||||
|
|
|
@ -2,24 +2,24 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||||
|
|
||||||
type realIP struct {
|
type (
|
||||||
realIPOpts
|
realIP struct {
|
||||||
m *Middleware
|
RealIPOpts
|
||||||
}
|
*Tracer
|
||||||
|
}
|
||||||
type realIPOpts struct {
|
RealIPOpts struct {
|
||||||
// Header is the name of the header to use for the real client IP
|
// Header is the name of the header to use for the real client IP
|
||||||
Header string `validate:"required"`
|
Header string `validate:"required"`
|
||||||
// From is a list of Address / CIDRs to trust
|
// From is a list of Address / CIDRs to trust
|
||||||
From []*types.CIDR `validate:"min=1"`
|
From []*types.CIDR `validate:"required,min=1"`
|
||||||
/*
|
/*
|
||||||
If recursive search is disabled,
|
If recursive search is disabled,
|
||||||
the original client address that matches one of the trusted addresses is replaced by
|
the original client address that matches one of the trusted addresses is replaced by
|
||||||
|
@ -29,31 +29,26 @@ type realIPOpts struct {
|
||||||
the last non-trusted address sent in the request header field.
|
the last non-trusted address sent in the request header field.
|
||||||
*/
|
*/
|
||||||
Recursive bool
|
Recursive bool
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
RealIP = &Middleware{withOptions: NewRealIP}
|
RealIP = NewMiddleware[realIP]()
|
||||||
realIPOptsDefault = realIPOpts{
|
realIPOptsDefault = RealIPOpts{
|
||||||
Header: "X-Real-IP",
|
Header: "X-Real-IP",
|
||||||
From: []*types.CIDR{},
|
From: []*types.CIDR{},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
// setup implements MiddlewareWithSetup.
|
||||||
riWithOpts := new(realIP)
|
func (ri *realIP) setup() {
|
||||||
riWithOpts.m = &Middleware{
|
ri.RealIPOpts = realIPOptsDefault
|
||||||
impl: riWithOpts,
|
}
|
||||||
before: Rewrite(riWithOpts.setRealIP),
|
|
||||||
}
|
// before implements RequestModifier.
|
||||||
riWithOpts.realIPOpts = realIPOptsDefault
|
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
err := Deserialize(opts, &riWithOpts.realIPOpts)
|
ri.setRealIP(r)
|
||||||
if err != nil {
|
return true
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(riWithOpts.From) == 0 {
|
|
||||||
return nil, E.New("no allowed CIDRs").Subject("from")
|
|
||||||
}
|
|
||||||
return riWithOpts.m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||||
|
@ -66,7 +61,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ri *realIP) setRealIP(req *Request) {
|
func (ri *realIP) setRealIP(req *http.Request) {
|
||||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clientIPStr = req.RemoteAddr
|
clientIPStr = req.RemoteAddr
|
||||||
|
@ -82,7 +77,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !isTrusted {
|
if !isTrusted {
|
||||||
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +85,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
var lastNonTrustedIP string
|
var lastNonTrustedIP string
|
||||||
|
|
||||||
if len(realIPs) == 0 {
|
if len(realIPs) == 0 {
|
||||||
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,12 +100,12 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastNonTrustedIP == "" {
|
if lastNonTrustedIP == "" {
|
||||||
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.RemoteAddr = lastNonTrustedIP
|
req.RemoteAddr = lastNonTrustedIP
|
||||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||||
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
|
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
|
||||||
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
|
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
||||||
},
|
},
|
||||||
"recursive": true,
|
"recursive": true,
|
||||||
}
|
}
|
||||||
optExpected := &realIPOpts{
|
optExpected := &RealIPOpts{
|
||||||
Header: gphttp.HeaderXRealIP,
|
Header: gphttp.HeaderXRealIP,
|
||||||
From: []*types.CIDR{
|
From: []*types.CIDR{
|
||||||
{
|
{
|
||||||
|
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
||||||
Recursive: true,
|
Recursive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
ri, err := NewRealIP(opts)
|
ri, err := RealIP.New(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||||
|
@ -61,18 +61,17 @@ func TestSetRealIP(t *testing.T) {
|
||||||
optsMr := OptionsRaw{
|
optsMr := OptionsRaw{
|
||||||
"set_headers": map[string]string{testHeader: testRealIP},
|
"set_headers": map[string]string{testHeader: testRealIP},
|
||||||
}
|
}
|
||||||
realip, err := NewRealIP(opts)
|
realip, err := RealIP.New(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
|
||||||
mr, err := NewModifyRequest(optsMr)
|
mr, err := ModifyRequest.New(optsMr)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
|
||||||
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
|
||||||
|
|
||||||
result, err := newMiddlewareTest(mid, nil)
|
result, err := newMiddlewareTest(mid, nil)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
t.Log(traces)
|
t.Log(traces)
|
||||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||||
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,19 +7,22 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var RedirectHTTP = &Middleware{
|
type redirectHTTP struct{}
|
||||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
|
||||||
if r.TLS == nil {
|
var RedirectHTTP = NewMiddleware[redirectHTTP]()
|
||||||
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
if r.TLS != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
r.URL.Scheme = "https"
|
r.URL.Scheme = "https"
|
||||||
host := r.Host
|
host := r.Host
|
||||||
if i := strings.Index(host, ":"); i != -1 {
|
if i := strings.Index(host, ":"); i != -1 {
|
||||||
host = host[:i] // strip port number if present
|
host = host[:i] // strip port number if present
|
||||||
}
|
}
|
||||||
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
||||||
logger.Info().Str("url", r.URL.String()).Msg("redirect to https")
|
logger.Debug().Str("url", r.URL.String()).Msg("redirect to https")
|
||||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||||
return
|
return true
|
||||||
}
|
|
||||||
next(w, r)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
34
internal/net/http/middleware/set_upstream_headers.go
Normal file
34
internal/net/http/middleware/set_upstream_headers.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// internal use only.
|
||||||
|
type setUpstreamHeaders struct {
|
||||||
|
Scheme, Host, Port string
|
||||||
|
}
|
||||||
|
|
||||||
|
var suh = NewMiddleware[setUpstreamHeaders]()
|
||||||
|
|
||||||
|
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware {
|
||||||
|
m, err := suh.New(OptionsRaw{
|
||||||
|
"scheme": rp.TargetURL.Scheme,
|
||||||
|
"host": rp.TargetURL.Hostname(),
|
||||||
|
"port": rp.TargetURL.Port(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
|
||||||
|
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
|
||||||
|
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
|
||||||
|
return true
|
||||||
|
}
|
|
@ -141,7 +141,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
|
|
||||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
||||||
|
|
||||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||||
if setOptErr != nil {
|
if setOptErr != nil {
|
||||||
return nil, setOptErr
|
return nil, setOptErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Trace struct {
|
type (
|
||||||
|
Trace struct {
|
||||||
Time string `json:"time,omitempty"`
|
Time string `json:"time,omitempty"`
|
||||||
Caller string `json:"caller,omitempty"`
|
Caller string `json:"caller,omitempty"`
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
|
@ -19,9 +17,9 @@ type Trace struct {
|
||||||
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
||||||
RespStatus int `json:"resp_status,omitempty"`
|
RespStatus int `json:"resp_status,omitempty"`
|
||||||
Additional map[string]any `json:"additional,omitempty"`
|
Additional map[string]any `json:"additional,omitempty"`
|
||||||
}
|
}
|
||||||
|
Traces []*Trace
|
||||||
type Traces []*Trace
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
traces = make(Traces, 0)
|
traces = make(Traces, 0)
|
||||||
|
@ -34,7 +32,7 @@ func GetAllTrace() []*Trace {
|
||||||
return traces
|
return traces
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *Trace) WithRequest(req *Request) *Trace {
|
func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
||||||
if tr == nil {
|
if tr == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -78,39 +76,6 @@ func (tr *Trace) WithError(err error) *Trace {
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) EnableTrace() {
|
|
||||||
m.trace = true
|
|
||||||
for _, child := range m.children {
|
|
||||||
child.parent = m
|
|
||||||
child.EnableTrace()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
|
||||||
if !m.trace {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return addTrace(&Trace{
|
|
||||||
Time: strutils.FormatTime(time.Now()),
|
|
||||||
Caller: m.Fullname(),
|
|
||||||
Message: fmt.Sprintf(msg, args...),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
|
||||||
if !m.trace {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.AddTracef("%s", msg).WithRequest(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
|
||||||
if !m.trace {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.AddTracef("%s", msg).WithResponse(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addTrace(t *Trace) *Trace {
|
func addTrace(t *Trace) *Trace {
|
||||||
tracesMu.Lock()
|
tracesMu.Lock()
|
||||||
defer tracesMu.Unlock()
|
defer tracesMu.Unlock()
|
||||||
|
|
50
internal/net/http/middleware/tracer.go
Normal file
50
internal/net/http/middleware/tracer.go
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tracer struct {
|
||||||
|
name string
|
||||||
|
parent *Tracer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) Fullname() string {
|
||||||
|
if t.parent != nil {
|
||||||
|
return t.parent.Fullname() + "." + t.name
|
||||||
|
}
|
||||||
|
return t.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) addTrace(msg string) *Trace {
|
||||||
|
return addTrace(&Trace{
|
||||||
|
Time: strutils.FormatTime(time.Now()),
|
||||||
|
Caller: t.Fullname(),
|
||||||
|
Message: msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return t.addTrace(msg).WithRequest(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return t.addTrace(msg).WithResponse(resp)
|
||||||
|
}
|
|
@ -11,8 +11,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
reqVarGetter func(*Request) string
|
reqVarGetter func(*http.Request) string
|
||||||
respVarGetter func(*Response) string
|
respVarGetter func(*http.Response) string
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -49,50 +49,50 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
VarRequestMethod: func(req *Request) string { return req.Method },
|
VarRequestMethod: func(req *http.Request) string { return req.Method },
|
||||||
VarRequestScheme: func(req *Request) string {
|
VarRequestScheme: func(req *http.Request) string {
|
||||||
if req.TLS != nil {
|
if req.TLS != nil {
|
||||||
return "https"
|
return "https"
|
||||||
}
|
}
|
||||||
return "http"
|
return "http"
|
||||||
},
|
},
|
||||||
VarRequestHost: func(req *Request) string {
|
VarRequestHost: func(req *http.Request) string {
|
||||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return req.Host
|
return req.Host
|
||||||
}
|
}
|
||||||
return reqHost
|
return reqHost
|
||||||
},
|
},
|
||||||
VarRequestPort: func(req *Request) string {
|
VarRequestPort: func(req *http.Request) string {
|
||||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||||
return reqPort
|
return reqPort
|
||||||
},
|
},
|
||||||
VarRequestAddr: func(req *Request) string { return req.Host },
|
VarRequestAddr: func(req *http.Request) string { return req.Host },
|
||||||
VarRequestPath: func(req *Request) string { return req.URL.Path },
|
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
|
||||||
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
|
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
|
||||||
VarRequestURL: func(req *Request) string { return req.URL.String() },
|
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
|
||||||
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
|
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
|
||||||
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
|
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
|
||||||
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||||
VarRemoteHost: func(req *Request) string {
|
VarRemoteHost: func(req *http.Request) string {
|
||||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return clientIP
|
return clientIP
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
VarRemotePort: func(req *Request) string {
|
VarRemotePort: func(req *http.Request) string {
|
||||||
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return clientPort
|
return clientPort
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
|
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||||
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||||
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||||
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||||
VarUpstreamAddr: func(req *Request) string {
|
VarUpstreamAddr: func(req *http.Request) string {
|
||||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||||
if upPort != "" {
|
if upPort != "" {
|
||||||
|
@ -100,7 +100,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
}
|
}
|
||||||
return upHost
|
return upHost
|
||||||
},
|
},
|
||||||
VarUpstreamURL: func(req *Request) string {
|
VarUpstreamURL: func(req *http.Request) string {
|
||||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||||
if upScheme == "" {
|
if upScheme == "" {
|
||||||
return ""
|
return ""
|
||||||
|
@ -116,12 +116,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
}
|
}
|
||||||
|
|
||||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||||
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
|
||||||
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||||
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||||
}
|
}
|
||||||
|
|
||||||
func varReplace(req *Request, resp *Response, s string) string {
|
func varReplace(req *http.Request, resp *http.Response, s string) string {
|
||||||
if req != nil {
|
if req != nil {
|
||||||
// Replace query parameters
|
// Replace query parameters
|
||||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
|
|
|
@ -2,27 +2,44 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var SetXForwarded = &Middleware{
|
type (
|
||||||
before: Rewrite(func(req *Request) {
|
setXForwarded struct{}
|
||||||
req.Header.Del(gphttp.HeaderXForwardedFor)
|
hideXForwarded struct{}
|
||||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
SetXForwarded = NewMiddleware[setXForwarded]()
|
||||||
|
HideXForwarded = NewMiddleware[hideXForwarded]()
|
||||||
|
)
|
||||||
|
|
||||||
|
// before implements RequestModifier.
|
||||||
|
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
r.Header.Del(gphttp.HeaderXForwardedFor)
|
||||||
|
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
||||||
}
|
}
|
||||||
}),
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var HideXForwarded = &Middleware{
|
// before implements RequestModifier.
|
||||||
before: Rewrite(func(req *Request) {
|
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
for k := range req.Header {
|
toDelete := make([]string, 0, len(r.Header))
|
||||||
|
for k := range r.Header {
|
||||||
if strings.HasPrefix(k, "X-Forwarded-") {
|
if strings.HasPrefix(k, "X-Forwarded-") {
|
||||||
req.Header.Del(k)
|
toDelete = append(toDelete, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}),
|
|
||||||
|
for _, k := range toDelete {
|
||||||
|
r.Header.Del(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
package http
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
type ProxyResponse struct {
|
|
||||||
*http.Response
|
|
||||||
OriginalRequest *http.Request
|
|
||||||
}
|
|
|
@ -87,7 +87,7 @@ type ReverseProxy struct {
|
||||||
// If ModifyResponse returns an error, ErrorHandler is called
|
// If ModifyResponse returns an error, ErrorHandler is called
|
||||||
// with its error value. If ErrorHandler is nil, its default
|
// with its error value. If ErrorHandler is nil, its default
|
||||||
// implementation is used.
|
// implementation is used.
|
||||||
ModifyResponse func(*ProxyResponse) error
|
ModifyResponse func(*http.Response) error
|
||||||
|
|
||||||
HandlerFunc http.HandlerFunc
|
HandlerFunc http.HandlerFunc
|
||||||
|
|
||||||
|
@ -251,11 +251,14 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
||||||
|
|
||||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||||
// and reports whether the request should proceed.
|
// and reports whether the request should proceed.
|
||||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
|
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
|
||||||
if p.ModifyResponse == nil {
|
if p.ModifyResponse == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
|
res.Request = origReq
|
||||||
|
err := p.ModifyResponse(res)
|
||||||
|
res.Request = req
|
||||||
|
if err != nil {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
p.errorHandler(rw, req, err, true)
|
p.errorHandler(rw, req, err, true)
|
||||||
return false
|
return false
|
||||||
|
@ -264,9 +267,6 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
|
|
||||||
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
|
|
||||||
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
|
|
||||||
p.HandlerFunc(rw, req)
|
p.HandlerFunc(rw, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -455,13 +455,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
res = &http.Response{
|
res = &http.Response{
|
||||||
Status: http.StatusText(http.StatusBadGateway),
|
Status: http.StatusText(http.StatusBadGateway),
|
||||||
StatusCode: http.StatusBadGateway,
|
StatusCode: http.StatusBadGateway,
|
||||||
Proto: outreq.Proto,
|
Proto: req.Proto,
|
||||||
ProtoMajor: outreq.ProtoMajor,
|
ProtoMajor: req.ProtoMajor,
|
||||||
ProtoMinor: outreq.ProtoMinor,
|
ProtoMinor: req.ProtoMinor,
|
||||||
Header: http.Header{},
|
Header: http.Header{},
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||||
Request: outreq,
|
Request: req,
|
||||||
TLS: outreq.TLS,
|
TLS: req.TLS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,26 +9,31 @@ import (
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTaskCreation(t *testing.T) {
|
const (
|
||||||
rootTask := GlobalTask("root-task")
|
rootTaskName = "root-task"
|
||||||
subTask := rootTask.Subtask("subtask")
|
subTaskName = "subtask"
|
||||||
|
)
|
||||||
|
|
||||||
ExpectEqual(t, "root-task", rootTask.Name())
|
func TestTaskCreation(t *testing.T) {
|
||||||
ExpectEqual(t, "subtask", subTask.Name())
|
rootTask := GlobalTask(rootTaskName)
|
||||||
|
subTask := rootTask.Subtask(subTaskName)
|
||||||
|
|
||||||
|
ExpectEqual(t, rootTaskName, rootTask.Name())
|
||||||
|
ExpectEqual(t, subTaskName, subTask.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTaskCancellation(t *testing.T) {
|
func TestTaskCancellation(t *testing.T) {
|
||||||
subTaskDone := make(chan struct{})
|
subTaskDone := make(chan struct{})
|
||||||
|
|
||||||
rootTask := GlobalTask("root-task")
|
rootTask := GlobalTask(rootTaskName)
|
||||||
subTask := rootTask.Subtask("subtask")
|
subTask := rootTask.Subtask(subTaskName)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
subTask.Wait()
|
subTask.Wait()
|
||||||
close(subTaskDone)
|
close(subTaskDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go rootTask.Finish("done")
|
go rootTask.Finish(nil)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-subTaskDone:
|
case <-subTaskDone:
|
||||||
|
@ -42,14 +47,14 @@ func TestTaskCancellation(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOnComplete(t *testing.T) {
|
func TestOnComplete(t *testing.T) {
|
||||||
rootTask := GlobalTask("root-task")
|
rootTask := GlobalTask(rootTaskName)
|
||||||
task := rootTask.Subtask("test")
|
task := rootTask.Subtask(subTaskName)
|
||||||
|
|
||||||
var value atomic.Int32
|
var value atomic.Int32
|
||||||
task.OnFinished("set value", func() {
|
task.OnFinished("set value", func() {
|
||||||
value.Store(1234)
|
value.Store(1234)
|
||||||
})
|
})
|
||||||
task.Finish("done")
|
task.Finish(nil)
|
||||||
ExpectEqual(t, value.Load(), 1234)
|
ExpectEqual(t, value.Load(), 1234)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,36 +62,36 @@ func TestGlobalContextWait(t *testing.T) {
|
||||||
testResetGlobalTask()
|
testResetGlobalTask()
|
||||||
defer CancelGlobalContext()
|
defer CancelGlobalContext()
|
||||||
|
|
||||||
rootTask := GlobalTask("root-task")
|
rootTask := GlobalTask(rootTaskName)
|
||||||
|
|
||||||
finished1, finished2 := false, false
|
finished1, finished2 := false, false
|
||||||
|
|
||||||
subTask1 := rootTask.Subtask("subtask1")
|
subTask1 := rootTask.Subtask(subTaskName)
|
||||||
subTask2 := rootTask.Subtask("subtask2")
|
subTask2 := rootTask.Subtask(subTaskName)
|
||||||
subTask1.OnFinished("set finished", func() {
|
subTask1.OnFinished("", func() {
|
||||||
finished1 = true
|
finished1 = true
|
||||||
})
|
})
|
||||||
subTask2.OnFinished("set finished", func() {
|
subTask2.OnFinished("", func() {
|
||||||
finished2 = true
|
finished2 = true
|
||||||
})
|
})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
subTask1.Finish("done")
|
subTask1.Finish(nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
subTask2.Finish("done")
|
subTask2.Finish(nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
subTask1.Wait()
|
subTask1.Wait()
|
||||||
subTask2.Wait()
|
subTask2.Wait()
|
||||||
rootTask.Finish("done")
|
rootTask.Finish(nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
GlobalContextWait(1 * time.Second)
|
_ = GlobalContextWait(1 * time.Second)
|
||||||
ExpectTrue(t, finished1)
|
ExpectTrue(t, finished1)
|
||||||
ExpectTrue(t, finished2)
|
ExpectTrue(t, finished2)
|
||||||
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
||||||
|
@ -97,8 +102,8 @@ func TestGlobalContextWait(t *testing.T) {
|
||||||
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
||||||
testResetGlobalTask()
|
testResetGlobalTask()
|
||||||
|
|
||||||
rootTask := GlobalTask("root-task")
|
rootTask := GlobalTask(rootTaskName)
|
||||||
rootTask.Subtask("subtask")
|
rootTask.Subtask(subTaskName)
|
||||||
|
|
||||||
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
|
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
|
||||||
}
|
}
|
||||||
|
@ -107,7 +112,7 @@ func TestGlobalContextCancellation(t *testing.T) {
|
||||||
testResetGlobalTask()
|
testResetGlobalTask()
|
||||||
|
|
||||||
taskDone := make(chan struct{})
|
taskDone := make(chan struct{})
|
||||||
rootTask := GlobalTask("root-task")
|
rootTask := GlobalTask(rootTaskName)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
rootTask.Wait()
|
rootTask.Wait()
|
||||||
|
|
|
@ -19,7 +19,8 @@ func TestSerializeDeserialize(t *testing.T) {
|
||||||
MIS map[int]string
|
MIS map[int]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var testStruct = S{
|
var (
|
||||||
|
testStruct = S{
|
||||||
I: 1,
|
I: 1,
|
||||||
S: "hello",
|
S: "hello",
|
||||||
IS: []int{1, 2, 3},
|
IS: []int{1, 2, 3},
|
||||||
|
@ -27,8 +28,7 @@ func TestSerializeDeserialize(t *testing.T) {
|
||||||
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
||||||
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
||||||
}
|
}
|
||||||
|
testStructSerialized = map[string]any{
|
||||||
var testStructSerialized = map[string]any{
|
|
||||||
"I": 1,
|
"I": 1,
|
||||||
"S": "hello",
|
"S": "hello",
|
||||||
"IS": []int{1, 2, 3},
|
"IS": []int{1, 2, 3},
|
||||||
|
@ -36,6 +36,7 @@ func TestSerializeDeserialize(t *testing.T) {
|
||||||
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
||||||
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
t.Run("serialize", func(t *testing.T) {
|
t.Run("serialize", func(t *testing.T) {
|
||||||
s, err := Serialize(testStruct)
|
s, err := Serialize(testStruct)
|
||||||
|
|
Loading…
Add table
Reference in a new issue