fixed middleware implementation, added middleware tracing for easier debug

This commit is contained in:
yusing 2024-10-02 13:55:41 +08:00
parent d172552fb0
commit ba13b81b0e
31 changed files with 561 additions and 196 deletions

View file

@ -83,13 +83,14 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
### Commands line arguments
| Argument | Description | Example |
| ----------- | -------------------------------- | -------------------------- |
| empty | start proxy server | |
| `validate` | validate config and exit | |
| `reload` | trigger a force reload of config | |
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
| Argument | Description | Example |
| ----------------- | ---------------------------------------------------- | -------------------------------- |
| empty | start proxy server | |
| `validate` | validate config and exit | |
| `reload` | trigger a force reload of config | |
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
| `debug-ls-mtrace` | list middleware trace **(works only in debug mode)** | `go-proxy debug-ls-mtrace \| jq` |
**run with `docker exec go-proxy /app/go-proxy <command>`**

View file

@ -18,7 +18,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api"
apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/docker"
@ -57,7 +57,7 @@ func main() {
}
if args.Command == common.CommandReload {
if err := apiUtils.ReloadServer(); err.HasError() {
if err := query.ReloadServer(); err.HasError() {
log.Fatal(err)
}
log.Print("ok")
@ -93,7 +93,7 @@ func main() {
printJSON(cfg.Value())
return
case common.CommandListRoutes:
routes, err := apiUtils.ListRoutes()
routes, err := query.ListRoutes()
if err.HasError() {
log.Printf("failed to connect to api server: %s", err)
log.Printf("falling back to config file")
@ -108,6 +108,12 @@ func main() {
case common.CommandDebugListProviders:
printJSON(cfg.DumpProviders())
return
case common.CommandDebugListMTrace:
trace, err := query.ListMiddlewareTraces()
if err.HasError() {
log.Fatal(err)
}
printJSON(trace)
}
if common.IsDebug {

View file

@ -25,10 +25,10 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return
case route.Type() == R.RouteTypeReverseProxy:
ok = U.IsSiteHealthy(route.URL().String())
ok = IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = U.IsStreamHealthy(
ok = IsStreamHealthy(
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)

View file

@ -1,21 +1,22 @@
package utils
package v1
import (
"net"
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
)
func IsSiteHealthy(url string) bool {
// try HEAD first
// if HEAD is not allowed, try GET
resp, err := httpClient.Head(url)
resp, err := U.Head(url)
if resp != nil {
resp.Body.Close()
}
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
_, err = httpClient.Get(url)
_, err = U.Get(url)
}
if resp != nil {
resp.Body.Close()

View file

@ -8,19 +8,28 @@ import (
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
const (
ListRoutes = "routes"
ListConfigFiles = "config_files"
ListMiddlewareTrace = "middleware_trace"
)
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what")
if what == "" {
what = "routes"
what = ListRoutes
}
switch what {
case "routes":
case ListRoutes:
listRoutes(cfg, w, r)
case "config_files":
case ListConfigFiles:
listConfigFiles(w, r)
case ListMiddlewareTrace:
listMiddlewareTrace(w, r)
default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
}
@ -59,3 +68,12 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
}
w.Write(resp)
}
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
resp, err := json.Marshal(middleware.GetAllTrace())
if err != nil {
U.HandleErr(w, r, err)
return
}
w.Write(resp)
}

View file

@ -1,4 +1,4 @@
package utils
package query
import (
"encoding/json"
@ -6,12 +6,15 @@ import (
"io"
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
func ReloadServer() E.NestedError {
resp, err := httpClient.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
if err != nil {
return E.From(err)
}
@ -32,7 +35,7 @@ func ReloadServer() E.NestedError {
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
resp, err := httpClient.Get(fmt.Sprintf("%s/v1/list/routes", common.APIHTTPURL))
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
if err != nil {
return nil, E.From(err)
}
@ -47,3 +50,20 @@ func ListRoutes() (map[string]map[string]any, E.NestedError) {
}
return routes, nil
}
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
if err != nil {
return nil, E.From(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
}
var traces middleware.Traces
err = json.NewDecoder(resp.Body).Decode(&traces)
if err != nil {
return nil, E.From(err)
}
return traces, nil
}

View file

@ -8,7 +8,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var httpClient = &http.Client{
var HTTPClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -21,3 +21,7 @@ var httpClient = &http.Client{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
var Get = HTTPClient.Get
var Post = HTTPClient.Post
var Head = HTTPClient.Head

View file

@ -2,9 +2,9 @@ package common
import (
"flag"
"fmt"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
type Args struct {
@ -20,6 +20,7 @@ const (
CommandReload = "reload"
CommandDebugListEntries = "debug-ls-entries"
CommandDebugListProviders = "debug-ls-providers"
CommandDebugListMTrace = "debug-ls-mtrace"
)
var ValidCommands = []string{
@ -31,23 +32,24 @@ var ValidCommands = []string{
CommandReload,
CommandDebugListEntries,
CommandDebugListProviders,
CommandDebugListMTrace,
}
func GetArgs() Args {
var args Args
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArg(args.Command); err.HasError() {
if err := validateArg(args.Command); err != nil {
logrus.Fatal(err)
}
return args
}
func validateArg(arg string) E.NestedError {
func validateArg(arg string) error {
for _, v := range ValidCommands {
if arg == v {
return nil
}
}
return E.Invalid("argument", arg)
return fmt.Errorf("invalid command: %s", arg)
}

View file

@ -4,14 +4,14 @@ import (
"fmt"
"net"
"os"
"strings"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils"
)
var (
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", false)
IsTest = GetEnvBool("GOPROXY_TEST", false)
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
ProxyHTTPAddr,
@ -35,7 +35,14 @@ func GetEnvBool(key string, defaultValue bool) bool {
if !ok || value == "" {
return defaultValue
}
return U.ParseBool(value)
switch strings.ToLower(value) {
case "true", "yes", "1":
return true
case "false", "no", "0":
return false
default:
return defaultValue
}
}
func GetEnv(key, defaultValue string) string {

View file

@ -7,6 +7,7 @@ import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type cidrWhitelist struct {
@ -19,7 +20,7 @@ type cidrWhitelistOpts struct {
StatusCode int
Message string
trustedAddr map[string]struct{} // cache for trusted IPs
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
var CIDRWhiteList = &cidrWhitelist{
@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{
"allow": D.YamlStringListParser,
"statusCode": D.IntParser,
},
withOptions: NewCIDRWhitelist,
},
}
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
return &cidrWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
trustedAddr: make(map[string]struct{}),
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
cachedAddr: F.NewMapOf[string, bool](),
}
}
@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
return wl.m, nil
}
func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) {
var ok bool
if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok {
ip := net.IP(r.RemoteAddr)
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.cidrWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.trustedAddr[r.RemoteAddr] = struct{}{}
ok = true
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
}
}
if !ok {
if !allow {
w.WriteHeader(wl.StatusCode)
w.Write([]byte(wl.Message))
return
}
next.ServeHTTP(w, r)
next(w, r)
}

View file

@ -0,0 +1,42 @@
package middleware
import (
_ "embed"
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/cidr_whitelist_test.yml
var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
if err != nil {
panic(err)
}
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
panic("bug occurred")
}
t.Run("deny", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
}
})
t.Run("accept", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})
}

View file

@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
rewrite: func(r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
next(w, r)
},
}
cri.realIPOpts = &realIPOpts{

View file

@ -15,9 +15,9 @@ import (
)
var CustomErrorPage = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next.ServeHTTP(w, r)
next(w, r)
}
},
modifyResponse: func(resp *Response) error {

View file

@ -13,7 +13,6 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth {
fa.m.withOptions = NewForwardAuthfunc
return fa
}()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
faWithOpts := new(forwardAuth)
@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
return faWithOpts.m, nil
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gpHTTP.RemoveHop(req.Header)
faReq, err := http.NewRequestWithContext(
@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
}
return
}

View file

@ -2,6 +2,7 @@ package middleware
import (
"encoding/json"
"errors"
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
@ -21,7 +22,7 @@ type (
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
@ -33,23 +34,38 @@ type (
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
parent *Middleware
children []*Middleware
trace bool
}
)
var Deserialize = U.Deserialize
func Rewrite(r RewriteFunc) BeforeFunc {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
r(req)
next(w, req)
}
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string {
return m.name
}
@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
return &Middleware{
m.name,
m.before,
m.modifyResponse,
nil, nil,
m.impl,
m.parent,
m.children,
false,
}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
invalidM.To(&res)
}()
for name, opts := range middlewares {
for name, opts := range middlewaresMap {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
}
middlewares = append(middlewares, m)
}
if invalidM.HasError() {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
}
}
if len(rewrites) > 0 {
origServeHTTP = rp.ServeHTTP
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
origServeHTTP(w, r)
}
}
if len(modResps) > 0 {
if rp.ModifyResponse != nil {
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
}
rp.ModifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware ModifyResponse")
for _, mr := range modResps {
b.AddE(mr(res))
}
return b.Build().Error()
}
}
patchReverseProxy(rpName, rp, middlewares)
return
}
func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rpName, middlewares)
if mid.before != nil {
ori := rp.ServeHTTP
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
mid.before(ori, w, r)
}
}
if mid.modifyResponse != nil {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
return errors.Join(mid.modifyResponse(res), ori(res))
}
} else {
rp.ModifyResponse = mid.modifyResponse
}
}
}

