fix: rules escaped backslash

This commit is contained in:
yusing 2025-01-09 19:59:53 +08:00
parent 694219c50a
commit 2639c2a836
62 changed files with 2620 additions and 496 deletions

View file

@ -23,10 +23,10 @@ lint:
enabled: enabled:
- hadolint@2.12.1-beta - hadolint@2.12.1-beta
- actionlint@1.7.6 - actionlint@1.7.6
- checkov@3.2.347 - checkov@3.2.350
- git-diff-check - git-diff-check
- gofmt@1.20.4 - gofmt@1.20.4
- golangci-lint@1.62.2 - golangci-lint@1.63.4
- osv-scanner@1.9.2 - osv-scanner@1.9.2
- oxipng@9.1.3 - oxipng@9.1.3
- prettier@3.4.2 - prettier@3.4.2

View file

@ -87,8 +87,10 @@ Setup DNS Records point to machine which runs `GoDoxy`, e.g.
- change username and password for WebUI authentication - change username and password for WebUI authentication
```shell ```shell
sed -i "s|API_USERNAME=.*|API_USERNAME=admin|g" .env USERNAME=admin
sed -i "s|API_PASSWORD=.*|API_PASSWORD=some-strong-password|g" .env PASSWORD=some-password
sed -i "s|API_USERNAME=.*|API_USERNAME=${USERNAME}|g" .env
sed -i "s|API_PASSWORD=.*|API_PASSWORD=${PASSWORD}|g" .env
``` ```
4. _(Optional)_ setup `docker-socket-proxy` other docker nodes (see [Multi docker nodes setup](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)) then add them inside `config.yml` 4. _(Optional)_ setup `docker-socket-proxy` other docker nodes (see [Multi docker nodes setup](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)) then add them inside `config.yml`

View file

