mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-21 20:04:03 +02:00
fixed middleware implementation, added middleware tracing for easier debug
This commit is contained in:
parent
d172552fb0
commit
ba13b81b0e
31 changed files with 561 additions and 196 deletions
15
README.md
15
README.md
|
@ -83,13 +83,14 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||||
|
|
||||||
### Commands line arguments
|
### Commands line arguments
|
||||||
|
|
||||||
| Argument | Description | Example |
|
| Argument | Description | Example |
|
||||||
| ----------- | -------------------------------- | -------------------------- |
|
| ----------------- | ---------------------------------------------------- | -------------------------------- |
|
||||||
| empty | start proxy server | |
|
| empty | start proxy server | |
|
||||||
| `validate` | validate config and exit | |
|
| `validate` | validate config and exit | |
|
||||||
| `reload` | trigger a force reload of config | |
|
| `reload` | trigger a force reload of config | |
|
||||||
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
|
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
|
||||||
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
|
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
|
||||||
|
| `debug-ls-mtrace` | list middleware trace **(works only in debug mode)** | `go-proxy debug-ls-mtrace \| jq` |
|
||||||
|
|
||||||
**run with `docker exec go-proxy /app/go-proxy <command>`**
|
**run with `docker exec go-proxy /app/go-proxy <command>`**
|
||||||
|
|
||||||
|
|
12
cmd/main.go
12
cmd/main.go
|
@ -18,7 +18,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/yusing/go-proxy/internal"
|
"github.com/yusing/go-proxy/internal"
|
||||||
"github.com/yusing/go-proxy/internal/api"
|
"github.com/yusing/go-proxy/internal/api"
|
||||||
apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils"
|
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
|
@ -57,7 +57,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.Command == common.CommandReload {
|
if args.Command == common.CommandReload {
|
||||||
if err := apiUtils.ReloadServer(); err.HasError() {
|
if err := query.ReloadServer(); err.HasError() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
log.Print("ok")
|
log.Print("ok")
|
||||||
|
@ -93,7 +93,7 @@ func main() {
|
||||||
printJSON(cfg.Value())
|
printJSON(cfg.Value())
|
||||||
return
|
return
|
||||||
case common.CommandListRoutes:
|
case common.CommandListRoutes:
|
||||||
routes, err := apiUtils.ListRoutes()
|
routes, err := query.ListRoutes()
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
log.Printf("failed to connect to api server: %s", err)
|
log.Printf("failed to connect to api server: %s", err)
|
||||||
log.Printf("falling back to config file")
|
log.Printf("falling back to config file")
|
||||||
|
@ -108,6 +108,12 @@ func main() {
|
||||||
case common.CommandDebugListProviders:
|
case common.CommandDebugListProviders:
|
||||||
printJSON(cfg.DumpProviders())
|
printJSON(cfg.DumpProviders())
|
||||||
return
|
return
|
||||||
|
case common.CommandDebugListMTrace:
|
||||||
|
trace, err := query.ListMiddlewareTraces()
|
||||||
|
if err.HasError() {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
printJSON(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.IsDebug {
|
if common.IsDebug {
|
||||||
|
|
|
@ -25,10 +25,10 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||||
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
|
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
|
||||||
return
|
return
|
||||||
case route.Type() == R.RouteTypeReverseProxy:
|
case route.Type() == R.RouteTypeReverseProxy:
|
||||||
ok = U.IsSiteHealthy(route.URL().String())
|
ok = IsSiteHealthy(route.URL().String())
|
||||||
case route.Type() == R.RouteTypeStream:
|
case route.Type() == R.RouteTypeStream:
|
||||||
entry := route.Entry()
|
entry := route.Entry()
|
||||||
ok = U.IsStreamHealthy(
|
ok = IsStreamHealthy(
|
||||||
strings.Split(entry.Scheme, ":")[1], // target scheme
|
strings.Split(entry.Scheme, ":")[1], // target scheme
|
||||||
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
|
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,21 +1,22 @@
|
||||||
package utils
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func IsSiteHealthy(url string) bool {
|
func IsSiteHealthy(url string) bool {
|
||||||
// try HEAD first
|
// try HEAD first
|
||||||
// if HEAD is not allowed, try GET
|
// if HEAD is not allowed, try GET
|
||||||
resp, err := httpClient.Head(url)
|
resp, err := U.Head(url)
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
}
|
}
|
||||||
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
||||||
_, err = httpClient.Get(url)
|
_, err = U.Get(url)
|
||||||
}
|
}
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
|
@ -8,19 +8,28 @@ import (
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ListRoutes = "routes"
|
||||||
|
ListConfigFiles = "config_files"
|
||||||
|
ListMiddlewareTrace = "middleware_trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||||
what := r.PathValue("what")
|
what := r.PathValue("what")
|
||||||
if what == "" {
|
if what == "" {
|
||||||
what = "routes"
|
what = ListRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
switch what {
|
switch what {
|
||||||
case "routes":
|
case ListRoutes:
|
||||||
listRoutes(cfg, w, r)
|
listRoutes(cfg, w, r)
|
||||||
case "config_files":
|
case ListConfigFiles:
|
||||||
listConfigFiles(w, r)
|
listConfigFiles(w, r)
|
||||||
|
case ListMiddlewareTrace:
|
||||||
|
listMiddlewareTrace(w, r)
|
||||||
default:
|
default:
|
||||||
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
|
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
@ -59,3 +68,12 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
w.Write(resp)
|
w.Write(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp, err := json.Marshal(middleware.GetAllTrace())
|
||||||
|
if err != nil {
|
||||||
|
U.HandleErr(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(resp)
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package query
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -6,12 +6,15 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||||
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReloadServer() E.NestedError {
|
func ReloadServer() E.NestedError {
|
||||||
resp, err := httpClient.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return E.From(err)
|
||||||
}
|
}
|
||||||
|
@ -32,7 +35,7 @@ func ReloadServer() E.NestedError {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||||
resp, err := httpClient.Get(fmt.Sprintf("%s/v1/list/routes", common.APIHTTPURL))
|
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.From(err)
|
return nil, E.From(err)
|
||||||
}
|
}
|
||||||
|
@ -47,3 +50,20 @@ func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||||
}
|
}
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
|
||||||
|
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var traces middleware.Traces
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&traces)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
|
return traces, nil
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient = &http.Client{
|
var HTTPClient = &http.Client{
|
||||||
Timeout: common.ConnectionTimeout,
|
Timeout: common.ConnectionTimeout,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
@ -21,3 +21,7 @@ var httpClient = &http.Client{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var Get = HTTPClient.Get
|
||||||
|
var Post = HTTPClient.Post
|
||||||
|
var Head = HTTPClient.Head
|
||||||
|
|
|
@ -2,9 +2,9 @@ package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Args struct {
|
type Args struct {
|
||||||
|
@ -20,6 +20,7 @@ const (
|
||||||
CommandReload = "reload"
|
CommandReload = "reload"
|
||||||
CommandDebugListEntries = "debug-ls-entries"
|
CommandDebugListEntries = "debug-ls-entries"
|
||||||
CommandDebugListProviders = "debug-ls-providers"
|
CommandDebugListProviders = "debug-ls-providers"
|
||||||
|
CommandDebugListMTrace = "debug-ls-mtrace"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ValidCommands = []string{
|
var ValidCommands = []string{
|
||||||
|
@ -31,23 +32,24 @@ var ValidCommands = []string{
|
||||||
CommandReload,
|
CommandReload,
|
||||||
CommandDebugListEntries,
|
CommandDebugListEntries,
|
||||||
CommandDebugListProviders,
|
CommandDebugListProviders,
|
||||||
|
CommandDebugListMTrace,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetArgs() Args {
|
func GetArgs() Args {
|
||||||
var args Args
|
var args Args
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
args.Command = flag.Arg(0)
|
args.Command = flag.Arg(0)
|
||||||
if err := validateArg(args.Command); err.HasError() {
|
if err := validateArg(args.Command); err != nil {
|
||||||
logrus.Fatal(err)
|
logrus.Fatal(err)
|
||||||
}
|
}
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateArg(arg string) E.NestedError {
|
func validateArg(arg string) error {
|
||||||
for _, v := range ValidCommands {
|
for _, v := range ValidCommands {
|
||||||
if arg == v {
|
if arg == v {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return E.Invalid("argument", arg)
|
return fmt.Errorf("invalid command: %s", arg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,14 +4,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", false)
|
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", false)
|
||||||
IsTest = GetEnvBool("GOPROXY_TEST", false)
|
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||||
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
|
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
|
||||||
|
|
||||||
ProxyHTTPAddr,
|
ProxyHTTPAddr,
|
||||||
|
@ -35,7 +35,14 @@ func GetEnvBool(key string, defaultValue bool) bool {
|
||||||
if !ok || value == "" {
|
if !ok || value == "" {
|
||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
return U.ParseBool(value)
|
switch strings.ToLower(value) {
|
||||||
|
case "true", "yes", "1":
|
||||||
|
return true
|
||||||
|
case "false", "no", "0":
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEnv(key, defaultValue string) string {
|
func GetEnv(key, defaultValue string) string {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/types"
|
||||||
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cidrWhitelist struct {
|
type cidrWhitelist struct {
|
||||||
|
@ -19,7 +20,7 @@ type cidrWhitelistOpts struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Message string
|
Message string
|
||||||
|
|
||||||
trustedAddr map[string]struct{} // cache for trusted IPs
|
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||||
}
|
}
|
||||||
|
|
||||||
var CIDRWhiteList = &cidrWhitelist{
|
var CIDRWhiteList = &cidrWhitelist{
|
||||||
|
@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{
|
||||||
"allow": D.YamlStringListParser,
|
"allow": D.YamlStringListParser,
|
||||||
"statusCode": D.IntParser,
|
"statusCode": D.IntParser,
|
||||||
},
|
},
|
||||||
|
withOptions: NewCIDRWhitelist,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
||||||
return &cidrWhitelistOpts{
|
return &cidrWhitelistOpts{
|
||||||
Allow: []*types.CIDR{},
|
Allow: []*types.CIDR{},
|
||||||
StatusCode: http.StatusForbidden,
|
StatusCode: http.StatusForbidden,
|
||||||
Message: "IP not allowed",
|
Message: "IP not allowed",
|
||||||
trustedAddr: make(map[string]struct{}),
|
cachedAddr: F.NewMapOf[string, bool](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
return wl.m, nil
|
return wl.m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) {
|
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
var ok bool
|
var allow, ok bool
|
||||||
if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok {
|
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||||
ip := net.IP(r.RemoteAddr)
|
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
ipStr = r.RemoteAddr
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
||||||
if cidr.Contains(ip) {
|
if cidr.Contains(ip) {
|
||||||
wl.trustedAddr[r.RemoteAddr] = struct{}{}
|
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||||
ok = true
|
allow = true
|
||||||
|
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !allow {
|
||||||
|
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||||
|
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !ok {
|
if !allow {
|
||||||
w.WriteHeader(wl.StatusCode)
|
w.WriteHeader(wl.StatusCode)
|
||||||
w.Write([]byte(wl.Message))
|
w.Write([]byte(wl.Message))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
|
|
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed test_data/cidr_whitelist_test.yml
|
||||||
|
var testCIDRWhitelistCompose []byte
|
||||||
|
var deny, accept *Middleware
|
||||||
|
|
||||||
|
func TestCIDRWhitelist(t *testing.T) {
|
||||||
|
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
deny = mids["deny@file"]
|
||||||
|
accept = mids["accept@file"]
|
||||||
|
if deny == nil || accept == nil {
|
||||||
|
panic("bug occurred")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("deny", func(t *testing.T) {
|
||||||
|
for range 10 {
|
||||||
|
result, err := newMiddlewareTest(deny, nil)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
||||||
|
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("accept", func(t *testing.T) {
|
||||||
|
for range 10 {
|
||||||
|
result, err := newMiddlewareTest(accept, nil)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
cri := new(realIP)
|
cri := new(realIP)
|
||||||
cri.m = &Middleware{
|
cri.m = &Middleware{
|
||||||
impl: cri,
|
impl: cri,
|
||||||
rewrite: func(r *Request) {
|
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
cidrs := tryFetchCFCIDR()
|
cidrs := tryFetchCFCIDR()
|
||||||
if cidrs != nil {
|
if cidrs != nil {
|
||||||
cri.From = cidrs
|
cri.From = cidrs
|
||||||
}
|
}
|
||||||
cri.setRealIP(r)
|
cri.setRealIP(r)
|
||||||
|
next(w, r)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cri.realIPOpts = &realIPOpts{
|
cri.realIPOpts = &realIPOpts{
|
||||||
|
|
|
@ -15,9 +15,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var CustomErrorPage = &Middleware{
|
var CustomErrorPage = &Middleware{
|
||||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
if !ServeStaticErrorPageFile(w, r) {
|
if !ServeStaticErrorPageFile(w, r) {
|
||||||
next.ServeHTTP(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
modifyResponse: func(resp *Response) error {
|
modifyResponse: func(resp *Response) error {
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||||
|
@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth {
|
||||||
fa.m.withOptions = NewForwardAuthfunc
|
fa.m.withOptions = NewForwardAuthfunc
|
||||||
return fa
|
return fa
|
||||||
}()
|
}()
|
||||||
var faLogger = logrus.WithField("middleware", "ForwardAuth")
|
|
||||||
|
|
||||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
faWithOpts := new(forwardAuth)
|
faWithOpts := new(forwardAuth)
|
||||||
|
@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
return faWithOpts.m, nil
|
return faWithOpts.m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
|
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||||
gpHTTP.RemoveHop(req.Header)
|
gpHTTP.RemoveHop(req.Header)
|
||||||
|
|
||||||
faReq, err := http.NewRequestWithContext(
|
faReq, err := http.NewRequestWithContext(
|
||||||
|
@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
|
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||||
|
|
||||||
faResp, err := fa.client.Do(faReq)
|
faResp, err := fa.client.Do(faReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
|
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||||
|
|
||||||
body, err := io.ReadAll(faResp.Body)
|
body, err := io.ReadAll(faResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
|
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||||
|
|
||||||
redirectURL, err := faResp.Location()
|
redirectURL, err := faResp.Location()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
|
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
} else if redirectURL.String() != "" {
|
} else if redirectURL.String() != "" {
|
||||||
|
@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||||
w.WriteHeader(faResp.StatusCode)
|
w.WriteHeader(faResp.StatusCode)
|
||||||
|
|
||||||
if _, err = w.Write(body); err != nil {
|
if _, err = w.Write(body); err != nil {
|
||||||
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
|
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
|
@ -21,7 +22,7 @@ type (
|
||||||
Header = http.Header
|
Header = http.Header
|
||||||
Cookie = http.Cookie
|
Cookie = http.Cookie
|
||||||
|
|
||||||
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
|
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||||
RewriteFunc func(req *Request)
|
RewriteFunc func(req *Request)
|
||||||
ModifyResponseFunc func(resp *Response) error
|
ModifyResponseFunc func(resp *Response) error
|
||||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
||||||
|
@ -33,23 +34,38 @@ type (
|
||||||
name string
|
name string
|
||||||
|
|
||||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||||
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
|
|
||||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||||
|
|
||||||
transport http.RoundTripper
|
|
||||||
|
|
||||||
withOptions CloneWithOptFunc
|
withOptions CloneWithOptFunc
|
||||||
labelParserMap D.ValueParserMap
|
labelParserMap D.ValueParserMap
|
||||||
impl any
|
impl any
|
||||||
|
|
||||||
|
parent *Middleware
|
||||||
|
children []*Middleware
|
||||||
|
trace bool
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var Deserialize = U.Deserialize
|
var Deserialize = U.Deserialize
|
||||||
|
|
||||||
|
func Rewrite(r RewriteFunc) BeforeFunc {
|
||||||
|
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||||
|
r(req)
|
||||||
|
next(w, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Middleware) Name() string {
|
func (m *Middleware) Name() string {
|
||||||
return m.name
|
return m.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) Fullname() string {
|
||||||
|
if m.parent != nil {
|
||||||
|
return m.parent.Fullname() + "." + m.name
|
||||||
|
}
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Middleware) String() string {
|
func (m *Middleware) String() string {
|
||||||
return m.name
|
return m.name
|
||||||
}
|
}
|
||||||
|
@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested
|
||||||
|
|
||||||
// WithOptionsClone is called only once
|
// WithOptionsClone is called only once
|
||||||
// set withOptions and labelParser will not be used after that
|
// set withOptions and labelParser will not be used after that
|
||||||
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
|
return &Middleware{
|
||||||
|
m.name,
|
||||||
|
m.before,
|
||||||
|
m.modifyResponse,
|
||||||
|
nil, nil,
|
||||||
|
m.impl,
|
||||||
|
m.parent,
|
||||||
|
m.children,
|
||||||
|
false,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check conflict or duplicates
|
// TODO: check conflict or duplicates
|
||||||
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
|
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
|
||||||
befores := make([]BeforeFunc, 0, len(middlewares))
|
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||||
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
|
||||||
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
|
|
||||||
|
|
||||||
invalidM := E.NewBuilder("invalid middlewares")
|
invalidM := E.NewBuilder("invalid middlewares")
|
||||||
invalidOpts := E.NewBuilder("invalid options")
|
invalidOpts := E.NewBuilder("invalid options")
|
||||||
|
@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
||||||
invalidM.To(&res)
|
invalidM.To(&res)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for name, opts := range middlewares {
|
for name, opts := range middlewaresMap {
|
||||||
m, ok := Get(name)
|
m, ok := Get(name)
|
||||||
if !ok {
|
if !ok {
|
||||||
invalidM.Add(E.NotExist("middleware", name))
|
invalidM.Add(E.NotExist("middleware", name))
|
||||||
|
@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
||||||
invalidOpts.Add(err.Subject(name))
|
invalidOpts.Add(err.Subject(name))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if m.before != nil {
|
middlewares = append(middlewares, m)
|
||||||
befores = append(befores, m.before)
|
|
||||||
}
|
|
||||||
if m.rewrite != nil {
|
|
||||||
rewrites = append(rewrites, m.rewrite)
|
|
||||||
}
|
|
||||||
if m.modifyResponse != nil {
|
|
||||||
modResps = append(modResps, m.modifyResponse)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalidM.HasError() {
|
if invalidM.HasError() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
origServeHTTP := rp.ServeHTTP
|
patchReverseProxy(rpName, rp, middlewares)
|
||||||
for i, before := range befores {
|
|
||||||
if i < len(befores)-1 {
|
|
||||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
|
||||||
before(rp.ServeHTTP, w, r)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
|
||||||
before(origServeHTTP, w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rewrites) > 0 {
|
|
||||||
origServeHTTP = rp.ServeHTTP
|
|
||||||
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
for _, rewrite := range rewrites {
|
|
||||||
rewrite(r)
|
|
||||||
}
|
|
||||||
origServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modResps) > 0 {
|
|
||||||
if rp.ModifyResponse != nil {
|
|
||||||
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
|
|
||||||
}
|
|
||||||
rp.ModifyResponse = func(res *Response) error {
|
|
||||||
b := E.NewBuilder("errors in middleware ModifyResponse")
|
|
||||||
for _, mr := range modResps {
|
|
||||||
b.AddE(mr(res))
|
|
||||||
}
|
|
||||||
return b.Build().Error()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
|
mid := BuildMiddlewareFromChain(rpName, middlewares)
|
||||||
|
|
||||||
|
if mid.before != nil {
|
||||||
|
ori := rp.ServeHTTP
|
||||||
|
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mid.before(ori, w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mid.modifyResponse != nil {
|
||||||
|
if rp.ModifyResponse != nil {
|
||||||
|
ori := rp.ModifyResponse
|
||||||
|
rp.ModifyResponse = func(res *http.Response) error {
|
||||||
|
return errors.Join(mid.modifyResponse(res), ori(res))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rp.ModifyResponse = mid.modifyResponse
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||||
var rawMap map[string][]map[string]any
|
var rawMap map[string][]map[string]any
|
||||||
err := yaml.Unmarshal(data, &rawMap)
|
err := yaml.Unmarshal(data, &rawMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Add(E.FailWith("toml unmarshal", err))
|
b.Add(E.FailWith("yaml unmarshal", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
middlewares = make(map[string]*Middleware)
|
middlewares = make(map[string]*Middleware)
|
||||||
|
@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||||
chainErr := E.NewBuilder(name)
|
chainErr := E.NewBuilder(name)
|
||||||
chain := make([]*Middleware, 0, len(defs))
|
chain := make([]*Middleware, 0, len(defs))
|
||||||
for i, def := range defs {
|
for i, def := range defs {
|
||||||
if def["use"] == nil || def["use"].(string) == "" {
|
if def["use"] == nil || def["use"] == "" {
|
||||||
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
|
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
baseName := def["use"].(string)
|
baseName := def["use"].(string)
|
||||||
base, ok := Get(baseName)
|
base, ok := Get(baseName)
|
||||||
if !ok {
|
if !ok {
|
||||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
|
base, ok = middlewares[baseName]
|
||||||
continue
|
if !ok {
|
||||||
|
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
delete(def, "use")
|
delete(def, "use")
|
||||||
m, err := base.WithOptionsClone(def)
|
m, err := base.WithOptionsClone(def)
|
||||||
|
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
chainErr.Add(err.Subjectf("item%d", i))
|
chainErr.Add(err.Subjectf("item%d", i))
|
||||||
continue
|
continue
|
||||||
|
@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||||
if chainErr.HasError() {
|
if chainErr.HasError() {
|
||||||
b.Add(chainErr.Build())
|
b.Add(chainErr.Build())
|
||||||
} else {
|
} else {
|
||||||
name = name + "@file"
|
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
|
||||||
middlewares[name] = BuildMiddlewareFromChain(name, chain)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||||
|
|
||||||
// TODO: check conflict or duplicates
|
// TODO: check conflict or duplicates
|
||||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||||
var (
|
m := &Middleware{name: name, children: chain}
|
||||||
befores []BeforeFunc
|
|
||||||
rewrites []RewriteFunc
|
var befores []*Middleware
|
||||||
modResps []ModifyResponseFunc
|
var modResps []*Middleware
|
||||||
)
|
|
||||||
for _, m := range chain {
|
for _, comp := range chain {
|
||||||
if m.before != nil {
|
if comp.before != nil {
|
||||||
befores = append(befores, m.before)
|
befores = append(befores, comp)
|
||||||
}
|
}
|
||||||
if m.rewrite != nil {
|
if comp.modifyResponse != nil {
|
||||||
rewrites = append(rewrites, m.rewrite)
|
modResps = append(modResps, comp)
|
||||||
}
|
|
||||||
if m.modifyResponse != nil {
|
|
||||||
modResps = append(modResps, m.modifyResponse)
|
|
||||||
}
|
}
|
||||||
|
comp.parent = m
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &Middleware{name: name}
|
|
||||||
if len(befores) > 0 {
|
if len(befores) > 0 {
|
||||||
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
|
m.before = buildBefores(befores)
|
||||||
for _, before := range befores {
|
|
||||||
before(next, w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(rewrites) > 0 {
|
|
||||||
m.rewrite = func(r *Request) {
|
|
||||||
for _, rewrite := range rewrites {
|
|
||||||
rewrite(r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(modResps) > 0 {
|
if len(modResps) > 0 {
|
||||||
m.modifyResponse = func(res *Response) error {
|
m.modifyResponse = func(res *Response) error {
|
||||||
b := E.NewBuilder("errors in middleware %s", name)
|
b := E.NewBuilder("errors in middleware")
|
||||||
for _, mr := range modResps {
|
for _, mr := range modResps {
|
||||||
b.AddE(mr(res))
|
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
|
||||||
}
|
}
|
||||||
return b.Build().Error()
|
return b.Build().Error()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if common.IsDebug {
|
||||||
|
m.EnableTrace()
|
||||||
|
m.AddTracef("middleware created")
|
||||||
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildBefores(befores []*Middleware) BeforeFunc {
|
||||||
|
if len(befores) == 1 {
|
||||||
|
return befores[0].before
|
||||||
|
}
|
||||||
|
nextBefores := buildBefores(befores[1:])
|
||||||
|
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
|
befores[0].before(func(w ResponseWriter, r *Request) {
|
||||||
|
nextBefores(next, w, r)
|
||||||
|
}, w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -67,10 +67,10 @@ func LoadComposeFiles() {
|
||||||
b.Add(E.Duplicated("middleware", name))
|
b.Add(E.Duplicated("middleware", name))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
middlewares[name] = m
|
middlewares[U.ToLowerNoSnake(name)] = m
|
||||||
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
||||||
}
|
}
|
||||||
b.Add(err.Subject(defFile))
|
b.Add(err.Subject(path.Base(defFile)))
|
||||||
}
|
}
|
||||||
if b.HasError() {
|
if b.HasError() {
|
||||||
logger.Error(b.Build())
|
logger.Error(b.Build())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
)
|
)
|
||||||
|
@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest {
|
||||||
|
|
||||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
mr := new(modifyRequest)
|
mr := new(modifyRequest)
|
||||||
|
var mrFunc RewriteFunc
|
||||||
|
if common.IsDebug {
|
||||||
|
mrFunc = mr.modifyRequestWithTrace
|
||||||
|
} else {
|
||||||
|
mrFunc = mr.modifyRequest
|
||||||
|
}
|
||||||
mr.m = &Middleware{
|
mr.m = &Middleware{
|
||||||
impl: mr,
|
impl: mr,
|
||||||
rewrite: mr.modifyRequest,
|
before: Rewrite(mrFunc),
|
||||||
}
|
}
|
||||||
mr.modifyRequestOpts = new(modifyRequestOpts)
|
mr.modifyRequestOpts = new(modifyRequestOpts)
|
||||||
err := Deserialize(optsRaw, mr.modifyRequestOpts)
|
err := Deserialize(optsRaw, mr.modifyRequestOpts)
|
||||||
|
@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) {
|
||||||
req.Header.Del(k)
|
req.Header.Del(k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
|
||||||
|
mr.m.AddTraceRequest("before modify request", req)
|
||||||
|
mr.modifyRequest(req)
|
||||||
|
mr.m.AddTraceRequest("after modify request", req)
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
)
|
)
|
||||||
|
@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) {
|
||||||
|
|
||||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
mr := new(modifyResponse)
|
mr := new(modifyResponse)
|
||||||
mr.m = &Middleware{
|
mr.m = &Middleware{impl: mr}
|
||||||
impl: mr,
|
if common.IsDebug {
|
||||||
modifyResponse: mr.modifyResponse,
|
mr.m.modifyResponse = mr.modifyResponseWithTrace
|
||||||
|
} else {
|
||||||
|
mr.m.modifyResponse = mr.modifyResponse
|
||||||
}
|
}
|
||||||
mr.modifyResponseOpts = new(modifyResponseOpts)
|
mr.modifyResponseOpts = new(modifyResponseOpts)
|
||||||
err := Deserialize(optsRaw, mr.modifyResponseOpts)
|
err := Deserialize(optsRaw, mr.modifyResponseOpts)
|
||||||
|
@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
|
||||||
|
mr.m.AddTraceResponse("before modify response", resp)
|
||||||
|
err := mr.modifyResponse(resp)
|
||||||
|
mr.m.AddTraceResponse("after modify response", resp)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -2,8 +2,8 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/types"
|
||||||
|
@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var realIPLogger = logrus.WithField("middleware", "RealIP")
|
|
||||||
|
|
||||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
riWithOpts := new(realIP)
|
riWithOpts := new(realIP)
|
||||||
riWithOpts.m = &Middleware{
|
riWithOpts.m = &Middleware{
|
||||||
impl: riWithOpts,
|
impl: riWithOpts,
|
||||||
rewrite: riWithOpts.setRealIP,
|
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
|
riWithOpts.setRealIP(r)
|
||||||
|
next(w, r)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
riWithOpts.realIPOpts = realIPOptsDefault()
|
riWithOpts.realIPOpts = realIPOptsDefault()
|
||||||
err := Deserialize(opts, riWithOpts.realIPOpts)
|
err := Deserialize(opts, riWithOpts.realIPOpts)
|
||||||
|
@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||||
func (ri *realIP) setRealIP(req *Request) {
|
func (ri *realIP) setRealIP(req *Request) {
|
||||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
realIPLogger.Debugf("failed to split host port %s", err)
|
clientIPStr = req.RemoteAddr
|
||||||
}
|
}
|
||||||
clientIP := net.ParseIP(clientIPStr)
|
clientIP := net.ParseIP(clientIPStr)
|
||||||
|
|
||||||
|
@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !isTrusted {
|
if !isTrusted {
|
||||||
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
|
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
var lastNonTrustedIP string
|
var lastNonTrustedIP string
|
||||||
|
|
||||||
if len(realIPs) == 0 {
|
if len(realIPs) == 0 {
|
||||||
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
|
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
lastNonTrustedIP = r
|
lastNonTrustedIP = r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if lastNonTrustedIP == "" {
|
}
|
||||||
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
|
|
||||||
return
|
if lastNonTrustedIP == "" {
|
||||||
}
|
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.RemoteAddr = lastNonTrustedIP
|
req.RemoteAddr = lastNonTrustedIP
|
||||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||||
req.Header.Set("X-Real-IP", lastNonTrustedIP)
|
req.Header.Set("X-Real-IP", lastNonTrustedIP)
|
||||||
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
|
req.Header.Set(xForwardedFor, lastNonTrustedIP)
|
||||||
|
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,13 +2,15 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetRealIP(t *testing.T) {
|
func TestSetRealIPOpts(t *testing.T) {
|
||||||
opts := OptionsRaw{
|
opts := OptionsRaw{
|
||||||
"header": "X-Real-IP",
|
"header": "X-Real-IP",
|
||||||
"from": []string{
|
"from": []string{
|
||||||
|
@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) {
|
||||||
Recursive: true,
|
Recursive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("set_options", func(t *testing.T) {
|
ri, err := NewRealIP(opts)
|
||||||
ri, err := RealIP.m.WithOptionsClone(opts)
|
ExpectNoError(t, err.Error())
|
||||||
ExpectNoError(t, err.Error())
|
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||||
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||||
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
|
for i, CIDR := range ri.impl.(*realIP).From {
|
||||||
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
|
||||||
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
|
}
|
||||||
})
|
}
|
||||||
// TODO test
|
|
||||||
|
func TestSetRealIP(t *testing.T) {
|
||||||
|
const (
|
||||||
|
testHeader = "X-Real-IP"
|
||||||
|
testRealIP = "192.168.1.1"
|
||||||
|
)
|
||||||
|
opts := OptionsRaw{
|
||||||
|
"header": testHeader,
|
||||||
|
"from": []string{"0.0.0.0/0"},
|
||||||
|
}
|
||||||
|
optsMr := OptionsRaw{
|
||||||
|
"set_headers": map[string]string{testHeader: testRealIP},
|
||||||
|
}
|
||||||
|
realip, err := NewRealIP(opts)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
|
||||||
|
mr, err := NewModifyRequest(optsMr)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
|
||||||
|
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
||||||
|
|
||||||
|
result, err := newMiddlewareTest(mid, nil)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
t.Log(traces)
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||||
|
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||||
|
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,13 +7,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var RedirectHTTP = &Middleware{
|
var RedirectHTTP = &Middleware{
|
||||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
if r.TLS == nil {
|
if r.TLS == nil {
|
||||||
r.URL.Scheme = "https"
|
r.URL.Scheme = "https"
|
||||||
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
||||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r)
|
next(w, r)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
deny:
|
||||||
|
- use: ModifyRequest
|
||||||
|
setHeaders:
|
||||||
|
X-Real-IP: 192.168.1.1:1234
|
||||||
|
- use: RealIP
|
||||||
|
header: X-Real-IP
|
||||||
|
from:
|
||||||
|
- 0.0.0.0/0
|
||||||
|
- use: CIDRWhitelist
|
||||||
|
allow:
|
||||||
|
- 192.168.0.0/24
|
||||||
|
accept:
|
||||||
|
- use: ModifyRequest
|
||||||
|
setHeaders:
|
||||||
|
X-Real-IP: 192.168.0.1:1234
|
||||||
|
- use: RealIP
|
||||||
|
header: X-Real-IP
|
||||||
|
from:
|
||||||
|
- 0.0.0.0/0
|
||||||
|
- use: CIDRWhitelist
|
||||||
|
allow:
|
||||||
|
- 192.168.0.0/24
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||||
)
|
)
|
||||||
|
@ -20,6 +21,9 @@ var testHeaders http.Header
|
||||||
const testHost = "example.com"
|
const testHost = "example.com"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
if !common.IsTest {
|
||||||
|
return
|
||||||
|
}
|
||||||
tmp := map[string]string{}
|
tmp := map[string]string{}
|
||||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -31,13 +35,15 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type requestHeaderRecorder struct {
|
type requestRecorder struct {
|
||||||
parent http.RoundTripper
|
parent http.RoundTripper
|
||||||
reqHeaders http.Header
|
headers http.Header
|
||||||
|
remoteAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
rt.reqHeaders = req.Header
|
rt.headers = req.Header
|
||||||
|
rt.remoteAddr = req.RemoteAddr
|
||||||
if rt.parent != nil {
|
if rt.parent != nil {
|
||||||
return rt.parent.RoundTrip(req)
|
return rt.parent.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e
|
||||||
Header: testHeaders,
|
Header: testHeaders,
|
||||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||||
Request: req,
|
Request: req,
|
||||||
|
TLS: req.TLS,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,6 +60,7 @@ type TestResult struct {
|
||||||
RequestHeaders http.Header
|
RequestHeaders http.Header
|
||||||
ResponseHeaders http.Header
|
ResponseHeaders http.Header
|
||||||
ResponseStatus int
|
ResponseStatus int
|
||||||
|
RemoteAddr string
|
||||||
Data []byte
|
Data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +73,7 @@ type testArgs struct {
|
||||||
|
|
||||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||||
var body io.Reader
|
var body io.Reader
|
||||||
var rt = new(requestHeaderRecorder)
|
var rr = new(requestRecorder)
|
||||||
var proxyURL *url.URL
|
var proxyURL *url.URL
|
||||||
var requestTarget string
|
var requestTarget string
|
||||||
var err error
|
var err error
|
||||||
|
@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.From(err)
|
return nil, E.From(err)
|
||||||
}
|
}
|
||||||
rt.parent = http.DefaultTransport
|
rr.parent = http.DefaultTransport
|
||||||
} else {
|
} else {
|
||||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||||
}
|
}
|
||||||
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
|
rp := gpHTTP.NewReverseProxy(proxyURL, rr)
|
||||||
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
|
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||||
middleware.name: args.middlewareOpt,
|
|
||||||
})
|
|
||||||
if setOptErr != nil {
|
if setOptErr != nil {
|
||||||
return nil, setOptErr
|
return nil, setOptErr
|
||||||
}
|
}
|
||||||
|
patchReverseProxy(middleware.name, rp, []*Middleware{mid})
|
||||||
rp.ServeHTTP(w, req)
|
rp.ServeHTTP(w, req)
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||||
return nil, E.From(err)
|
return nil, E.From(err)
|
||||||
}
|
}
|
||||||
return &TestResult{
|
return &TestResult{
|
||||||
RequestHeaders: rt.reqHeaders,
|
RequestHeaders: rr.headers,
|
||||||
ResponseHeaders: resp.Header,
|
ResponseHeaders: resp.Header,
|
||||||
ResponseStatus: resp.StatusCode,
|
ResponseStatus: resp.StatusCode,
|
||||||
|
RemoteAddr: rr.remoteAddr,
|
||||||
Data: data,
|
Data: data,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
99
internal/net/http/middleware/trace.go
Normal file
99
internal/net/http/middleware/trace.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Trace struct {
|
||||||
|
Time string `json:"time,omitempty"`
|
||||||
|
Caller string `json:"caller,omitempty"`
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
Message string `json:"msg"`
|
||||||
|
ReqHeaders http.Header `json:"req_headers,omitempty"`
|
||||||
|
RespHeaders http.Header `json:"resp_headers,omitempty"`
|
||||||
|
Additional map[string]any `json:"additional,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Traces []*Trace
|
||||||
|
|
||||||
|
var traces = Traces{}
|
||||||
|
var tracesMu sync.Mutex
|
||||||
|
|
||||||
|
const MaxTraceNum = 1000
|
||||||
|
|
||||||
|
func GetAllTrace() []*Trace {
|
||||||
|
return traces
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *Trace) WithRequest(req *Request) *Trace {
|
||||||
|
if tr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tr.URL = req.RequestURI
|
||||||
|
tr.ReqHeaders = req.Header.Clone()
|
||||||
|
return tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *Trace) WithResponse(resp *Response) *Trace {
|
||||||
|
if tr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tr.URL = resp.Request.RequestURI
|
||||||
|
tr.ReqHeaders = resp.Request.Header.Clone()
|
||||||
|
tr.RespHeaders = resp.Header.Clone()
|
||||||
|
return tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *Trace) With(what string, additional any) *Trace {
|
||||||
|
if tr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if tr.Additional == nil {
|
||||||
|
tr.Additional = map[string]any{}
|
||||||
|
}
|
||||||
|
tr.Additional[what] = additional
|
||||||
|
return tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) EnableTrace() {
|
||||||
|
m.trace = true
|
||||||
|
for _, child := range m.children {
|
||||||
|
child.parent = m
|
||||||
|
child.EnableTrace()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
||||||
|
if !m.trace {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return addTrace(&Trace{
|
||||||
|
Time: U.FormatTime(time.Now()),
|
||||||
|
Caller: m.Fullname(),
|
||||||
|
Message: fmt.Sprintf(msg, args...),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||||
|
return m.AddTracef("%s", msg).WithRequest(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
||||||
|
return m.AddTracef("%s", msg).WithResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addTrace(t *Trace) *Trace {
|
||||||
|
tracesMu.Lock()
|
||||||
|
defer tracesMu.Unlock()
|
||||||
|
if len(traces) > MaxTraceNum {
|
||||||
|
traces = traces[1:]
|
||||||
|
}
|
||||||
|
traces = append(traces, t)
|
||||||
|
return t
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -14,7 +15,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var SetXForwarded = &Middleware{
|
var SetXForwarded = &Middleware{
|
||||||
rewrite: func(req *Request) {
|
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||||
req.Header.Del("Forwarded")
|
req.Header.Del("Forwarded")
|
||||||
req.Header.Del(xForwardedFor)
|
req.Header.Del(xForwardedFor)
|
||||||
req.Header.Del(xForwardedHost)
|
req.Header.Del(xForwardedHost)
|
||||||
|
@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{
|
||||||
if err == nil {
|
if err == nil {
|
||||||
req.Header.Set(xForwardedFor, clientIP)
|
req.Header.Set(xForwardedFor, clientIP)
|
||||||
} else {
|
} else {
|
||||||
req.Header.Del(xForwardedFor)
|
req.Header.Set(xForwardedFor, req.RemoteAddr)
|
||||||
}
|
}
|
||||||
req.Header.Set(xForwardedHost, req.Host)
|
req.Header.Set(xForwardedHost, req.Host)
|
||||||
if req.TLS == nil {
|
if req.TLS == nil {
|
||||||
|
@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{
|
||||||
} else {
|
} else {
|
||||||
req.Header.Set(xForwardedProto, "https")
|
req.Header.Set(xForwardedProto, "https")
|
||||||
}
|
}
|
||||||
|
next(w, req)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var HideXForwarded = &Middleware{
|
var HideXForwarded = &Middleware{
|
||||||
rewrite: func(req *Request) {
|
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||||
req.Header.Del("Forwarded")
|
req.Header.Del("Forwarded")
|
||||||
req.Header.Del(xForwardedFor)
|
req.Header.Del(xForwardedFor)
|
||||||
req.Header.Del(xForwardedHost)
|
req.Header.Del(xForwardedHost)
|
||||||
req.Header.Del(xForwardedProto)
|
req.Header.Del(xForwardedProto)
|
||||||
|
next(w, req)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||||
rp := NewReverseProxy(entry.URL, trans)
|
rp := NewReverseProxy(entry.URL, trans)
|
||||||
|
|
||||||
if len(entry.Middlewares) > 0 {
|
if len(entry.Middlewares) > 0 {
|
||||||
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
|
err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,3 +32,7 @@ func (cidr *CIDR) Contains(ip net.IP) bool {
|
||||||
func (cidr *CIDR) String() string {
|
func (cidr *CIDR) String() string {
|
||||||
return (*net.IPNet)(cidr).String()
|
return (*net.IPNet)(cidr).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cidr *CIDR) Equals(other *CIDR) bool {
|
||||||
|
return (*net.IPNet)(cidr).IP.Equal(other.IP) && cidr.Mask.String() == other.Mask.String()
|
||||||
|
}
|
||||||
|
|
|
@ -42,6 +42,10 @@ func FormatDuration(d time.Duration) string {
|
||||||
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
|
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FormatTime(t time.Time) string {
|
||||||
|
return t.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
|
||||||
func ParseBool(s string) bool {
|
func ParseBool(s string) bool {
|
||||||
switch strings.ToLower(s) {
|
switch strings.ToLower(s) {
|
||||||
case "1", "true", "yes", "on":
|
case "1", "true", "yes", "on":
|
||||||
|
|
|
@ -1,19 +1,25 @@
|
||||||
package functional
|
package functional
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
type Slice[T any] struct {
|
type Slice[T any] struct {
|
||||||
s []T
|
s []T
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSlice[T any]() *Slice[T] {
|
func NewSlice[T any]() *Slice[T] {
|
||||||
return &Slice[T]{make([]T, 0)}
|
return &Slice[T]{s: make([]T, 0)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSliceN[T any](n int) *Slice[T] {
|
func NewSliceN[T any](n int) *Slice[T] {
|
||||||
return &Slice[T]{make([]T, n)}
|
return &Slice[T]{s: make([]T, n)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSliceFrom[T any](s []T) *Slice[T] {
|
func NewSliceFrom[T any](s []T) *Slice[T] {
|
||||||
return &Slice[T]{s}
|
return &Slice[T]{s: s}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Slice[T]) Size() int {
|
func (s *Slice[T]) Size() int {
|
||||||
|
@ -46,6 +52,30 @@ func (s *Slice[T]) AddRange(other *Slice[T]) *Slice[T] {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Slice[T]) SafeAdd(e T) *Slice[T] {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Add(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Slice[T]) SafeAddRange(other *Slice[T]) *Slice[T] {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.AddRange(other)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Slice[T]) Pop() T {
|
||||||
|
v := s.s[len(s.s)-1]
|
||||||
|
s.s = s.s[:len(s.s)-1]
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Slice[T]) SafePop() T {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Pop()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Slice[T]) ForEach(do func(T)) {
|
func (s *Slice[T]) ForEach(do func(T)) {
|
||||||
for _, v := range s.s {
|
for _, v := range s.s {
|
||||||
do(v)
|
do(v)
|
||||||
|
@ -57,7 +87,7 @@ func (s *Slice[T]) Map(m func(T) T) *Slice[T] {
|
||||||
for i, v := range s.s {
|
for i, v := range s.s {
|
||||||
n[i] = m(v)
|
n[i] = m(v)
|
||||||
}
|
}
|
||||||
return &Slice[T]{n}
|
return &Slice[T]{s: n}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
|
func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
|
||||||
|
@ -67,5 +97,13 @@ func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] {
|
||||||
n = append(n, v)
|
n = append(n, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &Slice[T]{n}
|
return &Slice[T]{s: n}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Slice[T]) String() string {
|
||||||
|
out, err := json.MarshalIndent(s.s, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,23 @@ package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if common.IsTest {
|
||||||
|
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IgnoreError[Result any](r Result, _ error) Result {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func ExpectNoError(t *testing.T, err error) {
|
func ExpectNoError(t *testing.T, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err != nil && !reflect.ValueOf(err).IsNil() {
|
if err != nil && !reflect.ValueOf(err).IsNil() {
|
||||||
|
|
Loading…
Add table
Reference in a new issue