View file

@ -1,9 +1,11 @@
package middleware
import (
"fmt"
"net/http"
"os"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
b.Add(E.FailWith("yaml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
continue
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
continue
@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
if chainErr.HasError() {
b.Add(chainErr.Build())
} else {
name = name + "@file"
middlewares[name] = BuildMiddlewareFromChain(name, chain)
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
// TODO: check conflict or duplicates
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
var (
befores []BeforeFunc
rewrites []RewriteFunc
modResps []ModifyResponseFunc
)
for _, m := range chain {
if m.before != nil {
befores = append(befores, m.before)
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
m := &Middleware{name: name}
if len(befores) > 0 {
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
for _, before := range befores {
before(next, w, r)
}
}
}
if len(rewrites) > 0 {
m.rewrite = func(r *Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
}
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware %s", name)
b := E.NewBuilder("errors in middleware")
for _, mr := range modResps {
b.AddE(mr(res))
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
}
return b.Build().Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
}

View file

@ -67,10 +67,10 @@ func LoadComposeFiles() {
b.Add(E.Duplicated("middleware", name))
continue
}
middlewares[name] = m
middlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
}
b.Add(err.Subject(defFile))
b.Add(err.Subject(path.Base(defFile)))
}
if b.HasError() {
logger.Error(b.Build())

View file

@ -1,6 +1,7 @@
package middleware
import (
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest {
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
mr := new(modifyRequest)
var mrFunc RewriteFunc
if common.IsDebug {
mrFunc = mr.modifyRequestWithTrace
} else {
mrFunc = mr.modifyRequest
}
mr.m = &Middleware{
impl: mr,
rewrite: mr.modifyRequest,
impl: mr,
before: Rewrite(mrFunc),
}
mr.modifyRequestOpts = new(modifyRequestOpts)
err := Deserialize(optsRaw, mr.modifyRequestOpts)
@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) {
req.Header.Del(k)
}
}
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyRequest(req)
mr.m.AddTraceRequest("after modify request", req)
}

View file

@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) {
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
modifyResponse: mr.modifyResponse,
mr.m = &Middleware{impl: mr}
if common.IsDebug {
mr.m.modifyResponse = mr.modifyResponseWithTrace
} else {
mr.m.modifyResponse = mr.modifyResponse
}
mr.modifyResponseOpts = new(modifyResponseOpts)
err := Deserialize(optsRaw, mr.modifyResponseOpts)
@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
}
return nil
}
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
mr.m.AddTraceResponse("before modify response", resp)
err := mr.modifyResponse(resp)
mr.m.AddTraceResponse("after modify response", resp)
return err
}