@ -87,8 +87,10 @@ _加入我們的 [Discord](https://discord.gg/umReR62nRd) 獲取幫助和討論_
- 更改網頁介面認證的使用者名稱和密碼 - 更改網頁介面認證的使用者名稱和密碼
```shell ```shell
sed -i "s|API_USERNAME=.*|API_USERNAME=admin|g" .env USERNAME=admin
sed -i "s|API_PASSWORD=.*|API_PASSWORD=some-strong-password|g" .env PASSWORD=some-password
sed -i "s|API_USERNAME=.*|API_USERNAME=${USERNAME}|g" .env
sed -i "s|API_PASSWORD=.*|API_PASSWORD=${PASSWORD}|g" .env
``` ```
4. _可選_ 設置其他 Docker 節點的 `docker-socket-proxy`(參見 [多 Docker 節點設置](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)),然後在 `config.yml` 中添加它們 4. _可選_ 設置其他 Docker 節點的 `docker-socket-proxy`(參見 [多 Docker 節點設置](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)),然後在 `config.yml` 中添加它們

View file

@ -3,24 +3,19 @@ package main
import ( import (
"encoding/json" "encoding/json"
"log" "log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time" "time"
"github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/entrypoint"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/net/http/server" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
) )
@ -98,16 +93,16 @@ func main() {
switch args.Command { switch args.Command {
case common.CommandListRoutes: case common.CommandListRoutes:
cfg.StartProxyProviders() cfg.StartProxyProviders()
printJSON(config.RoutesByAlias()) printJSON(routes.RoutesByAlias())
return return
case common.CommandListConfigs: case common.CommandListConfigs:
printJSON(config.Value()) printJSON(cfg.Value())
return return
case common.CommandDebugListEntries: case common.CommandDebugListEntries:
printJSON(config.DumpEntries()) printJSON(cfg.DumpEntries())
return return
case common.CommandDebugListProviders: case common.CommandDebugListProviders:
printJSON(config.DumpProviders()) printJSON(cfg.DumpProviders())
return return
} }
@ -115,58 +110,25 @@ func main() {
logging.Warn().Msg("API JWT secret is empty, authentication is disabled") logging.Warn().Msg("API JWT secret is empty, authentication is disabled")
} }
cfg.StartProxyProviders() cfg.Start()
config.WatchChanges() config.WatchChanges()
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT)
signal.Notify(sig, syscall.SIGTERM)
signal.Notify(sig, syscall.SIGHUP)
autocert := config.GetAutoCertProvider()
if autocert != nil {
if err := autocert.Setup(); err != nil {
E.LogFatal("autocert setup error", err)
}
} else {
logging.Info().Msg("autocert not configured")
}
server.StartServer(server.Options{
Name: "proxy",
CertProvider: autocert,
HTTPAddr: common.ProxyHTTPAddr,
HTTPSAddr: common.ProxyHTTPSAddr,
Handler: http.HandlerFunc(entrypoint.Handler),
})
// Initialize authentication providers // Initialize authentication providers
if err := auth.Initialize(); err != nil { if err := auth.Initialize(); err != nil {
logging.Warn().Err(err).Msg("Failed to initialize authentication providers") logging.Warn().Err(err).Msg("Failed to initialize authentication providers")
} }
server.StartServer(server.Options{ sig := make(chan os.Signal, 1)
Name: "api", signal.Notify(sig, syscall.SIGINT)
CertProvider: autocert, signal.Notify(sig, syscall.SIGTERM)
HTTPAddr: common.APIHTTPAddr, signal.Notify(sig, syscall.SIGHUP)
Handler: api.NewHandler(),
})
if common.PrometheusEnabled {
server.StartServer(server.Options{
Name: "metrics",
CertProvider: autocert,
HTTPAddr: common.MetricsHTTPAddr,
Handler: metrics.NewHandler(),
})
}
// wait for signal // wait for signal
<-sig <-sig
// gracefully shutdown // grafully shutdown
logging.Info().Msg("shutting down") logging.Info().Msg("shutting down")
_ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown)) _ = task.GracefulShutdown(time.Second * time.Duration(cfg.Value().TimeoutShutdown))
} }
func prepareDirectory(dir string) { func prepareDirectory(dir string) {

1
go.mod
View file

@ -69,6 +69,7 @@ require (
go.opentelemetry.io/otel/trace v1.33.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect
golang.org/x/crypto v0.32.0 // indirect golang.org/x/crypto v0.32.0 // indirect
golang.org/x/mod v0.22.0 // indirect golang.org/x/mod v0.22.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect golang.org/x/sys v0.29.0 // indirect
golang.org/x/tools v0.29.0 // indirect golang.org/x/tools v0.29.0 // indirect

View file

@ -8,20 +8,17 @@ import (
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
. "github.com/yusing/go-proxy/internal/api/v1/utils" . "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
) )
type ServeMux struct{ *http.ServeMux } type ServeMux struct{ *http.ServeMux }
func NewServeMux() ServeMux {
return ServeMux{http.NewServeMux()}
}
func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler))
} }
func NewHandler() http.Handler { func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := NewServeMux() mux := ServeMux{http.NewServeMux()}
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
@ -30,19 +27,25 @@ func NewHandler() http.Handler {
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/reload", v1.Reload) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List)) mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List)) mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List)) mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile)
mux.HandleFunc("GET", "/v1/stats", v1.Stats) mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
return mux return mux
} }
func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(cfg, w, r)
}
}
// allow only requests to API server with localhost. // allow only requests to API server with localhost.
func checkHost(f http.HandlerFunc) http.HandlerFunc { func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug { if common.IsDebug {

View file

@ -9,7 +9,7 @@ import (
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" config "github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/route/provider"

View file

@ -6,9 +6,10 @@ import (
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -24,7 +25,7 @@ const (
ListTasks = "tasks" ListTasks = "tasks"
) )
func List(w http.ResponseWriter, r *http.Request) { func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what") what := r.PathValue("what")
if what == "" { if what == "" {
what = ListRoutes what = ListRoutes
@ -40,7 +41,7 @@ func List(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, route) U.RespondJSON(w, r, route)
} }
case ListRoutes: case ListRoutes:
U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type")))) U.RespondJSON(w, r, routes.RoutesByAlias(route.RouteType(r.FormValue("type"))))
case ListFiles: case ListFiles:
listFiles(w, r) listFiles(w, r)
case ListMiddlewares: case ListMiddlewares:
@ -48,9 +49,9 @@ func List(w http.ResponseWriter, r *http.Request) {
case ListMiddlewareTraces: case ListMiddlewareTraces:
U.RespondJSON(w, r, middleware.GetAllTrace()) U.RespondJSON(w, r, middleware.GetAllTrace())
case ListMatchDomains: case ListMatchDomains:
U.RespondJSON(w, r, config.Value().MatchDomains) U.RespondJSON(w, r, cfg.Value().MatchDomains)
case ListHomepageConfig: case ListHomepageConfig:
U.RespondJSON(w, r, config.HomepageConfig()) U.RespondJSON(w, r, routes.HomepageConfig(cfg.Value().Homepage.UseDefaultCategories))
case ListTasks: case ListTasks:
U.RespondJSON(w, r, task.DebugTaskList()) U.RespondJSON(w, r, task.DebugTaskList())
default: default:
@ -60,9 +61,9 @@ func List(w http.ResponseWriter, r *http.Request) {
func listRoute(which string) any { func listRoute(which string) any {
if which == "" || which == "all" { if which == "" || which == "all" {
return config.RoutesByAlias() return routes.RoutesByAlias()
} }
routes := config.RoutesByAlias() routes := routes.RoutesByAlias()
route, ok := routes[which] route, ok := routes[which]
if !ok { if !ok {
return nil return nil

View file

@ -4,11 +4,11 @@ import (
"net/http" "net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config" config "github.com/yusing/go-proxy/internal/config/types"
) )
func Reload(w http.ResponseWriter, r *http.Request) { func Reload(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil { if err := cfg.Reload(); err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, err)
return return
} }

View file

@ -9,25 +9,25 @@ import (
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
func Stats(w http.ResponseWriter, r *http.Request) { func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats()) U.RespondJSON(w, r, getStats(cfg))
} }
func StatsWS(w http.ResponseWriter, r *http.Request) { func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
var originPats []string var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
if len(config.Value().MatchDomains) == 0 { if len(cfg.Value().MatchDomains) == 0 {
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
originPats = []string{"*"} originPats = []string{"*"}
} else { } else {
originPats = make([]string, len(config.Value().MatchDomains)) originPats = make([]string, len(cfg.Value().MatchDomains))
for i, domain := range config.Value().MatchDomains { for i, domain := range cfg.Value().MatchDomains {
originPats[i] = "*" + domain originPats[i] = "*" + domain
} }
originPats = append(originPats, localAddresses...) originPats = append(originPats, localAddresses...)
@ -52,7 +52,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
stats := getStats() stats := getStats(cfg)
if err := wsjson.Write(ctx, conn, stats); err != nil { if err := wsjson.Write(ctx, conn, stats); err != nil {
U.LogError(r).Msg("failed to write JSON") U.LogError(r).Msg("failed to write JSON")
return return
@ -62,9 +62,9 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
var startTime = time.Now() var startTime = time.Now()
func getStats() map[string]any { func getStats(cfg config.ConfigInstance) map[string]any {
return map[string]any{ return map[string]any{
"proxies": config.Statistics(), "proxies": cfg.Statistics(),
"uptime": strutils.FormatDuration(time.Since(startTime)), "uptime": strutils.FormatDuration(time.Since(startTime)),
} }
} }

View file

@ -7,6 +7,7 @@ import (
"os" "os"
"path" "path"
"reflect" "reflect"
"runtime"
"sort" "sort"
"time" "time"
@ -148,28 +149,40 @@ func (p *Provider) ShouldRenewOn() time.Time {
panic("no certificate available") panic("no certificate available")
} }
func (p *Provider) ScheduleRenewal() { func (p *Provider) ScheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal { if p.GetName() == ProviderLocal {
return return
} }
go func() { go func() {
task := task.RootTask("cert-renew-scheduler", true) lastErrOn := time.Time{}
renewalTime := p.ShouldRenewOn()
timer := time.NewTimer(time.Until(renewalTime))
defer timer.Stop()
task := parent.Subtask("cert-renew-scheduler")
defer task.Finish(nil) defer task.Finish(nil)
for { for {
renewalTime := p.ShouldRenewOn()
timer := time.NewTimer(time.Until(renewalTime))
select { select {
case <-task.Context().Done(): case <-task.Context().Done():
timer.Stop()
return return
case <-timer.C: case <-timer.C:
// Retry after 1 hour on failure
if time.Now().Before(lastErrOn.Add(time.Hour)) {
continue
}
if err := p.renewIfNeeded(); err != nil { if err := p.renewIfNeeded(); err != nil {
E.LogWarn("cert renew failed", err, &logger) E.LogWarn("cert renew failed", err, &logger)
// Retry after 1 hour on failure lastErrOn = time.Now()
time.Sleep(time.Hour) continue
} }
// Reset on success
lastErrOn = time.Time{}
renewalTime = p.ShouldRenewOn()
timer.Reset(time.Until(renewalTime))
default:
// Allow other tasks to run
runtime.Gosched()
} }
} }
}() }()

View file

@ -18,8 +18,6 @@ func (p *Provider) Setup() (err E.Error) {
} }
} }
p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() { for _, expiry := range p.GetExpiries() {
logger.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) logger.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
break break

View file

@ -11,6 +11,105 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
var (
prefixes = []string{"GODOXY_", "GOPROXY_", ""}
IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("DEBUG", IsTest)
IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false)
IsTrace = GetEnvBool("TRACE", false) && IsDebug
IsProduction = !IsTest && !IsDebug
ProxyHTTPAddr,
ProxyHTTPHost,
ProxyHTTPPort,
ProxyHTTPURL = GetAddrEnv("HTTP_ADDR", ":80", "http")
ProxyHTTPSAddr,
ProxyHTTPSHost,
ProxyHTTPSPort,
ProxyHTTPSURL = GetAddrEnv("HTTPS_ADDR", ":443", "https")
APIHTTPAddr,
APIHTTPHost,
APIHTTPPort,
APIHTTPURL = GetAddrEnv("API_ADDR", "127.0.0.1:8888", "http")
MetricsHTTPAddr,
MetricsHTTPHost,
MetricsHTTPPort,
MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http")
PrometheusEnabled = MetricsHTTPURL != ""
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
APIUser = GetEnvString("API_USER", "admin")
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))
)
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
var value string
var ok bool
for _, prefix := range prefixes {
value, ok = os.LookupEnv(prefix + key)
if ok && value != "" {
break
}
}
if !ok || value == "" {
return defaultValue
}
parsed, err := parser(value)
if err == nil {
return parsed
}
log.Fatal().Err(err).Msgf("env %s: invalid %T value: %s", key, parsed, value)
return defaultValue
}
func GetEnvString(key string, defaultValue string) string {
return GetEnv(key, defaultValue, func(s string) (string, error) {
return s, nil
})
}
func GetEnvBool(key string, defaultValue bool) bool {
return GetEnv(key, defaultValue, strconv.ParseBool)
}
func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL string) {
addr = GetEnvString(key, defaultValue)
if addr == "" {
return
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Fatal().Msgf("env %s: invalid address: %s", key, addr)
}
if host == "" {
host = "localhost"
}
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
return
}
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
return GetEnv(key, defaultValue, time.ParseDuration)
}
package common
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
)
var ( var (
prefixes = []string{"GODOXY_", "GOPROXY_", ""} prefixes = []string{"GODOXY_", "GOPROXY_", ""}

View file

@ -7,12 +7,15 @@ import (
"sync" "sync"
"time" "time"
"github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/entrypoint" "github.com/yusing/go-proxy/internal/entrypoint"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics"
"github.com/yusing/go-proxy/internal/net/http/server"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
proxy "github.com/yusing/go-proxy/internal/route/provider" proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -26,7 +29,9 @@ type Config struct {
value *types.Config value *types.Config
providers F.Map[string, *proxy.Provider] providers F.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider autocertProvider *autocert.Provider
task *task.Task entrypoint *entrypoint.Entrypoint
task *task.Task
} }
var ( var (
@ -45,15 +50,18 @@ Make sure you rename it back before next time you start.`
You may run "ls-config" to show or dump the current config.` You may run "ls-config" to show or dump the current config.`
) )
var Validate = types.Validate
func GetInstance() *Config { func GetInstance() *Config {
return instance return instance
} }
func newConfig() *Config { func newConfig() *Config {
return &Config{ return &Config{
value: types.DefaultConfig(), value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](), providers: F.NewMapOf[string, *proxy.Provider](),
task: task.RootTask("config", false), entrypoint: entrypoint.NewEntrypoint(),
task: task.RootTask("config", false),
} }
} }
@ -66,11 +74,6 @@ func Load() (*Config, E.Error) {
return instance, instance.load() return instance, instance.load()
} }
func Validate(data []byte) E.Error {
var model types.Config
return utils.DeserializeYAML(data, &model)
}
func MatchDomains() []string { func MatchDomains() []string {
return instance.value.MatchDomains return instance.value.MatchDomains
} }
@ -101,6 +104,7 @@ func OnConfigChange(ev []events.Event) {
} }
if err := Reload(); err != nil { if err := Reload(); err != nil {
logger.Warn().Msg("using last config")
// recovered in event queue // recovered in event queue
panic(err) panic(err)
} }
@ -122,15 +126,19 @@ func Reload() E.Error {
// -> replace config -> start new subtasks // -> replace config -> start new subtasks
instance.task.Finish("config changed") instance.task.Finish("config changed")
instance = newCfg instance = newCfg
instance.StartProxyProviders() instance.Start()
return nil return nil
} }
func Value() types.Config { func (cfg *Config) Value() *types.Config {
return *instance.value return instance.value
} }
func GetAutoCertProvider() *autocert.Provider { func (cfg *Config) Reload() E.Error {
return Reload()
}
func (cfg *Config) AutoCertProvider() *autocert.Provider {
return instance.autocertProvider return instance.autocertProvider
} }
@ -138,6 +146,26 @@ func (cfg *Config) Task() *task.Task {
return cfg.task return cfg.task
} }
func (cfg *Config) Start() {
cfg.StartAutoCert()
cfg.StartProxyProviders()
cfg.StartServers()
}
func (cfg *Config) StartAutoCert() {
autocert := cfg.autocertProvider
if autocert == nil {
logging.Info().Msg("autocert not configured")
return
}
if err := autocert.Setup(); err != nil {
E.LogFatal("autocert setup error", err)
} else {
autocert.ScheduleRenewal(cfg.task)
}
}
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
errs := cfg.providers.CollectErrorsParallel( errs := cfg.providers.CollectErrorsParallel(
func(_ string, p *proxy.Provider) error { func(_ string, p *proxy.Provider) error {
@ -149,6 +177,30 @@ func (cfg *Config) StartProxyProviders() {
} }
} }
func (cfg *Config) StartServers() {
server.StartServer(cfg.task, server.Options{
Name: "proxy",
CertProvider: cfg.AutoCertProvider(),
HTTPAddr: common.ProxyHTTPAddr,
HTTPSAddr: common.ProxyHTTPSAddr,
Handler: cfg.entrypoint,
})
server.StartServer(cfg.task, server.Options{
Name: "api",
CertProvider: cfg.AutoCertProvider(),
HTTPAddr: common.APIHTTPAddr,
Handler: api.NewHandler(cfg),
})
if common.PrometheusEnabled {
server.StartServer(cfg.task, server.Options{
Name: "metrics",
CertProvider: cfg.AutoCertProvider(),
HTTPAddr: common.MetricsHTTPAddr,
Handler: metrics.NewHandler(),
})
}
}
func (cfg *Config) load() E.Error { func (cfg *Config) load() E.Error {
const errMsg = "config load error" const errMsg = "config load error"
@ -164,8 +216,8 @@ func (cfg *Config) load() E.Error {
// errors are non fatal below // errors are non fatal below
errs := E.NewBuilder(errMsg) errs := E.NewBuilder(errMsg)
errs.Add(entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initNotification(model.Providers.Notification))
errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers)) errs.Add(cfg.loadRouteProviders(&model.Providers))
@ -176,7 +228,8 @@ func (cfg *Config) load() E.Error {
model.MatchDomains[i] = "." + domain model.MatchDomains[i] = "." + domain
} }
} }
entrypoint.SetFindRouteDomains(model.MatchDomains) cfg.entrypoint.SetFindRouteDomains(model.MatchDomains)
return errs.Error() return errs.Error()
} }

View file

@ -1,20 +1,14 @@
package config package config
import ( import (
"strings"
"github.com/yusing/go-proxy/internal/homepage"
route "github.com/yusing/go-proxy/internal/route" route "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/entry" "github.com/yusing/go-proxy/internal/route/provider"
proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
func DumpEntries() map[string]*types.RawEntry { func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
entries := make(map[string]*types.RawEntry) entries := make(map[string]*types.RawEntry)
instance.providers.RangeAll(func(_ string, p *proxy.Provider) { cfg.providers.RangeAll(func(_ string, p *provider.Provider) {
p.RangeRoutes(func(alias string, r *route.Route) { p.RangeRoutes(func(alias string, r *route.Route) {
entries[alias] = r.Entry entries[alias] = r.Entry
}) })
@ -22,107 +16,20 @@ func DumpEntries() map[string]*types.RawEntry {
return entries return entries
} }
func DumpProviders() map[string]*proxy.Provider { func (cfg *Config) DumpProviders() map[string]*provider.Provider {
entries := make(map[string]*proxy.Provider) entries := make(map[string]*provider.Provider)
instance.providers.RangeAll(func(name string, p *proxy.Provider) { cfg.providers.RangeAll(func(name string, p *provider.Provider) {
entries[name] = p entries[name] = p
}) })
return entries return entries
} }
func HomepageConfig() homepage.Config { func (cfg *Config) Statistics() map[string]any {
hpCfg := homepage.NewHomePageConfig()
routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) {
en := r.RawEntry()
item := en.Homepage
if item == nil {
item = new(homepage.Item)
item.Show = true
}
if !item.IsEmpty() {
item.Show = true
}
if !item.Show {
return
}
item.Alias = alias
if item.Name == "" {
item.Name = strutils.Title(
strings.ReplaceAll(
strings.ReplaceAll(alias, "-", " "),
"_", " ",
),
)
}
if instance.value.Homepage.UseDefaultCategories {
if en.Container != nil && item.Category == "" {
if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok {
item.Category = category
}
}
if item.Category == "" {
if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok {
item.Category = category
}
}
}
switch {
case entry.IsDocker(r):
if item.Category == "" {
item.Category = "Docker"
}
item.SourceType = string(proxy.ProviderTypeDocker)
case entry.UseLoadBalance(r):
if item.Category == "" {
item.Category = "Load-balanced"
}
item.SourceType = "loadbalancer"
default:
if item.Category == "" {
item.Category = "Others"
}
item.SourceType = string(proxy.ProviderTypeFile)
}
item.AltURL = r.TargetURL().String()
hpCfg.Add(item)
})
return hpCfg
}
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
rts := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
}
for _, t := range typeFilter {
switch t {
case route.RouteTypeReverseProxy:
routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) {
rts[alias] = r
})
case route.RouteTypeStream:
routes.GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) {
rts[alias] = r
})
}
}
return rts
}
func Statistics() map[string]any {
nTotalStreams := 0 nTotalStreams := 0
nTotalRPs := 0 nTotalRPs := 0
providerStats := make(map[string]proxy.ProviderStats) providerStats := make(map[string]provider.ProviderStats)
instance.providers.RangeAll(func(name string, p *proxy.Provider) { cfg.providers.RangeAll(func(name string, p *provider.Provider) {
stats := p.Statistics() stats := p.Statistics()
providerStats[name] = stats providerStats[name] = stats

View file

@ -3,6 +3,8 @@ package types
import ( import (
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
E "github.com/yusing/go-proxy/internal/error"
) )
type ( type (
@ -24,6 +26,12 @@ type (
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
} }
NotificationConfig map[string]any NotificationConfig map[string]any
ConfigInstance interface {
Value() *Config
Reload() E.Error
Statistics() map[string]any
}
) )
func DefaultConfig() *Config { func DefaultConfig() *Config {
@ -35,6 +43,11 @@ func DefaultConfig() *Config {
} }
} }
func Validate(data []byte) E.Error {
var model Config
return utils.DeserializeYAML(data, &model)
}
func init() { func init() {
utils.RegisterDefaultValueFactory(DefaultConfig) utils.RegisterDefaultValueFactory(DefaultConfig)
} }

View file

@ -28,16 +28,17 @@ type (
PrivateIP string `json:"private_ip"` PrivateIP string `json:"private_ip"`
NetworkMode string `json:"network_mode"` NetworkMode string `json:"network_mode"`
Aliases []string `json:"aliases"` Aliases []string `json:"aliases"`
IsExcluded bool `json:"is_excluded"` IsExcluded bool `json:"is_excluded"`
IsExplicit bool `json:"is_explicit"` IsExplicit bool `json:"is_explicit"`
IsDatabase bool `json:"is_database"` IsDatabase bool `json:"is_database"`
IdleTimeout string `json:"idle_timeout,omitempty"` IdleTimeout string `json:"idle_timeout,omitempty"`
WakeTimeout string `json:"wake_timeout,omitempty"` WakeTimeout string `json:"wake_timeout,omitempty"`
StopMethod string `json:"stop_method,omitempty"` StopMethod string `json:"stop_method,omitempty"`
StopTimeout string `json:"stop_timeout,omitempty"` // stop_method = "stop" only StopTimeout string `json:"stop_timeout,omitempty"` // stop_method = "stop" only
StopSignal string `json:"stop_signal,omitempty"` // stop_method = "stop" | "kill" only StopSignal string `json:"stop_signal,omitempty"` // stop_method = "stop" | "kill" only
Running bool `json:"running"` StartEndpoint string `json:"start_endpoint,omitempty"`
Running bool `json:"running"`
} }
) )
@ -58,16 +59,17 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) {
PrivatePortMapping: helper.getPrivatePortMapping(), PrivatePortMapping: helper.getPrivatePortMapping(),
NetworkMode: c.HostConfig.NetworkMode, NetworkMode: c.HostConfig.NetworkMode,
Aliases: helper.getAliases(), Aliases: helper.getAliases(),
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit, IsExplicit: isExplicit,
IsDatabase: helper.isDatabase(), IsDatabase: helper.isDatabase(),
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout), WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout),
StopMethod: helper.getDeleteLabel(LabelStopMethod), StopMethod: helper.getDeleteLabel(LabelStopMethod),
StopTimeout: helper.getDeleteLabel(LabelStopTimeout), StopTimeout: helper.getDeleteLabel(LabelStopTimeout),
StopSignal: helper.getDeleteLabel(LabelStopSignal), StopSignal: helper.getDeleteLabel(LabelStopSignal),
Running: c.Status == "running" || c.State == "running", StartEndpoint: helper.getDeleteLabel(LabelStartEndpoint),
Running: c.Status == "running" || c.State == "running",
} }
res.setPrivateIP(helper) res.setPrivateIP(helper)
res.setPublicIP() res.setPublicIP()

View file

@ -2,6 +2,8 @@ package types
import ( import (
"errors" "errors"
"net/url"
"strings"
"time" "time"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
@ -10,11 +12,12 @@ import (
type ( type (
Config struct { Config struct {
IdleTimeout time.Duration `json:"idle_timeout,omitempty"` IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"` WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
StopMethod StopMethod `json:"stop_method,omitempty"` StopMethod StopMethod `json:"stop_method,omitempty"`
StopSignal Signal `json:"stop_signal,omitempty"` StopSignal Signal `json:"stop_signal,omitempty"`
StartEndpoint string `json:"start_endpoint,omitempty"` // Optional path that must be hit to start container
DockerHost string `json:"docker_host,omitempty"` DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"` ContainerName string `json:"container_name,omitempty"`
@ -58,17 +61,19 @@ func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout) stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod) stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
signal := E.Collect(errs, validateSignal, cont.StopSignal) signal := E.Collect(errs, validateSignal, cont.StopSignal)
startEndpoint := E.Collect(errs, validateStartEndpoint, cont.StartEndpoint)
if errs.HasError() { if errs.HasError() {
return nil, errs.Error() return nil, errs.Error()
} }
return &Config{ return &Config{
IdleTimeout: idleTimeout, IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout, WakeTimeout: wakeTimeout,
StopTimeout: int(stopTimeout.Seconds()), StopTimeout: int(stopTimeout.Seconds()),
StopMethod: stopMethod, StopMethod: stopMethod,
StopSignal: signal, StopSignal: signal,
StartEndpoint: startEndpoint,
DockerHost: cont.DockerHost, DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName, ContainerName: cont.ContainerName,
@ -104,3 +109,21 @@ func validateStopMethod(s string) (StopMethod, error) {
return "", errors.New("invalid stop method " + s) return "", errors.New("invalid stop method " + s)
} }
} }
func validateStartEndpoint(s string) (string, error) {
if s == "" {
return "", nil
}
// checks needed as of Go 1.6 because of change https://github.com/golang/go/commit/617c93ce740c3c3cc28cdd1a0d712be183d0b328#diff-6c2d018290e298803c0c9419d8739885L195
// emulate browser and strip the '#' suffix prior to validation. see issue-#237
if i := strings.Index(s, "#"); i > -1 {
s = s[:i]
}
if len(s) == 0 {
return "", errors.New("start endpoint must not be empty if defined")
}
if _, err := url.ParseRequestURI(s); err != nil {
return "", err
}
return s, nil
}

View file

@ -0,0 +1,47 @@
package types
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestValidateStartEndpoint(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "valid",
input: "/start",
wantErr: false,
},
{
name: "invalid",
input: "../foo",
wantErr: true,
},
{
name: "single fragment",
input: "#",
wantErr: true,
},
{
name: "empty",
input: "",
wantErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s, err := validateStartEndpoint(tc.input)
if err == nil {
ExpectEqual(t, s, tc.input)
}
if (err != nil) != tc.wantErr {
t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr)
}
})
}
}

View file

@ -8,7 +8,7 @@ import (
"github.com/yusing/go-proxy/internal/docker/idlewatcher/types" "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/metrics" "github.com/yusing/go-proxy/internal/metrics"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -22,7 +22,7 @@ type (
waker struct { waker struct {
_ U.NoCopy _ U.NoCopy
rp *gphttp.ReverseProxy rp *reverseproxy.ReverseProxy
stream net.Stream stream net.Stream
hc health.HealthChecker hc health.HealthChecker
metric *metrics.Gauge metric *metrics.Gauge
@ -38,7 +38,7 @@ const (
// TODO: support stream // TODO: support stream
func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { func newWaker(parent task.Parent, entry route.Entry, rp *reverseproxy.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.RawEntry().HealthCheck hcCfg := entry.RawEntry().HealthCheck
hcCfg.Timeout = idleWakerCheckTimeout hcCfg.Timeout = idleWakerCheckTimeout
@ -71,7 +71,7 @@ func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, st
} }
// lifetime should follow route provider. // lifetime should follow route provider.
func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *reverseproxy.ReverseProxy) (Waker, E.Error) {
return newWaker(parent, entry, rp, nil) return newWaker(parent, entry, rp, nil)
} }

View file

@ -34,6 +34,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
return true return true
} }
// Check if start endpoint is configured and request path matches
if w.StartEndpoint != "" && r.URL.Path != w.StartEndpoint {
http.Error(rw, "Forbidden: Container can only be started via configured start endpoint", http.StatusForbidden)
return false
}
if r.Body != nil { if r.Body != nil {
defer r.Body.Close() defer r.Body.Close()
} }

View file

@ -5,11 +5,12 @@ const (
NSProxy = "proxy" NSProxy = "proxy"
LabelAliases = NSProxy + ".aliases" LabelAliases = NSProxy + ".aliases"
LabelExclude = NSProxy + ".exclude" LabelExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout" LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout" LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method" LabelStopMethod = NSProxy + ".stop_method"
LabelStopTimeout = NSProxy + ".stop_timeout" LabelStopTimeout = NSProxy + ".stop_timeout"
LabelStopSignal = NSProxy + ".stop_signal" LabelStopSignal = NSProxy + ".stop_signal"
LabelStartEndpoint = NSProxy + ".start_endpoint"
) )

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"sync"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
@ -17,32 +16,31 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
var findRouteFunc = findRouteAnyDomain type Entrypoint struct {
middleware *middleware.Middleware
var ( accessLogger *accesslog.AccessLogger
epMiddleware *middleware.Middleware findRouteFunc func(host string) (route.HTTPRoute, error)
epMiddlewareMu sync.Mutex }
epAccessLogger *accesslog.AccessLogger
epAccessLoggerMu sync.Mutex
)
var ErrNoSuchRoute = errors.New("no such route") var ErrNoSuchRoute = errors.New("no such route")
func SetFindRouteDomains(domains []string) { func NewEntrypoint() *Entrypoint {
if len(domains) == 0 { return &Entrypoint{
findRouteFunc = findRouteAnyDomain findRouteFunc: findRouteAnyDomain,
} else {
findRouteFunc = findRouteByDomains(domains)
} }
} }
func SetMiddlewares(mws []map[string]any) error { func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
epMiddlewareMu.Lock() if len(domains) == 0 {
defer epMiddlewareMu.Unlock() ep.findRouteFunc = findRouteAnyDomain
} else {
ep.findRouteFunc = findRouteByDomains(domains)
}
}
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
if len(mws) == 0 { if len(mws) == 0 {
epMiddleware = nil ep.middleware = nil
return nil return nil
} }
@ -50,22 +48,19 @@ func SetMiddlewares(mws []map[string]any) error {
if err != nil { if err != nil {
return err return err
} }
epMiddleware = mid ep.middleware = mid
logger.Debug().Msg("entrypoint middleware loaded") logger.Debug().Msg("entrypoint middleware loaded")
return nil return nil
} }
func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
epAccessLoggerMu.Lock()
defer epAccessLoggerMu.Unlock()
if cfg == nil { if cfg == nil {
epAccessLogger = nil ep.accessLogger = nil
return return
} }
epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
if err != nil { if err != nil {
return return
} }
@ -73,28 +68,18 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
return return
} }
func Handler(w http.ResponseWriter, r *http.Request) { func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mux, err := findRouteFunc(r.Host) mux, err := ep.findRouteFunc(r.Host)
if err == nil { if err == nil {
if epAccessLogger != nil { if ep.accessLogger != nil {
epMiddlewareMu.Lock() w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
if epAccessLogger != nil { ep.accessLogger.Log(r, resp)
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { return nil
epAccessLogger.Log(r, resp) })
return nil
})
}
epMiddlewareMu.Unlock()
} }
if epMiddleware != nil { if ep.middleware != nil {
epMiddlewareMu.Lock() ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
if epMiddleware != nil { return
mid := epMiddleware
epMiddlewareMu.Unlock()
mid.ServeHTTP(mux.ServeHTTP, w, r)
return
}
epMiddlewareMu.Unlock()
} }
mux.ServeHTTP(w, r) mux.ServeHTTP(w, r)
return return

