support variables in modify request,response middlewares

This commit is contained in:
yusing 2024-12-03 10:20:18 +08:00
parent cd749ac6a4
commit ef1863f810
15 changed files with 291 additions and 114 deletions

View file

@ -12,6 +12,10 @@ const (
HeaderXForwardedPort = "X-Forwarded-Port"
HeaderXForwardedURI = "X-Forwarded-Uri"
HeaderXRealIP = "X-Real-IP"
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
)
func RemoveHop(h http.Header) {

View file

@ -164,7 +164,15 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, middlewares)
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
next(w, r)
},
}}, middlewares...))
if mid.before != nil {
ori := rp.HandlerFunc

View file

@ -19,12 +19,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
eb.Add(err)
return nil
}
mids := BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
results := make(map[string]*Middleware, len(mids))
for k, v := range mids {
results[k+"@file"] = v
}
return results
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
@ -40,7 +35,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
if err != nil {
eb.Add(err.Subject(source))
} else {
middlewares[name] = chain
middlewares[name+"@file"] = chain
}
}
return middlewares

View file

@ -38,17 +38,23 @@ func init() {
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"modifyresponse": ModifyResponse,
"modifyrequest": ModifyRequest,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"redirecthttp": RedirectHTTP,
"request": ModifyRequest,
"modifyrequest": ModifyRequest,
"response": ModifyResponse,
"modifyresponse": ModifyResponse,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth,

View file

@ -2,15 +2,16 @@ package middleware
import (
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyRequest struct {
modifyRequestOpts
m *Middleware
m *Middleware
needVarSubstitution bool
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
@ -24,38 +25,49 @@ var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
mrFunc := mr.modifyRequest
if common.IsDebug {
mrFunc = mr.modifyRequestWithTrace
}
mr.m = &Middleware{
impl: mr,
before: Rewrite(mrFunc),
impl: mr,
before: Rewrite(func(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyHeaders(req, nil, req.Header)
mr.m.AddTraceRequest("after modify request", req)
}),
}
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution()
return mr.m, nil
}
func (mr *modifyRequest) modifyRequest(req *Request) {
for k, v := range mr.SetHeaders {
if http.CanonicalHeaderKey(k) == "Host" {
req.Host = v
func (mr *modifyRequest) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m {
if strings.Contains(v, "$") {
mr.needVarSubstitution = true
return
}
}
req.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Header.Del(k)
}
}
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyRequest(req)
mr.m.AddTraceRequest("after modify request", req)
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
replaceVars := varReplacerDummy
if mr.needVarSubstitution {
replaceVars = varReplacer(req, resp)
}
for k, v := range mr.SetHeaders {
if strings.ToLower(k) == "host" {
req.Host = replaceVars(v)
}
headers.Set(k, replaceVars(v))
}
for k, v := range mr.AddHeaders {
headers.Add(k, replaceVars(v))
}
for _, k := range mr.HideHeaders {
headers.Del(k)
}
}

View file

@ -1,17 +1,39 @@
package middleware
import (
"bytes"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"User-Agent": "go-proxy/v0.5.0",
"Host": "test.example.com",
"User-Agent": "go-proxy/v0.5.0",
"Host": "$upstream_addr",
"X-Test-Req-Method": "$req_method",
"X-Test-Req-Scheme": "$req_scheme",
"X-Test-Req-Host": "$req_host",
"X-Test-Req-Port": "$req_port",
"X-Test-Req-Addr": "$req_addr",
"X-Test-Req-Path": "$req_path",
"X-Test-Req-Query": "$req_query",
"X-Test-Req-Url": "$req_url",
"X-Test-Req-Uri": "$req_uri",
"X-Test-Req-Content-Type": "$req_content_type",
"X-Test-Req-Content-Length": "$req_content_length",
"X-Test-Remote-Addr": "$remote_addr",
"X-Test-Upstream-Scheme": "$upstream_scheme",
"X-Test-Upstream-Host": "$upstream_host",
"X-Test-Upstream-Port": "$upstream_port",
"X-Test-Upstream-Addr": "$upstream_addr",
"X-Test-Upstream-Url": "$upstream_url",
"X-Test-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
@ -26,13 +48,44 @@ func TestSetModifyRequest(t *testing.T) {
})
t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
}

View file

@ -3,52 +3,28 @@ package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyResponse struct {
modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts = modifyRequestOpts
)
type modifyResponse = modifyRequest
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
mr.m = &Middleware{impl: mr}
if common.IsDebug {
mr.m.modifyResponse = mr.modifyResponseWithTrace
} else {
mr.m.modifyResponse = mr.modifyResponse
mr.m = &Middleware{
impl: mr,
modifyResponse: func(resp *http.Response) error {
mr.m.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp)
return nil
},
}
err := Deserialize(optsRaw, &mr.modifyResponseOpts)
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}
mr.checkVarSubstitution()
return mr.m, nil
}
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
mr.m.AddTraceResponse("before modify response", resp)
err := mr.modifyResponse(resp)
mr.m.AddTraceResponse("after modify response", resp)
return err
}

View file

