mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fixed middleware implementation, added middleware tracing for easier debug
This commit is contained in:
parent
d172552fb0
commit
ba13b81b0e
31 changed files with 561 additions and 196 deletions
15
README.md
15
README.md
|
@ -83,13 +83,14 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
|||
|
||||
### Commands line arguments
|
||||
|
||||
| Argument | Description | Example |
|
||||
| ----------- | -------------------------------- | -------------------------- |
|
||||
| empty | start proxy server | |
|
||||
| `validate` | validate config and exit | |
|
||||
| `reload` | trigger a force reload of config | |
|
||||
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
|
||||
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
|
||||
| Argument | Description | Example |
|
||||
| ----------------- | ---------------------------------------------------- | -------------------------------- |
|
||||
| empty | start proxy server | |
|
||||
| `validate` | validate config and exit | |
|
||||
| `reload` | trigger a force reload of config | |
|
||||
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
|
||||
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
|
||||
| `debug-ls-mtrace` | list middleware trace **(works only in debug mode)** | `go-proxy debug-ls-mtrace \| jq` |
|
||||
|
||||
**run with `docker exec go-proxy /app/go-proxy <command>`**
|
||||
|
||||
|
|
12
cmd/main.go
12
cmd/main.go
|
@ -18,7 +18,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal"
|
||||
"github.com/yusing/go-proxy/internal/api"
|
||||
apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
|
@ -57,7 +57,7 @@ func main() {
|
|||
}
|
||||
|
||||
if args.Command == common.CommandReload {
|
||||
if err := apiUtils.ReloadServer(); err.HasError() {
|
||||
if err := query.ReloadServer(); err.HasError() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Print("ok")
|
||||
|
@ -93,7 +93,7 @@ func main() {
|
|||
printJSON(cfg.Value())
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := apiUtils.ListRoutes()
|
||||
routes, err := query.ListRoutes()
|
||||
if err.HasError() {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
|
@ -108,6 +108,12 @@ func main() {
|
|||
case common.CommandDebugListProviders:
|
||||
printJSON(cfg.DumpProviders())
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
if err.HasError() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
|
|
|
@ -25,10 +25,10 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
|||
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
|
||||
return
|
||||
case route.Type() == R.RouteTypeReverseProxy:
|
||||
ok = U.IsSiteHealthy(route.URL().String())
|
||||
ok = IsSiteHealthy(route.URL().String())
|
||||
case route.Type() == R.RouteTypeStream:
|
||||
entry := route.Entry()
|
||||
ok = U.IsStreamHealthy(
|
||||
ok = IsStreamHealthy(
|
||||
strings.Split(entry.Scheme, ":")[1], // target scheme
|
||||
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
|
||||
)
|
||||
|
|
|
@ -1,21 +1,22 @@
|
|||
package utils
|
||||
package v1
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
func IsSiteHealthy(url string) bool {
|
||||
// try HEAD first
|
||||
// if HEAD is not allowed, try GET
|
||||
resp, err := httpClient.Head(url)
|
||||
resp, err := U.Head(url)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
||||
_, err = httpClient.Get(url)
|
||||
_, err = U.Get(url)
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
|
@ -8,19 +8,28 @@ import (
|
|||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
const (
|
||||
ListRoutes = "routes"
|
||||
ListConfigFiles = "config_files"
|
||||
ListMiddlewareTrace = "middleware_trace"
|
||||
)
|
||||
|
||||
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
what := r.PathValue("what")
|
||||
if what == "" {
|
||||
what = "routes"
|
||||
what = ListRoutes
|
||||
}
|
||||
|
||||
switch what {
|
||||
case "routes":
|
||||
case ListRoutes:
|
||||
listRoutes(cfg, w, r)
|
||||
case "config_files":
|
||||
case ListConfigFiles:
|
||||
listConfigFiles(w, r)
|
||||
case ListMiddlewareTrace:
|
||||
listMiddlewareTrace(w, r)
|
||||
default:
|
||||
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
|
||||
}
|
||||
|
@ -59,3 +68,12 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := json.Marshal(middleware.GetAllTrace())
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package utils
|
||||
package query
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -6,12 +6,15 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
func ReloadServer() E.NestedError {
|
||||
resp, err := httpClient.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
@ -32,7 +35,7 @@ func ReloadServer() E.NestedError {
|
|||
}
|
||||
|
||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||
resp, err := httpClient.Get(fmt.Sprintf("%s/v1/list/routes", common.APIHTTPURL))
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
@ -47,3 +50,20 @@ func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
|||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
|
||||
}
|
||||
var traces middleware.Traces
|
||||
err = json.NewDecoder(resp.Body).Decode(&traces)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
return traces, nil
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{
|
||||
var HTTPClient = &http.Client{
|
||||
Timeout: common.ConnectionTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
@ -21,3 +21,7 @@ var httpClient = &http.Client{
|
|||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
|
||||
var Get = HTTPClient.Get
|
||||
var Post = HTTPClient.Post
|
||||
var Head = HTTPClient.Head
|
||||
|
|
|
@ -2,9 +2,9 @@ package common
|
|||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type Args struct {
|
||||
|
@ -20,6 +20,7 @@ const (
|
|||
CommandReload = "reload"
|
||||
CommandDebugListEntries = "debug-ls-entries"
|
||||
CommandDebugListProviders = "debug-ls-providers"
|
||||
CommandDebugListMTrace = "debug-ls-mtrace"
|
||||
)
|
||||
|
||||
var ValidCommands = []string{
|
||||
|
@ -31,23 +32,24 @@ var ValidCommands = []string{
|
|||
CommandReload,
|
||||
CommandDebugListEntries,
|
||||
CommandDebugListProviders,
|
||||
CommandDebugListMTrace,
|
||||
}
|
||||
|
||||
func GetArgs() Args {
|
||||
var args Args
|
||||
flag.Parse()
|
||||
args.Command = flag.Arg(0)
|
||||
if err := validateArg(args.Command); err.HasError() {
|
||||
if err := validateArg(args.Command); err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func validateArg(arg string) E.NestedError {
|
||||
func validateArg(arg string) error {
|
||||
for _, v := range ValidCommands {
|
||||
if arg == v {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return E.Invalid("argument", arg)
|
||||
return fmt.Errorf("invalid command: %s", arg)
|
||||
}
|
||||
|
|
|
@ -4,14 +4,14 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", false)
|
||||
IsTest = GetEnvBool("GOPROXY_TEST", false)
|
||||
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
|
||||
|
||||
ProxyHTTPAddr,
|
||||
|
@ -35,7 +35,14 @@ func GetEnvBool(key string, defaultValue bool) bool {
|
|||
if !ok || value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return U.ParseBool(value)
|
||||
switch strings.ToLower(value) {
|
||||
case "true", "yes", "1":
|
||||
return true
|
||||
case "false", "no", "0":
|
||||
return false
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func GetEnv(key, defaultValue string) string {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type cidrWhitelist struct {
|
||||
|
@ -19,7 +20,7 @@ type cidrWhitelistOpts struct {
|
|||
StatusCode int
|
||||
Message string
|
||||
|
||||
trustedAddr map[string]struct{} // cache for trusted IPs
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
|
||||
var CIDRWhiteList = &cidrWhitelist{
|
||||
|
@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{
|
|||
"allow": D.YamlStringListParser,
|
||||
"statusCode": D.IntParser,
|
||||
},
|
||||
withOptions: NewCIDRWhitelist,
|
||||
},
|
||||
}
|
||||
|
||||
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
||||
return &cidrWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
trustedAddr: make(map[string]struct{}),
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
cachedAddr: F.NewMapOf[string, bool](),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
|||
return wl.m, nil
|
||||
}
|
||||
|
||||
func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) {
|
||||
var ok bool
|
||||
if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok {
|
||||
ip := net.IP(r.RemoteAddr)
|
||||
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
var allow, ok bool
|
||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.trustedAddr[r.RemoteAddr] = struct{}{}
|
||||
ok = true
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
if !allow {
|
||||
w.WriteHeader(wl.StatusCode)
|
||||
w.Write([]byte(wl.Message))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
}
|
||||
|
|
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/cidr_whitelist_test.yml
|
||||
var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
panic("bug occurred")
|
||||
}
|
||||
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
||||
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
|||
cri := new(realIP)
|
||||
cri.m = &Middleware{
|
||||
impl: cri,
|
||||
rewrite: func(r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.From = cidrs
|
||||
}
|
||||
cri.setRealIP(r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
cri.realIPOpts = &realIPOpts{
|
||||
|
|
|
@ -15,9 +15,9 @@ import (
|
|||
)
|
||||
|
||||
var CustomErrorPage = &Middleware{
|
||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||
|
@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth {
|
|||
fa.m.withOptions = NewForwardAuthfunc
|
||||
return fa
|
||||
}()
|
||||
var faLogger = logrus.WithField("middleware", "ForwardAuth")
|
||||
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
faWithOpts := new(forwardAuth)
|
||||
|
@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
|||
return faWithOpts.m, nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
|
||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
gpHTTP.RemoveHop(req.Header)
|
||||
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
|
@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
|||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
|||
|
||||
faResp, err := fa.client.Do(faReq)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
|||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
|||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
|
@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
|||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
|
@ -21,7 +22,7 @@ type (
|
|||
Header = http.Header
|
||||
Cookie = http.Cookie
|
||||
|
||||
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
|
||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc func(resp *Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
||||
|
@ -33,23 +34,38 @@ type (
|
|||
name string
|
||||
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
|
||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||
|
||||
transport http.RoundTripper
|
||||
|
||||
withOptions CloneWithOptFunc
|
||||
labelParserMap D.ValueParserMap
|
||||
impl any
|
||||
|
||||
parent *Middleware
|
||||
children []*Middleware
|
||||
trace bool
|
||||
}
|
||||
)
|
||||
|
||||
var Deserialize = U.Deserialize
|
||||
|
||||
func Rewrite(r RewriteFunc) BeforeFunc {
|
||||
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
r(req)
|
||||
next(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) Fullname() string {
|
||||
if m.parent != nil {
|
||||
return m.parent.Fullname() + "." + m.name
|
||||
}
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) String() string {
|
||||
return m.name
|
||||
}
|
||||
|
@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested
|
|||
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
|
||||
return &Middleware{
|
||||
m.name,
|
||||
m.before,
|
||||
m.modifyResponse,
|
||||
nil, nil,
|
||||
m.impl,
|
||||
m.parent,
|
||||
m.children,
|
||||
false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
|
||||
befores := make([]BeforeFunc, 0, len(middlewares))
|
||||
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
||||
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
invalidOpts := E.NewBuilder("invalid options")
|
||||
|
@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
|||
invalidM.To(&res)
|
||||
}()
|
||||
|
||||
for name, opts := range middlewares {
|
||||
for name, opts := range middlewaresMap {
|
||||
m, ok := Get(name)
|
||||
if !ok {
|
||||
invalidM.Add(E.NotExist("middleware", name))
|
||||
|
@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
|||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
}
|
||||
if m.before != nil {
|
||||
befores = append(befores, m.before)
|
||||
}
|
||||
if m.rewrite != nil {
|
||||
rewrites = append(rewrites, m.rewrite)
|
||||
}
|
||||
if m.modifyResponse != nil {
|
||||
modResps = append(modResps, m.modifyResponse)
|
||||
}
|
||||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
if invalidM.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
origServeHTTP := rp.ServeHTTP
|
||||
for i, before := range befores {
|
||||
if i < len(befores)-1 {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(rp.ServeHTTP, w, r)
|
||||
}
|
||||
} else {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(origServeHTTP, w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(rewrites) > 0 {
|
||||
origServeHTTP = rp.ServeHTTP
|
||||
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, rewrite := range rewrites {
|
||||
rewrite(r)
|
||||
}
|
||||
origServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modResps) > 0 {
|
||||
if rp.ModifyResponse != nil {
|
||||
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
|
||||
}
|
||||
rp.ModifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware ModifyResponse")
|
||||
for _, mr := range modResps {
|
||||
b.AddE(mr(res))
|
||||
}
|
||||
return b.Build().Error()
|
||||
}
|
||||
}
|
||||
|
||||
patchReverseProxy(rpName, rp, middlewares)
|
||||
return
|
||||
}
|
||||
|
||||
func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) {
|
||||
mid := BuildMiddlewareFromChain(rpName, middlewares)
|
||||
|
||||
if mid.before != nil {
|
||||
ori := rp.ServeHTTP
|
||||
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
||||
mid.before(ori, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if mid.modifyResponse != nil {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
return errors.Join(mid.modifyResponse(res), ori(res))
|
||||
}
|
||||
} else {
|
||||
rp.ModifyResponse = mid.modifyResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
|||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("toml unmarshal", err))
|
||||
b.Add(E.FailWith("yaml unmarshal", err))
|
||||
return
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
|
@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
|||
chainErr := E.NewBuilder(name)
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"].(string) == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, ok := Get(baseName)
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
|
||||
continue
|
||||
base, ok = middlewares[baseName]
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
continue
|
||||
|
@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
|||
if chainErr.HasError() {
|
||||
b.Add(chainErr.Build())
|
||||
} else {
|
||||
name = name + "@file"
|
||||
middlewares[name] = BuildMiddlewareFromChain(name, chain)
|
||||
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
|||
|
||||
// TODO: check conflict or duplicates
|
||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
var (
|
||||
befores []BeforeFunc
|
||||
rewrites []RewriteFunc
|
||||
modResps []ModifyResponseFunc
|
||||
)
|
||||
for _, m := range chain {
|
||||
if m.before != nil {
|
||||
befores = append(befores, m.before)
|
||||
m := &Middleware{name: name, children: chain}
|
||||
|
||||
var befores []*Middleware
|
||||
var modResps []*Middleware
|
||||
|
||||
for _, comp := range chain {
|
||||
if comp.before != nil {
|
||||
befores = append(befores, comp)
|
||||
}
|
||||
if m.rewrite != nil {
|
||||
rewrites = append(rewrites, m.rewrite)
|
||||
}
|
||||
if m.modifyResponse != nil {
|
||||
modResps = append(modResps, m.modifyResponse)
|
||||
if comp.modifyResponse != nil {
|
||||
modResps = append(modResps, comp)
|
||||
}
|
||||
comp.parent = m
|
||||
}
|
||||
|
||||
m := &Middleware{name: name}
|
||||
if len(befores) > 0 {
|
||||
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
for _, before := range befores {
|
||||
before(next, w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(rewrites) > 0 {
|
||||
m.rewrite = func(r *Request) {
|
||||
for _, rewrite := range rewrites {
|
||||
rewrite(r)
|
||||
}
|
||||
}
|
||||
m.before = buildBefores(befores)
|
||||
}
|
||||
if len(modResps) > 0 {
|
||||
m.modifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware %s", name)
|
||||
b := E.NewBuilder("errors in middleware")
|
||||
for _, mr := range modResps {
|
||||
b.AddE(mr(res))
|
||||
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
|
||||
}
|
||||
return b.Build().Error()
|
||||
}
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
m.EnableTrace()
|
||||
m.AddTracef("middleware created")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func buildBefores(befores []*Middleware) BeforeFunc {
|
||||
if len(befores) == 1 {
|
||||
return befores[0].before
|
||||
}
|
||||
nextBefores := buildBefores(befores[1:])
|
||||
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
befores[0].before(func(w ResponseWriter, r *Request) {
|
||||
nextBefores(next, w, r)
|
||||
}, w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,10 +67,10 @@ func LoadComposeFiles() {
|
|||
b.Add(E.Duplicated("middleware", name))
|
||||
continue
|
||||
}
|
||||
middlewares[name] = m
|
||||
middlewares[U.ToLowerNoSnake(name)] = m
|
||||
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
||||
}
|
||||
b.Add(err.Subject(defFile))
|
||||
b.Add(err.Subject(path.Base(defFile)))
|
||||
}
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest {
|
|||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
mr := new(modifyRequest)
|
||||
var mrFunc RewriteFunc
|
||||
if common.IsDebug {
|
||||
mrFunc = mr.modifyRequestWithTrace
|
||||
} else {
|
||||
mrFunc = mr.modifyRequest
|
||||
}
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
rewrite: mr.modifyRequest,
|
||||
impl: mr,
|
||||
before: Rewrite(mrFunc),
|
||||
}
|
||||
mr.modifyRequestOpts = new(modifyRequestOpts)
|
||||
err := Deserialize(optsRaw, mr.modifyRequestOpts)
|
||||
|
@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) {
|
|||
req.Header.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
|
||||
mr.m.AddTraceRequest("before modify request", req)
|
||||
mr.modifyRequest(req)
|
||||
mr.m.AddTraceRequest("after modify request", req)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package middleware
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) {
|
|||
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
modifyResponse: mr.modifyResponse,
|
||||
mr.m = &Middleware{impl: mr}
|
||||
if common.IsDebug {
|
||||
mr.m.modifyResponse = mr.modifyResponseWithTrace
|
||||
} else {
|
||||
mr.m.modifyResponse = mr.modifyResponse
|
||||
}
|
||||
mr.modifyResponseOpts = new(modifyResponseOpts)
|
||||
err := Deserialize(optsRaw, mr.modifyResponseOpts)
|
||||
|
@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp)
|
||||
err := mr.modifyResponse(resp)
|
||||
mr.m.AddTraceResponse("after modify response", resp)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
|
@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts {
|
|||
}
|
||||
}
|
||||
|
||||
var realIPLogger = logrus.WithField("middleware", "RealIP")
|
||||
|
||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
riWithOpts := new(realIP)
|
||||
riWithOpts.m = &Middleware{
|
||||
impl: riWithOpts,
|
||||
rewrite: riWithOpts.setRealIP,
|
||||
impl: riWithOpts,
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
riWithOpts.setRealIP(r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
riWithOpts.realIPOpts = realIPOptsDefault()
|
||||
err := Deserialize(opts, riWithOpts.realIPOpts)
|
||||
|
@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
|||
func (ri *realIP) setRealIP(req *Request) {
|
||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
realIPLogger.Debugf("failed to split host port %s", err)
|
||||
clientIPStr = req.RemoteAddr
|
||||
}
|
||||
clientIP := net.ParseIP(clientIPStr)
|
||||
|
||||
|
@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
}
|
||||
}
|
||||
if !isTrusted {
|
||||
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
|
||||
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
var lastNonTrustedIP string
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
|
||||
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) {
|
|||
lastNonTrustedIP = r
|
||||
}
|
||||
}
|
||||
if lastNonTrustedIP == "" {
|
||||
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if lastNonTrustedIP == "" {
|
||||
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
return
|
||||
}
|
||||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set("X-Real-IP", lastNonTrustedIP)
|
||||
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
|
||||
req.Header.Set(xForwardedFor, lastNonTrustedIP)
|
||||
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
||||
|
|
|
@ -2,13 +2,15 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
func TestSetRealIPOpts(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": "X-Real-IP",
|
||||
"from": []string{
|
||||
|
@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) {
|
|||
Recursive: true,
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
ri, err := RealIP.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
|
||||
})
|
||||
// TODO test
|
||||
ri, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
for i, CIDR := range ri.impl.(*realIP).From {
|
||||
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
const (
|
||||
testHeader = "X-Real-IP"
|
||||
testRealIP = "192.168.1.1"
|
||||
)
|
||||
opts := OptionsRaw{
|
||||
"header": testHeader,
|
||||
"from": []string{"0.0.0.0/0"},
|
||||
}
|
||||
optsMr := OptionsRaw{
|
||||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
|
||||
mr, err := NewModifyRequest(optsMr)
|
||||
ExpectNoError(t, err.Error())
|
||||
|
||||
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
|
||||
}
|
||||
|
|
|
@ -7,13 +7,13 @@ import (
|
|||
)
|
||||
|
||||
var RedirectHTTP = &Middleware{
|
||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if r.TLS == nil {
|
||||
r.URL.Scheme = "https"
|
||||
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
deny:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.1.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
accept:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.0.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
|
@ -20,6 +21,9 @@ var testHeaders http.Header
|
|||
const testHost = "example.com"
|
||||
|
||||
func init() {
|
||||
if !common.IsTest {
|
||||
return
|
||||
}
|
||||
tmp := map[string]string{}
|
||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||
if err != nil {
|
||||
|
@ -31,13 +35,15 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
type requestHeaderRecorder struct {
|
||||
type requestRecorder struct {
|
||||
parent http.RoundTripper
|
||||
reqHeaders http.Header
|
||||
headers http.Header
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.reqHeaders = req.Header
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.headers = req.Header
|
||||
rt.remoteAddr = req.RemoteAddr
|
||||
if rt.parent != nil {
|
||||
return rt.parent.RoundTrip(req)
|
||||
}
|
||||
|
@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e
|
|||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -53,6 +60,7 @@ type TestResult struct {
|
|||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
ResponseStatus int
|
||||
RemoteAddr string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
|
@ -65,7 +73,7 @@ type testArgs struct {
|
|||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||
var body io.Reader
|
||||
var rt = new(requestHeaderRecorder)
|
||||
var rr = new(requestRecorder)
|
||||
var proxyURL *url.URL
|
||||
var requestTarget string
|
||||
var err error
|
||||
|
@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
|||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
rt.parent = http.DefaultTransport
|
||||
rr.parent = http.DefaultTransport
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
|
||||
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
|
||||
middleware.name: args.middlewareOpt,
|
||||
})
|
||||
rp := gpHTTP.NewReverseProxy(proxyURL, rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
patchReverseProxy(middleware.name, rp, []*Middleware{mid})
|
||||
rp.ServeHTTP(w, req)
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
|||
return nil, E.From(err)
|
||||
}
|
||||
return &TestResult{
|
||||
RequestHeaders: rt.reqHeaders,
|
||||
RequestHeaders: rr.headers,
|
||||
ResponseHeaders: resp.Header,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
RemoteAddr: rr.remoteAddr,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
|
99
internal/net/http/middleware/trace.go
Normal file
99
internal/net/http/middleware/trace.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders http.Header `json:"req_headers,omitempty"`
|
||||
RespHeaders http.Header `json:"resp_headers,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
|
||||
type Traces []*Trace
|
||||
|
||||
var traces = Traces{}
|
||||
var tracesMu sync.Mutex
|
||||
|
||||
const MaxTraceNum = 1000
|
||||
|
||||
func GetAllTrace() []*Trace {
|
||||
return traces
|
||||
}
|
||||
|
||||
func (tr *Trace) WithRequest(req *Request) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = req.RequestURI
|
||||
tr.ReqHeaders = req.Header.Clone()
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithResponse(resp *Response) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = resp.Request.RequestURI
|
||||
tr.ReqHeaders = resp.Request.Header.Clone()
|
||||
tr.RespHeaders = resp.Header.Clone()
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) With(what string, additional any) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional[what] = additional
|
||||
return tr
|
||||
}
|
||||
|
||||
func (m *Middleware) EnableTrace() {
|
||||
m.trace = true
|
||||
for _, child := range m.children {
|
||||
child.parent = m
|
||||
child.EnableTrace()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return addTrace(&Trace{
|
||||
Time: U.FormatTime(time.Now()),
|
||||
Caller: m.Fullname(),
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||
return m.AddTracef("%s", msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
||||
return m.AddTracef("%s", msg).WithResponse(resp)
|
||||
}
|
||||
|
||||
func addTrace(t *Trace) *Trace {
|
||||
tracesMu.Lock()
|
||||
defer tracesMu.Unlock()
|
||||
if len(traces) > MaxTraceNum {
|
||||
traces = traces[1:]
|
||||
}
|
||||
traces = append(traces, t)
|
||||
return t
|
||||
}
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -14,7 +15,7 @@ const (
|
|||
)
|
||||
|
||||
var SetXForwarded = &Middleware{
|
||||
rewrite: func(req *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Del(xForwardedHost)
|
||||
|
@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{
|
|||
if err == nil {
|
||||
req.Header.Set(xForwardedFor, clientIP)
|
||||
} else {
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Set(xForwardedFor, req.RemoteAddr)
|
||||
}
|
||||
req.Header.Set(xForwardedHost, req.Host)
|
||||
if req.TLS == nil {
|
||||
|
@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{
|
|||
} else {
|
||||
req.Header.Set(xForwardedProto, "https")
|
||||
}
|
||||
next(w, req)
|
||||
},
|
||||
}
|
||||
|
||||
var HideXForwarded = &Middleware{
|
||||
rewrite: func(req *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Del(xForwardedHost)
|
||||
req.Header.Del(xForwardedProto)
|
||||
next(w, req)
|
||||
},
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
|||
rp := NewReverseProxy(entry.URL, trans)
|
||||
|
||||
if len(entry.Middlewares) > 0 {
|
||||
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
|
||||
err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -32,3 +32,7 @@ func (cidr *CIDR) Contains(ip net.IP) bool {
|
|||
func (cidr *CIDR) String() string {
|
||||
return (*net.IPNet)(cidr).String()
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Equals(other *CIDR) bool {
|
||||
return (*net.IPNet)(cidr).IP.Equal(other.IP) && cidr.Mask.String() == other.Mask.String()
|
||||
}
|
||||
|
|
|
@ -42,6 +42,10 @@ func FormatDuration(d time.Duration) string {
|
|||
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
|
||||
}
|
||||
|
||||
func FormatTime(t time.Time) string {
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func ParseBool(s string) bool {
|
||||
switch strings.ToLower(s) {
|
||||
case "1", "true", "yes", "on":
|
||||
|
|
|
@ -1,19 +1,25 @@
|
|||
package functional
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Slice[T any] struct {
|
||||
s []T
|
||||
s []T
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewSlice[T any]() *Slice[T] {
|
||||
return &Slice[T]{make([]T, 0)}
|
||||
return &Slice[T]{s: make([]T, 0)}
|
||||
}
|
||||
|
||||
func NewSliceN[T any](n int) *Slice[T] {
|
||||
return &Slice[T]{make([]T, n)}
|
||||
return &Slice[T]{s: make([]T, n)}
|
||||
}
|
||||
|
||||
func NewSliceFrom[T any](s []T) *Slice[T] {
|
||||
return &Slice[T]{s}
|
||||
return &Slice[T]{s: s}
|
||||
}
|
||||
|
||||
func (s *Slice[T]) Size() int {
|
||||
|
@ -46,6 +52,30 @@ func (s *Slice[T]) AddRange(other *Slice[T]) *Slice[T] {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Slice[T]) SafeAdd(e T) *Slice[T] {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Add(e)
|
||||
}
|
||||
|
||||
func (s *Slice[T]) SafeAddRange(other *Slice[T]) *Slice[T] {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.AddRange(other)
|
||||
}
|
||||
|
||||
func (s *Slice[T]) Pop() T {
|
||||
v := s.s[len(s.s)-1]
|
||||
s.s = s.s[:len(s.s)-1]
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *Slice[T]) SafePop() T {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Pop()
|
||||
}
|
||||
|
||||
func (s *Slice[T]) ForEach(do func(T)) {
|
||||
for _, v := range s.s {
|
||||
do(v)
|
||||
|
@ -57,7 +87,7 @@ func (s *Slice[T]) Map(m func(T) T) *Slice[T] {
|
|||
for i, v := range s.s {
|
||||
n[i] = m(v)
|
||||
}
|
||||
return &Slice[T]{n}
|
||||
return &Slice[T]{s: n}
|
||||
}
|
||||
|
||||
func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
|
||||
|
@ -67,5 +97,13 @@ func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
|
|||
n = append(n, v)
|
||||
}
|
||||
}
|
||||
return &Slice[T]{n}
|
||||
return &Slice[T]{s: n}
|
||||
}
|
||||
|
||||
func (s *Slice[T]) String() string {
|
||||
out, err := json.MarshalIndent(s.s, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
|
|
@ -2,10 +2,23 @@ package utils
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if common.IsTest {
|
||||
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
func IgnoreError[Result any](r Result, _ error) Result {
|
||||
return r
|
||||
}
|
||||
|
||||
func ExpectNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil && !reflect.ValueOf(err).IsNil() {
|
||||
|
|
Loading…
Add table
Reference in a new issue