View file

@ -8,18 +8,19 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
var r route.HTTPRoute var (
r route.HTTPRoute
ep = NewEntrypoint()
)
func run(t *testing.T, match []string, noMatch []string) { func run(t *testing.T, match []string, noMatch []string) {
t.Helper() t.Helper()
t.Cleanup(routes.TestClear) t.Cleanup(routes.TestClear)
t.Cleanup(func() { t.Cleanup(func() { ep.SetFindRouteDomains(nil) })
SetFindRouteDomains(nil)
})
for _, test := range match { for _, test := range match {
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
found, err := findRouteFunc(test) found, err := ep.findRouteFunc(test)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectTrue(t, found == &r) ExpectTrue(t, found == &r)
}) })
@ -27,7 +28,7 @@ func run(t *testing.T, match []string, noMatch []string) {
for _, test := range noMatch { for _, test := range noMatch {
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
_, err := findRouteFunc(test) _, err := ep.findRouteFunc(test)
ExpectError(t, ErrNoSuchRoute, err) ExpectError(t, ErrNoSuchRoute, err)
}) })
} }
@ -72,7 +73,7 @@ func TestFindRouteExactHostMatch(t *testing.T) {
} }
func TestFindRouteByDomains(t *testing.T) { func TestFindRouteByDomains(t *testing.T) {
SetFindRouteDomains([]string{ ep.SetFindRouteDomains([]string{
".domain.com", ".domain.com",
".sub.domain.com", ".sub.domain.com",
}) })
@ -97,7 +98,7 @@ func TestFindRouteByDomains(t *testing.T) {
} }
func TestFindRouteByDomainsExactMatch(t *testing.T) { func TestFindRouteByDomainsExactMatch(t *testing.T) {
SetFindRouteDomains([]string{ ep.SetFindRouteDomains([]string{
".domain.com", ".domain.com",
".sub.domain.com", ".sub.domain.com",
}) })

View file

@ -8,8 +8,8 @@ import (
//nolint:errname //nolint:errname
type withSubject struct { type withSubject struct {
Subject string `json:"subject"` Subjects []string `json:"subjects"`
Err error `json:"err"` Err error `json:"err"`
} }
const subjectSep = " > " const subjectSep = " > "
@ -30,13 +30,18 @@ func PrependSubject(subject string, err error) error {
case Error: case Error:
return err.Subject(subject) return err.Subject(subject)
} }
return &withSubject{subject, err} return &withSubject{[]string{subject}, err}
} }
func (err *withSubject) Prepend(subject string) *withSubject { func (err *withSubject) Prepend(subject string) *withSubject {
clone := *err clone := *err
if subject != "" { if subject != "" {
clone.Subject = subject + subjectSep + clone.Subject switch subject[0] {
case '[', '(', '{':
clone.Subjects[len(clone.Subjects)-1] += subject
default:
clone.Subjects = append(clone.Subjects, subject)
}
} }
return &clone return &clone
} }
@ -50,7 +55,22 @@ func (err *withSubject) Unwrap() error {
} }
func (err *withSubject) Error() string { func (err *withSubject) Error() string {
subjects := strings.Split(err.Subject, subjectSep) // subject is in reversed order
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1]) n := len(err.Subjects)
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error() size := 0
errStr := err.Err.Error()
var sb strings.Builder
for _, s := range err.Subjects {
size += len(s)
}
sb.Grow(size + 2 + n*len(subjectSep) + len(errStr))
for i := n - 1; i > 0; i-- {
sb.WriteString(err.Subjects[i])
sb.WriteString(subjectSep)
}
sb.WriteString(highlight(err.Subjects[0]))
sb.WriteString(": ")
sb.WriteString(errStr)
return sb.String()
} }