@ -5,21 +5,22 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
reqURL: types.MustParseURL("http://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
reqURL: types.MustParseURL("https://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)

View file

@ -7,7 +7,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
@ -19,8 +18,6 @@ import (
var testHeadersRaw []byte
var testHeaders http.Header
const testHost = "example.com"
func init() {
if !common.IsTest {
return
@ -67,16 +64,16 @@ type TestResult struct {
type testArgs struct {
middlewareOpt OptionsRaw
proxyURL string
reqURL types.URL
upstreamURL types.URL
body []byte
scheme string
realRoundTrip bool
headers http.Header
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
var body io.Reader
var rr requestRecorder
var proxyURL *url.URL
var requestTarget string
var err error
if args == nil {
@ -87,34 +84,24 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
body = bytes.NewReader(args.body)
}
switch args.scheme {
case "":
fallthrough
case "http":
requestTarget = "http://" + testHost
case "https":
requestTarget = "https://" + testHost
default:
panic("typo?")
if args.reqURL.Nil() {
args.reqURL = E.Must(types.ParseURL("https://example.com"))
}
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
for k, v := range args.headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil {
panic("bug occurred")
if args.upstreamURL.Nil() {
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.proxyURL != "" {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
if args.realRoundTrip {
rr.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gphttp.NewReverseProxy(middleware.name, types.NewURL(proxyURL), &rr)
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr

View file

@ -97,10 +97,16 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithResponse(resp)
}

View file

@ -0,0 +1,106 @@
package middleware
import (
"net"
"net/http"
"regexp"
"strconv"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
type varReplaceFunc func(string) string
var (
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
reHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
func varSubsMap(req *Request, resp *Response) map[string]func() string {
reqHost, reqPort, err := net.SplitHostPort(req.Host)
if err != nil {
reqHost = req.Host
}
reqAddr := reqHost
if reqPort != "" {
reqAddr += ":" + reqPort
}
pairs := map[string]func() string{
"$req_method": func() string { return req.Method },
"$req_scheme": func() string { return req.URL.Scheme },
"$req_host": func() string { return reqHost },
"$req_port": func() string { return reqPort },
"$req_addr": func() string { return reqAddr },
"$req_path": func() string { return req.URL.Path },
"$req_query": func() string { return req.URL.RawQuery },
"$req_url": func() string { return req.URL.String() },
"$req_uri": req.URL.RequestURI,
"$req_content_type": func() string { return req.Header.Get("Content-Type") },
"$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func() string { return req.RemoteAddr },
}
if resp != nil {
pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") }
pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") }
pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) }
}
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return pairs
}
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort
}
upURL := upScheme + "://" + upAddr
pairs["$upstream_scheme"] = func() string { return upScheme }
pairs["$upstream_host"] = func() string { return upHost }
pairs["$upstream_port"] = func() string { return upPort }
pairs["$upstream_addr"] = func() string { return upAddr }
pairs["$upstream_url"] = func() string { return upURL }
return pairs
}
func varReplacer(req *Request, resp *Response) varReplaceFunc {
pairs := varSubsMap(req, resp)
return func(s string) string {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
name := match[5 : len(match)-1]
for k, v := range req.URL.Query() {
if strings.EqualFold(k, name) {
return v[0]
}
}
return ""
})
// Replace headers
s = reHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
return req.Header.Get(header)
})
// Replace static variables
return reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := pairs[match]; ok {
return fn()
}
return match
})
}
}
func varReplacerDummy(s string) string {
return s
}

View file

@ -264,6 +264,9 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
p.HandlerFunc(rw, req)
}

View file

@ -8,6 +8,14 @@ type URL struct {
*urlPkg.URL
}
func MustParseURL(url string) URL {
u, err := ParseURL(url)
if err != nil {
panic(err)
}
return u
}
func ParseURL(url string) (URL, error) {
u, err := urlPkg.Parse(url)
if err != nil {
@ -20,6 +28,10 @@ func NewURL(url *urlPkg.URL) URL {
return URL{url}
}
func (u URL) Nil() bool {
return u.URL == nil
}
func (u URL) String() string {
if u.URL == nil {
return "nil"

View file

@ -101,12 +101,21 @@ func TestWebhookBody(t *testing.T) {
})
ExpectNoError(t, err)
var body map[string][]map[string]any
var body struct {
Embeds []struct {
Title string `json:"title"`
Fields []struct {
Name string `json:"name"`
Value string `json:"value"`
} `json:"fields"`
} `json:"embeds"`
}
err = json.NewDecoder(bodyReader).Decode(&body)
ExpectNoError(t, err)
ExpectEqual(t, body["embeds"][0]["title"], "abc")
fields := ExpectType[[]map[string]any](t, body["embeds"][0]["fields"])
ExpectEqual(t, fields[0]["name"], "foo")
ExpectEqual(t, fields[0]["value"], "bar")
ExpectEqual(t, body.Embeds[0].Title, "abc")
fields := body.Embeds[0].Fields
ExpectEqual(t, fields[0].Name, "foo")
ExpectEqual(t, fields[0].Value, "bar")
}

View file

@ -109,8 +109,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
t.Helper()
_, ok := got.(T)
if !ok {
t.Fatalf("expected type %s, got %s", reflect.TypeFor[T](), reflect.TypeOf(got).Elem())
return
t.Fatalf("expected type %s, got %T", reflect.TypeFor[T](), got)
}
return got.(T)
}