View file

@ -2,8 +2,8 @@ package middleware
import (
"net"
"net/http"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts {
}
}
var realIPLogger = logrus.WithField("middleware", "RealIP")
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,
rewrite: riWithOpts.setRealIP,
impl: riWithOpts,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
riWithOpts.setRealIP(r)
next(w, r)
},
}
riWithOpts.realIPOpts = realIPOptsDefault()
err := Deserialize(opts, riWithOpts.realIPOpts)
@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
func (ri *realIP) setRealIP(req *Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
realIPLogger.Debugf("failed to split host port %s", err)
clientIPStr = req.RemoteAddr
}
clientIP := net.ParseIP(clientIPStr)
@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) {
}
}
if !isTrusted {
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) {
var lastNonTrustedIP string
if len(realIPs) == 0 {
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) {
lastNonTrustedIP = r
}
}
if lastNonTrustedIP == "" {
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
return
}
}
if lastNonTrustedIP == "" {
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set("X-Real-IP", lastNonTrustedIP)
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
req.Header.Set(xForwardedFor, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -2,13 +2,15 @@ package middleware
import (
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIP(t *testing.T) {
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"from": []string{
@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) {
Recursive: true,
}
t.Run("set_options", func(t *testing.T) {
ri, err := RealIP.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// TODO test
ri, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
}
}
func TestSetRealIP(t *testing.T) {
const (
testHeader = "X-Real-IP"
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{
"header": testHeader,
"from": []string{"0.0.0.0/0"},
}
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
mr, err := NewModifyRequest(optsMr)
ExpectNoError(t, err.Error())
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err.Error())
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
}

View file

@ -7,13 +7,13 @@ import (
)
var RedirectHTTP = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
next(w, r)
},
}

View file

@ -0,0 +1,22 @@
deny:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.1.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
accept:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.0.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24

View file

@ -9,6 +9,7 @@ import (
"net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
)
@ -20,6 +21,9 @@ var testHeaders http.Header
const testHost = "example.com"
func init() {
if !common.IsTest {
return
}
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
@ -31,13 +35,15 @@ func init() {
}
}
type requestHeaderRecorder struct {
type requestRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
headers http.Header
remoteAddr string
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.headers = req.Header
rt.remoteAddr = req.RemoteAddr
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
TLS: req.TLS,
}, nil
}
@ -53,6 +60,7 @@ type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
RemoteAddr string
Data []byte
}
@ -65,7 +73,7 @@ type testArgs struct {
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var rr = new(requestRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
rr.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
rp := gpHTTP.NewReverseProxy(proxyURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
patchReverseProxy(middleware.name, rp, []*Middleware{mid})
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
RequestHeaders: rr.headers,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
RemoteAddr: rr.remoteAddr,
Data: data,
}, nil
}

View file

@ -0,0 +1,99 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
U "github.com/yusing/go-proxy/internal/utils"
)
type Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders http.Header `json:"req_headers,omitempty"`
RespHeaders http.Header `json:"resp_headers,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
type Traces []*Trace
var traces = Traces{}
var tracesMu sync.Mutex
const MaxTraceNum = 1000
func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *Request) *Trace {
if tr == nil {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = req.Header.Clone()
return tr
}
func (tr *Trace) WithResponse(resp *Response) *Trace {
if tr == nil {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = resp.Request.Header.Clone()
tr.RespHeaders = resp.Header.Clone()
return tr
}
func (tr *Trace) With(what string, additional any) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional[what] = additional
return tr
}
func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
child.parent = m
child.EnableTrace()
}
}
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
if !m.trace {
return nil
}
return addTrace(&Trace{
Time: U.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
return m.AddTracef("%s", msg).WithResponse(resp)
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()
if len(traces) > MaxTraceNum {
traces = traces[1:]
}
traces = append(traces, t)
return t
}

View file

@ -2,6 +2,7 @@ package middleware
import (
"net"
"net/http"
)
const (
@ -14,7 +15,7 @@ const (
)
var SetXForwarded = &Middleware{
rewrite: func(req *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
req.Header.Del("Forwarded")
req.Header.Del(xForwardedFor)
req.Header.Del(xForwardedHost)
@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{
if err == nil {
req.Header.Set(xForwardedFor, clientIP)
} else {
req.Header.Del(xForwardedFor)
req.Header.Set(xForwardedFor, req.RemoteAddr)
}
req.Header.Set(xForwardedHost, req.Host)
if req.TLS == nil {
@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{
} else {
req.Header.Set(xForwardedProto, "https")
}
next(w, req)
},
}
var HideXForwarded = &Middleware{
rewrite: func(req *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
req.Header.Del("Forwarded")
req.Header.Del(xForwardedFor)
req.Header.Del(xForwardedHost)
req.Header.Del(xForwardedProto)
next(w, req)
},
}

View file

@ -68,7 +68,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
rp := NewReverseProxy(entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares)
if err != nil {
return nil, err
}

View file

@ -32,3 +32,7 @@ func (cidr *CIDR) Contains(ip net.IP) bool {
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}
func (cidr *CIDR) Equals(other *CIDR) bool {
return (*net.IPNet)(cidr).IP.Equal(other.IP) && cidr.Mask.String() == other.Mask.String()
}

View file

@ -42,6 +42,10 @@ func FormatDuration(d time.Duration) string {
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
}
func FormatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}
func ParseBool(s string) bool {
switch strings.ToLower(s) {
case "1", "true", "yes", "on":

View file

@ -1,19 +1,25 @@
package functional
import (
"encoding/json"
"sync"
)
type Slice[T any] struct {
s []T
s []T
mu sync.Mutex
}
func NewSlice[T any]() *Slice[T] {
return &Slice[T]{make([]T, 0)}
return &Slice[T]{s: make([]T, 0)}
}
func NewSliceN[T any](n int) *Slice[T] {
return &Slice[T]{make([]T, n)}
return &Slice[T]{s: make([]T, n)}
}
func NewSliceFrom[T any](s []T) *Slice[T] {
return &Slice[T]{s}
return &Slice[T]{s: s}
}
func (s *Slice[T]) Size() int {
@ -46,6 +52,30 @@ func (s *Slice[T]) AddRange(other *Slice[T]) *Slice[T] {
return s
}
func (s *Slice[T]) SafeAdd(e T) *Slice[T] {
s.mu.Lock()
defer s.mu.Unlock()
return s.Add(e)
}
func (s *Slice[T]) SafeAddRange(other *Slice[T]) *Slice[T] {
s.mu.Lock()
defer s.mu.Unlock()
return s.AddRange(other)
}
func (s *Slice[T]) Pop() T {
v := s.s[len(s.s)-1]
s.s = s.s[:len(s.s)-1]
return v
}
func (s *Slice[T]) SafePop() T {
s.mu.Lock()
defer s.mu.Unlock()
return s.Pop()
}
func (s *Slice[T]) ForEach(do func(T)) {
for _, v := range s.s {
do(v)
@ -57,7 +87,7 @@ func (s *Slice[T]) Map(m func(T) T) *Slice[T] {
for i, v := range s.s {
n[i] = m(v)
}
return &Slice[T]{n}
return &Slice[T]{s: n}
}
func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
@ -67,5 +97,13 @@ func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
n = append(n, v)
}
}
return &Slice[T]{n}
return &Slice[T]{s: n}
}
func (s *Slice[T]) String() string {
out, err := json.MarshalIndent(s.s, "", " ")
if err != nil {
panic(err)
}
return string(out)
}

View file

@ -2,10 +2,23 @@ package utils
import (
"errors"
"os"
"reflect"
"testing"
"github.com/yusing/go-proxy/internal/common"
)
func init() {
if common.IsTest {
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
}
func IgnoreError[Result any](r Result, _ error) Result {
return r
}
func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() {