View file

@ -129,7 +129,6 @@ func (l *AccessLogger) Flush(force bool) {
l.write(l.buf.Bytes()) l.write(l.buf.Bytes())
l.buf.Reset() l.buf.Reset()
l.bufMu.Unlock() l.bufMu.Unlock()
logger.Debug().Msg("access log flushed to " + l.io.Name())
} }
} }
@ -170,5 +169,7 @@ func (l *AccessLogger) write(data []byte) {
l.io.Unlock() l.io.Unlock()
if err != nil { if err != nil {
l.handleErr(err) l.handleErr(err)
} else {
logger.Debug().Msg("access log flushed to " + l.io.Name())
} }
} }

View file

@ -3,36 +3,66 @@ package accesslog
import ( import (
"fmt" "fmt"
"os" "os"
"path"
"sync" "sync"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
) )
type File struct { type File struct {
*os.File *os.File
sync.Mutex sync.Mutex
// os.File.Name() may not equal to key of `openedFiles`.
// Store it for later delete from `openedFiles`.
path string
refCount *utils.RefCount
} }
var ( var (
openedFiles = make(map[string]AccessLogIO) openedFiles = make(map[string]*File)
openedFilesMu sync.Mutex openedFilesMu sync.Mutex
) )
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
openedFilesMu.Lock() openedFilesMu.Lock()
var io AccessLogIO var file *File
if opened, ok := openedFiles[cfg.Path]; ok { path := path.Clean(cfg.Path)
io = opened if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
file = opened
} else { } else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil { if err != nil {
openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err) return nil, fmt.Errorf("access log open error: %w", err)
} }
io = &File{File: f} file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[cfg.Path] = io openedFiles[path] = file
go file.closeOnZero()
} }
openedFilesMu.Unlock() openedFilesMu.Unlock()
return NewAccessLogger(parent, io, cfg), nil return NewAccessLogger(parent, file, cfg), nil
}
func (f *File) Close() error {
f.refCount.Sub()
return nil
}
func (f *File) closeOnZero() {
defer logger.Debug().
Str("path", f.path).
Msg("access log closed")
<-f.refCount.Zero()
openedFilesMu.Lock()
delete(openedFiles, f.path)
openedFilesMu.Unlock()
f.File.Close()
} }

View file

@ -2,6 +2,10 @@ package http
import ( import (
"net/http" "net/http"
"net/textproto"
"github.com/yusing/go-proxy/internal/utils/strutils"
"golang.org/x/net/http/httpguts"
) )
const ( const (
@ -22,6 +26,48 @@ const (
HeaderContentLength = "Content-Length" HeaderContentLength = "Content-Length"
) )
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func UpgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// RemoveHopByHopHeaders removes hop-by-hop headers.
func RemoveHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strutils.SplitComma(f) {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func RemoveHop(h http.Header) { func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h) reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h) RemoveHopByHopHeaders(h)

View file

@ -0,0 +1,20 @@
package http
import "net/http"
var validMethods = map[string]struct{}{
http.MethodGet: {},
http.MethodHead: {},
http.MethodPost: {},
http.MethodPut: {},
http.MethodPatch: {},
http.MethodDelete: {},
http.MethodConnect: {},
http.MethodOptions: {},
http.MethodTrace: {},
}
func IsMethodValid(method string) bool {
_, ok := validMethods[method]
return ok
}

View file

@ -4,7 +4,10 @@ import (
"net" "net"
"net/http" "net/http"
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
@ -16,7 +19,7 @@ type (
} }
CIDRWhitelistOpts struct { CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"` Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Message string Message string
} }
) )
@ -30,6 +33,13 @@ var (
} }
) )
func init() {
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
statusCode := fl.Field().Int()
return gphttp.IsStatusCodeValid(int(statusCode))
})
}
// setup implements MiddlewareWithSetup. // setup implements MiddlewareWithSetup.
func (wl *cidrWhitelist) setup() { func (wl *cidrWhitelist) setup() {
wl.CIDRWhitelistOpts = cidrWhitelistDefaults wl.CIDRWhitelistOpts = cidrWhitelistDefaults

View file

@ -24,6 +24,18 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"message": testMessage, "message": testMessage,
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status": 403,
})
ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status_code": 403,
})
ExpectNoError(t, err)
}) })
t.Run("missing allow", func(t *testing.T) { t.Run("missing allow", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{

View file

@ -9,14 +9,15 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
type ( type (
Error = E.Error Error = E.Error
ReverseProxy = gphttp.ReverseProxy ReverseProxy = reverseproxy.ReverseProxy
ProxyRequest = gphttp.ProxyRequest ProxyRequest = reverseproxy.ProxyRequest
ImplNewFunc = func() any ImplNewFunc = func() any
OptionsRaw = map[string]any OptionsRaw = map[string]any
@ -93,9 +94,9 @@ func (m *Middleware) finalize() {
} }
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil { if m.construct == nil { // likely a middleware from compose
if optsRaw != nil { if len(optsRaw) != 0 {
panic("bug: middleware already constructed") return nil, E.New("additional options not allowed for middleware ").Subject(m.name)
} }
return m, nil return m, nil
} }

View file

@ -61,17 +61,38 @@ func LoadComposeFiles() {
logger.Err(err).Msg("failed to list middleware definitions") logger.Err(err).Msg("failed to list middleware definitions")
return return
} }
for _, defFile := range middlewareDefs {
voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
allMiddlewares[name] = m
logger.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
}
}
// build again to resolve cross references
for _, defFile := range middlewareDefs { for _, defFile := range middlewareDefs {
mws := BuildMiddlewaresFromComposeFile(defFile, errs) mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 { if len(mws) == 0 {
continue continue
} }
for name, m := range mws { for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok { if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name)) // already loaded above
continue continue
} }
allMiddlewares[strutils.ToLowerNoSnake(name)] = m allMiddlewares[name] = m
logger.Info(). logger.Info().
Str("src", path.Base(defFile)). Str("src", path.Base(defFile)).
Str("name", name). Str("name", name).

View file

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
) )
// internal use only. // internal use only.
@ -13,7 +14,7 @@ type setUpstreamHeaders struct {
var suh = NewMiddleware[setUpstreamHeaders]() var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware { func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{ m, err := suh.New(OptionsRaw{
"name": rp.TargetName, "name": rp.TargetName,
"scheme": rp.TargetURL.Scheme, "scheme": rp.TargetURL.Scheme,

View file

@ -10,7 +10,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -139,7 +139,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
rr.parent = http.DefaultTransport rr.parent = http.DefaultTransport
} }
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr) rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr)
mid, setOptErr := middleware.New(args.middlewareOpt) mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil { if setOptErr != nil {

View file

@ -0,0 +1,577 @@
// Copyright 2011 The Go Authors.
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
// https://cs.opensource.google/go/go/+/master:LICENSE
package reverseproxy
// This is a small mod on net/http/httputil/reverseproxy.go
// that boosts performance in some cases
// and compatible to other modules of this project
// Copyright (c) 2024 yusing
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
In *http.Request
// Out is the request which will be sent by the proxy.
// The Rewrite function may modify or replace this request.
// Hop-by-hop headers are removed from this request
// before Rewrite is called.
Out *http.Request
}
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
// X-Forwarded-Proto headers of the outbound request.
//
// - The X-Forwarded-For header is set to the client IP address.
// - The X-Forwarded-Host header is set to the host name requested
// by the client.
// - The X-Forwarded-Proto header is set to "http" or "https", depending
// on whether the inbound request was made on a TLS-enabled connection.
//
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
// [ReverseProxy] when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
// r.SetXForwarded()
// }
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
zerolog.Logger
// The transport used to perform proxy requests.
Transport http.RoundTripper
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called before ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
AccessLogger *accesslog.AccessLogger
HandlerFunc http.HandlerFunc
TargetName string
TargetURL types.URL
}
type httpMetricLogger struct {
http.ResponseWriter
timestamp time.Time
labels *metrics.HTTPRouteMetricLabels
}
var logger = logging.With().Str("module", "reverse_proxy").Logger()
// WriteHeader implements http.ResponseWriter.
func (l *httpMetricLogger) WriteHeader(status int) {
l.ResponseWriter.WriteHeader(status)
duration := time.Since(l.timestamp)
go func() {
m := metrics.GetRouteMetrics()
m.HTTPReqTotal.Inc()
m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds()))
// ignore 1xx
switch {
case status >= 500:
m.HTTP5xx.With(l.labels).Inc()
case status >= 400:
m.HTTP4xx.With(l.labels).Inc()
case status >= 200:
m.HTTP2xx3xx.With(l.labels).Inc()
}
}()
}
func (l *httpMetricLogger) Unwrap() http.ResponseWriter {
return l.ResponseWriter
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
// NewReverseProxy returns a new [ReverseProxy] that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{
Logger: logger.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
}
rp.HandlerFunc = rp.handler
return rp
}
func (p *ReverseProxy) UnregisterMetrics() {
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
}
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
targetQuery := p.TargetURL.RawQuery
req.URL.Scheme = p.TargetURL.Scheme
req.URL.Host = p.TargetURL.Host
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF):
logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error")
default:
logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error")
}
if writeHeader {
rw.WriteHeader(http.StatusInternalServerError)
}
if p.AccessLogger != nil {
p.AccessLogger.LogError(r, err)
}
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
res.Request = origReq
err := p.ModifyResponse(res)
res.Request = req
if err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
}
return true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.HandlerFunc(rw, req)
}
func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
visitorIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
visitorIP = req.RemoteAddr
}
if common.PrometheusEnabled {
t := time.Now()
// req.RemoteAddr had been modified by middleware (if any)
lbls := &metrics.HTTPRouteMetricLabels{
Service: p.TargetName,
Method: req.Method,
Host: req.Host,
Visitor: visitorIP,
Path: req.URL.Path,
}
rw = &httpMetricLogger{
ResponseWriter: rw,
timestamp: t,
labels: lbls,
}
}
transport := p.Transport
ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
// a Context that carries a cancellation signal, don't
// bother spinning up a goroutine to watch the CloseNotify
// channel (if any).
//
// If the request Context has a nil Done channel (which
// means it is either context.Background, or a custom
// Context implementation with no cancellation signal),
// then consult the CloseNotifier if available.
} else if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
p.rewriteRequestURL(outreq)
outreq.Close = false
reqUpType := gphttp.UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return
}
req.Header.Del("Forwarded")
gphttp.RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
if strings.EqualFold(reqUpType, "websocket") {
cleanWebsocketHeaders(outreq)
}
}
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header[gphttp.HeaderXForwardedFor]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
xff := visitorIP
if len(prior) > 0 {
xff = strings.Join(prior, ", ") + ", " + xff
}
if !omit {
outreq.Header.Set(gphttp.HeaderXForwardedFor, xff)
}
var reqScheme string
if req.TLS != nil {
reqScheme = "https"
} else {
reqScheme = "http"
}
outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme)
outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI)
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
var (
roundTripMutex sync.Mutex
roundTripDone bool
)
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
roundTripMutex.Lock()
defer roundTripMutex.Unlock()
if roundTripDone {
// If RoundTrip has returned, don't try to further modify
// the ResponseWriter's header map.
return nil
}
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
roundTripMutex.Lock()
roundTripDone = true
roundTripMutex.Unlock()
if err != nil {
p.errorHandler(rw, outreq, err, false)
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: req,
TLS: req.TLS,
}
}
if p.AccessLogger != nil {
defer func() {
p.AccessLogger.Log(req, res)
}()
}
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, req, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
return
}
gphttp.RemoveHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, req, outreq) {
return
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
_, err = io.Copy(rw, res.Body)
if err != nil {
if !errors.Is(err, context.Canceled) {
p.errorHandler(rw, req, err, true)
}
res.Body.Close()
return
}
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
http.NewResponseController(rw).Flush()
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
// reference: https://github.com/traefik/traefik/blob/master/pkg/proxy/httputil/proxy.go
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebsocketHeaders(req *http.Request) {
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
delete(req.Header, "Sec-Websocket-Key")
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
delete(req.Header, "Sec-Websocket-Extensions")
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
delete(req.Header, "Sec-Websocket-Accept")
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
delete(req.Header, "Sec-Websocket-Protocol")
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
delete(req.Header, "Sec-Websocket-Version")
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := gphttp.UpgradeType(req.Header)
resUpType := gphttp.UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return
}
if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true)
return
}
rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true)
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
if hijackErr != nil {
p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true)
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */
bdp.Start()
}
func IsPrint(s string) bool {
for _, r := range s {
if r < ' ' || r > '~' {
return false
}
}
return true
}

