mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
feat(middlewares): middleware bypass rules
This commit is contained in:
parent
ef95682116
commit
ddab2766b4
4 changed files with 221 additions and 19 deletions
53
internal/net/gphttp/middleware/bypass.go
Normal file
53
internal/net/gphttp/middleware/bypass.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/route/rules"
|
||||
)
|
||||
|
||||
type Bypass []rules.RuleOn
|
||||
|
||||
func (b Bypass) ShouldBypass(r *http.Request) bool {
|
||||
cached := rules.NewCache()
|
||||
defer cached.Release()
|
||||
for _, rule := range b {
|
||||
if rule.Check(cached, r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type checkBypass struct {
|
||||
bypass Bypass
|
||||
modReq RequestModifier
|
||||
modRes ResponseModifier
|
||||
}
|
||||
|
||||
func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||
if c.modReq == nil || c.bypass.ShouldBypass(r) {
|
||||
return true
|
||||
}
|
||||
return c.modReq.before(w, r)
|
||||
}
|
||||
|
||||
func (c *checkBypass) modifyResponse(resp *http.Response) error {
|
||||
if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) {
|
||||
return nil
|
||||
}
|
||||
return c.modRes.modifyResponse(resp)
|
||||
}
|
||||
|
||||
func (m *Middleware) withCheckBypass() any {
|
||||
if len(m.Bypass) > 0 {
|
||||
modReq, _ := m.impl.(RequestModifier)
|
||||
modRes, _ := m.impl.(ResponseModifier)
|
||||
return &checkBypass{
|
||||
bypass: m.Bypass,
|
||||
modReq: modReq,
|
||||
modRes: modRes,
|
||||
}
|
||||
}
|
||||
return m.impl
|
||||
}
|
131
internal/net/gphttp/middleware/bypass_test.go
Normal file
131
internal/net/gphttp/middleware/bypass_test.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package middleware_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func noOpHandler(w http.ResponseWriter, r *http.Request) {}
|
||||
|
||||
func TestBypassCIDR(t *testing.T) {
|
||||
mr, err := ModifyRequest.New(map[string]any{
|
||||
"set_headers": map[string]string{
|
||||
"Test-Header": "test-value",
|
||||
},
|
||||
"bypass": []string{"remote 127.0.0.1/32"},
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expectBypass bool
|
||||
}{
|
||||
{"bypass", "127.0.0.1:8080", true},
|
||||
{"no_bypass", "192.168.1.1:8080", false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
recorder := httptest.NewRecorder()
|
||||
mr.ModifyRequest(noOpHandler, recorder, req)
|
||||
expect.NoError(t, err)
|
||||
if test.expectBypass {
|
||||
expect.Equal(t, req.Header.Get("Test-Header"), "")
|
||||
} else {
|
||||
expect.Equal(t, req.Header.Get("Test-Header"), "test-value")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBypassPath(t *testing.T) {
|
||||
mr, err := ModifyRequest.New(map[string]any{
|
||||
"bypass": []string{"path /test/*", "path /api"},
|
||||
"set_headers": map[string]string{
|
||||
"Test-Header": "test-value",
|
||||
},
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectBypass bool
|
||||
}{
|
||||
{"bypass", "/test/123", true},
|
||||
{"bypass2", "/test/123/456", true},
|
||||
{"bypass3", "/api", true},
|
||||
{"no_bypass", "/test1/123/456", false},
|
||||
{"no_bypass2", "/api/123", false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
mr.ModifyRequest(noOpHandler, recorder, req)
|
||||
expect.NoError(t, err)
|
||||
if test.expectBypass {
|
||||
expect.Equal(t, req.Header.Get("Test-Header"), "")
|
||||
} else {
|
||||
expect.Equal(t, req.Header.Get("Test-Header"), "test-value")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct{}
|
||||
|
||||
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestReverseProxyBypass(t *testing.T) {
|
||||
rp := reverseproxy.NewReverseProxy("test", types.MustParseURL("http://example.com"), fakeRoundTripper{})
|
||||
err := PatchReverseProxy(rp, map[string]OptionsRaw{
|
||||
"response": {
|
||||
"bypass": "path /test/* | path /api",
|
||||
"set_headers": map[string]string{
|
||||
"Test-Header": "test-value",
|
||||
},
|
||||
},
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectBypass bool
|
||||
}{
|
||||
{"bypass", "/test/123", true},
|
||||
{"bypass2", "/test/123/456", true},
|
||||
{"bypass3", "/api", true},
|
||||
{"no_bypass", "/test1/123/456", false},
|
||||
{"no_bypass2", "/api/123", false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
rp.ServeHTTP(recorder, req)
|
||||
if test.expectBypass {
|
||||
expect.Equal(t, recorder.Header().Get("Test-Header"), "")
|
||||
} else {
|
||||
expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
@ -23,16 +24,22 @@ type (
|
|||
ImplNewFunc = func() any
|
||||
OptionsRaw = map[string]any
|
||||
|
||||
Middleware struct {
|
||||
name string
|
||||
construct ImplNewFunc
|
||||
impl any
|
||||
commonOptions = struct {
|
||||
// priority is only applied for ReverseProxy.
|
||||
//
|
||||
// Middleware compose follows the order of the slice
|
||||
//
|
||||
// Default is 10, 0 is the highest
|
||||
priority int
|
||||
Priority int `json:"priority"`
|
||||
Bypass Bypass `json:"bypass"`
|
||||
}
|
||||
|
||||
Middleware struct {
|
||||
name string
|
||||
construct ImplNewFunc
|
||||
impl any
|
||||
|
||||
commonOptions
|
||||
}
|
||||
ByPriority []*Middleware
|
||||
|
||||
|
@ -55,7 +62,7 @@ const DefaultPriority = 10
|
|||
|
||||
func (m ByPriority) Len() int { return len(m) }
|
||||
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
|
||||
func (m ByPriority) Less(i, j int) bool { return m[i].Priority < m[j].Priority }
|
||||
|
||||
func NewMiddleware[ImplType any]() *Middleware {
|
||||
// type check
|
||||
|
@ -107,21 +114,22 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
|||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
priority, ok := optsRaw["priority"].(int)
|
||||
if ok {
|
||||
m.priority = priority
|
||||
// remove priority for deserialization, restore later
|
||||
delete(optsRaw, "priority")
|
||||
defer func() {
|
||||
optsRaw["priority"] = priority
|
||||
}()
|
||||
} else {
|
||||
m.priority = DefaultPriority
|
||||
commonOpts := map[string]any{
|
||||
"priority": optsRaw["priority"],
|
||||
"bypass": optsRaw["bypass"],
|
||||
}
|
||||
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
||||
return err
|
||||
}
|
||||
optsRaw = maps.Clone(optsRaw)
|
||||
for k := range commonOpts {
|
||||
delete(optsRaw, k)
|
||||
}
|
||||
return utils.MapUnmarshalValidate(optsRaw, m.impl)
|
||||
}
|
||||
|
||||
func (m *Middleware) finalize() error {
|
||||
m.impl = m.withCheckBypass()
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||
finalizer.finalize()
|
||||
return nil
|
||||
|
@ -159,10 +167,16 @@ func (m *Middleware) String() string {
|
|||
}
|
||||
|
||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||
type allOptions struct {
|
||||
commonOptions
|
||||
any
|
||||
}
|
||||
return json.MarshalIndent(map[string]any{
|
||||
"name": m.name,
|
||||
"options": m.impl,
|
||||
"priority": m.priority,
|
||||
"name": m.name,
|
||||
"options": allOptions{
|
||||
commonOptions: m.commonOptions,
|
||||
any: m.impl,
|
||||
},
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,10 @@ type RuleOn struct {
|
|||
checker Checker
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(cached Cache, r *http.Request) bool {
|
||||
return on.checker.Check(cached, r)
|
||||
}
|
||||
|
||||
const (
|
||||
OnHeader = "header"
|
||||
OnQuery = "query"
|
||||
|
|
Loading…
Add table
Reference in a new issue