migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher

This commit is contained in:
yusing 2024-10-29 11:34:58 +08:00
parent cfa74d69ae
commit e5bbb18414
137 changed files with 2640 additions and 2348 deletions

View file

@ -65,3 +65,6 @@ debug-list-containers:
ci-test:
mkdir -p /tmp/artifacts
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
cloc:
cloc --not-match-f '_test.go$$' cmd internal pkg

View file

@ -2,7 +2,6 @@ package main
import (
"encoding/json"
"io"
"log"
"net/http"
"os"
@ -10,15 +9,14 @@ import (
"syscall"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/notif"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/task"
@ -33,44 +31,26 @@ func main() {
return
}
l := logrus.WithField("module", "main")
timeFmt := "01-02 15:04:05"
fullTS := true
if common.IsTrace {
logrus.SetLevel(logrus.TraceLevel)
timeFmt = "04:05"
fullTS = false
} else if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel)
}
if args.Command != common.CommandStart {
logrus.SetOutput(io.Discard)
} else {
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
FullTimestamp: fullTS,
ForceColors: true,
TimestampFormat: timeFmt,
})
logrus.Infof("go-proxy version %s", pkg.GetVersion())
logrus.AddHook(notif.GetDispatcher())
}
if args.Command == common.CommandReload {
if err := query.ReloadServer(); err != nil {
log.Fatal(err)
E.LogFatal("server reload error", err)
}
log.Print("ok")
logging.Info().Msg("ok")
return
}
// exit if only validate config
if args.Command == common.CommandStart {
logging.Info().Msgf("go-proxy version %s", pkg.GetVersion())
logging.Trace().Msg("trace enabled")
// logging.AddHook(notif.GetDispatcher())
} else {
logging.DiscardLogger()
}
if args.Command == common.CommandValidate {
data, err := os.ReadFile(common.ConfigPath)
if err == nil {
err = config.Validate(data).Error()
err = config.Validate(data)
}
if err != nil {
log.Fatal("config error: ", err)
@ -88,7 +68,7 @@ func main() {
var cfg *config.Config
var err E.Error
if cfg, err = config.Load(); err != nil {
logrus.Warn(err)
E.LogWarn("errors in config", err)
}
switch args.Command {
@ -145,10 +125,10 @@ func main() {
autocert := config.GetAutoCertProvider()
if autocert != nil {
if err := autocert.Setup(); err != nil {
l.Fatal(err)
E.LogFatal("autocert setup error", err)
}
} else {
l.Info("autocert not configured")
logging.Info().Msg("autocert not configured")
}
proxyServer := server.InitProxyServer(server.Options{
@ -174,7 +154,7 @@ func main() {
<-sig
// grafully shutdown
logrus.Info("shutting down")
logging.Info().Msg("shutting down")
task.CancelGlobalContext()
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
}
@ -182,15 +162,15 @@ func main() {
func prepareDirectory(dir string) {
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0o755); err != nil {
logrus.Fatalf("failed to create directory %s: %v", dir, err)
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
}
}
}
func printJSON(obj any) {
j, err := E.Check(json.MarshalIndent(obj, "", " "))
j, err := json.MarshalIndent(obj, "", " ")
if err != nil {
logrus.Fatal(err)
logging.Fatal().Err(err).Send()
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq"

7
go.mod
View file

@ -10,8 +10,8 @@ require (
github.com/go-acme/lego/v4 v4.19.2
github.com/gotify/server/v2 v2.5.0
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/rs/zerolog v1.33.0
github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.30.0
golang.org/x/text v0.19.0
gopkg.in/yaml.v3 v3.0.1
@ -20,7 +20,7 @@ require (
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudflare/cloudflare-go v0.107.0 // indirect
github.com/cloudflare/cloudflare-go v0.108.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect
@ -33,6 +33,8 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect
@ -42,6 +44,7 @@ require (
github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect

18
go.sum
View file

@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc=
github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc=
github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0=
github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -40,6 +41,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@ -61,6 +63,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
@ -88,6 +96,9 @@ github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPK
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@ -140,6 +151,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -6,7 +6,6 @@ import (
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
@ -38,7 +37,6 @@ func NewHandler() http.Handler {
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux
}
@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
Logger.Warnf("blocked API request from %s", host)
LogWarn(r).Msgf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return
}

View file

@ -1,31 +0,0 @@
package errorpage
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func GetHandleFunc() http.HandlerFunc {
setup()
return serveHTTP
}
func serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
content, ok := fileContentMap.Load(r.URL.Path)
if !ok {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
if _, err := w.Write(content); err != nil {
HandleErr(w, r, err)
}
}

View file

@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return
}
var validateErr E.Error
var valErr E.Error
if filename == common.ConfigFileName {
validateErr = config.Validate(content)
valErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
validateErr = provider.Validate(content)
valErr = provider.Validate(content)
}
// no validation for include files
if validateErr != nil {
U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest)
if valErr != nil {
U.RespondJSON(w, r, valErr, http.StatusBadRequest)
return
}

View file

@ -20,16 +20,13 @@ func ReloadServer() E.Error {
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
b, err := io.ReadAll(resp.Body)
failure := E.Errorf("server reload status %v", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
return failure.Extraf("unable to read response body: %s", err)
return failure.With(err)
}
reloadErr, ok := E.FromJSON(b)
if ok {
return E.Join("reload success, but server returned error", reloadErr)
}
return failure.Extraf("unable to read response body")
reloadErr := string(body)
return failure.Withf(reloadErr)
}
return nil
}
@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) {
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
return
}
var res T

View file

@ -9,8 +9,8 @@ import (
func Reload(w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil {
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
U.HandleErr(w, r, err)
return
}
U.WriteBody(w, []byte("OK"))
}

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func Stats(w http.ResponseWriter, r *http.Request) {
@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
originPats = []string{"*"}
} else {
for i, domain := range config.Value().MatchDomains {
@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
OriginPatterns: originPats,
})
if err != nil {
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
return
}
/* trunk-ignore(golangci-lint/errcheck) */
@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
for range ticker.C {
stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
U.LogError(r).Msg("failed to write JSON")
return
}
}
@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
func getStats() map[string]any {
return map[string]any{
"proxies": config.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
"uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
}
}

View file

@ -1,37 +1,31 @@
package utils
import (
"errors"
"fmt"
"net/http"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
var Logger = logrus.WithField("module", "api")
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
if origErr == nil {
return
}
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
Logger.Error(err)
LogError(r).Msg(origErr.Error())
if len(code) > 0 {
http.Error(w, err.String(), code[0])
http.Error(w, origErr.Error(), code[0])
return
}
http.Error(w, err.String(), http.StatusInternalServerError)
http.Error(w, origErr.Error(), http.StatusInternalServerError)
}
func ErrMissingKey(k string) error {
return errors.New("missing key '" + k + "' in query or request body")
return E.New("missing key '" + k + "' in query or request body")
}
func ErrInvalidKey(k string) error {
return errors.New("invalid key '" + k + "' in query or request body")
return E.New("invalid key '" + k + "' in query or request body")
}
func ErrNotFound(k, v string) error {
return fmt.Errorf("key %q with value %q not found", k, v)
return E.Errorf("key %q with value %q not found", k, v)
}

View file

@ -9,12 +9,11 @@ import (
)
var (
HTTPClient = &http.Client{
httpClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
@ -23,7 +22,7 @@ var (
},
}
Get = HTTPClient.Get
Post = HTTPClient.Post
Head = HTTPClient.Head
Get = httpClient.Get
Post = httpClient.Post
Head = httpClient.Head
)

View file

@ -0,0 +1,18 @@
package utils
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).Str("module", "api").
Str("method", r.Method).
Str("path", r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }

View file

@ -3,6 +3,8 @@ package utils
import (
"encoding/json"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
)
func WriteBody(w http.ResponseWriter, body []byte) {
@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) {
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool {
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ")
if err != nil {
HandleErr(w, r, err)
return false
var j []byte
var err error
switch data := data.(type) {
case string:
j = []byte(`"` + data + `"`)
case []byte:
j = data
default:
j, err = json.MarshalIndent(data, "", " ")
if err != nil {
logging.Panic().Err(err).Msg("failed to marshal json")
return false
}
}
_, err = w.Write(j)
if err != nil {

View file

@ -8,12 +8,21 @@ import (
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/config/types"
)
type Config types.AutoCertConfig
var (
ErrMissingDomain = E.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'")
ErrUnknownProvider = E.New("unknown provider")
)
func NewConfig(cfg *types.AutoCertConfig) *Config {
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
func (cfg *Config) GetProvider() (*Provider, E.Error) {
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Addf("%s", "no domains specified")
b.Add(ErrMissingDomain)
}
if cfg.Provider == "" {
b.Addf("%s", "no provider specified")
b.Add(ErrMissingProvider)
}
if cfg.Email == "" {
b.Addf("%s", "no email specified")
b.Add(ErrMissingEmail)
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Addf("unknown provider: %q", cfg.Provider)
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
}
}
if b.HasError() {
return
return nil, b.Error()
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.HasError() {
b.Add(E.FailWith("generate private key", err))
return
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
b.Addf("generate private key: %w", err)
return nil, b.Error()
}
user := &User{
@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
provider = &Provider{
return &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
}
return
}, nil
}

View file

@ -7,7 +7,6 @@ import (
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/sirupsen/logrus"
)
const (
@ -36,5 +35,3 @@ var providersGenMap = map[string]ProviderGenerator{
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert")

View file

@ -0,0 +1,5 @@
package autocert
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "autocert").Logger()

View file

@ -17,6 +17,7 @@ import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
@ -57,25 +58,20 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() (res E.Error) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
func (p *Provider) ObtainCert() E.Error {
if p.cfg.Provider == ProviderLocal {
return nil
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
b.Add(E.FailWith("init autocert client", err))
return
if err := p.initClient(); err != nil {
return err
}
}
if p.user.Registration == nil {
if err := p.registerACME(); err.HasError() {
b.Add(E.FailWith("register ACME", err))
return
if err := p.registerACME(); err != nil {
return E.From(err)
}
}
@ -84,27 +80,23 @@ func (p *Provider) ObtainCert() (res E.Error) {
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.HasError() {
b.Add(err)
return
cert, err := client.Certificate.Obtain(req)
if err != nil {
return E.From(err)
}
if err = p.saveCert(cert); err.HasError() {
b.Add(E.FailWith("save certificate", err))
return
if err = p.saveCert(cert); err != nil {
return E.From(err)
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.HasError() {
b.Add(E.FailWith("parse obtained certificate", err))
return
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
if err != nil {
return E.From(err)
}
expiries, err := getCertExpiries(&tlsCert)
if err.HasError() {
b.Add(E.FailWith("get certificate expiry", err))
return
if err != nil {
return E.From(err)
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
@ -113,21 +105,22 @@ func (p *Provider) ObtainCert() (res E.Error) {
}
func (p *Provider) LoadCert() E.Error {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.HasError() {
return err
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err != nil {
return E.Errorf("load SSL certificate: %w", err)
}
expiries, err := getCertExpiries(&cert)
if err.HasError() {
return err
if err != nil {
return E.Errorf("parse SSL certificate: %w", err)
}
p.tlsCert = &cert
p.certExpiries = expiries
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
// ShouldRenewOn returns the time at which the certificate should be renewed.
func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before
@ -150,8 +143,8 @@ func (p *Provider) ScheduleRenewal() {
case <-task.Context().Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
if err := p.renewIfNeeded(); err != nil {
E.LogWarn("cert renew failed", err, &logger)
}
}
}
@ -159,31 +152,32 @@ func (p *Provider) ScheduleRenewal() {
}
func (p *Provider) initClient() E.Error {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.FailWith("create lego client", err)
legoClient, err := lego.NewClient(p.legoCfg)
if err != nil {
return E.From(err)
}
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.FailWith("create lego provider", err)
generator := providersGenMap[p.cfg.Provider]
legoProvider, pErr := generator(p.cfg.Options)
if pErr != nil {
return pErr
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.FailWith("set challenge provider", err)
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
if err != nil {
return E.From(err)
}
p.client = legoClient
return nil
}
func (p *Provider) registerACME() E.Error {
func (p *Provider) registerACME() error {
if p.user.Registration != nil {
return nil
}
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
p.user.Registration = reg
@ -191,26 +185,27 @@ func (p *Provider) registerACME() E.Error {
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) E.Error {
func (p *Provider) saveCert(cert *certificate.Resource) error {
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
return E.FailWith("create cert directory", err)
return err
}
} else {
return E.FailWith("stat cert directory", err)
return err
}
}
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.FailWith("write key file", err)
return err
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.FailWith("write cert file", err)
return err
}
return nil
}
@ -232,7 +227,7 @@ func (p *Provider) certState() CertState {
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}
@ -246,25 +241,25 @@ func (p *Provider) renewIfNeeded() E.Error {
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")
logger.Info().Msg("certs expired, renewing")
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
logger.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
}
if err := p.ObtainCert(); err.HasError() {
return E.FailWith("renew certificate", err)
if err := p.ObtainCert(); err != nil {
return err
}
return nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) {
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.HasError() {
return nil, E.FailWith("parse certificate", err)
x509Cert, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
if x509Cert.IsCA {
continue
@ -284,13 +279,10 @@ func providerGenerator[CT any, PT challenge.Provider](
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg()
err := U.Deserialize(opt, cfg)
if err.HasError() {
if err != nil {
return nil, err
}
p, err := E.Check(newProvider(cfg))
if err.HasError() {
return nil, err
}
return p, nil
p, pErr := newProvider(cfg)
return p, E.From(pErr)
}
}

View file

@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, U.Deserialize(opt, cfg).Error())
ExpectNoError(t, U.Deserialize(opt, cfg))
ExpectDeepEqual(t, cfg, cfgExpected)
}

View file

@ -11,7 +11,7 @@ func (p *Provider) Setup() (err E.Error) {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
logger.Debug("obtaining cert due to error loading cert")
logger.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
@ -20,7 +20,7 @@ func (p *Provider) Setup() (err E.Error) {
p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry)
logger.Info().Msg("certificate expire on " + expiry.String())
break
}

View file

@ -3,8 +3,7 @@ package common
import (
"flag"
"fmt"
"github.com/sirupsen/logrus"
"log"
)
type Args struct {
@ -44,7 +43,7 @@ func GetArgs() Args {
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArg(args.Command); err != nil {
logrus.Fatal(err)
log.Fatalf("invalid command: %s", err)
}
return args
}
@ -55,5 +54,5 @@ func validateArg(arg string) error {
return nil
}
}
return fmt.Errorf("invalid command: %s", arg)
return fmt.Errorf("invalid command %q", arg)
}

View file

@ -7,8 +7,6 @@ import (
"os"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
var (
@ -40,7 +38,7 @@ func GetEnvBool(key string, defaultValue bool) bool {
}
b, err := strconv.ParseBool(value)
if err != nil {
log.Fatalf("Invalid boolean value: %s", value)
log.Fatalf("env %s: invalid boolean value: %s", key, value)
}
return b
}
@ -57,7 +55,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
addr = GetEnv(key, defaultValue)
host, port, err := net.SplitHostPort(addr)
if err != nil {
logrus.Fatalf("Invalid address: %s", addr)
log.Fatalf("env %s: invalid address: %s", key, addr)
}
if host == "" {
host = "localhost"

View file

@ -2,14 +2,15 @@ package config
import (
"os"
"strconv"
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
@ -31,7 +32,7 @@ type Config struct {
var (
instance *Config
cfgWatcher watcher.Watcher
logger = logrus.WithField("module", "config")
logger = logging.With().Str("module", "config").Logger()
reloadMu sync.Mutex
)
@ -80,7 +81,7 @@ func WatchChanges() {
configEventFlushInterval,
OnConfigChange,
func(err E.Error) {
logger.Error(err)
E.LogError("config reload error", err, &logger)
},
)
eventQueue.Start(cfgWatcher.Events(task.Context()))
@ -93,15 +94,16 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
// just reload once and check the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
logger.Warn(cfgRenameWarn)
logger.Warn().Msg(cfgRenameWarn)
return
case events.ActionFileDeleted:
logger.Warn(cfgDeleteWarn)
logger.Warn().Msg(cfgDeleteWarn)
return
}
if err := Reload(); err != nil {
logger.Error(err)
// recovered in event queue
panic(err)
}
}
@ -138,63 +140,57 @@ func (cfg *Config) Task() task.Task {
}
func (cfg *Config) StartProxyProviders() {
b := E.NewBuilder("errors starting providers")
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
b.Add(p.Start(cfg.task.Subtask(p.String())))
})
errs := cfg.providers.CollectErrorsParallel(
func(_ string, p *proxy.Provider) error {
subtask := cfg.task.Subtask(p.String())
return p.Start(subtask)
})
if b.HasError() {
logger.Error(b.Build())
if err := E.Join(errs...); err != nil {
E.LogError("route provider errors", err, &logger)
}
}
func (cfg *Config) load() (res E.Error) {
errs := E.NewBuilder("errors loading config")
defer errs.To(&res)
func (cfg *Config) load() E.Error {
const errMsg = "config load error"
logger.Debug("loading config")
defer logger.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
data, err := os.ReadFile(common.ConfigPath)
if err != nil {
errs.Add(E.FailWith("read config", err))
logrus.Fatal(errs.Build())
E.LogFatal(errMsg, err, &logger)
}
if !common.NoSchemaValidation {
if err = Validate(data); err != nil {
errs.Add(E.FailWith("schema validation", err))
logrus.Fatal(errs.Build())
if err := Validate(data); err != nil {
E.LogFatal(errMsg, err, &logger)
}
}
model := types.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err != nil {
errs.Add(E.FailWith("parse config", err))
logrus.Fatal(errs.Build())
E.LogFatal(errMsg, err, &logger)
}
// errors are non fatal below
errs := E.NewBuilder(errMsg)
errs.Add(cfg.initNotification(model.Providers.Notification))
errs.Add(cfg.initAutoCert(&model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers))
cfg.value = model
route.SetFindMuxDomains(model.MatchDomains)
return
return errs.Error()
}
func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) {
if len(notifCfgMap) == 0 {
return
}
errs := E.NewBuilder("errors initializing notification providers")
errs := E.NewBuilder("notification providers load errors")
for name, notifCfg := range notifCfgMap {
_, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg)
errs.Add(err)
}
return errs.Build()
return errs.Error()
}
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
@ -203,40 +199,45 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
}
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err != nil {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadRouteProviders(providers *types.Providers) (outErr E.Error) {
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
subtask := cfg.task.Subtask("load route providers")
defer subtask.Finish("done")
errs := E.NewBuilder("errors loading route providers")
results := E.NewBuilder("loaded providers")
defer errs.To(&outErr)
errs := E.NewBuilder("route provider errors")
results := E.NewBuilder("loaded route providers")
lenLongestName := 0
for _, filename := range providers.Files {
p, err := proxy.NewFileProvider(filename)
if err != nil {
errs.Add(err)
errs.Add(E.PrependSubject(filename, err))
continue
}
cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(filename))
results.Addf("%d routes from %s", p.NumRoutes(), p.String())
if len(p.GetName()) > lenLongestName {
lenLongestName = len(p.GetName())
}
}
for name, dockerHost := range providers.Docker {
p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil {
errs.Add(err.Subjectf("%s (%s)", name, dockerHost))
errs.Add(E.PrependSubject(name, err))
continue
}
cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(p.GetName()))
results.Addf("%d routes from %s", p.NumRoutes(), p.String())
if len(p.GetName()) > lenLongestName {
lenLongestName = len(p.GetName())
}
}
logger.Info(results.Build())
return
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String()))
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes())
})
logger.Info().Msg(results.String())
return errs.Error()
}

View file

@ -9,8 +9,8 @@ import (
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func DumpEntries() map[string]*entry.RawEntry {
@ -61,7 +61,7 @@ func HomepageConfig() homepage.Config {
}
if item.Name == "" {
item.Name = U.Title(
item.Name = strutils.Title(
strings.ReplaceAll(
strings.ReplaceAll(alias, "-", " "),
"_", " ",

View file

@ -1,14 +1,15 @@
package docker
import (
"errors"
"net/http"
"sync"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@ -22,7 +23,7 @@ type (
key string
refCount *U.RefCount
l logrus.FieldLogger
l zerolog.Logger
}
)
@ -70,7 +71,7 @@ func (c *SharedClient) Close() error {
// Returns:
// - Client: the Docker client connection.
// - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.Error) {
func ConnectClient(host string) (Client, error) {
clientMapMu.Lock()
defer clientMapMu.Unlock()
@ -85,13 +86,13 @@ func ConnectClient(host string) (Client, E.Error) {
switch host {
case "":
return nil, E.Invalid("docker host", "empty")
return nil, errors.New("empty docker host")
case common.DockerHostFromEnv:
opt = clientOptEnvHost
default:
helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() {
return nil, E.UnexpectedError(err.Error())
helper, err := connhelper.GetConnectionHelper(host)
if err != nil {
logging.Panic().Err(err).Msg("failed to get connection helper")
}
if helper != nil {
httpClient := &http.Client{
@ -113,9 +114,9 @@ func ConnectClient(host string) (Client, E.Error) {
}
}
client, err := E.Check(client.NewClientWithOpts(opt...))
client, err := client.NewClientWithOpts(opt...)
if err.HasError() {
if err != nil {
return nil, err
}
@ -123,9 +124,9 @@ func ConnectClient(host string) (Client, E.Error) {
Client: client,
key: host,
refCount: U.NewRefCounter(),
l: logger.WithField("docker_client", client.DaemonHost()),
l: logger.With().Str("address", client.DaemonHost()).Logger(),
}
c.l.Debugf("client connected")
c.l.Trace().Msg("client connected")
clientMap.Store(host, c)
@ -135,7 +136,7 @@ func ConnectClient(host string) (Client, E.Error) {
if c.Connected() {
c.Client.Close()
c.l.Debugf("client closed")
c.l.Trace().Msg("client closed")
}
}()
return c, nil

View file

@ -6,8 +6,8 @@ import (
"strings"
"github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
@ -59,7 +59,7 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) {
NetworkMode: c.HostConfig.NetworkMode,
Aliases: helper.getAliases(),
IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit,
IsDatabase: helper.isDatabase(),
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
@ -120,7 +120,7 @@ func (c *Container) setPublicIP() {
}
url, err := url.Parse(c.DockerHost)
if err != nil {
logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err)
logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
c.PublicIP = "127.0.0.1"
return
}

View file

@ -4,7 +4,7 @@ import (
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type containerHelper struct {
@ -23,7 +23,7 @@ func (c containerHelper) getDeleteLabel(label string) string {
func (c containerHelper) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l)
return strutils.CommaSeperatedList(l)
}
return []string{c.getName()}
}
@ -44,7 +44,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
if v.PublicPort == 0 {
continue
}
res[U.PortString(v.PublicPort)] = v
res[strutils.PortString(v.PublicPort)] = v
}
return res
}
@ -52,7 +52,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
func (c containerHelper) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
res[U.PortString(v.PrivatePort)] = v
res[strutils.PortString(v.PrivatePort)] = v
}
return res
}

View file

@ -1,6 +1,7 @@
package types
import (
"errors"
"time"
"github.com/yusing/go-proxy/internal/docker"
@ -30,7 +31,7 @@ const (
StopMethodKill StopMethod = "kill"
)
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
if cont == nil {
return nil, nil
}
@ -44,26 +45,16 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
}, nil
}
b := E.NewBuilder("invalid idlewatcher config")
defer b.To(&res)
errs := E.NewBuilder("invalid idlewatcher config")
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout)
b.Add(err.Subjectf("%s", "idle_timeout"))
idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
signal := E.Collect(errs, validateSignal, cont.StopSignal)
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout)
b.Add(err.Subjectf("%s", "wake_timeout"))
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
b.Add(err.Subjectf("%s", "stop_timeout"))
stopMethod, err := validateStopMethod(cont.StopMethod)
b.Add(err)
signal, err := validateSignal(cont.StopSignal)
b.Add(err)
if err := b.Build(); err != nil {
return
if errs.HasError() {
return nil, errs.Error()
}
return &Config{
@ -80,33 +71,33 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
}, nil
}
func validateDurationPostitive(value string) (time.Duration, E.Error) {
func validateDurationPostitive(value string) (time.Duration, error) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value).With(err)
return 0, err
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
return 0, errors.New("duration must be positive")
}
return d, nil
}
func validateSignal(s string) (Signal, E.Error) {
func validateSignal(s string) (Signal, error) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
return "", errors.New("invalid signal " + s)
}
func validateStopMethod(s string) (StopMethod, E.Error) {
func validateStopMethod(s string) (StopMethod, error) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
return "", errors.New("invalid stop method " + s)
}
}

View file

@ -42,7 +42,7 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
watcher, err := registerWatcher(providerSubTask, entry, waker)
if err != nil {
return nil, err
return nil, E.Errorf("register watcher: %w", err)
}
if rp != nil {
@ -75,6 +75,9 @@ func (w *Watcher) Start(routeSubTask task.Task) E.Error {
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) {
if w.stream != nil {
w.stream.Close()
}
}
// Name implements health.HealthMonitor.

View file

@ -7,9 +7,7 @@ import (
"strconv"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -20,21 +18,26 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if !shouldNext {
return
}
w.rp.ServeHTTP(rw, r)
select {
case <-r.Context().Done():
return
default:
w.rp.ServeHTTP(rw, r)
}
}
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is already ready
if w.ready.Load() {
return true
}
if r.Body != nil {
defer r.Body.Close()
}
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
@ -49,23 +52,22 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
w.Err(err).Msg("error writing http response")
}
return
return false
}
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
checkCanceled := func() bool {
checkCanceled := func() (canceled bool) {
select {
case <-w.task.Context().Done():
w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
case <-ctx.Done():
w.l.Debugf("wake canceled: %s", context.Cause(ctx))
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled")
return true
case <-w.task.Context().Done():
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
default:
return false
@ -76,12 +78,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
return false
}
w.l.Debug("wake signal received")
w.WakeTrace().Msg("signal received")
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
w.WakeError(err).Send()
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return
return false
}
for {
@ -92,11 +94,11 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL())
w.Debug().Msgf("redirecting to %s ...", w.hc.URL())
rw.WriteHeader(http.StatusOK)
return
return false
}
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
w.Debug().Msgf("passing through to %s ...", w.hc.URL())
return true
}

View file

@ -7,7 +7,7 @@ import (
"net"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -16,25 +16,25 @@ func (w *Watcher) Addr() net.Addr {
return w.stream.Addr()
}
// Setup implements types.Stream.
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn net.Conn, err error) {
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
conn, err = w.stream.Accept()
if err != nil {
logrus.Errorf("accept failed with error: %s", err)
return
}
if err := w.wakeFromStream(); err != nil {
w.l.Error(err)
if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.WakeError(wakeErr).Msg("error waking from stream")
}
return
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn net.Conn) error {
func (w *Watcher) Handle(conn types.StreamConn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
@ -54,11 +54,11 @@ func (w *Watcher) wakeFromStream() error {
return nil
}
w.l.Debug("wake signal received")
w.WakeDebug().Msg("wake signal received")
wakeErr := w.wakeIfStopped()
if wakeErr != nil {
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr)
w.l.Error(wakeErr)
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
w.WakeError(wakeErr).Msg("wake failed")
return wakeErr
}
@ -69,18 +69,18 @@ func (w *Watcher) wakeFromStream() error {
select {
case <-w.task.Context().Done():
cause := w.task.FinishCause()
w.l.Debugf("wake canceled: %s", cause)
w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.l.Debugf("wake canceled: %s", cause)
w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
return cause
default:
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String())
return nil
}

View file

@ -3,15 +3,15 @@ package idlewatcher
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
@ -25,6 +25,8 @@ type (
Watcher struct {
_ U.NoCopy
zerolog.Logger
*idlewatcher.Config
*waker
@ -32,7 +34,6 @@ type (
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
ticker *time.Ticker
task task.Task
l *logrus.Entry
}
WakeDone <-chan error
@ -44,13 +45,12 @@ var (
watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex
logger = logrus.WithField("module", "idle_watcher")
logger = logging.With().Str("module", "idle_watcher").Logger()
)
const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) {
failure := E.Failure("idle_watcher register")
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) {
cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 {
@ -71,17 +71,17 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
}
client, err := D.ConnectClient(cfg.DockerHost)
if err.HasError() {
return nil, failure.With(err)
if err != nil {
return nil, err
}
w := &Watcher{
Logger: logger.With().Str("name", cfg.ContainerName).Logger(),
Config: cfg,
waker: waker,
client: client,
task: providerSubtask,
ticker: time.NewTicker(cfg.IdleTimeout),
l: logger.WithField("container", cfg.ContainerName),
}
w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w)
@ -99,6 +99,23 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
return w, nil
}
// WakeDebug logs a debug message related to waking the container.
func (w *Watcher) WakeDebug() *zerolog.Event {
return w.Debug().Str("action", "wake")
}
func (w *Watcher) WakeTrace() *zerolog.Event {
return w.Trace().Str("action", "wake")
}
func (w *Watcher) WakeError(err error) *zerolog.Event {
return w.Err(err).Str("action", "wake")
}
func (w *Watcher) LogReason(action, reason string) {
w.Info().Str("reason", reason).Msg(action)
}
func (w *Watcher) containerStop(ctx context.Context) error {
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal),
@ -130,7 +147,7 @@ func (w *Watcher) containerStatus() (string, error) {
defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
if err != nil {
return "", fmt.Errorf("failed to inspect container: %w", err)
return "", err
}
return json.State.Status, nil
}
@ -181,7 +198,7 @@ func (w *Watcher) getStopCallback() StopCallback {
}
func (w *Watcher) resetIdleTimer() {
w.l.Trace("reset idle timer")
w.Trace().Msg("reset idle timer")
w.ticker.Reset(w.IdleTimeout)
}
@ -190,7 +207,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID),
W.DockerFilterContainerNameID(w.ContainerID),
W.DockerFilterStart,
W.DockerFilterStop,
W.DockerFilterDie,
@ -214,7 +231,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
//
// it exits only if the context is canceled, the container is destroyed,
// errors occured on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() error {
func (w *Watcher) watchUntilDestroy() (returnCause error) {
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped")
@ -224,36 +241,36 @@ func (w *Watcher) watchUntilDestroy() error {
case <-w.task.Context().Done():
return w.task.FinishCause()
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))
return err.Error()
if !err.Is(context.Canceled) {
E.LogError("idlewatcher error", err, &w.Logger)
}
return err
case e := <-dockerEventCh:
switch {
case e.Action == events.ActionContainerDestroy:
w.ContainerRunning = false
w.ready.Store(false)
w.l.Info("watcher stopped by container destruction")
w.LogReason("watcher stopped", "container destroyed")
return errors.New("container destroyed")
// create / start / unpause
case e.Action.IsContainerWake():
w.ContainerRunning = true
w.resetIdleTimer()
w.l.Info("container awaken")
w.Info().Msg("awaken")
case e.Action.IsContainerSleep(): // stop / pause / kil
w.ContainerRunning = false
w.ready.Store(false)
w.ticker.Stop()
default:
w.l.Errorf("unexpected docker event: %s", e)
w.Error().Msg("unexpected docker event: " + e.String())
}
// container name changed should also change the container id
if w.ContainerName != e.ActorName {
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName)
w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
w.ContainerName = e.ActorName
}
if w.ContainerID != e.ActorID {
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID)
w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
w.ContainerID = e.ActorID
// recreate event stream
eventTask.Finish("recreate event stream")
@ -263,9 +280,9 @@ func (w *Watcher) watchUntilDestroy() error {
w.ticker.Stop()
if w.ContainerRunning {
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err)
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
} else {
w.l.Info("container stopped by idle timeout")
w.LogReason("container stopped", "idle timeout")
}
}
}

View file

@ -4,28 +4,26 @@ import (
"context"
"errors"
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func Inspect(dockerHost string, containerID string) (*Container, E.Error) {
func Inspect(dockerHost string, containerID string) (*Container, error) {
client, err := ConnectClient(dockerHost)
defer client.Close()
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
if err != nil {
return nil, err
}
return client.Inspect(containerID)
}
func (c Client) Inspect(containerID string) (*Container, E.Error) {
func (c Client) Inspect(containerID string) (*Container, error) {
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return nil, E.From(err)
return nil, err
}
return FromJSON(json, c.key), nil
}

View file

@ -24,6 +24,11 @@ type (
NestedLabelMap map[string]U.SerializedObject
)
var (
ErrApplyToNil = E.New("label value is nil")
ErrFieldNotExist = E.New("field does not exist")
)
func (l *Label) String() string {
if l.Attribute == "" {
return l.Namespace + "." + l.Target
@ -41,7 +46,7 @@ func (l *Label) String() string {
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.Error {
if obj == nil {
return E.Invalid("nil object", l)
return ErrApplyToNil.Subject(l.String())
}
switch nestedLabel := l.Value.(type) {
case *Label:
@ -54,7 +59,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
}
}
if !field.IsValid() {
return E.NotExist("field", l.Attribute)
return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String())
}
dst, ok := field.Interface().(NestedLabelMap)
if !ok {
@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
} else {
field = field.Addr()
}
return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
if err != nil {
return err.Subject(l.String())
}
return nil
}
if dst == nil {
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
return nil
default:
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
if err != nil {
return err.Subject(l.String())
}
return nil
}
}
func ParseLabel(label string, value string) (*Label, E.Error) {
func ParseLabel(label string, value string) *Label {
parts := strings.Split(label, ".")
if len(parts) < 2 {
return &Label{
Namespace: label,
Value: value,
}, nil
}
}
l := &Label{
@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.Error) {
l.Attribute = parts[2]
default:
l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
if err != nil {
return nil, err
}
nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value)
l.Value = nestedLabel
}
return l, nil
return l
}

View file

@ -20,9 +20,8 @@ func makeLabel(ns, name, attr string) string {
func TestNestedLabel(t *testing.T) {
mAttr := "prop1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[*Label](t, pl.Value)
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
sGot := ExpectType[*Label](t, lbl.Value)
ExpectFalse(t, sGot == nil)
ExpectEqual(t, sGot.Namespace, mName)
ExpectEqual(t, sGot.Attribute, mAttr)
@ -32,10 +31,9 @@ func TestApplyNestedLabel(t *testing.T) {
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
@ -52,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) {
entry.Middlewares[mName] = make(U.SerializedObject)
entry.Middlewares[mName][checkAttr] = checkV
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
@ -74,10 +71,9 @@ func TestApplyNestedLabelNoAttr(t *testing.T) {
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
_, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
}

View file

@ -8,7 +8,6 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
E "github.com/yusing/go-proxy/internal/error"
)
var listOptions = container.ListOptions{
@ -23,19 +22,19 @@ var listOptions = container.ListOptions{
All: true,
}
func ListContainers(clientHost string) ([]types.Container, E.Error) {
func ListContainers(clientHost string) ([]types.Container, error) {
dockerClient, err := ConnectClient(clientHost)
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
if err != nil {
return nil, err
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
defer cancel()
containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
containers, err := dockerClient.ContainerList(ctx, listOptions)
if err != nil {
return nil, err
}
return containers, nil
}

View file

@ -1,5 +1,7 @@
package docker
import "github.com/sirupsen/logrus"
import (
"github.com/yusing/go-proxy/internal/logging"
)
var logger = logrus.WithField("module", "docker")
var logger = logging.With().Str("module", "docker").Logger()

46
internal/error/base.go Normal file
View file

@ -0,0 +1,46 @@
package error
import (
"errors"
"fmt"
)
// baseError is an immutable wrapper around an error.
type baseError struct {
Err error `json:"err"`
}
func (err *baseError) Unwrap() error {
return err.Err
}
func (err *baseError) Is(other error) bool {
if other, ok := other.(*baseError); ok {
return errors.Is(err.Err, other.Err)
}
return errors.Is(err.Err, other)
}
func (err baseError) Subject(subject string) Error {
err.Err = PrependSubject(subject, err.Err)
return &err
}
func (err *baseError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err baseError) With(extra error) Error {
return &nestedError{&err, []error{extra}}
}
func (err baseError) Withf(format string, args ...any) Error {
return &nestedError{&err, []error{fmt.Errorf(format, args...)}}
}
func (err *baseError) Error() string {
return err.Err.Error()
}

View file

@ -1,104 +1,104 @@
package error
import (
"errors"
"fmt"
"sync"
)
type Builder struct {
*builder
}
type builder struct {
message string
errors []Error
about string
errs []error
sync.Mutex
}
func NewBuilder(format string, args ...any) Builder {
if len(args) > 0 {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
func NewBuilder(about string) *Builder {
return &Builder{about: about}
}
func (b *Builder) About() string {
if !b.HasError() {
return ""
}
return Builder{&builder{message: format}}
return b.about
}
//go:inline
func (b *Builder) HasError() bool {
return len(b.errs) > 0
}
func (b *Builder) Error() Error {
if !b.HasError() {
return nil
}
if len(b.errs) == 1 {
return From(b.errs[0])
}
return &nestedError{Err: New(b.about), Extras: b.errs}
}
func (b *Builder) String() string {
if !b.HasError() {
return ""
}
return (&nestedError{Err: New(b.about), Extras: b.errs}).Error()
}
// Add adds an error to the Builder.
//
// adding nil is no-op,
//
// flatten is a boolean flag to flatten the NestedError.
func (b Builder) Add(err Error, flatten ...bool) {
if err != nil {
b.Lock()
if len(flatten) > 0 && flatten[0] {
for _, e := range err.extras {
b.errors = append(b.errors, &e)
}
func (b *Builder) Add(err error) *Builder {
if err == nil {
return b
}
b.Lock()
defer b.Unlock()
switch err := err.(type) {
case *baseError:
b.errs = append(b.errs, err.Err)
case *nestedError:
if err.Err == nil {
b.errs = append(b.errs, err.Extras...)
} else {
b.errors = append(b.errors, err)
b.errs = append(b.errs, err)
}
b.Unlock()
}
}
func (b Builder) AddE(err error) {
b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) {
if len(args) > 0 {
b.Add(errorf(format, args...))
} else {
b.AddE(errors.New(format))
}
}
func (b Builder) AddRange(errs ...Error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.errors = append(b.errors, err)
}
}
func (b Builder) AddRangeE(errs ...error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.errors = append(b.errors, From(err))
}
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
// Otherwise, it returns a NestedError with the message and the errors collected.
//
// Returns:
// - NestedError: the built NestedError.
func (b Builder) Build() Error {
if len(b.errors) == 0 {
return nil
}
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *Error) {
switch {
case ptr == nil:
return
case *ptr == nil:
*ptr = b.Build()
default:
(*ptr).extras = append((*ptr).extras, *b.Build())
b.errs = append(b.errs, err)
}
return b
}
func (b Builder) String() string {
return b.Build().String()
func (b *Builder) Adds(err string) *Builder {
b.Lock()
defer b.Unlock()
b.errs = append(b.errs, newError(err))
return b
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
func (b *Builder) Addf(format string, args ...any) *Builder {
if len(args) > 0 {
b.Lock()
defer b.Unlock()
b.errs = append(b.errs, fmt.Errorf(format, args...))
} else {
b.Adds(format)
}
return b
}
func (b *Builder) AddRange(errs ...error) *Builder {
b.Lock()
defer b.Unlock()
for _, err := range errs {
if err != nil {
b.errs = append(b.errs, err)
}
}
return b
}

View file

@ -1,6 +1,9 @@
package error_test
import (
"context"
"errors"
"io"
"testing"
. "github.com/yusing/go-proxy/internal/error"
@ -8,14 +11,13 @@ import (
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
eb := NewBuilder("foo")
ExpectTrue(t, errors.Is(eb.Error(), nil))
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
eb := NewBuilder("foo")
var err Error
for range 3 {
eb.Add(nil)
@ -23,41 +25,31 @@ func TestBuilderAddNil(t *testing.T) {
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
eb.AddRange(nil, nil, err)
ExpectFalse(t, eb.HasError())
ExpectTrue(t, eb.Error() == nil)
}
func TestBuilderIs(t *testing.T) {
eb := NewBuilder("foo")
eb.Add(context.Canceled)
eb.Add(io.ErrShortBuffer)
ExpectTrue(t, eb.HasError())
ExpectError(t, io.ErrShortBuffer, eb.Error())
ExpectError(t, context.Canceled, eb.Error())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().String()
expected1 := (`error occurred:
- Action 1 failed:
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner: 3`)
expected2 := (`error occurred:
- Action 1 failed:
- invalid Inner: "1"
- invalid Inner: "2"
- Action 2 failed:
- invalid Inner: "3"`)
ExpectEqualAny(t, got, []string{expected1, expected2})
}
func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Addf("abcd")
var err Error
eb.To(&err)
got := err.String()
expected := (`error occurred:
- abcd`)
eb := NewBuilder("action failed")
eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2"))
eb.Add(New("Action 2").Withf("Inner: 3"))
got := eb.String()
expected := `action failed
Action 1
Inner: 1
Inner: 2
Action 2
Inner: 3`
ExpectEqual(t, got, expected)
}

View file

@ -1,317 +1,31 @@
package error
import (
"encoding/json"
"errors"
"fmt"
"strings"
)
type Error interface {
error
type (
Error = *ErrorImpl
ErrorImpl struct {
subject string
err error
extras []ErrorImpl
}
ErrorJSONMarshaller struct {
Subject string `json:"subject"`
Err string `json:"error"`
Extras []ErrorJSONMarshaller `json:"extras,omitempty"`
}
)
func From(err error) Error {
if IsNil(err) {
return nil
}
return &ErrorImpl{err: err}
// Is is a wrapper for errors.Is when there is no sub-error.
//
// When there are sub-errors, they will also be checked.
Is(other error) bool
// With appends a sub-error to the error.
With(extra error) Error
// Withf is a wrapper for With(fmt.Errorf(format, args...)).
Withf(format string, args ...any) Error
// Subject prepends the given subject with a colon and space to the error message.
//
// If there is already a subject in the error message, the subject will be
// prepended to the existing subject with " > ".
//
// Subject empty string is ignored.
Subject(subject string) Error
// Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)).
Subjectf(format string, args ...any) Error
}
func FromJSON(data []byte) (Error, bool) {
var j ErrorJSONMarshaller
if err := json.Unmarshal(data, &j); err != nil {
return nil, false
}
if j.Err == "" {
return nil, false
}
extras := make([]ErrorImpl, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
return nil, false
}
extras[i] = *extra
}
return &ErrorImpl{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
}, true
}
func TryUnwrap(err error) error {
if err == nil {
return nil
}
if unwrapped := errors.Unwrap(err); unwrapped != nil {
return unwrapped
}
return err
}
// Check is a helper function that
// convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, Error) {
return obj, From(err)
}
func Join(message string, err ...Error) Error {
extras := make([]ErrorImpl, len(err))
nErr := 0
for i, e := range err {
if e == nil {
continue
}
extras[i] = *e
nErr++
}
if nErr == 0 {
return nil
}
return &ErrorImpl{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) Error {
b := NewBuilder("%s", message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne Error) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne Error) Is(err error) bool {
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne Error) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne Error) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne Error) With(s any) Error {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case *ErrorImpl:
if len(ss.extras) == 1 {
ne.extras = append(ne.extras, ss.extras[0])
return ne
}
return ne.withError(ss)
case error:
// unwrap only once
return ne.withError(From(TryUnwrap(ss)))
case string:
msg = ss
case fmt.Stringer:
return ne.appendMsg(ss.String())
default:
return ne.appendMsg(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
func (ne Error) Extraf(format string, args ...any) Error {
return ne.With(errorf(format, args...))
}
func (ne Error) Subject(s any, sep ...string) Error {
if ne == nil {
return ne
}
var subject string
switch ss := s.(type) {
case string:
subject = ss
case fmt.Stringer:
subject = ss.String()
default:
subject = fmt.Sprint(s)
}
switch {
case ne.subject == "":
ne.subject = subject
case len(sep) > 0:
ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject)
default:
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
return ne
}
func (ne Error) Subjectf(format string, args ...any) Error {
if ne == nil {
return ne
}
return ne.Subject(fmt.Sprintf(format, args...))
}
func (ne Error) JSONObject() ErrorJSONMarshaller {
extras := make([]ErrorJSONMarshaller, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return ErrorJSONMarshaller{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
}
}
func (ne Error) JSON() []byte {
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
if err != nil {
panic(err)
}
return b
}
func (ne Error) NoError() bool {
return ne == nil
}
func (ne Error) HasError() bool {
return ne != nil
}
func errorf(format string, args ...any) Error {
for i, arg := range args {
if err, ok := arg.(error); ok {
if unwrapped := errors.Unwrap(err); unwrapped != nil {
args[i] = unwrapped
}
}
}
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
}
return FromJSON(data)
}
func (ne Error) withError(err Error) Error {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne Error) appendMsg(msg string) Error {
if ne == nil {
return nil
}
ne.err = fmt.Errorf("%w %s", ne.err, msg)
return ne
}
func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return
}
if ne.subject != "" {
sb.WriteString(ne.subject)
sb.WriteRune(' ')
}
sb.WriteString(ne.err.Error())
if len(ne.extras) > 0 {
sb.WriteRune(':')
for _, extra := range ne.extras {
sb.WriteRune('\n')
extra.writeToSB(sb, level+1, "- ")
}
}
}
func (ne Error) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return errors.New(sb.String())
}
res = fmt.Errorf("%s%w", sb.String(), ne.err)
sb.Reset()
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
res = fmt.Errorf("%w%s", res, sb.String())
for _, extra := range ne.extras {
res = errors.Join(res, extra.buildError(level+1, "- "))
}
} else {
res = fmt.Errorf("%w%s", res, sb.String())
}
return res
// this makes JSON marshalling work,
// as the builtin one doesn't.
type errStr string
func (err errStr) Error() string {
return string(err)
}

View file

@ -1,107 +1,157 @@
package error_test
package error
import (
"errors"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBaseString(t *testing.T) {
ExpectEqual(t, New("error").Error(), "error")
}
func TestBaseWithSubject(t *testing.T) {
err := New("error")
withSubject := err.Subject("foo")
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
ExpectError(t, err, withSubject)
ExpectStrEqual(t, withSubject.Error(), "foo: error")
ExpectTrue(t, withSubject.Is(err))
ExpectError(t, err, withSubjectf)
ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error")
ExpectTrue(t, withSubjectf.Is(err))
}
func TestBaseWithExtra(t *testing.T) {
err := New("error")
extra := New("bar").Subject("baz")
withExtra := err.With(extra)
ExpectTrue(t, withExtra.Is(extra))
ExpectTrue(t, withExtra.Is(err))
ExpectTrue(t, errors.Is(withExtra, extra))
ExpectTrue(t, errors.Is(withExtra, err))
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
}
func TestBaseUnwrap(t *testing.T) {
err := errors.New("err")
wrapped := From(err)
ExpectError(t, err, errors.Unwrap(wrapped))
}
func TestNestedUnwrap(t *testing.T) {
err := errors.New("err")
err2 := New("err2")
wrapped := From(err).Subject("foo").With(err2.Subject("bar"))
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
ExpectTrue(t, ok)
ExpectError(t, err, wrapped)
ExpectError(t, err2, wrapped)
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
}
func TestErrorIs(t *testing.T) {
ExpectTrue(t, Failure("foo").Is(ErrFailure))
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
from := errors.New("error")
err := From(from)
ExpectError(t, from, err)
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
ExpectTrue(t, err.Is(from))
ExpectFalse(t, err.Is(New("error")))
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
ExpectTrue(t, errors.Is(err.Subject("foo"), from))
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
}
func TestErrorNestedIs(t *testing.T) {
var err Error
ExpectTrue(t, err.Is(nil))
func TestErrorImmutability(t *testing.T) {
err := New("err")
err2 := New("err2")
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrDuplicated))
for range 3 {
// t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
err.Subject("foo")
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
err.With(Duplicated("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrDuplicated))
ExpectFalse(t, err.Is(ErrInvalid))
}
err.With(err2)
ExpectFalse(t, strings.Contains(err.Error(), "extra"))
ExpectFalse(t, err.Is(err2))
func TestIsNil(t *testing.T) {
var err Error
ExpectTrue(t, err.Is(nil))
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
err = err.Subject("bar").Withf("baz")
ExpectTrue(t, err != nil)
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
}
func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
err1 := New("err1")
err2 := New("err2")
err3 := err1.With(err2)
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
err2.Subject("foo")
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
// check if err3 is affected by err2.Subject
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
}
func TestErrorNested(t *testing.T) {
inner := Failure("inner").
With("1").
With("1")
inner2 := Failure("inner2").
func TestErrorStringSimple(t *testing.T) {
errFailure := New("generic failure")
ne := errFailure.Subject("foo bar")
ExpectStrEqual(t, ne.Error(), "foo bar: generic failure")
ne = ne.Subject("baz")
ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure")
}
func TestErrorStringNested(t *testing.T) {
errFailure := New("generic failure")
inner := errFailure.Subject("inner").
Withf("1").
Withf("1")
inner2 := errFailure.Subject("inner2").
Subject("action 2").
With("2").
With("2")
inner3 := Failure("inner3").
Withf("2").
Withf("2")
inner3 := errFailure.Subject("inner3").
Subject("action 3").
With("3").
With("3")
ne := Failure("foo").
With("bar").
With("baz").
Withf("3").
Withf("3")
ne := errFailure.
Subject("foo").
Withf("bar").
Withf("baz").
With(inner).
With(inner.With(inner2.With(inner3)))
want := `foo failed:
- bar
- baz
- inner failed:
- 1
- 1
- inner failed:
- 1
- 1
- inner2 failed for "action 2":
- 2
- 2
- inner3 failed for "action 3":
- 3
- 3`
ExpectEqual(t, ne.String(), want)
ExpectEqual(t, ne.Error().Error(), want)
want := `foo: generic failure
bar
baz
inner: generic failure
1
1
inner: generic failure
1
1
action 2 > inner2: generic failure
2
2
action 3 > inner3: generic failure
3
3`
ExpectStrEqual(t, ne.Error(), want)
}

View file

@ -1,83 +0,0 @@
package error
import (
stderrors "errors"
"fmt"
"reflect"
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
ErrPanicRecv = stderrors.New("panic recovered from")
)
const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) Error {
return errorf("%s %w", what, ErrFailure)
}
func FailedWhy(what string, why string) Error {
return Failure(what).With(why)
}
func FailWith(what string, err any) Error {
return Failure(what).With(err)
}
func Invalid(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func Unexpected(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func UnexpectedError(err error) Error {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) Error {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func Missing(subject any) Error {
return errorf("%w %v", ErrMissing, subject)
}
func Duplicated(subject, what any) Error {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject any, value any) Error {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}
func TypeError(subject any, from, to reflect.Type) Error {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) Error {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) Error {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}
func PanicRecv(format string, args ...any) Error {
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
}

43
internal/error/log.go Normal file
View file

@ -0,0 +1,43 @@
package error
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func getLogger(logger ...*zerolog.Logger) *zerolog.Logger {
if len(logger) > 0 {
return logger[0]
}
return logging.GetLogger()
}
//go:inline
func LogFatal(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Fatal().Msg(err.Error())
}
//go:inline
func LogError(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Error().Msg(err.Error())
}
//go:inline
func LogWarn(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Warn().Msg(err.Error())
}
//go:inline
func LogPanic(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Panic().Msg(err.Error())
}
//go:inline
func LogInfo(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Info().Msg(err.Error())
}
//go:inline
func LogDebug(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Debug().Msg(err.Error())
}

View file

@ -0,0 +1,120 @@
package error
import (
"errors"
"fmt"
"strings"
)
type nestedError struct {
Err error `json:"err"`
Extras []error `json:"extras"`
}
func (err nestedError) Subject(subject string) Error {
if err.Err == nil {
err.Err = newError(subject)
} else {
err.Err = PrependSubject(subject, err.Err)
}
return &err
}
func (err *nestedError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err nestedError) With(extra error) Error {
if extra != nil {
err.Extras = append(err.Extras, extra)
}
return &err
}
func (err nestedError) Withf(format string, args ...any) Error {
if len(args) > 0 {
err.Extras = append(err.Extras, fmt.Errorf(format, args...))
} else {
err.Extras = append(err.Extras, newError(format))
}
return &err
}
func (err *nestedError) Unwrap() []error {
if err.Err == nil {
if len(err.Extras) == 0 {
return nil
}
return err.Extras
}
return append([]error{err.Err}, err.Extras...)
}
func (err *nestedError) Is(other error) bool {
if errors.Is(err.Err, other) {
return true
}
for _, e := range err.Extras {
if errors.Is(e, other) {
return true
}
}
return false
}
func (err *nestedError) Error() string {
return buildError(err, 0)
}
//go:inline
func makeLine(err string, level int) string {
const bulletPrefix = "• "
const spaces = " "
if level == 0 {
return err
}
return spaces[:2*level] + bulletPrefix + err
}
func makeLines(errs []error, level int) []string {
if len(errs) == 0 {
return nil
}
lines := make([]string, 0, len(errs))
for _, err := range errs {
switch err := err.(type) {
case *nestedError:
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
default:
lines = append(lines, makeLine(err.Error(), level))
}
}
return lines
}
func buildError(err error, level int) string {
switch err := err.(type) {
case nil:
return makeLine("<nil>", level)
case *nestedError:
lines := make([]string, 0, 1+len(err.Extras))
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
return strings.Join(lines, "\n")
default:
return makeLine(err.Error(), level)
}
}

50
internal/error/subject.go Normal file
View file

@ -0,0 +1,50 @@
package error
import (
"strings"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
)
type withSubject struct {
Subject string `json:"subject"`
Err error `json:"err"`
}
const subjectSep = " > "
func highlight(subject string) string {
return ansi.HighlightRed + subject + ansi.Reset
}
func PrependSubject(subject string, err error) *withSubject {
switch err := err.(type) {
case *withSubject:
return err.Prepend(subject)
case *baseError:
return PrependSubject(subject, err.Err)
default:
return &withSubject{subject, err}
}
}
func (err withSubject) Prepend(subject string) *withSubject {
if subject != "" {
err.Subject = subject + subjectSep + err.Subject
}
return &err
}
func (err *withSubject) Is(other error) bool {
return err.Err == other
}
func (err *withSubject) Unwrap() error {
return err.Err
}
func (err *withSubject) Error() string {
subjects := strings.Split(err.Subject, subjectSep)
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1])
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error()
}

71
internal/error/utils.go Normal file
View file

@ -0,0 +1,71 @@
package error
import (
"errors"
"fmt"
)
var ErrInvalidErrorJson = errors.New("invalid error json")
func newError(message string) error {
return errStr(message)
}
func New(message string) Error {
if message == "" {
return nil
}
return &baseError{newError(message)}
}
func Errorf(format string, args ...any) Error {
return &baseError{fmt.Errorf(format, args...)}
}
func From(err error) Error {
if err == nil {
return nil
}
if err, ok := err.(Error); ok {
return err
}
return &baseError{err}
}
func Must[T any](v T, err error) T {
if err != nil {
LogPanic("must failed", err)
}
return v
}
func Join(errors ...error) Error {
n := 0
for _, err := range errors {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
errs := make([]error, 0, n)
for _, err := range errors {
if err != nil {
errs = append(errs, err)
}
}
return &nestedError{Extras: errs}
}
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
result, err := fn(arg)
eb.Add(err)
return result
}
func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T {
result, err := fn(arg1, arg2)
eb.Add(err)
return result
}

View file

@ -53,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
icons = append(icons, content.Path)
}
}
err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error()
err = utils.SaveJSON(iconsCachePath, &icons, 0o644)
if err != nil {
log.Print("error saving cache", err)
}

View file

@ -0,0 +1,69 @@
package logging
import (
"os"
"strings"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
)
var logger zerolog.Logger
func init() {
var timeFmt string
var level zerolog.Level
var exclude []string
if common.IsTrace {
timeFmt = "04:05"
level = zerolog.TraceLevel
} else if common.IsDebug {
timeFmt = "01-02 15:04"
level = zerolog.DebugLevel
} else {
timeFmt = "01-02 15:04"
level = zerolog.InfoLevel
exclude = []string{"module"}
}
prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces
prefix := strings.Repeat(" ", prefixLength)
logger = zerolog.New(
zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: timeFmt,
FieldsExclude: exclude,
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
msg := msgI.(string)
lines := strings.Split(msg, "\n")
if len(lines) == 1 {
return msg
}
for i := 1; i < len(lines); i++ {
lines[i] = prefix + lines[i]
}
return strings.Join(lines, "\n")
},
},
).Level(level).With().Timestamp().Logger()
}
func DiscardLogger() { logger = zerolog.Nop() }
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
func GetLogger() *zerolog.Logger { return &logger }
func With() zerolog.Context { return logger.With() }
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
func Info() *zerolog.Event { return logger.Info() }
func Warn() *zerolog.Event { return logger.Warn() }
func Error() *zerolog.Event { return logger.Error() }
func Err(err error) *zerolog.Event { return logger.Err(err) }
func Debug() *zerolog.Event { return logger.Debug() }
func Fatal() *zerolog.Event { return logger.Fatal() }
func Panic() *zerolog.Event { return logger.Panic() }
func Trace() *zerolog.Event { return logger.Trace() }

View file

@ -0,0 +1,15 @@
package http
import "net/http"
type DummyResponseWriter struct{}
func (w DummyResponseWriter) Header() http.Header {
return make(http.Header)
}
func (w DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w DummyResponseWriter) WriteHeader(int) {}

View file

@ -1,15 +0,0 @@
package loadbalancer
import "net/http"
type DummyResponseWriter struct{}
func (w *DummyResponseWriter) Header() (_ http.Header) {
return
}
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w *DummyResponseWriter) WriteHeader(int) {}

View file

@ -11,20 +11,22 @@ import (
)
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool servers
mu sync.Mutex
}
func (lb *LoadBalancer) newIPHash() impl {
impl := new(ipHash)
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
}
return impl
}
@ -70,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))

View file

@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.R
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {

View file

@ -6,9 +6,11 @@ import (
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
@ -29,6 +31,8 @@ type (
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
}
LoadBalancer struct {
zerolog.Logger
impl
*Config
@ -48,6 +52,7 @@ const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Logger: logger.With().Str("name", cfg.Link).Logger(),
Config: new(Config),
pool: newPool(),
task: task.DummyTask(),
@ -102,7 +107,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if lb.Mode == Unset && cfg.Mode != Unset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode)
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
@ -131,7 +136,11 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.rebalance()
lb.impl.OnAddServer(srv)
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
lb.Debug().
Str("action", "add").
Str("server", srv.Name).
Msgf("%d servers available", lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
@ -148,13 +157,15 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.rebalance()
lb.impl.OnRemoveServer(srv)
lb.Debug().
Str("action", "remove").
Str("server", srv.Name).
Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
logger.Infof("loadbalancer %s stopped", lb.Link)
return
}
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
}
func (lb *LoadBalancer) rebalance() {
@ -218,7 +229,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
defer cancel()
// send dummy request to wake all servers
var dummyRW *DummyResponseWriter
var dummyRW gphttp.DummyResponseWriter
for _, srv := range srvs {
// wake only if server implements Waker
_, ok := srv.handler.(idlewatcher.Waker)

View file

@ -1,5 +1,5 @@
package loadbalancer
import "github.com/sirupsen/logrus"
import "github.com/yusing/go-proxy/internal/logging"
var logger = logrus.WithField("module", "load_balancer")
var logger = logging.With().Str("module", "load_balancer").Logger()

View file

@ -0,0 +1,5 @@
package http
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "http").Logger()

View file

@ -15,9 +15,9 @@ type cidrWhitelist struct {
}
type cidrWhitelistOpts struct {
Allow []*types.CIDR
StatusCode int
Message string
Allow []*types.CIDR `json:"allow"`
StatusCode int `json:"statusCode"`
Message string `json:"message"`
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
return nil, err
}
if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.Missing("allow range")
return nil, E.New("no allowed CIDRs")
}
return wl.m, nil
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
if err != nil {
panic(err)
}
errs := E.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("deny", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
}
@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("accept", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})

View file

@ -10,10 +10,10 @@ import (
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
const (
@ -26,7 +26,7 @@ const (
var (
cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &realIP{
@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
}
cfCIDRsLastUpdate = time.Now()
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
return
}

View file

@ -8,24 +8,31 @@ import (
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
var CustomErrorPage = &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
modifyResponse: func(resp *Response) error {
var CustomErrorPage *Middleware
func init() {
CustomErrorPage = customErrorPage()
}
func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
return m
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
filename := path[len(gphttp.StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
errPageLogger.Errorf("unable to load resource %s", filename)
logger.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename)
logger.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View file

@ -1,14 +1,15 @@
package errorpage
import (
"context"
"fmt"
"os"
"path"
"sync"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
@ -23,9 +24,10 @@ var (
)
var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
loadContent()
go watchDir()
go watchDir(task)
})
func GetStaticFile(filename string) ([]byte, bool) {
@ -44,7 +46,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
Logger.Error(err)
logger.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
@ -53,19 +55,21 @@ func loadContent() {
}
content, err := os.ReadFile(file)
if err != nil {
Logger.Errorf("failed to read error page resource %s: %s", file, err)
logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
Logger.Infof("error page resource %s loaded", file)
logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(context.Background())
func watchDir(task task.Task) {
eventCh, errCh := dirWatcher.Events(task.Context())
for {
select {
case <-task.Context().Done():
return
case event, ok := <-eventCh:
if !ok {
return
@ -77,14 +81,14 @@ func watchDir() {
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
Logger.Infof("error page resource %s deleted", filename)
logger.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
Logger.Infof("error page resource %s deleted", filename)
logger.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
Logger.Errorf("error watching error page directory: %s", err)
E.LogError("error watching error page directory", err, &logger)
}
}
}

View file

@ -0,0 +1,5 @@
package errorpage
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "errorpage").Logger()

View file

@ -24,11 +24,12 @@ type (
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
transport http.RoundTripper
Address string `json:"address"`
TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
}
)
@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts)
if err != nil {
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
return nil, err
}
_, err = E.Check(url.Parse(fa.Address))
if err != nil {
return nil, E.Invalid("address", fa.Address)
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{

View file

@ -0,0 +1,5 @@
package middleware
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "middleware").Logger()

View file

@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
@ -32,6 +33,8 @@ type (
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if len(optsRaw) != 0 && m.withOptions != nil {
return m.withOptions(optsRaw)
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
middlewares = make([]*Middleware, 0, len(middlewaresMap))
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err := m.WithOptionsClone(opts)
m, err = m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
middlewares = append(middlewares, m)
}
return
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
return middlewares, errs.Error()
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {

View file

@ -4,64 +4,60 @@ import (
"fmt"
"net/http"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
eb.Add(err)
return nil
}
return BuildMiddlewaresFromYAML(fileContent)
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("yaml unmarshal", err))
return
eb.Add(err)
return nil
}
middlewares = make(map[string]*Middleware)
middlewares := make(map[string]*Middleware)
for name, defs := range rawMap {
chainErr := E.NewBuilder("%s", name)
chainErr := E.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
chainErr.Addf("item %d: missing field 'use'", i)
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
base, err := Get(baseName)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {
b.Add(chainErr.Build())
eb.Add(chainErr.Error().Subject(source))
} else {
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
return middlewares
}
// TODO: check conflict or duplicates.
@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware")
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return b.Build().Error()
return errs.Error()
}
}

View file

@ -13,10 +13,10 @@ import (
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
ExpectNoError(t, err.Error())
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
errs := E.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
E.Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data))
// TODO: test
}

View file

@ -6,26 +6,37 @@ import (
"path"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var middlewares map[string]*Middleware
var allMiddlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[U.ToLowerNoSnake(name)]
return
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return middlewares
return allMiddlewares
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
@ -39,10 +50,10 @@ func init() {
// !experimental
"forwardauth": ForwardAuth.m,
"oauth2": OAuth2.m,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name))
}
for m, names := range names {
@ -55,27 +66,30 @@ func init() {
}
func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares")
errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err)
logger.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromComposeFile(defFile)
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
middlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
allMiddlewares[U.ToLowerNoSnake(name)] = m
logger.Info().
Str("name", name).
Str("src", path.Base(defFile)).
Msg("middleware loaded")
}
b.Add(err.Subject(path.Base(defFile)))
}
if b.HasError() {
logger.Error(b.Build())
if errs.HasError() {
E.LogError(errs.About(), errs.Error(), &logger)
}
}
var logger = logrus.WithField("module", "middlewares")

View file

@ -12,9 +12,9 @@ type (
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
SetHeaders map[string]string `json:"setHeaders"`
AddHeaders map[string]string `json:"addHeaders"`
HideHeaders []string `json:"hideHeaders"`
}
)

View file

@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")

View file

@ -13,11 +13,7 @@ type (
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
modifyResponseOpts = modifyRequestOpts
)
var ModifyResponse = &modifyResponse{

View file

@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))

View file

@ -1,129 +1,117 @@
package middleware
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
// import (
// "encoding/json"
// "fmt"
// "net/http"
// "net/url"
E "github.com/yusing/go-proxy/internal/error"
)
// E "github.com/yusing/go-proxy/internal/error"
// )
type oAuth2 struct {
*oAuth2Opts
m *Middleware
}
// type oAuth2 struct {
// oAuth2Opts
// m *Middleware
// }
type oAuth2Opts struct {
ClientID string
ClientSecret string
AuthURL string // Authorization Endpoint
TokenURL string // Token Endpoint
}
// type oAuth2Opts struct {
// ClientID string `validate:"required"`
// ClientSecret string `validate:"required"`
// AuthURL string `validate:"required"` // Authorization Endpoint
// TokenURL string `validate:"required"` // Token Endpoint
// }
var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2},
}
// var OAuth2 = &oAuth2{
// m: &Middleware{withOptions: NewAuthentikOAuth2},
// }
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
oauth := new(oAuth2)
oauth.m = &Middleware{
impl: oauth,
before: oauth.handleOAuth2,
}
oauth.oAuth2Opts = &oAuth2Opts{}
err := Deserialize(opts, oauth.oAuth2Opts)
if err != nil {
return nil, err
}
b := E.NewBuilder("missing required fields")
optV := reflect.ValueOf(oauth.oAuth2Opts)
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
if optV.FieldByName(field.Name).Len() == 0 {
b.Add(E.Missing(field.Name))
}
}
if b.HasError() {
return nil, b.Build().Subject("oAuth2")
}
return oauth.m, nil
}
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
// oauth := new(oAuth2)
// oauth.m = &Middleware{
// impl: oauth,
// before: oauth.handleOAuth2,
// }
// err := Deserialize(opts, &oauth.oAuth2Opts)
// if err != nil {
// return nil, err
// }
// return oauth.m, nil
// }
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// Check if the user is authenticated (you may use session, cookie, etc.)
if !userIsAuthenticated(r) {
// TODO: Redirect to OAuth2 auth URL
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
return
}
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// // Check if the user is authenticated (you may use session, cookie, etc.)
// if !userIsAuthenticated(r) {
// // TODO: Redirect to OAuth2 auth URL
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
// return
// }
// If you have a token in the query string, process it
if code := r.URL.Query().Get("code"); code != "" {
// Exchange the authorization code for a token here
// Use the TokenURL and authenticate the user
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
if err != nil {
// handle error
http.Error(rw, "failed to get token", http.StatusUnauthorized)
return
}
// // If you have a token in the query string, process it
// if code := r.URL.Query().Get("code"); code != "" {
// // Exchange the authorization code for a token here
// // Use the TokenURL and authenticate the user
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
// if err != nil {
// // handle error
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
// return
// }
// Save token and user info based on your requirements
saveToken(rw, token)
// // Save token and user info based on your requirements
// saveToken(rw, token)
// Redirect to the originally requested URL
http.Redirect(rw, r, "/", http.StatusFound)
return
}
// // Redirect to the originally requested URL
// http.Redirect(rw, r, "/", http.StatusFound)
// return
// }
// If user is authenticated, go to the next handler
next(rw, r)
}
// // If user is authenticated, go to the next handler
// next(rw, r)
// }
func userIsAuthenticated(r *http.Request) bool {
// Example: Check for a session or cookie
session, err := r.Cookie("session_token")
if err != nil || session.Value == "" {
return false
}
// Validate the session_token if necessary
return true
}
// func userIsAuthenticated(r *http.Request) bool {
// // Example: Check for a session or cookie
// session, err := r.Cookie("session_token")
// if err != nil || session.Value == "" {
// return false
// }
// // Validate the session_token if necessary
// return true
// }
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// Prepare the request body
data := url.Values{
"client_id": {opts.ClientID},
"client_secret": {opts.ClientSecret},
"code": {code},
"grant_type": {"authorization_code"},
"redirect_uri": {requestURI},
}
resp, err := http.PostForm(opts.TokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to request token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
}
// Decode the response
var tokenResp struct {
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
return tokenResp.AccessToken, nil
}
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// // Prepare the request body
// data := url.Values{
// "client_id": {opts.ClientID},
// "client_secret": {opts.ClientSecret},
// "code": {code},
// "grant_type": {"authorization_code"},
// "redirect_uri": {requestURI},
// }
// resp, err := http.PostForm(opts.TokenURL, data)
// if err != nil {
// return "", fmt.Errorf("failed to request token: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
// }
// // Decode the response
// var tokenResp struct {
// AccessToken string `json:"access_token"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
// return "", fmt.Errorf("failed to decode token response: %w", err)
// }
// return tokenResp.AccessToken, nil
// }
func saveToken(rw ResponseWriter, token string) {
// Example: Save token in cookie
http.SetCookie(rw, &http.Cookie{
Name: "auth_token",
Value: token,
// set other properties as necessary, such as Secure and HttpOnly
})
}
// func saveToken(rw ResponseWriter, token string) {
// // Example: Save token in cookie
// http.SetCookie(rw, &http.Cookie{
// Name: "auth_token",
// Value: token,
// // set other properties as necessary, such as Secure and HttpOnly
// })
// }

View file

@ -16,9 +16,9 @@ type realIP struct {
type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string
Header string `json:"header"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR
From []*types.CIDR `json:"from"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@ -27,7 +27,7 @@ type realIPOpts struct {
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
Recursive bool `json:"recursive"`
}
var RealIP = &realIP{

View file

@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
}
ri, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) {
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)

View file

@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View file

@ -6,7 +6,7 @@ import (
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Trace struct {
@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
return nil
}
return addTrace(&Trace{
Time: U.FormatTime(time.Now()),
Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})

View file

@ -57,8 +57,7 @@ func (w *ModifyResponseWriter) WriteHeader(code int) {
}
if err := w.modifier(&resp); err != nil {
w.modifierErr = err
logger.Errorf("error modifying response: %s", err)
w.modifierErr = fmt.Errorf("response modifier error: %w", err)
w.w.WriteHeader(http.StatusInternalServerError)
return
}

View file

@ -23,7 +23,7 @@ import (
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
@ -69,6 +69,8 @@ type ProxyRequest struct {
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
zerolog.Logger
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
@ -149,7 +151,12 @@ func NewReverseProxy(name string, target types.URL, transport http.RoundTripper)
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target}
rp := &ReverseProxy{
Logger: logger.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
}
rp.ServeHTTP = rp.serveHTTP
return rp
}
@ -195,9 +202,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF):
logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error")
default:
logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error")
}
if writeHeader {
rw.WriteHeader(http.StatusBadGateway)
@ -219,6 +226,10 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
}
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
if _, ok := rw.(DummyResponseWriter); ok {
return
}
transport := p.Transport
ctx := req.Context()
@ -453,6 +464,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
resUpType := UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return
}
if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
@ -518,5 +530,3 @@ func IsPrint(s string) bool {
}
return true
}
var logger = logrus.WithField("module", "http")

View file

@ -9,10 +9,15 @@ import (
type CIDR net.IPNet
var (
ErrInvalidCIDR = E.New("invalid CIDR")
ErrInvalidCIDRType = E.New("invalid CIDR type")
)
func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string)
if !ok {
return E.TypeMismatch[string](val)
return ErrInvalidCIDRType.Subjectf("%T", val)
}
if !strings.Contains(cidrStr, "/") {
@ -20,7 +25,7 @@ func (cidr *CIDR) ConvertFrom(val any) E.Error {
}
_, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil {
return E.Invalid("CIDR", cidr)
return ErrInvalidCIDR.Subject(cidrStr)
}
*cidr = CIDR(*ipnet)
return nil

View file

@ -5,9 +5,28 @@ import (
"net"
)
type Stream interface {
fmt.Stringer
net.Listener
Setup() error
Handle(conn net.Conn) error
type (
Stream interface {
fmt.Stringer
StreamListener
Setup() error
Handle(conn StreamConn) error
}
StreamListener interface {
Addr() net.Addr
Accept() (StreamConn, error)
Close() error
}
StreamConn any
NetListenerWrapper struct {
net.Listener
}
)
func NetListener(l net.Listener) StreamListener {
return NetListenerWrapper{Listener: l}
}
func (l NetListenerWrapper) Accept() (StreamConn, error) {
return l.Listener.Accept()
}

View file

@ -1,22 +1,32 @@
package notif
import (
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
Dispatcher struct {
task task.Task
logCh chan *logrus.Entry
logCh chan *LogMessage
providers F.Set[Provider]
}
LogMessage struct {
Level zerolog.Level
Title, Message string
}
)
var dispatcher *Dispatcher
var ErrUnknownNotifProvider = E.New("unknown notification provider")
const dispatchErr = "notification dispatch error"
func init() {
dispatcher = newNotifDispatcher()
go dispatcher.start()
@ -25,7 +35,7 @@ func init() {
func newNotifDispatcher() *Dispatcher {
return &Dispatcher{
task: task.GlobalTask("notif dispatcher"),
logCh: make(chan *logrus.Entry),
logCh: make(chan *LogMessage),
providers: F.NewSet[Provider](),
}
}
@ -34,11 +44,13 @@ func GetDispatcher() *Dispatcher {
return dispatcher
}
func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.Error) {
func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) {
name := configSubTask.Name()
createFunc, ok := Providers[name]
if !ok {
return nil, E.NotExist("provider", name)
return nil, ErrUnknownNotifProvider.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, Providers)))
}
if provider, err := createFunc(cfg); err != nil {
return nil, err
@ -53,7 +65,6 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.
func (disp *Dispatcher) start() {
defer dispatcher.task.Finish("dispatcher stopped")
defer close(dispatcher.logCh)
for {
select {
@ -65,36 +76,39 @@ func (disp *Dispatcher) start() {
}
}
func (disp *Dispatcher) dispatch(entry *logrus.Entry) {
func (disp *Dispatcher) dispatch(msg *LogMessage) {
task := disp.task.Subtask("dispatch notif")
defer task.Finish("notifs dispatched")
defer task.Finish("notif dispatched")
errs := E.NewBuilder("errors sending notif")
errs := E.NewBuilder(dispatchErr)
disp.providers.RangeAllParallel(func(p Provider) {
if err := p.Send(task.Context(), entry); err != nil {
errs.Addf("%s: %s", p.Name(), err)
if err := p.Send(task.Context(), msg); err != nil {
errs.Add(E.PrependSubject(p.Name(), err))
}
})
if err := errs.Build(); err != nil {
logrus.Error("notif dispatcher failure: ", err)
if errs.HasError() {
E.LogError(errs.About(), errs.Error())
}
}
// Levels implements logrus.Hook.
func (disp *Dispatcher) Levels() []logrus.Level {
return []logrus.Level{
logrus.WarnLevel,
logrus.ErrorLevel,
logrus.FatalLevel,
logrus.PanicLevel,
}
}
// Run implements zerolog.Hook.
// func (disp *Dispatcher) Run(e *zerolog.Event, level zerolog.Level, message string) {
// if strings.HasPrefix(message, dispatchErr) { // prevent recursion
// return
// }
// switch level {
// case zerolog.WarnLevel, zerolog.ErrorLevel, zerolog.FatalLevel, zerolog.PanicLevel:
// disp.logCh <- &LogMessage{
// Level: level,
// Message: message,
// }
// }
// }
// Fire implements logrus.Hook.
func (disp *Dispatcher) Fire(entry *logrus.Entry) error {
if disp.providers.Size() == 0 {
return nil
func Notify(title, msg string) {
dispatcher.logCh <- &LogMessage{
Level: zerolog.InfoLevel,
Title: title,
Message: msg,
}
disp.logCh <- entry
return nil
}

View file

@ -9,7 +9,7 @@ import (
"net/url"
"github.com/gotify/server/v2/model"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
)
@ -39,7 +39,7 @@ func newGotifyClient(cfg map[string]any) (Provider, E.Error) {
url, uErr := url.Parse(client.URL)
if uErr != nil {
return nil, E.FailWith("parse url", uErr)
return nil, E.Errorf("invalid gotify URL %s", client.URL)
}
client.url = url
@ -52,30 +52,23 @@ func (client *GotifyClient) Name() string {
}
// Send implements NotifProvider.
func (client *GotifyClient) Send(ctx context.Context, entry *logrus.Entry) error {
func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error {
var priority int
var title string
switch entry.Level {
case logrus.WarnLevel:
switch logMsg.Level {
case zerolog.WarnLevel:
priority = 2
title = "Warning"
case logrus.ErrorLevel:
case zerolog.ErrorLevel:
priority = 5
title = "Error"
case logrus.FatalLevel, logrus.PanicLevel:
case zerolog.FatalLevel, zerolog.PanicLevel:
priority = 8
title = "Critical"
default:
return nil
}
if subjects := FieldsAsTitle(entry); subjects != "" {
title = subjects + " " + title
}
msg := &GotifyMessage{
Title: title,
Message: entry.Message,
Title: logMsg.Title,
Message: logMsg.Message,
Priority: priority,
}

View file

@ -1,21 +0,0 @@
package notif
import (
"fmt"
"strings"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils"
)
func FieldsAsTitle(entry *logrus.Entry) string {
if len(entry.Data) == 0 {
return ""
}
var parts []string
for k, v := range entry.Data {
parts = append(parts, fmt.Sprintf("%s: %s", k, v))
}
parts[0] = U.Title(parts[0])
return strings.Join(parts, ", ")
}

View file

@ -3,14 +3,13 @@ package notif
import (
"context"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
type (
Provider interface {
Name() string
Send(ctx context.Context, entry *logrus.Entry) error
Send(ctx context.Context, logMsg *LogMessage) error
}
ProviderCreateFunc func(map[string]any) (Provider, E.Error)
ProviderConfig map[string]any

View file

@ -23,18 +23,18 @@ func ValidateEntry(m *RawEntry) (Entry, E.Error) {
scheme, err := T.NewScheme(m.Scheme)
if err != nil {
return nil, err
return nil, E.From(err)
}
var entry Entry
e := E.NewBuilder("error validating entry")
errs := E.NewBuilder("entry validation failed")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
entry = validateStreamEntry(m, errs)
} else {
entry = validateRPEntry(m, scheme, e)
entry = validateRPEntry(m, scheme, errs)
}
if err := e.Build(); err != nil {
return nil, err
if errs.HasError() {
return nil, errs.Error()
}
return entry, nil
}

View file

@ -5,13 +5,14 @@ import (
"strings"
"github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -85,20 +86,20 @@ func (e *RawEntry) FillMissingFields() {
} else if !isDocker {
pp = "80"
} else {
logrus.Debugf("no port found for %s", e.Alias)
logging.Debug().Msg("no port found for " + e.Alias)
}
}
// replace private port with public port if using public IP.
if e.Host == cont.PublicIP {
if p, ok := cont.PrivatePortMapping[pp]; ok {
pp = U.PortString(p.PublicPort)
pp = strutils.PortString(p.PublicPort)
}
}
// replace public port with private port if using private IP.
if e.Host == cont.PrivateIP {
if p, ok := cont.PublicPortMapping[pp]; ok {
pp = U.PortString(p.PrivatePort)
pp = strutils.PortString(p.PrivatePort)
}
}

View file

@ -53,7 +53,7 @@ func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
return rp.Idlewatcher
}
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry {
func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProxyEntry {
cont := m.Container
if cont == nil {
cont = docker.DummyContainer
@ -64,35 +64,26 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEnt
lb = nil
}
host, err := fields.ValidateHost(m.Host)
b.Add(err)
host := E.Collect(errs, fields.ValidateHost, m.Host)
port := E.Collect(errs, fields.ValidatePort, m.Port)
pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns)
url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port))
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
port, err := fields.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if err != nil {
if errs.HasError() {
return nil
}
return &ReverseProxyEntry{
Raw: m,
Alias: fields.NewAlias(m.Alias),
Alias: fields.Alias(m.Alias),
Scheme: s,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
PathPatterns: pathPats,
HealthCheck: m.HealthCheck,
LoadBalance: lb,
Middlewares: m.Middlewares,
Idlewatcher: idleWatcherCfg,
Idlewatcher: iwCfg,
}
}

View file

@ -51,34 +51,25 @@ func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
return s.Idlewatcher
}
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry {
func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry {
cont := m.Container
if cont == nil {
cont = docker.DummyContainer
}
host, err := fields.ValidateHost(m.Host)
b.Add(err)
host := E.Collect(errs, fields.ValidateHost, m.Host)
port := E.Collect(errs, fields.ValidateStreamPort, m.Port)
scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme)
url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort))
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
port, err := fields.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := fields.ValidateStreamScheme(m.Scheme)
b.Add(err)
url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if b.HasError() {
if errs.HasError() {
return nil
}
return &StreamEntry{
Raw: m,
Alias: fields.NewAlias(m.Alias),
Alias: fields.Alias(m.Alias),
Scheme: *scheme,
URL: url,
Host: host,

View file

@ -1,6 +1,3 @@
package fields
type (
Alias string
NewAlias = Alias
)
type Alias string

View file

@ -1,14 +1,10 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type (
Host string
Subdomain = Alias
)
func ValidateHost[String ~string](s String) (Host, E.Error) {
func ValidateHost[String ~string](s String) (Host, error) {
return Host(s), nil
}

View file

@ -1,6 +1,8 @@
package fields
import (
"errors"
"fmt"
"regexp"
E "github.com/yusing/go-proxy/internal/error"
@ -13,12 +15,17 @@ type (
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func ValidatePathPattern(s string) (PathPattern, E.Error) {
var (
ErrEmptyPathPattern = errors.New("path must not be empty")
ErrInvalidPathPattern = errors.New("invalid path pattern")
)
func ValidatePathPattern(s string) (PathPattern, error) {
if len(s) == 0 {
return "", E.Invalid("path", "must not be empty")
return "", ErrEmptyPathPattern
}
if !pathPattern.MatchString(s) {
return "", E.Invalid("path pattern", s)
return "", fmt.Errorf("%w %q", ErrInvalidPathPattern, s)
}
return PathPattern(s), nil
}
@ -27,13 +34,15 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
if len(s) == 0 {
return []PathPattern{"/"}, nil
}
errs := E.NewBuilder("invalid path patterns")
pp := make(PathPatterns, len(s))
for i, v := range s {
pattern, err := ValidatePathPattern(v)
if err != nil {
return nil, err
errs.Add(err)
} else {
pp[i] = pattern
}
pp[i] = pattern
}
return pp, nil
return pp, errs.Error()
}

View file

@ -1,9 +1,9 @@
package fields
import (
"errors"
"testing"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -38,10 +38,10 @@ var invalidPatterns = []string{
func TestPathPatternRegex(t *testing.T) {
for _, pattern := range validPatterns {
_, err := ValidatePathPattern(pattern)
U.ExpectNoError(t, err.Error())
U.ExpectNoError(t, err)
}
for _, pattern := range invalidPatterns {
_, err := ValidatePathPattern(pattern)
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error())
U.ExpectTrue(t, errors.Is(err, ErrInvalidPathPattern))
}
}

View file

@ -4,22 +4,25 @@ import (
"strconv"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Port int
func ValidatePort[String ~string](v String) (Port, E.Error) {
p, err := strconv.Atoi(string(v))
var ErrPortOutOfRange = E.New("port out of range")
func ValidatePort[String ~string](v String) (Port, error) {
p, err := strutils.Atoi(string(v))
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
return ErrPort, err
}
return ValidatePortInt(p)
}
func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) {
func ValidatePortInt[Int int | uint16](v Int) (Port, error) {
p := Port(v)
if !p.inBound() {
return ErrPort, E.OutOfRange("port", p)
return ErrPort, ErrPortOutOfRange.Subject(strconv.Itoa(int(p)))
}
return p, nil
}

View file

@ -6,12 +6,14 @@ import (
type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.Error) {
var ErrInvalidScheme = E.New("invalid scheme")
func NewScheme(s string) (Scheme, error) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), nil
}
return "", E.Invalid("scheme", s)
return "", ErrInvalidScheme.Subject(s)
}
func (s Scheme) IsHTTP() bool { return s == "http" }

View file

@ -3,7 +3,6 @@ package fields
import (
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
@ -12,7 +11,9 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"`
}
func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
var ErrStreamPortTooManyColons = E.New("too many colons")
func ValidateStreamPort(p string) (StreamPort, error) {
split := strings.Split(p, ":")
switch len(split) {
@ -21,36 +22,14 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
case 2:
break
default:
err = E.Invalid("stream port", p).With("too many colons")
return
return StreamPort{}, ErrStreamPortTooManyColons.Subject(p)
}
listeningPort, err := ValidatePort(split[0])
if err != nil {
err = err.Subject("listening port")
return
}
proxyPort, err := ValidatePort(split[1])
if err.Is(E.ErrOutOfRange) {
err = err.Subject("proxy port")
return
} else if err != nil {
proxyPort, err = parseNameToPort(split[1])
if err != nil {
err = E.Invalid("proxy port", proxyPort)
return
}
listeningPort, lErr := ValidatePort(split[0])
proxyPort, pErr := ValidatePort(split[1])
if err := E.Join(lErr, pErr); err != nil {
return StreamPort{}, err
}
return StreamPort{listeningPort, proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.Error) {
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return ErrPort, E.Invalid("service", name)
}
return Port(port), nil
}

View file

@ -1,9 +1,9 @@
package fields
import (
"strconv"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -11,7 +11,6 @@ var validPorts = []string{
"1234:5678",
"0:2345",
"2345",
"1234:postgres",
}
var invalidPorts = []string{
@ -19,7 +18,6 @@ var invalidPorts = []string{
"123:",
"0:",
":1234",
"1234:1234:1234",
"qwerty",
"asdfgh:asdfgh",
"1234:asdfgh",
@ -32,17 +30,25 @@ var outOfRangePorts = []string{
"0:65536",
}
var tooManyColonsPorts = []string{
"1234:1234:1234",
}
func TestStreamPort(t *testing.T) {
for _, port := range validPorts {
_, err := ValidateStreamPort(port)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
}
for _, port := range invalidPorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrInvalid, err.Error())
ExpectError2(t, port, strconv.ErrSyntax, err)
}
for _, port := range outOfRangePorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrOutOfRange, err.Error())
ExpectError2(t, port, ErrPortOutOfRange, err)
}
for _, port := range tooManyColonsPorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, ErrStreamPortTooManyColons, err)
}
}

View file

@ -12,22 +12,23 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"`
}
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) {
ss = &StreamScheme{}
func ValidateStreamScheme(s string) (*StreamScheme, error) {
ss := &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {
parts = []string{s, s}
} else if len(parts) != 2 {
return nil, E.Invalid("stream scheme", s)
return nil, ErrInvalidScheme.Subject(s)
}
ss.ListeningScheme, err = NewScheme(parts[0])
if err.HasError() {
return nil, err
}
ss.ProxyScheme, err = NewScheme(parts[1])
if err.HasError() {
var lErr, pErr error
ss.ListeningScheme, lErr = NewScheme(parts[0])
ss.ProxyScheme, pErr = NewScheme(parts[1])
if err := E.Join(lErr, pErr); err != nil {
return nil, err
}
return ss, nil
}

View file

@ -0,0 +1,37 @@
package fields
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var (
validStreamSchemes = []string{
"tcp:tcp",
"tcp:udp",
"udp:tcp",
"udp:udp",
"tcp",
"udp",
}
invalidStreamSchemes = []string{
"tcp:tcp:",
"tcp:",
":udp:",
":udp",
"top",
}
)
func TestNewStreamScheme(t *testing.T) {
for _, s := range validStreamSchemes {
_, err := ValidateStreamScheme(s)
ExpectNoError(t, err)
}
for _, s := range invalidStreamSchemes {
_, err := ValidateStreamScheme(s)
ExpectError(t, ErrInvalidScheme, err)
}
}

View file

@ -7,13 +7,13 @@ import (
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
"github.com/yusing/go-proxy/internal/proxy/entry"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/task"
@ -33,6 +33,8 @@ type (
rp *gphttp.ReverseProxy
task task.Task
l zerolog.Logger
}
SubdomainKey = PT.Alias
@ -88,6 +90,10 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
ReverseProxyEntry: entry,
rp: rp,
task: task.DummyTask(),
l: logger.With().
Str("type", string(entry.Scheme)).
Str("name", string(entry.Alias)).
Logger(),
}
return r, nil
}
@ -107,11 +113,11 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
defer httpRoutesMu.Unlock()
if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled")
if r.HealthCheck == nil {
r.HealthCheck = new(health.HealthCheckConfig)
}
r.HealthCheck.Disable = true
r.HealthCheck.Disable = false
}
switch {
@ -143,7 +149,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
if r.HealthMon != nil {
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn(E.FailWith("health monitor", err))
E.LogWarn("health monitor error", err, &r.l)
}
}
@ -209,15 +215,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if err != nil {
if !middleware.ServeStaticErrorPageFile(w, r) {
logrus.Error(E.Failure("request").
Subjectf("%s %s", r.Method, r.URL.String()).
With(err))
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if _, err := w.Write(errorPage); err != nil {
logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err)
logger.Err(err).Msg("failed to write error page")
}
} else {
http.Error(w, err.Error(), http.StatusNotFound)

Some files were not shown because too many files have changed in this diff Show more