View file

@ -35,9 +35,9 @@ type Options struct {
Handler http.Handler Handler http.Handler
} }
func StartServer(opt Options) (s *Server) { func StartServer(parent task.Parent, opt Options) (s *Server) {
s = NewServer(opt) s = NewServer(opt)
s.Start() s.Start(parent)
return s return s
} }
@ -83,11 +83,13 @@ func NewServer(opt Options) (s *Server) {
// If both are not set, this does nothing. // If both are not set, this does nothing.
// //
// Start() is non-blocking. // Start() is non-blocking.
func (s *Server) Start() { func (s *Server) Start(parent task.Parent) {
if s.http == nil && s.https == nil { if s.http == nil && s.https == nil {
return return
} }
task := parent.Subtask("server."+s.Name, false)
s.startTime = time.Now() s.startTime = time.Now()
if s.http != nil { if s.http != nil {
go func() { go func() {
@ -105,7 +107,7 @@ func (s *Server) Start() {
s.l.Info().Str("addr", s.https.Addr).Msgf("server started") s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
} }
task.OnProgramExit("server."+s.Name+".stop", s.stop) task.OnCancel("stop", s.stop)
} }
func (s *Server) stop() { func (s *Server) stop() {
@ -113,14 +115,19 @@ func (s *Server) stop() {
return return
} }
ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second)
defer cancel()
if s.http != nil && s.httpStarted { if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(task.RootContext())) s.handleErr("http", s.http.Shutdown(ctx))
s.httpStarted = false s.httpStarted = false
s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped")
} }
if s.https != nil && s.httpsStarted { if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(task.RootContext())) s.handleErr("https", s.https.Shutdown(ctx))
s.httpsStarted = false s.httpsStarted = false
s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped")
} }
} }

View file

@ -5,3 +5,7 @@ import "net/http"
func IsSuccess(status int) bool { func IsSuccess(status int) bool {
return status >= http.StatusOK && status < http.StatusMultipleChoices return status >= http.StatusOK && status < http.StatusMultipleChoices
} }
func IsStatusCodeValid(status int) bool {
return http.StatusText(status) != ""
}

View file

@ -8,6 +8,11 @@ import (
//nolint:recvcheck //nolint:recvcheck
type CIDR net.IPNet type CIDR net.IPNet
func ParseCIDR(v string) (cidr CIDR, err error) {
err = cidr.Parse(v)
return
}
func (cidr *CIDR) Parse(v string) error { func (cidr *CIDR) Parse(v string) error {
if !strings.Contains(v, "/") { if !strings.Contains(v, "/") {
v += "/32" // single IP v += "/32" // single IP

View file

@ -49,10 +49,7 @@ func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
func init() { func init() {
utils.RegisterDefaultValueFactory(DefaultValue) utils.RegisterDefaultValueFactory(DefaultValue)
err := utils.Validator().RegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed) utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
if err != nil {
panic(err)
}
} }
// Name implements Provider. // Name implements Provider.

View file

@ -13,6 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/route/entry" "github.com/yusing/go-proxy/internal/route/entry"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
@ -30,7 +31,7 @@ type (
loadBalancer *loadbalancer.LoadBalancer loadBalancer *loadbalancer.LoadBalancer
server *loadbalancer.Server server *loadbalancer.Server
handler http.Handler handler http.Handler
rp *gphttp.ReverseProxy rp *reverseproxy.ReverseProxy
task *task.Task task *task.Task
@ -49,7 +50,7 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
} }
service := entry.TargetName() service := entry.TargetName()
rp := gphttp.NewReverseProxy(service, entry.URL, trans) rp := reverseproxy.NewReverseProxy(service, entry.URL, trans)
if len(entry.Raw.Middlewares) > 0 { if len(entry.Raw.Middlewares) > 0 {
err := middleware.PatchReverseProxy(rp, entry.Raw.Middlewares) err := middleware.PatchReverseProxy(rp, entry.Raw.Middlewares)
@ -138,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error {
} }
if len(r.Raw.Rules) > 0 { if len(r.Raw.Rules) > 0 {
r.handler = r.Raw.Rules.BuildHandler(r.rp) r.handler = r.Raw.Rules.BuildHandler(r.handler)
} }
if r.HealthMon != nil { if r.HealthMon != nil {

View file

@ -72,9 +72,9 @@ proxy.app1.host: 10.0.0.254
proxy.app1.port: 80 proxy.app1.port: 80
proxy.app1.path_patterns: proxy.app1.path_patterns:
| # Check https://pkg.go.dev/net/http#hdr-Patterns-ServeMux for syntax | # Check https://pkg.go.dev/net/http#hdr-Patterns-ServeMux for syntax
GET / # accept any GET request - GET / # accept any GET request
POST /auth # for /auth and /auth/* accept only POST - POST /auth # for /auth and /auth/* accept only POST
GET /home/{$} # for exactly /home - GET /home/{$} # for exactly /home
proxy.app1.healthcheck.disabled: false proxy.app1.healthcheck.disabled: false
proxy.app1.healthcheck.path: / proxy.app1.healthcheck.path: /
proxy.app1.healthcheck.interval: 5s proxy.app1.healthcheck.interval: 5s

View file

@ -5,6 +5,7 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/entry" "github.com/yusing/go-proxy/internal/route/entry"
"github.com/yusing/go-proxy/internal/route/provider/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
) )
@ -87,10 +88,10 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.GetType() { switch handler.provider.GetType() {
case ProviderTypeDocker: case types.ProviderTypeDocker:
return route.Entry.Container.ContainerID == event.ActorID || return route.Entry.Container.ContainerID == event.ActorID ||
route.Entry.Container.ContainerName == event.ActorName route.Entry.Container.ContainerName == event.ActorName
case ProviderTypeFile: case types.ProviderTypeFile:
return true return true
} }
// should never happen // should never happen

View file

@ -10,6 +10,8 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/provider/types"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
@ -20,7 +22,7 @@ type (
ProviderImpl `json:"-"` ProviderImpl `json:"-"`
name string name string
t ProviderType t types.ProviderType
routes R.Routes routes R.Routes
watcher W.Watcher watcher W.Watcher
@ -31,24 +33,20 @@ type (
NewWatcher() W.Watcher NewWatcher() W.Watcher
Logger() *zerolog.Logger Logger() *zerolog.Logger
} }
ProviderType string
ProviderStats struct { ProviderStats struct {
NumRPs int `json:"num_reverse_proxies"` NumRPs int `json:"num_reverse_proxies"`
NumStreams int `json:"num_streams"` NumStreams int `json:"num_streams"`
Type ProviderType `json:"type"` Type types.ProviderType `json:"type"`
} }
) )
const ( const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 300 * time.Millisecond providerEventFlushInterval = 300 * time.Millisecond
) )
var ErrEmptyProviderName = errors.New("empty provider name") var ErrEmptyProviderName = errors.New("empty provider name")
func newProvider(name string, t ProviderType) *Provider { func newProvider(name string, t types.ProviderType) *Provider {
return &Provider{ return &Provider{
name: name, name: name,
t: t, t: t,
@ -61,7 +59,7 @@ func NewFileProvider(filename string) (p *Provider, err error) {
if name == "" { if name == "" {
return nil, ErrEmptyProviderName return nil, ErrEmptyProviderName
} }
p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile) p = newProvider(strings.ReplaceAll(name, ".", "_"), types.ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename) p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,7 +73,7 @@ func NewDockerProvider(name string, dockerHost string) (p *Provider, err error)
return nil, ErrEmptyProviderName return nil, ErrEmptyProviderName
} }
p = newProvider(name, ProviderTypeDocker) p = newProvider(name, types.ProviderTypeDocker)
p.ProviderImpl, err = DockerProviderImpl(name, dockerHost, p.IsExplicitOnly()) p.ProviderImpl, err = DockerProviderImpl(name, dockerHost, p.IsExplicitOnly())
if err != nil { if err != nil {
return nil, err return nil, err
@ -92,7 +90,7 @@ func (p *Provider) GetName() string {
return p.name return p.name
} }
func (p *Provider) GetType() ProviderType { func (p *Provider) GetType() types.ProviderType {
return p.t return p.t
} }
@ -111,7 +109,7 @@ func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error {
return nil return nil
} }
// Start implements*task.TaskStarter. // Start implements task.TaskStarter.
func (p *Provider) Start(parent task.Parent) E.Error { func (p *Provider) Start(parent task.Parent) E.Error {
t := parent.Subtask("provider."+p.name, false) t := parent.Subtask("provider."+p.name, false)
@ -171,9 +169,9 @@ func (p *Provider) Statistics() ProviderStats {
numStreams := 0 numStreams := 0
p.routes.RangeAll(func(_ string, r *R.Route) { p.routes.RangeAll(func(_ string, r *R.Route) {
switch r.Type { switch r.Type {
case R.RouteTypeReverseProxy: case route.RouteTypeReverseProxy:
numRPs++ numRPs++
case R.RouteTypeStream: case route.RouteTypeStream:
numStreams++ numStreams++
} }
}) })

View file

@ -0,0 +1,8 @@
package types
type ProviderType string
const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
)

View file

@ -14,11 +14,10 @@ import (
) )
type ( type (
RouteType string Route struct {
Route struct {
_ U.NoCopy _ U.NoCopy
impl impl
Type RouteType Type types.RouteType
Entry *RawEntry Entry *RawEntry
} }
Routes = F.Map[string, *Route] Routes = F.Map[string, *Route]
@ -34,11 +33,6 @@ type (
RawEntries = types.RawEntries RawEntries = types.RawEntries
) )
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias. // function alias.
var ( var (
NewRoutes = F.NewMap[Routes] NewRoutes = F.NewMap[Routes]
@ -59,15 +53,15 @@ func NewRoute(raw *RawEntry) (*Route, E.Error) {
return nil, err return nil, err
} }
var t RouteType var t types.RouteType
var rt impl var rt impl
switch e := en.(type) { switch e := en.(type) {
case *entry.StreamEntry: case *entry.StreamEntry:
t = RouteTypeStream t = types.RouteTypeStream
rt, err = NewStreamRoute(e) rt, err = NewStreamRoute(e)
case *entry.ReverseProxyEntry: case *entry.ReverseProxyEntry:
t = RouteTypeReverseProxy t = types.RouteTypeReverseProxy
rt, err = NewHTTPRoute(e) rt, err = NewHTTPRoute(e)
default: default:
panic("bug: should not reach here") panic("bug: should not reach here")

View file

@ -0,0 +1,99 @@
package routes
import (
"strings"
"github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/route/entry"
provider "github.com/yusing/go-proxy/internal/route/provider/types"
"github.com/yusing/go-proxy/internal/route/types"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func HomepageConfig(useDefaultCategories bool) homepage.Config {
hpCfg := homepage.NewHomePageConfig()
GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) {
en := r.RawEntry()
item := en.Homepage
if item == nil {
item = new(homepage.Item)
item.Show = true
}
if !item.IsEmpty() {
item.Show = true
}
if !item.Show {
return
}
item.Alias = alias
if item.Name == "" {
item.Name = strutils.Title(
strings.ReplaceAll(
strings.ReplaceAll(alias, "-", " "),
"_", " ",
),
)
}
if useDefaultCategories {
if en.Container != nil && item.Category == "" {
if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok {
item.Category = category
}
}
if item.Category == "" {
if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok {
item.Category = category
}
}
}
switch {
case entry.IsDocker(r):
if item.Category == "" {
item.Category = "Docker"
}
item.SourceType = string(provider.ProviderTypeDocker)
case entry.UseLoadBalance(r):
if item.Category == "" {
item.Category = "Load-balanced"
}
item.SourceType = "loadbalancer"
default:
if item.Category == "" {
item.Category = "Others"
}
item.SourceType = string(provider.ProviderTypeFile)
}
item.AltURL = r.TargetURL().String()
hpCfg.Add(item)
})
return hpCfg
}
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
rts := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
}
for _, t := range typeFilter {
switch t {
case route.RouteTypeReverseProxy:
GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) {
rts[alias] = r
})
case route.RouteTypeStream:
GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) {
rts[alias] = r
})
}
}
return rts
}

248
internal/route/rules/do.go Normal file
View file

@ -0,0 +1,248 @@
package rules
import (
"net/http"
"path"
"strconv"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
Command struct {
raw string
exec *CommandExecutor
}
CommandExecutor struct {
directive string
http.HandlerFunc
proceed bool
}
)
const (
CommandRewrite = "rewrite"
CommandServe = "serve"
CommandProxy = "proxy"
CommandRedirect = "redirect"
CommandError = "error"
CommandBypass = "bypass"
)
var commands = map[string]struct {
help Help
validate ValidateFunc
build func(args any) *CommandExecutor
}{
CommandRewrite: {
help: Help{
command: CommandRewrite,
args: map[string]string{
"from": "the path to rewrite, must start with /",
"to": "the path to rewrite to, must start with /",
},
},
validate: func(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return validateURLPaths(args)
},
build: func(args any) *CommandExecutor {
a := args.([]string)
orig, repl := a[0], a[1]
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {
path = "/" + path
}
if !strings.HasPrefix(path, orig) {
return
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
},
proceed: true,
}
},
},
CommandServe: {
help: Help{
command: CommandServe,
args: map[string]string{
"root": "the file system path to serve, must be an existing directory",
},
},
validate: validateFSPath,
build: func(args any) *CommandExecutor {
root := args.(string)
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
},
proceed: false,
}
},
},
CommandRedirect: {
help: Help{
command: CommandRedirect,
args: map[string]string{
"to": "the url to redirect to, can be relative or absolute URL",
},
},
validate: validateURL,
build: func(args any) *CommandExecutor {
target := args.(types.URL).String()
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
},
proceed: false,
}
},
},
CommandError: {
help: Help{
command: CommandError,
args: map[string]string{
"code": "the http status code to return",
"text": "the error message to return",
},
},
validate: func(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
codeStr, text := args[0], args[1]
code, err := strconv.Atoi(codeStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if !gphttp.IsStatusCodeValid(code) {
return nil, ErrInvalidArguments.Subject(codeStr)
}
return []any{code, text}, nil
},
build: func(args any) *CommandExecutor {
a := args.([]any)
code, text := a[0].(int), a[1].(string)
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, text, code)
},
proceed: false,
}
},
},
CommandProxy: {
help: Help{
command: CommandProxy,
args: map[string]string{
"to": "the url to proxy to, must be an absolute URL",
},
},
validate: validateAbsoluteURL,
build: func(args any) *CommandExecutor {
target := args.(types.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
return &CommandExecutor{
HandlerFunc: rp.ServeHTTP,
proceed: false,
}
},
},
}
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {
cmd.raw = v
lines := strutils.SplitLine(v)
if len(lines) == 0 {
return nil
}
executors := make([]*CommandExecutor, 0, len(lines))
for _, line := range lines {
if line == "" {
continue
}
directive, args, err := parse(line)
if err != nil {
return err
}
if directive == CommandBypass {
if len(args) != 0 {
return ErrInvalidArguments.Subject(directive)
}
return nil
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
validArgs, err := builder.validate(args)
if err != nil {
return err.Subject(directive).Withf("%s", builder.help.String())
}
exec := builder.build(validArgs)
exec.directive = directive
executors = append(executors, exec)
}
exec, err := buildCmd(executors)
if err != nil {
return err
}
cmd.exec = exec
return nil
}
func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) {
for i, exec := range executors {
if !exec.proceed && i != len(executors)-1 {
return nil, ErrInvalidCommandSequence.
Withf("%s cannot follow %s", exec, executors[i+1])
}
}
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
for _, exec := range executors {
exec.HandlerFunc(w, r)
}
},
proceed: executors[len(executors)-1].proceed,
}, nil
}
func (cmd *Command) isBypass() bool {
return cmd.exec == nil
}
func (cmd *Command) String() string {
return cmd.raw
}
func (cmd *Command) MarshalJSON() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil
}
func (exec *CommandExecutor) String() string {
return exec.directive
}

View file

@ -0,0 +1,15 @@
package rules
import E "github.com/yusing/go-proxy/internal/error"
var (
ErrUnterminatedQuotes = E.New("unterminated quotes")
ErrUnsupportedEscapeChar = E.New("unsupported escape char")
ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrInvalidCommandSequence = E.New("invalid command sequence")
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")
)

View file

@ -0,0 +1,41 @@
package rules
import "strings"
type Help struct {
command string
description string
args map[string]string // args[arg] -> description
}
/*
Generate help string, e.g.
rewrite <from> <to>
from: the path to rewrite, must start with /
to: the path to rewrite to, must start with /
*/
func (h *Help) String() string {
var sb strings.Builder
sb.WriteString(h.command)
sb.WriteString(" ")
for arg := range h.args {
sb.WriteRune('<')
sb.WriteString(arg)
sb.WriteString("> ")
}
if h.description != "" {
sb.WriteString("\n\t")
sb.WriteString(h.description)
sb.WriteRune('\n')
}
sb.WriteRune('\n')
for arg, desc := range h.args {
sb.WriteRune('\t')
sb.WriteString(arg)
sb.WriteString(": ")
sb.WriteString(desc)
sb.WriteRune('\n')
}
return sb.String()
}

254
internal/route/rules/on.go Normal file
View file

@ -0,0 +1,254 @@
package rules
import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
RuleOn struct {
raw string
check CheckFulfill
}
CheckFulfill func(r *http.Request) bool
Checkers []CheckFulfill
)
const (
OnHeader = "header"
OnQuery = "query"
OnCookie = "cookie"
OnForm = "form"
OnPostForm = "postform"
OnMethod = "method"
OnPath = "path"
OnRemote = "remote"
)
var checkers = map[string]struct {
help Help
validate ValidateFunc
check func(r *http.Request, args any) bool
}{
OnHeader: {
help: Help{
command: OnHeader,
args: map[string]string{
"key": "the header key",
"value": "the header value",
},
},
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnQuery: {
help: Help{
command: OnQuery,
args: map[string]string{
"key": "the query key",
"value": "the query value",
},
},
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnCookie: {
help: Help{
command: OnCookie,
args: map[string]string{
"key": "the cookie key",
"value": "the cookie value",
},
},
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
cookies := r.CookiesNamed(args.(StrTuple).First)
for _, cookie := range cookies {
if cookie.Value == args.(StrTuple).Second {
return true
}
}
return false
},
},
OnForm: {
help: Help{
command: OnForm,
args: map[string]string{
"key": "the form key",
"value": "the form value",
},
},
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.FormValue(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnPostForm: {
help: Help{
command: OnPostForm,
args: map[string]string{
"key": "the form key",
"value": "the form value",
},
},
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.PostFormValue(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnMethod: {
help: Help{
command: OnMethod,
args: map[string]string{
"method": "the http method",
},
},
validate: validateMethod,
check: func(r *http.Request, method any) bool {
return r.Method == method.(string)
},
},
OnPath: {
help: Help{
command: OnPath,
description: `The path can be a glob pattern, e.g.:
/path/to
/path/to/*`,
args: map[string]string{
"path": "the request path, must start with /",
},
},
validate: validateURLPath,
check: func(r *http.Request, globPath any) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
}
return strutils.GlobMatch(globPath.(string), reqPath)
},
},
OnRemote: {
help: Help{
command: OnRemote,
args: map[string]string{
"ip|cidr": "the remote ip or cidr",
},
},
validate: validateCIDR,
check: func(r *http.Request, cidr any) bool {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return cidr.(*net.IPNet).Contains(ip)
},
},
}
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error {
on.raw = v
lines := strutils.SplitLine(v)
checks := make(Checkers, 0, len(lines))
errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines {
if line == "" {
continue
}
parsed, err := parseOn(line)
if err != nil {
errs.Add(err.Subjectf("line %d", i+1))
continue
}
checks = append(checks, parsed.matchOne())
}
on.check = checks.matchAll()
return errs.Error()
}
func (on *RuleOn) String() string {
return on.raw
}
func (on *RuleOn) MarshalJSON() ([]byte, error) {
return []byte("\"" + on.String() + "\""), nil
}
func parseOn(line string) (Checkers, E.Error) {
ors := strutils.SplitRune(line, '|')
if len(ors) > 1 {
errs := E.NewBuilder("rule.on syntax errors")
checks := make([]CheckFulfill, len(ors))
for i, or := range ors {
curCheckers, err := parseOn(or)
if err != nil {
errs.Add(err)
continue
}
checks[i] = curCheckers[0]
}
if err := errs.Error(); err != nil {
return nil, err
}
return checks, nil
}
subject, args, err := parse(line)
if err != nil {
return nil, err
}
checker, ok := checkers[subject]
if !ok {
return nil, ErrInvalidOnTarget.Subject(subject)
}
validArgs, err := checker.validate(args)
if err != nil {
return nil, err.Subject(subject).Withf("%s", checker.help.String())
}
return Checkers{
func(r *http.Request) bool {
return checker.check(r, validArgs)
},
}, nil
}
func (checkers Checkers) matchOne() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if checker(r) {
return true
}
}
return false
}
}
func (checkers Checkers) matchAll() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if !checker(r) {
return false
}
}
return true
}
}

View file

@ -0,0 +1,79 @@
package rules
import (
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
var escapedChars = map[rune]rune{
'n': '\n',
't': '\t',
'r': '\r',
'\'': '\'',
'"': '"',
'\\': '\\',
' ': ' ',
}
// parse expression to subject and args
// with support for quotes and escaped chars, e.g.
//
// error 403 "Forbidden 'foo' 'bar'"
// error 403 Forbidden\ \"foo\"\ \"bar\".
func parse(v string) (subject string, args []string, err E.Error) {
v = strings.TrimSpace(v)
var buf strings.Builder
escaped := false
quotes := make([]rune, 0, 4)
flush := func() {
if subject == "" {
subject = buf.String()
} else {
args = append(args, buf.String())
}
buf.Reset()
}
for _, r := range v {
if escaped {
if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch)
} else {
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r)
return
}
escaped = false
continue
}
switch r {
case '\\':
escaped = true
continue
case '"', '\'':
switch {
case len(quotes) > 0 && quotes[len(quotes)-1] == r:
quotes = quotes[:len(quotes)-1]
if len(quotes) == 0 {
flush()
} else {
buf.WriteRune(r)
}
case len(quotes) == 0:
quotes = append(quotes, r)
default:
buf.WriteRune(r)
}
case ' ':
flush()
default:
buf.WriteRune(r)
}
}
if len(quotes) > 0 {
err = ErrUnterminatedQuotes
} else {
flush()
}
return
}

View file

@ -0,0 +1,103 @@
package rules
import (
"net/http"
)
type (
/*
Example:
proxy.app1.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection Upgrade
header Upgrade websocket
do: bypass
proxy.app2.rules: |
- name: default
do: bypass
- name: block POST and PUT
on: method POST | method PUT
do: error 403 Forbidden
*/
Rules []Rule
/*
Rule is a rule for a reverse proxy.
It do `Do` when `On` matches.
A rule can have multiple lines of on.
All lines of on must match,
but each line can have multiple checks that
one match means this line is matched.
*/
Rule struct {
Name string `json:"name" validate:"required"`
On RuleOn `json:"on"`
Do Command `json:"do"`
}
)
// BuildHandler returns a http.HandlerFunc that implements the rules.
//
// if a bypass rule matches,
// the request is passed to the upstream and no more rules are executed.
//
// if no rule matches, the default rule is executed
// if no rule matches and default rule is not set,
// the request is passed to the upstream.
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
var (
defaultRule Rule
defaultRuleIndex int
)
for i, rule := range rules {
if rule.Name == "default" {
defaultRule = rule
defaultRuleIndex = i
break
}
}
rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...)
// free allocated empty slices
// before encapsulating them into the handlerFunc.
if len(rules) == 0 {
if defaultRule.Do.isBypass() {
return up.ServeHTTP
}
rules = []Rule{}
}
return func(w http.ResponseWriter, r *http.Request) {
hasMatch := false
for _, rule := range rules {
if rule.On.check(r) {
if rule.Do.isBypass() {
up.ServeHTTP(w, r)
return
}
rule.Do.exec.HandlerFunc(w, r)
if !rule.Do.exec.proceed {
return
}
hasMatch = true
}
}
if hasMatch || defaultRule.Do.isBypass() {
up.ServeHTTP(w, r)
return
}
defaultRule.Do.exec.HandlerFunc(w, r)
}
}

View file

@ -0,0 +1,251 @@
package rules
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseSubjectArgs(t *testing.T) {
t.Run("basic", func(t *testing.T) {
subject, args, err := parse("rewrite / /foo/bar")
ExpectNoError(t, err)
ExpectEqual(t, subject, "rewrite")
ExpectDeepEqual(t, args, []string{"/", "/foo/bar"})
})
t.Run("with quotes", func(t *testing.T) {
subject, args, err := parse(`error 403 "Forbidden 'foo' 'bar'."`)
ExpectNoError(t, err)
ExpectEqual(t, subject, "error")
ExpectDeepEqual(t, args, []string{"403", "Forbidden 'foo' 'bar'."})
})
t.Run("with escaped", func(t *testing.T) {
subject, args, err := parse(`error 403 Forbidden\ \"foo\"\ \"bar\".`)
ExpectNoError(t, err)
ExpectEqual(t, subject, "error")
ExpectDeepEqual(t, args, []string{"403", "Forbidden \"foo\" \"bar\"."})
})
}
func TestParseCommands(t *testing.T) {
tests := []struct {
name string
input string
wantErr error
}{
// bypass tests
{
name: "bypass_valid",
input: "bypass",
wantErr: nil,
},
{
name: "bypass_invalid_with_args",
input: "bypass /",
wantErr: ErrInvalidArguments,
},
// rewrite tests
{
name: "rewrite_valid",
input: "rewrite / /foo/bar",
wantErr: nil,
},
{
name: "rewrite_missing_target",
input: "rewrite /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_too_many_args",
input: "rewrite / / /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_no_leading_slash",
input: "rewrite abc /",
wantErr: ErrInvalidArguments,
},
// serve tests
{
name: "serve_valid",
input: "serve /var/www",
wantErr: nil,
},
{
name: "serve_missing_path",
input: "serve ",
wantErr: ErrInvalidArguments,
},
{
name: "serve_too_many_args",
input: "serve / / /",
wantErr: ErrInvalidArguments,
},
// redirect tests
{
name: "redirect_valid",
input: "redirect /",
wantErr: nil,
},
{
name: "redirect_too_many_args",
input: "redirect / /",
wantErr: ErrInvalidArguments,
},
// error directive tests
{
name: "error_valid",
input: "error 404 Not\\ Found",
wantErr: nil,
},
{
name: "error_missing_status_code",
input: "error Not\\ Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_too_many_args",
input: "error 404 Not\\ Found extra",
wantErr: ErrInvalidArguments,
},
{
name: "error_no_escaped_space",
input: "error 404 Not Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_invalid_status_code",
input: "error 123 abc",
wantErr: ErrInvalidArguments,
},
// proxy directive tests
{
name: "proxy_valid",
input: "proxy localhost:8080",
wantErr: nil,
},
{
name: "proxy_missing_target",
input: "proxy",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_too_many_args",
input: "proxy localhost:8080 extra",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_invalid_url",
input: "proxy :invalid_url",
wantErr: ErrInvalidArguments,
},
// unknown directive test
{
name: "unknown_directive",
input: "unknown /",
wantErr: ErrUnknownDirective,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := Command{}
err := cmd.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr E.Error
}{
// header
{
name: "header_valid",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_invalid",
input: "header Connection",
wantErr: ErrInvalidArguments,
},
// query
{
name: "query_valid",
input: "query key value",
wantErr: nil,
},
{
name: "query_invalid",
input: "query key",
wantErr: ErrInvalidArguments,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method",
wantErr: ErrInvalidArguments,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_invalid",
input: "path",
wantErr: ErrInvalidArguments,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote",
wantErr: ErrInvalidArguments,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
func TestParseRule(t *testing.T) {
// test := map[string]any{
// "name": "test",
// "on": "method GET",
// "do": "bypass",
// }
}

View file

@ -0,0 +1,125 @@
package rules
import (
"os"
"path"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
ValidateFunc func(args []string) (any, E.Error)
StrTuple struct {
First, Second string
}
)
func toStrTuple(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return StrTuple{args[0], args[1]}, nil
}
func validateURL(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
u, err := types.ParseURL(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return u, nil
}
func validateAbsoluteURL(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
u, err := types.ParseURL(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if u.Scheme == "" {
u.Scheme = "http"
}
if u.Host == "" {
return nil, ErrInvalidArguments.Withf("missing host")
}
return u, nil
}
func validateCIDR(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
if !strings.Contains(args[0], "/") {
args[0] += "/32"
}
cidr, err := types.ParseCIDR(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return cidr, nil
}
func validateURLPath(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
p := args[0]
trailingSlash := len(p) > 1 && p[len(p)-1] == '/'
p, _, _ = strings.Cut(p, "#")
p = path.Clean(p)
if len(p) == 0 {
return nil, ErrInvalidArguments.Withf("empty path")
}
if trailingSlash {
p += "/"
}
if p[0] != '/' {
return nil, ErrInvalidArguments.Withf("must start with /")
}
return p, nil
}
func validateURLPaths(paths []string) (any, E.Error) {
errs := E.NewBuilder("invalid url paths")
for i, p := range paths {
val, err := validateURLPath([]string{p})
if err != nil {
errs.Add(err.Subject(p))
continue
}
paths[i] = val.(string)
}
if err := errs.Error(); err != nil {
return nil, err
}
return paths, nil
}
func validateFSPath(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
p := path.Clean(args[0])
if _, err := os.Stat(p); err != nil {
return nil, ErrInvalidArguments.With(err)
}
return p, nil
}
func validateMethod(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
method := strings.ToUpper(args[0])
if !gphttp.IsMethodValid(method) {
return nil, ErrInvalidArguments.Subject(method)
}
return method, nil
}

View file

@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/rules"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
@ -30,7 +31,7 @@ type (
Port string `json:"port,omitempty"` Port string `json:"port,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns []string `json:"path_patterns,omitempty"` PathPatterns []string `json:"path_patterns,omitempty"`
Rules Rules `json:"rules,omitempty"` Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` LoadBalance *loadbalance.Config `json:"load_balance,omitempty"`
Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"`

View file

@ -0,0 +1,8 @@
package types
type RouteType string
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)

View file

@ -9,7 +9,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unicode"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -48,82 +47,6 @@ func New(t reflect.Type) reflect.Value {
return reflect.New(t) return reflect.New(t)
} }
// Serialize converts the given data into a map[string]any representation.
//
// It uses reflection to inspect the data type and handle different kinds of data.
// For a struct, it extracts the fields using the json tag if present, or the field name if not.
// For an embedded struct, it recursively converts its fields into the result map.
// For any other type, it returns an error.
//
// Parameters:
// - data: The data to be converted into a map.
//
// Returns:
// - result: The resulting map[string]any representation of the data.
// - error: An error if the data type is unsupported or if there is an error during conversion.
func Serialize(data any) (SerializedObject, error) {
result := make(map[string]any)
// Use reflection to inspect the data type
value := reflect.ValueOf(data)
// Check if the value is valid
if !value.IsValid() {
return nil, ErrInvalidType.Subjectf("%T", data)
}
// Dereference pointers if necessary
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
// Handle different kinds of data
switch value.Kind() {
case reflect.Map:
for _, key := range value.MapKeys() {
result[key.String()] = value.MapIndex(key).Interface()
}
case reflect.Struct:
for i := range value.NumField() {
field := value.Type().Field(i)
if !field.IsExported() {
continue
}
jsonTag := field.Tag.Get("json") // Get the json tag
if jsonTag == "-" {
continue // Ignore this field if the tag is "-"
}
if strings.Contains(jsonTag, ",omitempty") {
if value.Field(i).IsZero() {
continue
}
jsonTag = strings.Replace(jsonTag, ",omitempty", "", 1)
}
// If the json tag is not empty, use it as the key
switch {
case jsonTag != "":
result[jsonTag] = value.Field(i).Interface()
case field.Anonymous:
// If the field is an embedded struct, add its fields to the result
fieldMap, err := Serialize(value.Field(i).Interface())
if err != nil {
return nil, err
}
for k, v := range fieldMap {
result[k] = v
}
default:
result[field.Name] = value.Field(i).Interface()
}
}
default:
return nil, errors.New("serialize: unsupported data type " + value.Kind().String())
}
return result, nil
}
func extractFields(t reflect.Type) []reflect.StructField { func extractFields(t reflect.Type) []reflect.StructField {
for t.Kind() == reflect.Ptr { for t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
@ -203,9 +126,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
mapping[key] = dstV.FieldByName(field.Name) mapping[key] = dstV.FieldByName(field.Name)
fieldName[field.Name] = key fieldName[field.Name] = key
_, ok := field.Tag.Lookup("validate") if !needValidate {
if ok { _, needValidate = field.Tag.Lookup("validate")
needValidate = true
} }
aliases, ok := field.Tag.Lookup("aliases") aliases, ok := field.Tag.Lookup("aliases")
@ -258,7 +180,7 @@ func Deserialize(src SerializedObject, dst any) E.Error {
} }
return errs.Error() return errs.Error()
default: default:
return ErrUnsupportedConversion.Subject("deserialize to " + dstT.String()) return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
} }
} }
@ -355,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
if dstT.Kind() != reflect.Slice { if dstT.Kind() != reflect.Slice {
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String()) return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
} }
newSlice := reflect.MakeSlice(dstT, 0, src.Len()) newSlice := reflect.MakeSlice(dstT, src.Len(), src.Len())
i := 0 i := 0
for _, v := range src.Seq2() { for _, v := range src.Seq2() {
tmp := New(dstT.Elem()).Elem() tmp := New(dstT.Elem()).Elem()
@ -363,7 +285,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
if err != nil { if err != nil {
return err.Subjectf("[%d]", i) return err.Subjectf("[%d]", i)
} }
newSlice = reflect.Append(newSlice, tmp) newSlice.Index(i).Set(tmp)
i++ i++
} }
dst.Set(newSlice) dst.Set(newSlice)
@ -424,10 +346,11 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
return true, E.From(parser.Parse(src)) return true, E.From(parser.Parse(src))
} }
// yaml like // yaml like
isMultiline := strings.ContainsRune(src, '\n')
var tmp any var tmp any
switch dst.Kind() { switch dst.Kind() {
case reflect.Slice: case reflect.Slice:
src = strings.TrimSpace(src)
isMultiline := strings.ContainsRune(src, '\n')
// one liner is comma separated list // one liner is comma separated list
if !isMultiline { if !isMultiline {
values := strutils.CommaSeperatedList(src) values := strutils.CommaSeperatedList(src)
@ -444,16 +367,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
} }
return return
} }
lines := strutils.SplitLine(src) sl := make([]any, 0)
sl := make([]string, 0, len(lines)) err := yaml.Unmarshal([]byte(src), &sl)
for _, line := range lines { if err != nil {
line = strings.TrimLeftFunc(line, func(r rune) bool { return true, E.From(err)
return r == '-' || unicode.IsSpace(r)
})
if line == "" || line[0] == '#' {
continue
}
sl = append(sl, line)
} }
tmp = sl tmp = sl
case reflect.Map, reflect.Struct: case reflect.Map, reflect.Struct:

View file

@ -8,7 +8,7 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestSerializeDeserialize(t *testing.T) { func TestDeserialize(t *testing.T) {
type S struct { type S struct {
I int I int
S string S string
@ -37,12 +37,6 @@ func TestSerializeDeserialize(t *testing.T) {
} }
) )
t.Run("serialize", func(t *testing.T) {
s, err := Serialize(testStruct)
ExpectNoError(t, err)
ExpectDeepEqual(t, s, testStructSerialized)
})
t.Run("deserialize", func(t *testing.T) { t.Run("deserialize", func(t *testing.T) {
var s2 S var s2 S
err := Deserialize(testStructSerialized, &s2) err := Deserialize(testStructSerialized, &s2)
@ -174,7 +168,7 @@ func TestStringToSlice(t *testing.T) {
}) })
t.Run("multiline", func(t *testing.T) { t.Run("multiline", func(t *testing.T) {
dst := make([]string, 0) dst := make([]string, 0)
convertible, err := ConvertString(" a\n b\n c", reflect.ValueOf(&dst)) convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible) ExpectTrue(t, convertible)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectDeepEqual(t, dst, []string{"a", "b", "c"}) ExpectDeepEqual(t, dst, []string{"a", "b", "c"})

View file

@ -1,7 +1,6 @@
package strutils package strutils
import ( import (
"net/url"
"strings" "strings"
"golang.org/x/text/cases" "golang.org/x/text/cases"
@ -22,14 +21,6 @@ func Title(s string) string {
return cases.Title(language.AmericanEnglish).String(s) return cases.Title(language.AmericanEnglish).String(s)
} }
func ExtractPort(fullURL string) (int, error) {
url, err := url.Parse(fullURL)
if err != nil {
return 0, err
}
return Atoi(url.Port())
}
func ToLowerNoSnake(s string) string { func ToLowerNoSnake(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", "")) return strings.ToLower(strings.ReplaceAll(s, "_", ""))
} }

View file

@ -12,3 +12,10 @@ var ErrValidationError = E.New("validation error")
func Validator() *validator.Validate { func Validator() *validator.Validate {
return validate return validate
} }
func MustRegisterValidation(tag string, fn validator.Func) {
err := validate.RegisterValidation(tag, fn)
if err != nil {
panic(err)
}
}

118
next-release.md Normal file
View file

@ -0,0 +1,118 @@
GoDoxy v0.8.2 expected changes
- **Thanks [polds](https://github.com/polds)**
Optionally allow a user to specify a “warm-up” endpoint to start the container, returning a 403 if the endpoint isnt hit and the container has been stopped.
This can help prevent bots from starting random containers, or allow health check systems to run some probes. Or potentially lock the start endpoints behind a different authentication mechanism, etc.
Sample service showing this:
```yaml
hello-world:
image: nginxdemos/hello
container_name: hello-world
restart: "no"
ports:
- "9100:80"
labels:
proxy.aliases: hello-world
proxy.#1.port: 9100
proxy.idle_timeout: 45s
proxy.wake_timeout: 30s
proxy.stop_method: stop
proxy.stop_timeout: 10s
proxy.stop_signal: SIGTERM
proxy.start_endpoint: "/start"
```
Hitting `/` on this service when the container is down:
```curl
$ curl -sv -X GET -H "Host: hello-world.godoxy.local" http://localhost/
* Host localhost:80 was resolved.
* IPv6: ::1
* IPv4: 127.0.0.1
* Trying [::1]:80...
* Connected to localhost (::1) port 80
> GET / HTTP/1.1
> Host: hello-world.godoxy.local
> User-Agent: curl/8.7.1
> Accept: */*
>
* Request completely sent off
< HTTP/1.1 403 Forbidden
< Content-Type: text/plain; charset=utf-8
< X-Content-Type-Options: nosniff
< Date: Wed, 08 Jan 2025 02:04:51 GMT
< Content-Length: 71
<
Forbidden: Container can only be started via configured start endpoint
* Connection #0 to host localhost left intact
```
Hitting `/start` when the container is down:
```curl
curl -sv -X GET -H "Host: hello-world.godoxy.local" -H "X-Goproxy-Check-Redirect: skip" http://localhost/start
* Host localhost:80 was resolved.
* IPv6: ::1
* IPv4: 127.0.0.1
* Trying [::1]:80...
* Connected to localhost (::1) port 80
> GET /start HTTP/1.1
> Host: hello-world.godoxy.local
> User-Agent: curl/8.7.1
> Accept: */*
> X-Goproxy-Check-Redirect: skip
>
* Request completely sent off
< HTTP/1.1 200 OK
< Date: Wed, 08 Jan 2025 02:13:39 GMT
< Content-Length: 0
<
* Connection #0 to host localhost left intact
```
- Caddyfile like rules
```yaml
proxy.goaccess.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection Upgrade
header Upgrade websocket
do: bypass # do nothing, pass to reverse proxy
proxy.app.rules: |
- name: default
do: bypass # do nothing, pass to reverse proxy
- name: block POST and PUT
on: method POST | method PUT
do: error 403 Forbidden
```
````
- config reload will now cause all servers to fully restart (i.e. proxy, api, prometheus, etc)
- multiline-string as list now treated as YAML list, which requires hyphen prefix `-`, i.e.
```yaml
proxy.app.middlewares.request.hide_headers:
- X-Header1
- X-Header2
````
- autocert now supports hot-reload
- middleware compose now supports cross-referencing, e.g.
```yaml
foo:
- use: RedirectHTTP
bar: # in the same file or different file
- use: foo@file
```
- Fixes
- bug: cert renewal failure no longer causes renew schdueler to stuck forever
- bug: access log writes to closed file after config reload