mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher
This commit is contained in:
parent
cfa74d69ae
commit
e5bbb18414
137 changed files with 2640 additions and 2348 deletions
3
Makefile
3
Makefile
|
@ -65,3 +65,6 @@ debug-list-containers:
|
|||
ci-test:
|
||||
mkdir -p /tmp/artifacts
|
||||
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
|
||||
|
||||
cloc:
|
||||
cloc --not-match-f '_test.go$$' cmd internal pkg
|
58
cmd/main.go
58
cmd/main.go
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -10,15 +9,14 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal"
|
||||
"github.com/yusing/go-proxy/internal/api"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/server"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -33,44 +31,26 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
l := logrus.WithField("module", "main")
|
||||
timeFmt := "01-02 15:04:05"
|
||||
fullTS := true
|
||||
|
||||
if common.IsTrace {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
timeFmt = "04:05"
|
||||
fullTS = false
|
||||
} else if common.IsDebug {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
if args.Command != common.CommandStart {
|
||||
logrus.SetOutput(io.Discard)
|
||||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
DisableSorting: true,
|
||||
FullTimestamp: fullTS,
|
||||
ForceColors: true,
|
||||
TimestampFormat: timeFmt,
|
||||
})
|
||||
logrus.Infof("go-proxy version %s", pkg.GetVersion())
|
||||
logrus.AddHook(notif.GetDispatcher())
|
||||
}
|
||||
|
||||
if args.Command == common.CommandReload {
|
||||
if err := query.ReloadServer(); err != nil {
|
||||
log.Fatal(err)
|
||||
E.LogFatal("server reload error", err)
|
||||
}
|
||||
log.Print("ok")
|
||||
logging.Info().Msg("ok")
|
||||
return
|
||||
}
|
||||
|
||||
// exit if only validate config
|
||||
if args.Command == common.CommandStart {
|
||||
logging.Info().Msgf("go-proxy version %s", pkg.GetVersion())
|
||||
logging.Trace().Msg("trace enabled")
|
||||
// logging.AddHook(notif.GetDispatcher())
|
||||
} else {
|
||||
logging.DiscardLogger()
|
||||
}
|
||||
|
||||
if args.Command == common.CommandValidate {
|
||||
data, err := os.ReadFile(common.ConfigPath)
|
||||
if err == nil {
|
||||
err = config.Validate(data).Error()
|
||||
err = config.Validate(data)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal("config error: ", err)
|
||||
|
@ -88,7 +68,7 @@ func main() {
|
|||
var cfg *config.Config
|
||||
var err E.Error
|
||||
if cfg, err = config.Load(); err != nil {
|
||||
logrus.Warn(err)
|
||||
E.LogWarn("errors in config", err)
|
||||
}
|
||||
|
||||
switch args.Command {
|
||||
|
@ -145,10 +125,10 @@ func main() {
|
|||
autocert := config.GetAutoCertProvider()
|
||||
if autocert != nil {
|
||||
if err := autocert.Setup(); err != nil {
|
||||
l.Fatal(err)
|
||||
E.LogFatal("autocert setup error", err)
|
||||
}
|
||||
} else {
|
||||
l.Info("autocert not configured")
|
||||
logging.Info().Msg("autocert not configured")
|
||||
}
|
||||
|
||||
proxyServer := server.InitProxyServer(server.Options{
|
||||
|
@ -174,7 +154,7 @@ func main() {
|
|||
<-sig
|
||||
|
||||
// grafully shutdown
|
||||
logrus.Info("shutting down")
|
||||
logging.Info().Msg("shutting down")
|
||||
task.CancelGlobalContext()
|
||||
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
|
||||
}
|
||||
|
@ -182,15 +162,15 @@ func main() {
|
|||
func prepareDirectory(dir string) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
logrus.Fatalf("failed to create directory %s: %v", dir, err)
|
||||
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func printJSON(obj any) {
|
||||
j, err := E.Check(json.MarshalIndent(obj, "", " "))
|
||||
j, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
logging.Fatal().Err(err).Send()
|
||||
}
|
||||
rawLogger := log.New(os.Stdout, "", 0)
|
||||
rawLogger.Printf("%s", j) // raw output for convenience using "jq"
|
||||
|
|
7
go.mod
7
go.mod
|
@ -10,8 +10,8 @@ require (
|
|||
github.com/go-acme/lego/v4 v4.19.2
|
||||
github.com/gotify/server/v2 v2.5.0
|
||||
github.com/puzpuzpuz/xsync/v3 v3.4.0
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/text v0.19.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
|
@ -20,7 +20,7 @@ require (
|
|||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.107.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.108.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
|
@ -33,6 +33,8 @@ require (
|
|||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
|
@ -42,6 +44,7 @@ require (
|
|||
github.com/ovh/go-ovh v1.6.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/rogpeppe/go-internal v1.13.1 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
|
||||
go.opentelemetry.io/otel v1.31.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect
|
||||
|
|
18
go.sum
18
go.sum
|
@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
|||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc=
|
||||
github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc=
|
||||
github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0=
|
||||
github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU=
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
@ -40,6 +41,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
|||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
|
@ -61,6 +63,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
|||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
|
||||
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
|
@ -88,6 +96,9 @@ github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPK
|
|||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
|
@ -140,6 +151,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
|
@ -38,7 +37,6 @@ func NewHandler() http.Handler {
|
|||
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
|
||||
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
|
||||
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
|
||||
return mux
|
||||
}
|
||||
|
||||
|
@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
|
||||
Logger.Warnf("blocked API request from %s", host)
|
||||
LogWarn(r).Msgf("blocked API request from %s", host)
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
package errorpage
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
)
|
||||
|
||||
func GetHandleFunc() http.HandlerFunc {
|
||||
setup()
|
||||
return serveHTTP
|
||||
}
|
||||
|
||||
func serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/" {
|
||||
http.Error(w, "invalid path", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
content, ok := fileContentMap.Load(r.URL.Path)
|
||||
if !ok {
|
||||
http.Error(w, "404 not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(content); err != nil {
|
||||
HandleErr(w, r, err)
|
||||
}
|
||||
}
|
|
@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var validateErr E.Error
|
||||
var valErr E.Error
|
||||
if filename == common.ConfigFileName {
|
||||
validateErr = config.Validate(content)
|
||||
valErr = config.Validate(content)
|
||||
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
|
||||
validateErr = provider.Validate(content)
|
||||
valErr = provider.Validate(content)
|
||||
}
|
||||
// no validation for include files
|
||||
|
||||
if validateErr != nil {
|
||||
U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest)
|
||||
if valErr != nil {
|
||||
U.RespondJSON(w, r, valErr, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -20,16 +20,13 @@ func ReloadServer() E.Error {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
failure := E.Errorf("server reload status %v", resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return failure.Extraf("unable to read response body: %s", err)
|
||||
return failure.With(err)
|
||||
}
|
||||
reloadErr, ok := E.FromJSON(b)
|
||||
if ok {
|
||||
return E.Join("reload success, but server returned error", reloadErr)
|
||||
}
|
||||
return failure.Extraf("unable to read response body")
|
||||
reloadErr := string(body)
|
||||
return failure.Withf(reloadErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
|
||||
outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
var res T
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
|
||||
func Reload(w http.ResponseWriter, r *http.Request) {
|
||||
if err := config.Reload(); err != nil {
|
||||
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
U.WriteBody(w, []byte("OK"))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/server"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
func Stats(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
|
|||
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
|
||||
|
||||
if len(originPats) == 0 {
|
||||
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
|
||||
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
for i, domain := range config.Value().MatchDomains {
|
||||
|
@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
|
|||
OriginPatterns: originPats,
|
||||
})
|
||||
if err != nil {
|
||||
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
|
||||
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
|
||||
return
|
||||
}
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
|
@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
|
|||
for range ticker.C {
|
||||
stats := getStats()
|
||||
if err := wsjson.Write(ctx, conn, stats); err != nil {
|
||||
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
|
||||
U.LogError(r).Msg("failed to write JSON")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
|
|||
func getStats() map[string]any {
|
||||
return map[string]any{
|
||||
"proxies": config.Statistics(),
|
||||
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
|
||||
"uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,37 +1,31 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
var Logger = logrus.WithField("module", "api")
|
||||
|
||||
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
|
||||
if origErr == nil {
|
||||
return
|
||||
}
|
||||
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
|
||||
Logger.Error(err)
|
||||
LogError(r).Msg(origErr.Error())
|
||||
if len(code) > 0 {
|
||||
http.Error(w, err.String(), code[0])
|
||||
http.Error(w, origErr.Error(), code[0])
|
||||
return
|
||||
}
|
||||
http.Error(w, err.String(), http.StatusInternalServerError)
|
||||
http.Error(w, origErr.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func ErrMissingKey(k string) error {
|
||||
return errors.New("missing key '" + k + "' in query or request body")
|
||||
return E.New("missing key '" + k + "' in query or request body")
|
||||
}
|
||||
|
||||
func ErrInvalidKey(k string) error {
|
||||
return errors.New("invalid key '" + k + "' in query or request body")
|
||||
return E.New("invalid key '" + k + "' in query or request body")
|
||||
}
|
||||
|
||||
func ErrNotFound(k, v string) error {
|
||||
return fmt.Errorf("key %q with value %q not found", k, v)
|
||||
return E.Errorf("key %q with value %q not found", k, v)
|
||||
}
|
||||
|
|
|
@ -9,12 +9,11 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
HTTPClient = &http.Client{
|
||||
httpClient = &http.Client{
|
||||
Timeout: common.ConnectionTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: true,
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: common.DialTimeout,
|
||||
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
|
||||
|
@ -23,7 +22,7 @@ var (
|
|||
},
|
||||
}
|
||||
|
||||
Get = HTTPClient.Get
|
||||
Post = HTTPClient.Post
|
||||
Head = HTTPClient.Head
|
||||
Get = httpClient.Get
|
||||
Post = httpClient.Post
|
||||
Head = httpClient.Head
|
||||
)
|
||||
|
|
18
internal/api/v1/utils/logging.go
Normal file
18
internal/api/v1/utils/logging.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||
return logging.WithLevel(level).Str("module", "api").
|
||||
Str("method", r.Method).
|
||||
Str("path", r.RequestURI)
|
||||
}
|
||||
|
||||
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
||||
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
|
||||
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }
|
|
@ -3,6 +3,8 @@ package utils
|
|||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||
|
@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) {
|
|||
}
|
||||
}
|
||||
|
||||
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool {
|
||||
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
|
||||
if len(code) > 0 {
|
||||
w.WriteHeader(code[0])
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
j, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
HandleErr(w, r, err)
|
||||
return false
|
||||
var j []byte
|
||||
var err error
|
||||
|
||||
switch data := data.(type) {
|
||||
case string:
|
||||
j = []byte(`"` + data + `"`)
|
||||
case []byte:
|
||||
j = data
|
||||
default:
|
||||
j, err = json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("failed to marshal json")
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, err = w.Write(j)
|
||||
if err != nil {
|
||||
|
|
|
@ -8,12 +8,21 @@ import (
|
|||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
)
|
||||
|
||||
type Config types.AutoCertConfig
|
||||
|
||||
var (
|
||||
ErrMissingDomain = E.New("missing field 'domains'")
|
||||
ErrMissingEmail = E.New("missing field 'email'")
|
||||
ErrMissingProvider = E.New("missing field 'provider'")
|
||||
ErrUnknownProvider = E.New("unknown provider")
|
||||
)
|
||||
|
||||
func NewConfig(cfg *types.AutoCertConfig) *Config {
|
||||
if cfg.CertPath == "" {
|
||||
cfg.CertPath = CertFileDefault
|
||||
|
@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
|
|||
return (*Config)(cfg)
|
||||
}
|
||||
|
||||
func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
|
||||
b := E.NewBuilder("unable to initialize autocert")
|
||||
defer b.To(&res)
|
||||
func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
||||
b := E.NewBuilder("autocert errors")
|
||||
|
||||
if cfg.Provider != ProviderLocal {
|
||||
if len(cfg.Domains) == 0 {
|
||||
b.Addf("%s", "no domains specified")
|
||||
b.Add(ErrMissingDomain)
|
||||
}
|
||||
if cfg.Provider == "" {
|
||||
b.Addf("%s", "no provider specified")
|
||||
b.Add(ErrMissingProvider)
|
||||
}
|
||||
if cfg.Email == "" {
|
||||
b.Addf("%s", "no email specified")
|
||||
b.Add(ErrMissingEmail)
|
||||
}
|
||||
// check if provider is implemented
|
||||
_, ok := providersGenMap[cfg.Provider]
|
||||
if !ok {
|
||||
b.Addf("unknown provider: %q", cfg.Provider)
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
||||
}
|
||||
}
|
||||
|
||||
if b.HasError() {
|
||||
return
|
||||
return nil, b.Error()
|
||||
}
|
||||
|
||||
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("generate private key", err))
|
||||
return
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
b.Addf("generate private key: %w", err)
|
||||
return nil, b.Error()
|
||||
}
|
||||
|
||||
user := &User{
|
||||
|
@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
|
|||
legoCfg := lego.NewConfig(user)
|
||||
legoCfg.Certificate.KeyType = certcrypto.RSA2048
|
||||
|
||||
provider = &Provider{
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
user: user,
|
||||
legoCfg: legoCfg,
|
||||
}
|
||||
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/duckdns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/ovh"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -36,5 +35,3 @@ var providersGenMap = map[string]ProviderGenerator{
|
|||
var (
|
||||
ErrGetCertFailure = errors.New("get certificate failed")
|
||||
)
|
||||
|
||||
var logger = logrus.WithField("module", "autocert")
|
||||
|
|
5
internal/autocert/logger.go
Normal file
5
internal/autocert/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package autocert
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "autocert").Logger()
|
|
@ -17,6 +17,7 @@ import (
|
|||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -57,25 +58,20 @@ func (p *Provider) GetExpiries() CertExpiries {
|
|||
return p.certExpiries
|
||||
}
|
||||
|
||||
func (p *Provider) ObtainCert() (res E.Error) {
|
||||
b := E.NewBuilder("failed to obtain certificate")
|
||||
defer b.To(&res)
|
||||
|
||||
func (p *Provider) ObtainCert() E.Error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.client == nil {
|
||||
if err := p.initClient(); err.HasError() {
|
||||
b.Add(E.FailWith("init autocert client", err))
|
||||
return
|
||||
if err := p.initClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if p.user.Registration == nil {
|
||||
if err := p.registerACME(); err.HasError() {
|
||||
b.Add(E.FailWith("register ACME", err))
|
||||
return
|
||||
if err := p.registerACME(); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,27 +80,23 @@ func (p *Provider) ObtainCert() (res E.Error) {
|
|||
Domains: p.cfg.Domains,
|
||||
Bundle: true,
|
||||
}
|
||||
cert, err := E.Check(client.Certificate.Obtain(req))
|
||||
if err.HasError() {
|
||||
b.Add(err)
|
||||
return
|
||||
cert, err := client.Certificate.Obtain(req)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
if err = p.saveCert(cert); err.HasError() {
|
||||
b.Add(E.FailWith("save certificate", err))
|
||||
return
|
||||
if err = p.saveCert(cert); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("parse obtained certificate", err))
|
||||
return
|
||||
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
expiries, err := getCertExpiries(&tlsCert)
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("get certificate expiry", err))
|
||||
return
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
p.tlsCert = &tlsCert
|
||||
p.certExpiries = expiries
|
||||
|
@ -113,21 +105,22 @@ func (p *Provider) ObtainCert() (res E.Error) {
|
|||
}
|
||||
|
||||
func (p *Provider) LoadCert() E.Error {
|
||||
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
|
||||
if err.HasError() {
|
||||
return err
|
||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||
if err != nil {
|
||||
return E.Errorf("load SSL certificate: %w", err)
|
||||
}
|
||||
expiries, err := getCertExpiries(&cert)
|
||||
if err.HasError() {
|
||||
return err
|
||||
if err != nil {
|
||||
return E.Errorf("parse SSL certificate: %w", err)
|
||||
}
|
||||
p.tlsCert = &cert
|
||||
p.certExpiries = expiries
|
||||
|
||||
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
return p.renewIfNeeded()
|
||||
}
|
||||
|
||||
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
||||
func (p *Provider) ShouldRenewOn() time.Time {
|
||||
for _, expiry := range p.certExpiries {
|
||||
return expiry.AddDate(0, -1, 0) // 1 month before
|
||||
|
@ -150,8 +143,8 @@ func (p *Provider) ScheduleRenewal() {
|
|||
case <-task.Context().Done():
|
||||
return
|
||||
case <-ticker.C: // check every 5 seconds
|
||||
if err := p.renewIfNeeded(); err.HasError() {
|
||||
logger.Warn(err)
|
||||
if err := p.renewIfNeeded(); err != nil {
|
||||
E.LogWarn("cert renew failed", err, &logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,31 +152,32 @@ func (p *Provider) ScheduleRenewal() {
|
|||
}
|
||||
|
||||
func (p *Provider) initClient() E.Error {
|
||||
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
|
||||
if err.HasError() {
|
||||
return E.FailWith("create lego client", err)
|
||||
legoClient, err := lego.NewClient(p.legoCfg)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
|
||||
if err.HasError() {
|
||||
return E.FailWith("create lego provider", err)
|
||||
generator := providersGenMap[p.cfg.Provider]
|
||||
legoProvider, pErr := generator(p.cfg.Options)
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
}
|
||||
|
||||
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
|
||||
if err.HasError() {
|
||||
return E.FailWith("set challenge provider", err)
|
||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
p.client = legoClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) registerACME() E.Error {
|
||||
func (p *Provider) registerACME() error {
|
||||
if p.user.Registration != nil {
|
||||
return nil
|
||||
}
|
||||
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
|
||||
if err.HasError() {
|
||||
reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.user.Registration = reg
|
||||
|
@ -191,26 +185,27 @@ func (p *Provider) registerACME() E.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) E.Error {
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) error {
|
||||
/* This should have been done in setup
|
||||
but double check is always a good choice.*/
|
||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
|
||||
return E.FailWith("create cert directory", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return E.FailWith("stat cert directory", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
||||
if err != nil {
|
||||
return E.FailWith("write key file", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
|
||||
if err != nil {
|
||||
return E.FailWith("write cert file", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -232,7 +227,7 @@ func (p *Provider) certState() CertState {
|
|||
sort.Strings(certDomains)
|
||||
|
||||
if !reflect.DeepEqual(certDomains, wantedDomains) {
|
||||
logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
||||
logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
||||
return CertStateMismatch
|
||||
}
|
||||
|
||||
|
@ -246,25 +241,25 @@ func (p *Provider) renewIfNeeded() E.Error {
|
|||
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
logger.Info("certs expired, renewing")
|
||||
logger.Info().Msg("certs expired, renewing")
|
||||
case CertStateMismatch:
|
||||
logger.Info("cert domains mismatch with config, renewing")
|
||||
logger.Info().Msg("cert domains mismatch with config, renewing")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.ObtainCert(); err.HasError() {
|
||||
return E.FailWith("renew certificate", err)
|
||||
if err := p.ObtainCert(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) {
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||
r := make(CertExpiries, len(cert.Certificate))
|
||||
for _, cert := range cert.Certificate {
|
||||
x509Cert, err := E.Check(x509.ParseCertificate(cert))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("parse certificate", err)
|
||||
x509Cert, err := x509.ParseCertificate(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x509Cert.IsCA {
|
||||
continue
|
||||
|
@ -284,13 +279,10 @@ func providerGenerator[CT any, PT challenge.Provider](
|
|||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := U.Deserialize(opt, cfg)
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, err := E.Check(newProvider(cfg))
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
p, pErr := newProvider(cfg)
|
||||
return p, E.From(pErr)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,6 +45,6 @@ oauth2_config:
|
|||
testYaml = testYaml[1:] // remove first \n
|
||||
opt := make(map[string]any)
|
||||
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
|
||||
ExpectNoError(t, U.Deserialize(opt, cfg).Error())
|
||||
ExpectNoError(t, U.Deserialize(opt, cfg))
|
||||
ExpectDeepEqual(t, cfg, cfgExpected)
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ func (p *Provider) Setup() (err E.Error) {
|
|||
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
|
||||
return err
|
||||
}
|
||||
logger.Debug("obtaining cert due to error loading cert")
|
||||
logger.Debug().Msg("obtaining cert due to error loading cert")
|
||||
if err = p.ObtainCert(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func (p *Provider) Setup() (err E.Error) {
|
|||
p.ScheduleRenewal()
|
||||
|
||||
for _, expiry := range p.GetExpiries() {
|
||||
logger.Infof("certificate expire on %s", expiry)
|
||||
logger.Info().Msg("certificate expire on " + expiry.String())
|
||||
break
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,7 @@ package common
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log"
|
||||
)
|
||||
|
||||
type Args struct {
|
||||
|
@ -44,7 +43,7 @@ func GetArgs() Args {
|
|||
flag.Parse()
|
||||
args.Command = flag.Arg(0)
|
||||
if err := validateArg(args.Command); err != nil {
|
||||
logrus.Fatal(err)
|
||||
log.Fatalf("invalid command: %s", err)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
@ -55,5 +54,5 @@ func validateArg(arg string) error {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid command: %s", arg)
|
||||
return fmt.Errorf("invalid command %q", arg)
|
||||
}
|
||||
|
|
|
@ -7,8 +7,6 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,7 +38,7 @@ func GetEnvBool(key string, defaultValue bool) bool {
|
|||
}
|
||||
b, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid boolean value: %s", value)
|
||||
log.Fatalf("env %s: invalid boolean value: %s", key, value)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
@ -57,7 +55,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
|||
addr = GetEnv(key, defaultValue)
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Invalid address: %s", addr)
|
||||
log.Fatalf("env %s: invalid address: %s", key, addr)
|
||||
}
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
|
|
|
@ -2,14 +2,15 @@ package config
|
|||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
|
@ -31,7 +32,7 @@ type Config struct {
|
|||
var (
|
||||
instance *Config
|
||||
cfgWatcher watcher.Watcher
|
||||
logger = logrus.WithField("module", "config")
|
||||
logger = logging.With().Str("module", "config").Logger()
|
||||
reloadMu sync.Mutex
|
||||
)
|
||||
|
||||
|
@ -80,7 +81,7 @@ func WatchChanges() {
|
|||
configEventFlushInterval,
|
||||
OnConfigChange,
|
||||
func(err E.Error) {
|
||||
logger.Error(err)
|
||||
E.LogError("config reload error", err, &logger)
|
||||
},
|
||||
)
|
||||
eventQueue.Start(cfgWatcher.Events(task.Context()))
|
||||
|
@ -93,15 +94,16 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
|
|||
// just reload once and check the last event
|
||||
switch ev[len(ev)-1].Action {
|
||||
case events.ActionFileRenamed:
|
||||
logger.Warn(cfgRenameWarn)
|
||||
logger.Warn().Msg(cfgRenameWarn)
|
||||
return
|
||||
case events.ActionFileDeleted:
|
||||
logger.Warn(cfgDeleteWarn)
|
||||
logger.Warn().Msg(cfgDeleteWarn)
|
||||
return
|
||||
}
|
||||
|
||||
if err := Reload(); err != nil {
|
||||
logger.Error(err)
|
||||
// recovered in event queue
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,63 +140,57 @@ func (cfg *Config) Task() task.Task {
|
|||
}
|
||||
|
||||
func (cfg *Config) StartProxyProviders() {
|
||||
b := E.NewBuilder("errors starting providers")
|
||||
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
|
||||
b.Add(p.Start(cfg.task.Subtask(p.String())))
|
||||
})
|
||||
errs := cfg.providers.CollectErrorsParallel(
|
||||
func(_ string, p *proxy.Provider) error {
|
||||
subtask := cfg.task.Subtask(p.String())
|
||||
return p.Start(subtask)
|
||||
})
|
||||
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
if err := E.Join(errs...); err != nil {
|
||||
E.LogError("route provider errors", err, &logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) load() (res E.Error) {
|
||||
errs := E.NewBuilder("errors loading config")
|
||||
defer errs.To(&res)
|
||||
func (cfg *Config) load() E.Error {
|
||||
const errMsg = "config load error"
|
||||
|
||||
logger.Debug("loading config")
|
||||
defer logger.Debug("loaded config")
|
||||
|
||||
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
||||
data, err := os.ReadFile(common.ConfigPath)
|
||||
if err != nil {
|
||||
errs.Add(E.FailWith("read config", err))
|
||||
logrus.Fatal(errs.Build())
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
|
||||
if !common.NoSchemaValidation {
|
||||
if err = Validate(data); err != nil {
|
||||
errs.Add(E.FailWith("schema validation", err))
|
||||
logrus.Fatal(errs.Build())
|
||||
if err := Validate(data); err != nil {
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
}
|
||||
|
||||
model := types.DefaultConfig()
|
||||
if err := E.From(yaml.Unmarshal(data, model)); err != nil {
|
||||
errs.Add(E.FailWith("parse config", err))
|
||||
logrus.Fatal(errs.Build())
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
|
||||
// errors are non fatal below
|
||||
errs := E.NewBuilder(errMsg)
|
||||
errs.Add(cfg.initNotification(model.Providers.Notification))
|
||||
errs.Add(cfg.initAutoCert(&model.AutoCert))
|
||||
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
||||
|
||||
cfg.value = model
|
||||
route.SetFindMuxDomains(model.MatchDomains)
|
||||
return
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) {
|
||||
if len(notifCfgMap) == 0 {
|
||||
return
|
||||
}
|
||||
errs := E.NewBuilder("errors initializing notification providers")
|
||||
|
||||
errs := E.NewBuilder("notification providers load errors")
|
||||
for name, notifCfg := range notifCfgMap {
|
||||
_, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg)
|
||||
errs.Add(err)
|
||||
}
|
||||
return errs.Build()
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
|
||||
|
@ -203,40 +199,45 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
|
|||
}
|
||||
|
||||
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
|
||||
if err != nil {
|
||||
err = E.FailWith("autocert provider", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) loadRouteProviders(providers *types.Providers) (outErr E.Error) {
|
||||
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
|
||||
subtask := cfg.task.Subtask("load route providers")
|
||||
defer subtask.Finish("done")
|
||||
|
||||
errs := E.NewBuilder("errors loading route providers")
|
||||
results := E.NewBuilder("loaded providers")
|
||||
defer errs.To(&outErr)
|
||||
errs := E.NewBuilder("route provider errors")
|
||||
results := E.NewBuilder("loaded route providers")
|
||||
|
||||
lenLongestName := 0
|
||||
for _, filename := range providers.Files {
|
||||
p, err := proxy.NewFileProvider(filename)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
errs.Add(E.PrependSubject(filename, err))
|
||||
continue
|
||||
}
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
errs.Add(p.LoadRoutes().Subject(filename))
|
||||
results.Addf("%d routes from %s", p.NumRoutes(), p.String())
|
||||
if len(p.GetName()) > lenLongestName {
|
||||
lenLongestName = len(p.GetName())
|
||||
}
|
||||
}
|
||||
for name, dockerHost := range providers.Docker {
|
||||
p, err := proxy.NewDockerProvider(name, dockerHost)
|
||||
if err != nil {
|
||||
errs.Add(err.Subjectf("%s (%s)", name, dockerHost))
|
||||
errs.Add(E.PrependSubject(name, err))
|
||||
continue
|
||||
}
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
errs.Add(p.LoadRoutes().Subject(p.GetName()))
|
||||
results.Addf("%d routes from %s", p.NumRoutes(), p.String())
|
||||
if len(p.GetName()) > lenLongestName {
|
||||
lenLongestName = len(p.GetName())
|
||||
}
|
||||
}
|
||||
logger.Info(results.Build())
|
||||
return
|
||||
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
|
||||
if err := p.LoadRoutes(); err != nil {
|
||||
errs.Add(err.Subject(p.String()))
|
||||
}
|
||||
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes())
|
||||
})
|
||||
logger.Info().Msg(results.String())
|
||||
return errs.Error()
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
func DumpEntries() map[string]*entry.RawEntry {
|
||||
|
@ -61,7 +61,7 @@ func HomepageConfig() homepage.Config {
|
|||
}
|
||||
|
||||
if item.Name == "" {
|
||||
item.Name = U.Title(
|
||||
item.Name = strutils.Title(
|
||||
strings.ReplaceAll(
|
||||
strings.ReplaceAll(alias, "-", " "),
|
||||
"_", " ",
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package docker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/docker/cli/cli/connhelper"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
|
@ -22,7 +23,7 @@ type (
|
|||
key string
|
||||
refCount *U.RefCount
|
||||
|
||||
l logrus.FieldLogger
|
||||
l zerolog.Logger
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -70,7 +71,7 @@ func (c *SharedClient) Close() error {
|
|||
// Returns:
|
||||
// - Client: the Docker client connection.
|
||||
// - error: an error if the connection failed.
|
||||
func ConnectClient(host string) (Client, E.Error) {
|
||||
func ConnectClient(host string) (Client, error) {
|
||||
clientMapMu.Lock()
|
||||
defer clientMapMu.Unlock()
|
||||
|
||||
|
@ -85,13 +86,13 @@ func ConnectClient(host string) (Client, E.Error) {
|
|||
|
||||
switch host {
|
||||
case "":
|
||||
return nil, E.Invalid("docker host", "empty")
|
||||
return nil, errors.New("empty docker host")
|
||||
case common.DockerHostFromEnv:
|
||||
opt = clientOptEnvHost
|
||||
default:
|
||||
helper, err := E.Check(connhelper.GetConnectionHelper(host))
|
||||
if err.HasError() {
|
||||
return nil, E.UnexpectedError(err.Error())
|
||||
helper, err := connhelper.GetConnectionHelper(host)
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("failed to get connection helper")
|
||||
}
|
||||
if helper != nil {
|
||||
httpClient := &http.Client{
|
||||
|
@ -113,9 +114,9 @@ func ConnectClient(host string) (Client, E.Error) {
|
|||
}
|
||||
}
|
||||
|
||||
client, err := E.Check(client.NewClientWithOpts(opt...))
|
||||
client, err := client.NewClientWithOpts(opt...)
|
||||
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -123,9 +124,9 @@ func ConnectClient(host string) (Client, E.Error) {
|
|||
Client: client,
|
||||
key: host,
|
||||
refCount: U.NewRefCounter(),
|
||||
l: logger.WithField("docker_client", client.DaemonHost()),
|
||||
l: logger.With().Str("address", client.DaemonHost()).Logger(),
|
||||
}
|
||||
c.l.Debugf("client connected")
|
||||
c.l.Trace().Msg("client connected")
|
||||
|
||||
clientMap.Store(host, c)
|
||||
|
||||
|
@ -135,7 +136,7 @@ func ConnectClient(host string) (Client, E.Error) {
|
|||
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
c.l.Debugf("client closed")
|
||||
c.l.Trace().Msg("client closed")
|
||||
}
|
||||
}()
|
||||
return c, nil
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -59,7 +59,7 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) {
|
|||
NetworkMode: c.HostConfig.NetworkMode,
|
||||
|
||||
Aliases: helper.getAliases(),
|
||||
IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)),
|
||||
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
|
||||
IsExplicit: isExplicit,
|
||||
IsDatabase: helper.isDatabase(),
|
||||
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
|
||||
|
@ -120,7 +120,7 @@ func (c *Container) setPublicIP() {
|
|||
}
|
||||
url, err := url.Parse(c.DockerHost)
|
||||
if err != nil {
|
||||
logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err)
|
||||
logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
||||
c.PublicIP = "127.0.0.1"
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type containerHelper struct {
|
||||
|
@ -23,7 +23,7 @@ func (c containerHelper) getDeleteLabel(label string) string {
|
|||
|
||||
func (c containerHelper) getAliases() []string {
|
||||
if l := c.getDeleteLabel(LabelAliases); l != "" {
|
||||
return U.CommaSeperatedList(l)
|
||||
return strutils.CommaSeperatedList(l)
|
||||
}
|
||||
return []string{c.getName()}
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
|
|||
if v.PublicPort == 0 {
|
||||
continue
|
||||
}
|
||||
res[U.PortString(v.PublicPort)] = v
|
||||
res[strutils.PortString(v.PublicPort)] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
|
|||
func (c containerHelper) getPrivatePortMapping() PortMapping {
|
||||
res := make(PortMapping)
|
||||
for _, v := range c.Ports {
|
||||
res[U.PortString(v.PrivatePort)] = v
|
||||
res[strutils.PortString(v.PrivatePort)] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
|
@ -30,7 +31,7 @@ const (
|
|||
StopMethodKill StopMethod = "kill"
|
||||
)
|
||||
|
||||
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
|
||||
func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
|
||||
if cont == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -44,26 +45,16 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
b := E.NewBuilder("invalid idlewatcher config")
|
||||
defer b.To(&res)
|
||||
errs := E.NewBuilder("invalid idlewatcher config")
|
||||
|
||||
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout)
|
||||
b.Add(err.Subjectf("%s", "idle_timeout"))
|
||||
idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
|
||||
wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
|
||||
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
|
||||
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
|
||||
signal := E.Collect(errs, validateSignal, cont.StopSignal)
|
||||
|
||||
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout)
|
||||
b.Add(err.Subjectf("%s", "wake_timeout"))
|
||||
|
||||
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
|
||||
b.Add(err.Subjectf("%s", "stop_timeout"))
|
||||
|
||||
stopMethod, err := validateStopMethod(cont.StopMethod)
|
||||
b.Add(err)
|
||||
|
||||
signal, err := validateSignal(cont.StopSignal)
|
||||
b.Add(err)
|
||||
|
||||
if err := b.Build(); err != nil {
|
||||
return
|
||||
if errs.HasError() {
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
return &Config{
|
||||
|
@ -80,33 +71,33 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func validateDurationPostitive(value string) (time.Duration, E.Error) {
|
||||
func validateDurationPostitive(value string) (time.Duration, error) {
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, E.Invalid("duration", value).With(err)
|
||||
return 0, err
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, E.Invalid("duration", "negative value")
|
||||
return 0, errors.New("duration must be positive")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func validateSignal(s string) (Signal, E.Error) {
|
||||
func validateSignal(s string) (Signal, error) {
|
||||
switch s {
|
||||
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
|
||||
"INT", "TERM", "HUP", "QUIT":
|
||||
return Signal(s), nil
|
||||
}
|
||||
|
||||
return "", E.Invalid("signal", s)
|
||||
return "", errors.New("invalid signal " + s)
|
||||
}
|
||||
|
||||
func validateStopMethod(s string) (StopMethod, E.Error) {
|
||||
func validateStopMethod(s string) (StopMethod, error) {
|
||||
sm := StopMethod(s)
|
||||
switch sm {
|
||||
case StopMethodPause, StopMethodStop, StopMethodKill:
|
||||
return sm, nil
|
||||
default:
|
||||
return "", E.Invalid("stop_method", sm)
|
||||
return "", errors.New("invalid stop method " + s)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
|
|||
|
||||
watcher, err := registerWatcher(providerSubTask, entry, waker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, E.Errorf("register watcher: %w", err)
|
||||
}
|
||||
|
||||
if rp != nil {
|
||||
|
@ -75,6 +75,9 @@ func (w *Watcher) Start(routeSubTask task.Task) E.Error {
|
|||
|
||||
// Finish implements health.HealthMonitor.
|
||||
func (w *Watcher) Finish(reason any) {
|
||||
if w.stream != nil {
|
||||
w.stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
|
|
|
@ -7,9 +7,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
@ -20,21 +18,26 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
if !shouldNext {
|
||||
return
|
||||
}
|
||||
w.rp.ServeHTTP(rw, r)
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
default:
|
||||
w.rp.ServeHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
|
||||
w.resetIdleTimer()
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return true
|
||||
}
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
accept := gphttp.GetAccept(r.Header)
|
||||
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
|
||||
|
||||
|
@ -49,23 +52,22 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
rw.Header().Add("Cache-Control", "must-revalidate")
|
||||
rw.Header().Add("Connection", "close")
|
||||
if _, err := rw.Write(body); err != nil {
|
||||
w.l.Errorf("error writing http response: %s", err)
|
||||
w.Err(err).Msg("error writing http response")
|
||||
}
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
checkCanceled := func() bool {
|
||||
checkCanceled := func() (canceled bool) {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
w.l.Debugf("wake canceled: %s", context.Cause(ctx))
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled")
|
||||
return true
|
||||
case <-w.task.Context().Done():
|
||||
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
@ -76,12 +78,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
return false
|
||||
}
|
||||
|
||||
w.l.Debug("wake signal received")
|
||||
w.WakeTrace().Msg("signal received")
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.l.Error(E.FailWith("wake", err))
|
||||
w.WakeError(err).Send()
|
||||
http.Error(rw, "Error waking container", http.StatusInternalServerError)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
for {
|
||||
|
@ -92,11 +94,11 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
if isCheckRedirect {
|
||||
logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL())
|
||||
w.Debug().Msgf("redirecting to %s ...", w.hc.URL())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
return false
|
||||
}
|
||||
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
|
||||
w.Debug().Msgf("passing through to %s ...", w.hc.URL())
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
|
@ -16,25 +16,25 @@ func (w *Watcher) Addr() net.Addr {
|
|||
return w.stream.Addr()
|
||||
}
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Setup() error {
|
||||
return w.stream.Setup()
|
||||
}
|
||||
|
||||
// Accept implements types.Stream.
|
||||
func (w *Watcher) Accept() (conn net.Conn, err error) {
|
||||
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
|
||||
conn, err = w.stream.Accept()
|
||||
if err != nil {
|
||||
logrus.Errorf("accept failed with error: %s", err)
|
||||
return
|
||||
}
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
w.l.Error(err)
|
||||
if wakeErr := w.wakeFromStream(); wakeErr != nil {
|
||||
w.WakeError(wakeErr).Msg("error waking from stream")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle implements types.Stream.
|
||||
func (w *Watcher) Handle(conn net.Conn) error {
|
||||
func (w *Watcher) Handle(conn types.StreamConn) error {
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -54,11 +54,11 @@ func (w *Watcher) wakeFromStream() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
w.l.Debug("wake signal received")
|
||||
w.WakeDebug().Msg("wake signal received")
|
||||
wakeErr := w.wakeIfStopped()
|
||||
if wakeErr != nil {
|
||||
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr)
|
||||
w.l.Error(wakeErr)
|
||||
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
|
||||
w.WakeError(wakeErr).Msg("wake failed")
|
||||
return wakeErr
|
||||
}
|
||||
|
||||
|
@ -69,18 +69,18 @@ func (w *Watcher) wakeFromStream() error {
|
|||
select {
|
||||
case <-w.task.Context().Done():
|
||||
cause := w.task.FinishCause()
|
||||
w.l.Debugf("wake canceled: %s", cause)
|
||||
w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
|
||||
return cause
|
||||
case <-ctx.Done():
|
||||
cause := context.Cause(ctx)
|
||||
w.l.Debugf("wake canceled: %s", cause)
|
||||
w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
|
||||
return cause
|
||||
default:
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
|
||||
w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,15 +3,15 @@ package idlewatcher
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
|
@ -25,6 +25,8 @@ type (
|
|||
Watcher struct {
|
||||
_ U.NoCopy
|
||||
|
||||
zerolog.Logger
|
||||
|
||||
*idlewatcher.Config
|
||||
*waker
|
||||
|
||||
|
@ -32,7 +34,6 @@ type (
|
|||
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
|
||||
ticker *time.Ticker
|
||||
task task.Task
|
||||
l *logrus.Entry
|
||||
}
|
||||
|
||||
WakeDone <-chan error
|
||||
|
@ -44,13 +45,12 @@ var (
|
|||
watcherMap = F.NewMapOf[string, *Watcher]()
|
||||
watcherMapMu sync.Mutex
|
||||
|
||||
logger = logrus.WithField("module", "idle_watcher")
|
||||
logger = logging.With().Str("module", "idle_watcher").Logger()
|
||||
)
|
||||
|
||||
const dockerReqTimeout = 3 * time.Second
|
||||
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) {
|
||||
failure := E.Failure("idle_watcher register")
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) {
|
||||
cfg := entry.IdlewatcherConfig()
|
||||
|
||||
if cfg.IdleTimeout == 0 {
|
||||
|
@ -71,17 +71,17 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
|
|||
}
|
||||
|
||||
client, err := D.ConnectClient(cfg.DockerHost)
|
||||
if err.HasError() {
|
||||
return nil, failure.With(err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w := &Watcher{
|
||||
Logger: logger.With().Str("name", cfg.ContainerName).Logger(),
|
||||
Config: cfg,
|
||||
waker: waker,
|
||||
client: client,
|
||||
task: providerSubtask,
|
||||
ticker: time.NewTicker(cfg.IdleTimeout),
|
||||
l: logger.WithField("container", cfg.ContainerName),
|
||||
}
|
||||
w.stopByMethod = w.getStopCallback()
|
||||
watcherMap.Store(key, w)
|
||||
|
@ -99,6 +99,23 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
|
|||
return w, nil
|
||||
}
|
||||
|
||||
// WakeDebug logs a debug message related to waking the container.
|
||||
func (w *Watcher) WakeDebug() *zerolog.Event {
|
||||
return w.Debug().Str("action", "wake")
|
||||
}
|
||||
|
||||
func (w *Watcher) WakeTrace() *zerolog.Event {
|
||||
return w.Trace().Str("action", "wake")
|
||||
}
|
||||
|
||||
func (w *Watcher) WakeError(err error) *zerolog.Event {
|
||||
return w.Err(err).Str("action", "wake")
|
||||
}
|
||||
|
||||
func (w *Watcher) LogReason(action, reason string) {
|
||||
w.Info().Str("reason", reason).Msg(action)
|
||||
}
|
||||
|
||||
func (w *Watcher) containerStop(ctx context.Context) error {
|
||||
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
|
||||
Signal: string(w.StopSignal),
|
||||
|
@ -130,7 +147,7 @@ func (w *Watcher) containerStatus() (string, error) {
|
|||
defer cancel()
|
||||
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to inspect container: %w", err)
|
||||
return "", err
|
||||
}
|
||||
return json.State.Status, nil
|
||||
}
|
||||
|
@ -181,7 +198,7 @@ func (w *Watcher) getStopCallback() StopCallback {
|
|||
}
|
||||
|
||||
func (w *Watcher) resetIdleTimer() {
|
||||
w.l.Trace("reset idle timer")
|
||||
w.Trace().Msg("reset idle timer")
|
||||
w.ticker.Reset(w.IdleTimeout)
|
||||
}
|
||||
|
||||
|
@ -190,7 +207,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
|
|||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
|
||||
Filters: W.NewDockerFilter(
|
||||
W.DockerFilterContainer,
|
||||
W.DockerrFilterContainer(w.ContainerID),
|
||||
W.DockerFilterContainerNameID(w.ContainerID),
|
||||
W.DockerFilterStart,
|
||||
W.DockerFilterStop,
|
||||
W.DockerFilterDie,
|
||||
|
@ -214,7 +231,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
|
|||
//
|
||||
// it exits only if the context is canceled, the container is destroyed,
|
||||
// errors occured on docker client, or route provider died (mainly caused by config reload).
|
||||
func (w *Watcher) watchUntilDestroy() error {
|
||||
func (w *Watcher) watchUntilDestroy() (returnCause error) {
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
|
||||
defer eventTask.Finish("stopped")
|
||||
|
@ -224,36 +241,36 @@ func (w *Watcher) watchUntilDestroy() error {
|
|||
case <-w.task.Context().Done():
|
||||
return w.task.FinishCause()
|
||||
case err := <-dockerEventErrCh:
|
||||
if err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("docker watcher", err))
|
||||
return err.Error()
|
||||
if !err.Is(context.Canceled) {
|
||||
E.LogError("idlewatcher error", err, &w.Logger)
|
||||
}
|
||||
return err
|
||||
case e := <-dockerEventCh:
|
||||
switch {
|
||||
case e.Action == events.ActionContainerDestroy:
|
||||
w.ContainerRunning = false
|
||||
w.ready.Store(false)
|
||||
w.l.Info("watcher stopped by container destruction")
|
||||
w.LogReason("watcher stopped", "container destroyed")
|
||||
return errors.New("container destroyed")
|
||||
// create / start / unpause
|
||||
case e.Action.IsContainerWake():
|
||||
w.ContainerRunning = true
|
||||
w.resetIdleTimer()
|
||||
w.l.Info("container awaken")
|
||||
w.Info().Msg("awaken")
|
||||
case e.Action.IsContainerSleep(): // stop / pause / kil
|
||||
w.ContainerRunning = false
|
||||
w.ready.Store(false)
|
||||
w.ticker.Stop()
|
||||
default:
|
||||
w.l.Errorf("unexpected docker event: %s", e)
|
||||
w.Error().Msg("unexpected docker event: " + e.String())
|
||||
}
|
||||
// container name changed should also change the container id
|
||||
if w.ContainerName != e.ActorName {
|
||||
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName)
|
||||
w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
|
||||
w.ContainerName = e.ActorName
|
||||
}
|
||||
if w.ContainerID != e.ActorID {
|
||||
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID)
|
||||
w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
|
||||
w.ContainerID = e.ActorID
|
||||
// recreate event stream
|
||||
eventTask.Finish("recreate event stream")
|
||||
|
@ -263,9 +280,9 @@ func (w *Watcher) watchUntilDestroy() error {
|
|||
w.ticker.Stop()
|
||||
if w.ContainerRunning {
|
||||
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err)
|
||||
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
|
||||
} else {
|
||||
w.l.Info("container stopped by idle timeout")
|
||||
w.LogReason("container stopped", "idle timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,28 +4,26 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func Inspect(dockerHost string, containerID string) (*Container, E.Error) {
|
||||
func Inspect(dockerHost string, containerID string) (*Container, error) {
|
||||
client, err := ConnectClient(dockerHost)
|
||||
defer client.Close()
|
||||
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("connect to docker", err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Inspect(containerID)
|
||||
}
|
||||
|
||||
func (c Client) Inspect(containerID string) (*Container, E.Error) {
|
||||
func (c Client) Inspect(containerID string) (*Container, error) {
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
|
||||
defer cancel()
|
||||
|
||||
json, err := c.ContainerInspect(ctx, containerID)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
return nil, err
|
||||
}
|
||||
return FromJSON(json, c.key), nil
|
||||
}
|
||||
|
|
|
@ -24,6 +24,11 @@ type (
|
|||
NestedLabelMap map[string]U.SerializedObject
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApplyToNil = E.New("label value is nil")
|
||||
ErrFieldNotExist = E.New("field does not exist")
|
||||
)
|
||||
|
||||
func (l *Label) String() string {
|
||||
if l.Attribute == "" {
|
||||
return l.Namespace + "." + l.Target
|
||||
|
@ -41,7 +46,7 @@ func (l *Label) String() string {
|
|||
// - error: an error if the field does not exist.
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
||||
if obj == nil {
|
||||
return E.Invalid("nil object", l)
|
||||
return ErrApplyToNil.Subject(l.String())
|
||||
}
|
||||
switch nestedLabel := l.Value.(type) {
|
||||
case *Label:
|
||||
|
@ -54,7 +59,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
|||
}
|
||||
}
|
||||
if !field.IsValid() {
|
||||
return E.NotExist("field", l.Attribute)
|
||||
return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String())
|
||||
}
|
||||
dst, ok := field.Interface().(NestedLabelMap)
|
||||
if !ok {
|
||||
|
@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
|||
} else {
|
||||
field = field.Addr()
|
||||
}
|
||||
return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
|
||||
err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
|
||||
if err != nil {
|
||||
return err.Subject(l.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if dst == nil {
|
||||
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
|
||||
|
@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
|||
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
|
||||
return nil
|
||||
default:
|
||||
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||
err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||
if err != nil {
|
||||
return err.Subject(l.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func ParseLabel(label string, value string) (*Label, E.Error) {
|
||||
func ParseLabel(label string, value string) *Label {
|
||||
parts := strings.Split(label, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return &Label{
|
||||
Namespace: label,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
l := &Label{
|
||||
|
@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.Error) {
|
|||
l.Attribute = parts[2]
|
||||
default:
|
||||
l.Attribute = parts[2]
|
||||
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||
l.Value = nestedLabel
|
||||
}
|
||||
|
||||
return l, nil
|
||||
return l
|
||||
}
|
||||
|
|
|
@ -20,9 +20,8 @@ func makeLabel(ns, name, attr string) string {
|
|||
|
||||
func TestNestedLabel(t *testing.T) {
|
||||
mAttr := "prop1"
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
sGot := ExpectType[*Label](t, pl.Value)
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
sGot := ExpectType[*Label](t, lbl.Value)
|
||||
ExpectFalse(t, sGot == nil)
|
||||
ExpectEqual(t, sGot.Namespace, mName)
|
||||
ExpectEqual(t, sGot.Attribute, mAttr)
|
||||
|
@ -32,10 +31,9 @@ func TestApplyNestedLabel(t *testing.T) {
|
|||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
|
@ -52,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) {
|
|||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
entry.Middlewares[mName][checkAttr] = checkV
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
|
@ -74,10 +71,9 @@ func TestApplyNestedLabelNoAttr(t *testing.T) {
|
|||
entry.Middlewares = make(NestedLabelMap)
|
||||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
_, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
var listOptions = container.ListOptions{
|
||||
|
@ -23,19 +22,19 @@ var listOptions = container.ListOptions{
|
|||
All: true,
|
||||
}
|
||||
|
||||
func ListContainers(clientHost string) ([]types.Container, E.Error) {
|
||||
func ListContainers(clientHost string) ([]types.Container, error) {
|
||||
dockerClient, err := ConnectClient(clientHost)
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("connect to docker", err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
|
||||
defer cancel()
|
||||
|
||||
containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("list containers", err)
|
||||
containers, err := dockerClient.ContainerList(ctx, listOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return containers, nil
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package docker
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
var logger = logrus.WithField("module", "docker")
|
||||
var logger = logging.With().Str("module", "docker").Logger()
|
||||
|
|
46
internal/error/base.go
Normal file
46
internal/error/base.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// baseError is an immutable wrapper around an error.
|
||||
type baseError struct {
|
||||
Err error `json:"err"`
|
||||
}
|
||||
|
||||
func (err *baseError) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
func (err *baseError) Is(other error) bool {
|
||||
if other, ok := other.(*baseError); ok {
|
||||
return errors.Is(err.Err, other.Err)
|
||||
}
|
||||
return errors.Is(err.Err, other)
|
||||
}
|
||||
|
||||
func (err baseError) Subject(subject string) Error {
|
||||
err.Err = PrependSubject(subject, err.Err)
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *baseError) Subjectf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
return err.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
return err.Subject(format)
|
||||
}
|
||||
|
||||
func (err baseError) With(extra error) Error {
|
||||
return &nestedError{&err, []error{extra}}
|
||||
}
|
||||
|
||||
func (err baseError) Withf(format string, args ...any) Error {
|
||||
return &nestedError{&err, []error{fmt.Errorf(format, args...)}}
|
||||
}
|
||||
|
||||
func (err *baseError) Error() string {
|
||||
return err.Err.Error()
|
||||
}
|
|
@ -1,104 +1,104 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Builder struct {
|
||||
*builder
|
||||
}
|
||||
|
||||
type builder struct {
|
||||
message string
|
||||
errors []Error
|
||||
about string
|
||||
errs []error
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewBuilder(format string, args ...any) Builder {
|
||||
if len(args) > 0 {
|
||||
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
|
||||
func NewBuilder(about string) *Builder {
|
||||
return &Builder{about: about}
|
||||
}
|
||||
|
||||
func (b *Builder) About() string {
|
||||
if !b.HasError() {
|
||||
return ""
|
||||
}
|
||||
return Builder{&builder{message: format}}
|
||||
return b.about
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func (b *Builder) HasError() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
func (b *Builder) Error() Error {
|
||||
if !b.HasError() {
|
||||
return nil
|
||||
}
|
||||
if len(b.errs) == 1 {
|
||||
return From(b.errs[0])
|
||||
}
|
||||
return &nestedError{Err: New(b.about), Extras: b.errs}
|
||||
}
|
||||
|
||||
func (b *Builder) String() string {
|
||||
if !b.HasError() {
|
||||
return ""
|
||||
}
|
||||
return (&nestedError{Err: New(b.about), Extras: b.errs}).Error()
|
||||
}
|
||||
|
||||
// Add adds an error to the Builder.
|
||||
//
|
||||
// adding nil is no-op,
|
||||
//
|
||||
// flatten is a boolean flag to flatten the NestedError.
|
||||
func (b Builder) Add(err Error, flatten ...bool) {
|
||||
if err != nil {
|
||||
b.Lock()
|
||||
if len(flatten) > 0 && flatten[0] {
|
||||
for _, e := range err.extras {
|
||||
b.errors = append(b.errors, &e)
|
||||
}
|
||||
func (b *Builder) Add(err error) *Builder {
|
||||
if err == nil {
|
||||
return b
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
switch err := err.(type) {
|
||||
case *baseError:
|
||||
b.errs = append(b.errs, err.Err)
|
||||
case *nestedError:
|
||||
if err.Err == nil {
|
||||
b.errs = append(b.errs, err.Extras...)
|
||||
} else {
|
||||
b.errors = append(b.errors, err)
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
b.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (b Builder) AddE(err error) {
|
||||
b.Add(From(err))
|
||||
}
|
||||
|
||||
func (b Builder) Addf(format string, args ...any) {
|
||||
if len(args) > 0 {
|
||||
b.Add(errorf(format, args...))
|
||||
} else {
|
||||
b.AddE(errors.New(format))
|
||||
}
|
||||
}
|
||||
|
||||
func (b Builder) AddRange(errs ...Error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
for _, err := range errs {
|
||||
b.errors = append(b.errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b Builder) AddRangeE(errs ...error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
for _, err := range errs {
|
||||
b.errors = append(b.errors, From(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Build builds a NestedError based on the errors collected in the Builder.
|
||||
//
|
||||
// If there are no errors in the Builder, it returns a Nil() NestedError.
|
||||
// Otherwise, it returns a NestedError with the message and the errors collected.
|
||||
//
|
||||
// Returns:
|
||||
// - NestedError: the built NestedError.
|
||||
func (b Builder) Build() Error {
|
||||
if len(b.errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return Join(b.message, b.errors...)
|
||||
}
|
||||
|
||||
func (b Builder) To(ptr *Error) {
|
||||
switch {
|
||||
case ptr == nil:
|
||||
return
|
||||
case *ptr == nil:
|
||||
*ptr = b.Build()
|
||||
default:
|
||||
(*ptr).extras = append((*ptr).extras, *b.Build())
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b Builder) String() string {
|
||||
return b.Build().String()
|
||||
func (b *Builder) Adds(err string) *Builder {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.errs = append(b.errs, newError(err))
|
||||
return b
|
||||
}
|
||||
|
||||
func (b Builder) HasError() bool {
|
||||
return len(b.errors) > 0
|
||||
func (b *Builder) Addf(format string, args ...any) *Builder {
|
||||
if len(args) > 0 {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.errs = append(b.errs, fmt.Errorf(format, args...))
|
||||
} else {
|
||||
b.Adds(format)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) AddRange(errs ...error) *Builder {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package error_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -8,14 +11,13 @@ import (
|
|||
)
|
||||
|
||||
func TestBuilderEmpty(t *testing.T) {
|
||||
eb := NewBuilder("qwer")
|
||||
ExpectTrue(t, eb.Build() == nil)
|
||||
ExpectTrue(t, eb.Build().NoError())
|
||||
eb := NewBuilder("foo")
|
||||
ExpectTrue(t, errors.Is(eb.Error(), nil))
|
||||
ExpectFalse(t, eb.HasError())
|
||||
}
|
||||
|
||||
func TestBuilderAddNil(t *testing.T) {
|
||||
eb := NewBuilder("asdf")
|
||||
eb := NewBuilder("foo")
|
||||
var err Error
|
||||
for range 3 {
|
||||
eb.Add(nil)
|
||||
|
@ -23,41 +25,31 @@ func TestBuilderAddNil(t *testing.T) {
|
|||
for range 3 {
|
||||
eb.Add(err)
|
||||
}
|
||||
ExpectTrue(t, eb.Build() == nil)
|
||||
ExpectTrue(t, eb.Build().NoError())
|
||||
eb.AddRange(nil, nil, err)
|
||||
ExpectFalse(t, eb.HasError())
|
||||
ExpectTrue(t, eb.Error() == nil)
|
||||
}
|
||||
|
||||
func TestBuilderIs(t *testing.T) {
|
||||
eb := NewBuilder("foo")
|
||||
eb.Add(context.Canceled)
|
||||
eb.Add(io.ErrShortBuffer)
|
||||
ExpectTrue(t, eb.HasError())
|
||||
ExpectError(t, io.ErrShortBuffer, eb.Error())
|
||||
ExpectError(t, context.Canceled, eb.Error())
|
||||
}
|
||||
|
||||
func TestBuilderNested(t *testing.T) {
|
||||
eb := NewBuilder("error occurred")
|
||||
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
|
||||
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
|
||||
|
||||
got := eb.Build().String()
|
||||
expected1 := (`error occurred:
|
||||
- Action 1 failed:
|
||||
- invalid Inner: 1
|
||||
- invalid Inner: 2
|
||||
- Action 2 failed:
|
||||
- invalid Inner: 3`)
|
||||
expected2 := (`error occurred:
|
||||
- Action 1 failed:
|
||||
- invalid Inner: "1"
|
||||
- invalid Inner: "2"
|
||||
- Action 2 failed:
|
||||
- invalid Inner: "3"`)
|
||||
ExpectEqualAny(t, got, []string{expected1, expected2})
|
||||
}
|
||||
|
||||
func TestBuilderTo(t *testing.T) {
|
||||
eb := NewBuilder("error occurred")
|
||||
eb.Addf("abcd")
|
||||
|
||||
var err Error
|
||||
eb.To(&err)
|
||||
got := err.String()
|
||||
expected := (`error occurred:
|
||||
- abcd`)
|
||||
eb := NewBuilder("action failed")
|
||||
eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2"))
|
||||
eb.Add(New("Action 2").Withf("Inner: 3"))
|
||||
|
||||
got := eb.String()
|
||||
expected := `action failed
|
||||
• Action 1
|
||||
• Inner: 1
|
||||
• Inner: 2
|
||||
• Action 2
|
||||
• Inner: 3`
|
||||
ExpectEqual(t, got, expected)
|
||||
}
|
||||
|
|
|
@ -1,317 +1,31 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
type Error interface {
|
||||
error
|
||||
|
||||
type (
|
||||
Error = *ErrorImpl
|
||||
ErrorImpl struct {
|
||||
subject string
|
||||
err error
|
||||
extras []ErrorImpl
|
||||
}
|
||||
ErrorJSONMarshaller struct {
|
||||
Subject string `json:"subject"`
|
||||
Err string `json:"error"`
|
||||
Extras []ErrorJSONMarshaller `json:"extras,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func From(err error) Error {
|
||||
if IsNil(err) {
|
||||
return nil
|
||||
}
|
||||
return &ErrorImpl{err: err}
|
||||
// Is is a wrapper for errors.Is when there is no sub-error.
|
||||
//
|
||||
// When there are sub-errors, they will also be checked.
|
||||
Is(other error) bool
|
||||
// With appends a sub-error to the error.
|
||||
With(extra error) Error
|
||||
// Withf is a wrapper for With(fmt.Errorf(format, args...)).
|
||||
Withf(format string, args ...any) Error
|
||||
// Subject prepends the given subject with a colon and space to the error message.
|
||||
//
|
||||
// If there is already a subject in the error message, the subject will be
|
||||
// prepended to the existing subject with " > ".
|
||||
//
|
||||
// Subject empty string is ignored.
|
||||
Subject(subject string) Error
|
||||
// Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)).
|
||||
Subjectf(format string, args ...any) Error
|
||||
}
|
||||
|
||||
func FromJSON(data []byte) (Error, bool) {
|
||||
var j ErrorJSONMarshaller
|
||||
if err := json.Unmarshal(data, &j); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if j.Err == "" {
|
||||
return nil, false
|
||||
}
|
||||
extras := make([]ErrorImpl, len(j.Extras))
|
||||
for i, e := range j.Extras {
|
||||
extra, ok := fromJSONObject(e)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
extras[i] = *extra
|
||||
}
|
||||
return &ErrorImpl{
|
||||
subject: j.Subject,
|
||||
err: errors.New(j.Err),
|
||||
extras: extras,
|
||||
}, true
|
||||
}
|
||||
|
||||
func TryUnwrap(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if unwrapped := errors.Unwrap(err); unwrapped != nil {
|
||||
return unwrapped
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Check is a helper function that
|
||||
// convert (T, error) to (T, NestedError).
|
||||
func Check[T any](obj T, err error) (T, Error) {
|
||||
return obj, From(err)
|
||||
}
|
||||
|
||||
func Join(message string, err ...Error) Error {
|
||||
extras := make([]ErrorImpl, len(err))
|
||||
nErr := 0
|
||||
for i, e := range err {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
extras[i] = *e
|
||||
nErr++
|
||||
}
|
||||
if nErr == 0 {
|
||||
return nil
|
||||
}
|
||||
return &ErrorImpl{
|
||||
err: errors.New(message),
|
||||
extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func JoinE(message string, err ...error) Error {
|
||||
b := NewBuilder("%s", message)
|
||||
for _, e := range err {
|
||||
b.AddE(e)
|
||||
}
|
||||
return b.Build()
|
||||
}
|
||||
|
||||
func IsNil(err error) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func IsNotNil(err error) bool {
|
||||
return err != nil
|
||||
}
|
||||
|
||||
func (ne Error) String() string {
|
||||
var buf strings.Builder
|
||||
ne.writeToSB(&buf, 0, "")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (ne Error) Is(err error) bool {
|
||||
if ne == nil {
|
||||
return err == nil
|
||||
}
|
||||
// return errors.Is(ne.err, err)
|
||||
if errors.Is(ne.err, err) {
|
||||
return true
|
||||
}
|
||||
for _, e := range ne.extras {
|
||||
if e.Is(err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ne Error) IsNot(err error) bool {
|
||||
return !ne.Is(err)
|
||||
}
|
||||
|
||||
func (ne Error) Error() error {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
return ne.buildError(0, "")
|
||||
}
|
||||
|
||||
func (ne Error) With(s any) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
var msg string
|
||||
switch ss := s.(type) {
|
||||
case nil:
|
||||
return ne
|
||||
case *ErrorImpl:
|
||||
if len(ss.extras) == 1 {
|
||||
ne.extras = append(ne.extras, ss.extras[0])
|
||||
return ne
|
||||
}
|
||||
return ne.withError(ss)
|
||||
case error:
|
||||
// unwrap only once
|
||||
return ne.withError(From(TryUnwrap(ss)))
|
||||
case string:
|
||||
msg = ss
|
||||
case fmt.Stringer:
|
||||
return ne.appendMsg(ss.String())
|
||||
default:
|
||||
return ne.appendMsg(fmt.Sprint(s))
|
||||
}
|
||||
return ne.withError(From(errors.New(msg)))
|
||||
}
|
||||
|
||||
func (ne Error) Extraf(format string, args ...any) Error {
|
||||
return ne.With(errorf(format, args...))
|
||||
}
|
||||
|
||||
func (ne Error) Subject(s any, sep ...string) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
var subject string
|
||||
switch ss := s.(type) {
|
||||
case string:
|
||||
subject = ss
|
||||
case fmt.Stringer:
|
||||
subject = ss.String()
|
||||
default:
|
||||
subject = fmt.Sprint(s)
|
||||
}
|
||||
switch {
|
||||
case ne.subject == "":
|
||||
ne.subject = subject
|
||||
case len(sep) > 0:
|
||||
ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject)
|
||||
default:
|
||||
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
|
||||
}
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne Error) Subjectf(format string, args ...any) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
return ne.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (ne Error) JSONObject() ErrorJSONMarshaller {
|
||||
extras := make([]ErrorJSONMarshaller, len(ne.extras))
|
||||
for i, e := range ne.extras {
|
||||
extras[i] = e.JSONObject()
|
||||
}
|
||||
return ErrorJSONMarshaller{
|
||||
Subject: ne.subject,
|
||||
Err: ne.err.Error(),
|
||||
Extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func (ne Error) JSON() []byte {
|
||||
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (ne Error) NoError() bool {
|
||||
return ne == nil
|
||||
}
|
||||
|
||||
func (ne Error) HasError() bool {
|
||||
return ne != nil
|
||||
}
|
||||
|
||||
func errorf(format string, args ...any) Error {
|
||||
for i, arg := range args {
|
||||
if err, ok := arg.(error); ok {
|
||||
if unwrapped := errors.Unwrap(err); unwrapped != nil {
|
||||
args[i] = unwrapped
|
||||
}
|
||||
}
|
||||
}
|
||||
return From(fmt.Errorf(format, args...))
|
||||
}
|
||||
|
||||
func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return FromJSON(data)
|
||||
}
|
||||
|
||||
func (ne Error) withError(err Error) Error {
|
||||
if ne != nil && err != nil {
|
||||
ne.extras = append(ne.extras, *err)
|
||||
}
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne Error) appendMsg(msg string) Error {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
ne.err = fmt.Errorf("%w %s", ne.err, msg)
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
|
||||
for range level {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
sb.WriteString(prefix)
|
||||
|
||||
if ne.NoError() {
|
||||
sb.WriteString("nil")
|
||||
return
|
||||
}
|
||||
|
||||
if ne.subject != "" {
|
||||
sb.WriteString(ne.subject)
|
||||
sb.WriteRune(' ')
|
||||
}
|
||||
sb.WriteString(ne.err.Error())
|
||||
if len(ne.extras) > 0 {
|
||||
sb.WriteRune(':')
|
||||
for _, extra := range ne.extras {
|
||||
sb.WriteRune('\n')
|
||||
extra.writeToSB(sb, level+1, "- ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ne Error) buildError(level int, prefix string) error {
|
||||
var res error
|
||||
var sb strings.Builder
|
||||
|
||||
for range level {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
sb.WriteString(prefix)
|
||||
|
||||
if ne.NoError() {
|
||||
sb.WriteString("nil")
|
||||
return errors.New(sb.String())
|
||||
}
|
||||
|
||||
res = fmt.Errorf("%s%w", sb.String(), ne.err)
|
||||
sb.Reset()
|
||||
|
||||
if ne.subject != "" {
|
||||
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
|
||||
}
|
||||
if len(ne.extras) > 0 {
|
||||
sb.WriteRune(':')
|
||||
res = fmt.Errorf("%w%s", res, sb.String())
|
||||
for _, extra := range ne.extras {
|
||||
res = errors.Join(res, extra.buildError(level+1, "- "))
|
||||
}
|
||||
} else {
|
||||
res = fmt.Errorf("%w%s", res, sb.String())
|
||||
}
|
||||
return res
|
||||
// this makes JSON marshalling work,
|
||||
// as the builtin one doesn't.
|
||||
type errStr string
|
||||
|
||||
func (err errStr) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
|
|
@ -1,107 +1,157 @@
|
|||
package error_test
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestBaseString(t *testing.T) {
|
||||
ExpectEqual(t, New("error").Error(), "error")
|
||||
}
|
||||
|
||||
func TestBaseWithSubject(t *testing.T) {
|
||||
err := New("error")
|
||||
withSubject := err.Subject("foo")
|
||||
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
|
||||
|
||||
ExpectError(t, err, withSubject)
|
||||
ExpectStrEqual(t, withSubject.Error(), "foo: error")
|
||||
ExpectTrue(t, withSubject.Is(err))
|
||||
|
||||
ExpectError(t, err, withSubjectf)
|
||||
ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error")
|
||||
ExpectTrue(t, withSubjectf.Is(err))
|
||||
}
|
||||
|
||||
func TestBaseWithExtra(t *testing.T) {
|
||||
err := New("error")
|
||||
extra := New("bar").Subject("baz")
|
||||
withExtra := err.With(extra)
|
||||
|
||||
ExpectTrue(t, withExtra.Is(extra))
|
||||
ExpectTrue(t, withExtra.Is(err))
|
||||
|
||||
ExpectTrue(t, errors.Is(withExtra, extra))
|
||||
ExpectTrue(t, errors.Is(withExtra, err))
|
||||
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
|
||||
}
|
||||
|
||||
func TestBaseUnwrap(t *testing.T) {
|
||||
err := errors.New("err")
|
||||
wrapped := From(err)
|
||||
|
||||
ExpectError(t, err, errors.Unwrap(wrapped))
|
||||
}
|
||||
|
||||
func TestNestedUnwrap(t *testing.T) {
|
||||
err := errors.New("err")
|
||||
err2 := New("err2")
|
||||
wrapped := From(err).Subject("foo").With(err2.Subject("bar"))
|
||||
|
||||
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
|
||||
ExpectTrue(t, ok)
|
||||
|
||||
ExpectError(t, err, wrapped)
|
||||
ExpectError(t, err2, wrapped)
|
||||
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
|
||||
}
|
||||
|
||||
func TestErrorIs(t *testing.T) {
|
||||
ExpectTrue(t, Failure("foo").Is(ErrFailure))
|
||||
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
|
||||
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
|
||||
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
|
||||
from := errors.New("error")
|
||||
err := From(from)
|
||||
ExpectError(t, from, err)
|
||||
|
||||
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
|
||||
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
|
||||
ExpectTrue(t, err.Is(from))
|
||||
ExpectFalse(t, err.Is(New("error")))
|
||||
|
||||
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
|
||||
|
||||
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
|
||||
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
|
||||
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
|
||||
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
|
||||
ExpectTrue(t, errors.Is(err.Subject("foo"), from))
|
||||
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
|
||||
ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
|
||||
}
|
||||
|
||||
func TestErrorNestedIs(t *testing.T) {
|
||||
var err Error
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
func TestErrorImmutability(t *testing.T) {
|
||||
err := New("err")
|
||||
err2 := New("err2")
|
||||
|
||||
err = Failure("some reason")
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectFalse(t, err.Is(ErrDuplicated))
|
||||
for range 3 {
|
||||
// t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
|
||||
err.Subject("foo")
|
||||
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
|
||||
|
||||
err.With(Duplicated("something", ""))
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectTrue(t, err.Is(ErrDuplicated))
|
||||
ExpectFalse(t, err.Is(ErrInvalid))
|
||||
}
|
||||
err.With(err2)
|
||||
ExpectFalse(t, strings.Contains(err.Error(), "extra"))
|
||||
ExpectFalse(t, err.Is(err2))
|
||||
|
||||
func TestIsNil(t *testing.T) {
|
||||
var err Error
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
ExpectTrue(t, err == nil)
|
||||
ExpectTrue(t, err.NoError())
|
||||
|
||||
eb := NewBuilder("")
|
||||
returnNil := func() error {
|
||||
return eb.Build().Error()
|
||||
err = err.Subject("bar").Withf("baz")
|
||||
ExpectTrue(t, err != nil)
|
||||
}
|
||||
ExpectTrue(t, IsNil(returnNil()))
|
||||
ExpectTrue(t, returnNil() == nil)
|
||||
|
||||
ExpectTrue(t, (err.
|
||||
Subject("any").
|
||||
With("something").
|
||||
Extraf("foo %s", "bar")) == nil)
|
||||
}
|
||||
|
||||
func TestErrorSimple(t *testing.T) {
|
||||
ne := Failure("foo bar")
|
||||
ExpectEqual(t, ne.String(), "foo bar failed")
|
||||
ne = ne.Subject("baz")
|
||||
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
|
||||
}
|
||||
|
||||
func TestErrorWith(t *testing.T) {
|
||||
ne := Failure("foo").With("bar").With("baz")
|
||||
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
|
||||
err1 := New("err1")
|
||||
err2 := New("err2")
|
||||
|
||||
err3 := err1.With(err2)
|
||||
|
||||
ExpectTrue(t, err3.Is(err1))
|
||||
ExpectTrue(t, err3.Is(err2))
|
||||
|
||||
err2.Subject("foo")
|
||||
|
||||
ExpectTrue(t, err3.Is(err1))
|
||||
ExpectTrue(t, err3.Is(err2))
|
||||
|
||||
// check if err3 is affected by err2.Subject
|
||||
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
|
||||
}
|
||||
|
||||
func TestErrorNested(t *testing.T) {
|
||||
inner := Failure("inner").
|
||||
With("1").
|
||||
With("1")
|
||||
inner2 := Failure("inner2").
|
||||
func TestErrorStringSimple(t *testing.T) {
|
||||
errFailure := New("generic failure")
|
||||
ne := errFailure.Subject("foo bar")
|
||||
ExpectStrEqual(t, ne.Error(), "foo bar: generic failure")
|
||||
ne = ne.Subject("baz")
|
||||
ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure")
|
||||
}
|
||||
|
||||
func TestErrorStringNested(t *testing.T) {
|
||||
errFailure := New("generic failure")
|
||||
inner := errFailure.Subject("inner").
|
||||
Withf("1").
|
||||
Withf("1")
|
||||
inner2 := errFailure.Subject("inner2").
|
||||
Subject("action 2").
|
||||
With("2").
|
||||
With("2")
|
||||
inner3 := Failure("inner3").
|
||||
Withf("2").
|
||||
Withf("2")
|
||||
inner3 := errFailure.Subject("inner3").
|
||||
Subject("action 3").
|
||||
With("3").
|
||||
With("3")
|
||||
ne := Failure("foo").
|
||||
With("bar").
|
||||
With("baz").
|
||||
Withf("3").
|
||||
Withf("3")
|
||||
ne := errFailure.
|
||||
Subject("foo").
|
||||
Withf("bar").
|
||||
Withf("baz").
|
||||
With(inner).
|
||||
With(inner.With(inner2.With(inner3)))
|
||||
want := `foo failed:
|
||||
- bar
|
||||
- baz
|
||||
- inner failed:
|
||||
- 1
|
||||
- 1
|
||||
- inner failed:
|
||||
- 1
|
||||
- 1
|
||||
- inner2 failed for "action 2":
|
||||
- 2
|
||||
- 2
|
||||
- inner3 failed for "action 3":
|
||||
- 3
|
||||
- 3`
|
||||
ExpectEqual(t, ne.String(), want)
|
||||
ExpectEqual(t, ne.Error().Error(), want)
|
||||
want := `foo: generic failure
|
||||
• bar
|
||||
• baz
|
||||
• inner: generic failure
|
||||
• 1
|
||||
• 1
|
||||
• inner: generic failure
|
||||
• 1
|
||||
• 1
|
||||
• action 2 > inner2: generic failure
|
||||
• 2
|
||||
• 2
|
||||
• action 3 > inner3: generic failure
|
||||
• 3
|
||||
• 3`
|
||||
ExpectStrEqual(t, ne.Error(), want)
|
||||
}
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrDuplicated = stderrors.New("duplicated")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrTypeError = stderrors.New("type error")
|
||||
ErrTypeMismatch = stderrors.New("type mismatch")
|
||||
ErrPanicRecv = stderrors.New("panic recovered from")
|
||||
)
|
||||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
||||
func Failure(what string) Error {
|
||||
return errorf("%s %w", what, ErrFailure)
|
||||
}
|
||||
|
||||
func FailedWhy(what string, why string) Error {
|
||||
return Failure(what).With(why)
|
||||
}
|
||||
|
||||
func FailWith(what string, err any) Error {
|
||||
return Failure(what).With(err)
|
||||
}
|
||||
|
||||
func Invalid(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
|
||||
}
|
||||
|
||||
func Unsupported(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
|
||||
}
|
||||
|
||||
func Unexpected(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
|
||||
}
|
||||
|
||||
func UnexpectedError(err error) Error {
|
||||
return errorf("%w error: %w", ErrUnexpected, err)
|
||||
}
|
||||
|
||||
func NotExist(subject, what any) Error {
|
||||
return errorf("%v %w: %v", subject, ErrNotExists, what)
|
||||
}
|
||||
|
||||
func Missing(subject any) Error {
|
||||
return errorf("%w %v", ErrMissing, subject)
|
||||
}
|
||||
|
||||
func Duplicated(subject, what any) Error {
|
||||
return errorf("%w %v: %v", ErrDuplicated, subject, what)
|
||||
}
|
||||
|
||||
func OutOfRange(subject any, value any) Error {
|
||||
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
|
||||
}
|
||||
|
||||
func TypeError(subject any, from, to reflect.Type) Error {
|
||||
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
|
||||
}
|
||||
|
||||
func TypeError2(subject any, from, to reflect.Value) Error {
|
||||
return TypeError(subject, from.Type(), to.Type())
|
||||
}
|
||||
|
||||
func TypeMismatch[Expect any](value any) Error {
|
||||
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
|
||||
}
|
||||
|
||||
func PanicRecv(format string, args ...any) Error {
|
||||
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
|
||||
}
|
43
internal/error/log.go
Normal file
43
internal/error/log.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func getLogger(logger ...*zerolog.Logger) *zerolog.Logger {
|
||||
if len(logger) > 0 {
|
||||
return logger[0]
|
||||
}
|
||||
return logging.GetLogger()
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogFatal(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Fatal().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogError(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Error().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogWarn(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Warn().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogPanic(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Panic().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogInfo(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Info().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogDebug(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Debug().Msg(err.Error())
|
||||
}
|
120
internal/error/nested_error.go
Normal file
120
internal/error/nested_error.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type nestedError struct {
|
||||
Err error `json:"err"`
|
||||
Extras []error `json:"extras"`
|
||||
}
|
||||
|
||||
func (err nestedError) Subject(subject string) Error {
|
||||
if err.Err == nil {
|
||||
err.Err = newError(subject)
|
||||
} else {
|
||||
err.Err = PrependSubject(subject, err.Err)
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *nestedError) Subjectf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
return err.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
return err.Subject(format)
|
||||
}
|
||||
|
||||
func (err nestedError) With(extra error) Error {
|
||||
if extra != nil {
|
||||
err.Extras = append(err.Extras, extra)
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err nestedError) Withf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
err.Extras = append(err.Extras, fmt.Errorf(format, args...))
|
||||
} else {
|
||||
err.Extras = append(err.Extras, newError(format))
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *nestedError) Unwrap() []error {
|
||||
if err.Err == nil {
|
||||
if len(err.Extras) == 0 {
|
||||
return nil
|
||||
}
|
||||
return err.Extras
|
||||
}
|
||||
return append([]error{err.Err}, err.Extras...)
|
||||
}
|
||||
|
||||
func (err *nestedError) Is(other error) bool {
|
||||
if errors.Is(err.Err, other) {
|
||||
return true
|
||||
}
|
||||
for _, e := range err.Extras {
|
||||
if errors.Is(e, other) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (err *nestedError) Error() string {
|
||||
return buildError(err, 0)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func makeLine(err string, level int) string {
|
||||
const bulletPrefix = "• "
|
||||
const spaces = " "
|
||||
|
||||
if level == 0 {
|
||||
return err
|
||||
}
|
||||
return spaces[:2*level] + bulletPrefix + err
|
||||
}
|
||||
|
||||
func makeLines(errs []error, level int) []string {
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
lines := make([]string, 0, len(errs))
|
||||
for _, err := range errs {
|
||||
switch err := err.(type) {
|
||||
case *nestedError:
|
||||
if err.Err != nil {
|
||||
lines = append(lines, makeLine(err.Err.Error(), level))
|
||||
}
|
||||
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
|
||||
lines = append(lines, extras...)
|
||||
}
|
||||
default:
|
||||
lines = append(lines, makeLine(err.Error(), level))
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func buildError(err error, level int) string {
|
||||
switch err := err.(type) {
|
||||
case nil:
|
||||
return makeLine("<nil>", level)
|
||||
case *nestedError:
|
||||
lines := make([]string, 0, 1+len(err.Extras))
|
||||
if err.Err != nil {
|
||||
lines = append(lines, makeLine(err.Err.Error(), level))
|
||||
}
|
||||
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
|
||||
lines = append(lines, extras...)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
default:
|
||||
return makeLine(err.Error(), level)
|
||||
}
|
||||
}
|
50
internal/error/subject.go
Normal file
50
internal/error/subject.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
||||
)
|
||||
|
||||
type withSubject struct {
|
||||
Subject string `json:"subject"`
|
||||
Err error `json:"err"`
|
||||
}
|
||||
|
||||
const subjectSep = " > "
|
||||
|
||||
func highlight(subject string) string {
|
||||
return ansi.HighlightRed + subject + ansi.Reset
|
||||
}
|
||||
|
||||
func PrependSubject(subject string, err error) *withSubject {
|
||||
switch err := err.(type) {
|
||||
case *withSubject:
|
||||
return err.Prepend(subject)
|
||||
case *baseError:
|
||||
return PrependSubject(subject, err.Err)
|
||||
default:
|
||||
return &withSubject{subject, err}
|
||||
}
|
||||
}
|
||||
|
||||
func (err withSubject) Prepend(subject string) *withSubject {
|
||||
if subject != "" {
|
||||
err.Subject = subject + subjectSep + err.Subject
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *withSubject) Is(other error) bool {
|
||||
return err.Err == other
|
||||
}
|
||||
|
||||
func (err *withSubject) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
func (err *withSubject) Error() string {
|
||||
subjects := strings.Split(err.Subject, subjectSep)
|
||||
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1])
|
||||
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error()
|
||||
}
|
71
internal/error/utils.go
Normal file
71
internal/error/utils.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrInvalidErrorJson = errors.New("invalid error json")
|
||||
|
||||
func newError(message string) error {
|
||||
return errStr(message)
|
||||
}
|
||||
|
||||
func New(message string) Error {
|
||||
if message == "" {
|
||||
return nil
|
||||
}
|
||||
return &baseError{newError(message)}
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...any) Error {
|
||||
return &baseError{fmt.Errorf(format, args...)}
|
||||
}
|
||||
|
||||
func From(err error) Error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err, ok := err.(Error); ok {
|
||||
return err
|
||||
}
|
||||
return &baseError{err}
|
||||
}
|
||||
|
||||
func Must[T any](v T, err error) T {
|
||||
if err != nil {
|
||||
LogPanic("must failed", err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func Join(errors ...error) Error {
|
||||
n := 0
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
n++
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := make([]error, 0, n)
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return &nestedError{Extras: errs}
|
||||
}
|
||||
|
||||
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
|
||||
result, err := fn(arg)
|
||||
eb.Add(err)
|
||||
return result
|
||||
}
|
||||
|
||||
func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T {
|
||||
result, err := fn(arg1, arg2)
|
||||
eb.Add(err)
|
||||
return result
|
||||
}
|
|
@ -53,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
|
|||
icons = append(icons, content.Path)
|
||||
}
|
||||
}
|
||||
err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error()
|
||||
err = utils.SaveJSON(iconsCachePath, &icons, 0o644)
|
||||
if err != nil {
|
||||
log.Print("error saving cache", err)
|
||||
}
|
||||
|
|
69
internal/logging/logging.go
Normal file
69
internal/logging/logging.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var logger zerolog.Logger
|
||||
|
||||
func init() {
|
||||
var timeFmt string
|
||||
var level zerolog.Level
|
||||
var exclude []string
|
||||
|
||||
if common.IsTrace {
|
||||
timeFmt = "04:05"
|
||||
level = zerolog.TraceLevel
|
||||
} else if common.IsDebug {
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.DebugLevel
|
||||
} else {
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.InfoLevel
|
||||
exclude = []string{"module"}
|
||||
}
|
||||
|
||||
prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces
|
||||
prefix := strings.Repeat(" ", prefixLength)
|
||||
|
||||
logger = zerolog.New(
|
||||
zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: timeFmt,
|
||||
FieldsExclude: exclude,
|
||||
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
|
||||
msg := msgI.(string)
|
||||
lines := strings.Split(msg, "\n")
|
||||
if len(lines) == 1 {
|
||||
return msg
|
||||
}
|
||||
for i := 1; i < len(lines); i++ {
|
||||
lines[i] = prefix + lines[i]
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
},
|
||||
},
|
||||
).Level(level).With().Timestamp().Logger()
|
||||
}
|
||||
|
||||
func DiscardLogger() { logger = zerolog.Nop() }
|
||||
|
||||
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
|
||||
|
||||
func GetLogger() *zerolog.Logger { return &logger }
|
||||
func With() zerolog.Context { return logger.With() }
|
||||
|
||||
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
|
||||
|
||||
func Info() *zerolog.Event { return logger.Info() }
|
||||
func Warn() *zerolog.Event { return logger.Warn() }
|
||||
func Error() *zerolog.Event { return logger.Error() }
|
||||
func Err(err error) *zerolog.Event { return logger.Err(err) }
|
||||
func Debug() *zerolog.Event { return logger.Debug() }
|
||||
func Fatal() *zerolog.Event { return logger.Fatal() }
|
||||
func Panic() *zerolog.Event { return logger.Panic() }
|
||||
func Trace() *zerolog.Event { return logger.Trace() }
|
15
internal/net/http/dummy_response_writer.go
Normal file
15
internal/net/http/dummy_response_writer.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type DummyResponseWriter struct{}
|
||||
|
||||
func (w DummyResponseWriter) Header() http.Header {
|
||||
return make(http.Header)
|
||||
}
|
||||
|
||||
func (w DummyResponseWriter) Write([]byte) (_ int, _ error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w DummyResponseWriter) WriteHeader(int) {}
|
|
@ -1,15 +0,0 @@
|
|||
package loadbalancer
|
||||
|
||||
import "net/http"
|
||||
|
||||
type DummyResponseWriter struct{}
|
||||
|
||||
func (w *DummyResponseWriter) Header() (_ http.Header) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w *DummyResponseWriter) WriteHeader(int) {}
|
|
@ -11,20 +11,22 @@ import (
|
|||
)
|
||||
|
||||
type ipHash struct {
|
||||
*LoadBalancer
|
||||
|
||||
realIP *middleware.Middleware
|
||||
pool servers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl {
|
||||
impl := new(ipHash)
|
||||
impl := &ipHash{LoadBalancer: lb}
|
||||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err E.Error
|
||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
||||
if err != nil {
|
||||
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
|
||||
}
|
||||
return impl
|
||||
}
|
||||
|
@ -70,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
|
||||
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
|
|
|
@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.R
|
|||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
|
|
|
@ -6,9 +6,11 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
|
@ -29,6 +31,8 @@ type (
|
|||
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
|
||||
}
|
||||
LoadBalancer struct {
|
||||
zerolog.Logger
|
||||
|
||||
impl
|
||||
*Config
|
||||
|
||||
|
@ -48,6 +52,7 @@ const maxWeight weightType = 100
|
|||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
Logger: logger.With().Str("name", cfg.Link).Logger(),
|
||||
Config: new(Config),
|
||||
pool: newPool(),
|
||||
task: task.DummyTask(),
|
||||
|
@ -102,7 +107,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
|||
if lb.Mode == Unset && cfg.Mode != Unset {
|
||||
lb.Mode = cfg.Mode
|
||||
if !lb.Mode.ValidateUpdate() {
|
||||
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode)
|
||||
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
}
|
||||
lb.updateImpl()
|
||||
}
|
||||
|
@ -131,7 +136,11 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
|
|||
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
|
||||
|
||||
lb.Debug().
|
||||
Str("action", "add").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers available", lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
|
@ -148,13 +157,15 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
|||
lb.rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
lb.Debug().
|
||||
Str("action", "remove").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers left", lb.pool.Size())
|
||||
|
||||
if lb.pool.Size() == 0 {
|
||||
lb.task.Finish("no server left")
|
||||
logger.Infof("loadbalancer %s stopped", lb.Link)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) rebalance() {
|
||||
|
@ -218,7 +229,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
|
||||
defer cancel()
|
||||
// send dummy request to wake all servers
|
||||
var dummyRW *DummyResponseWriter
|
||||
var dummyRW gphttp.DummyResponseWriter
|
||||
for _, srv := range srvs {
|
||||
// wake only if server implements Waker
|
||||
_, ok := srv.handler.(idlewatcher.Waker)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
package loadbalancer
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logrus.WithField("module", "load_balancer")
|
||||
var logger = logging.With().Str("module", "load_balancer").Logger()
|
||||
|
|
5
internal/net/http/logger.go
Normal file
5
internal/net/http/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package http
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "http").Logger()
|
|
@ -15,9 +15,9 @@ type cidrWhitelist struct {
|
|||
}
|
||||
|
||||
type cidrWhitelistOpts struct {
|
||||
Allow []*types.CIDR
|
||||
StatusCode int
|
||||
Message string
|
||||
Allow []*types.CIDR `json:"allow"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
Message string `json:"message"`
|
||||
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
|||
return nil, err
|
||||
}
|
||||
if len(wl.cidrWhitelistOpts.Allow) == 0 {
|
||||
return nil, E.Missing("allow range")
|
||||
return nil, E.New("no allowed CIDRs")
|
||||
}
|
||||
return wl.m, nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
|
|||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
errs := E.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
|
@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
|||
t.Run("deny", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
||||
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
|||
t.Run("accept", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -10,10 +10,10 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,7 +26,7 @@ const (
|
|||
var (
|
||||
cfCIDRsLastUpdate time.Time
|
||||
cfCIDRsMu sync.Mutex
|
||||
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
|
||||
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
||||
)
|
||||
|
||||
var CloudflareRealIP = &realIP{
|
||||
|
@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
|
||||
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
|
||||
cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,24 +8,31 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||
)
|
||||
|
||||
var CustomErrorPage = &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
var CustomErrorPage *Middleware
|
||||
|
||||
func init() {
|
||||
CustomErrorPage = customErrorPage()
|
||||
}
|
||||
|
||||
func customErrorPage() *Middleware {
|
||||
m := &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
}
|
||||
m.modifyResponse = func(resp *Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
|
||||
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
|
@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{
|
|||
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
} else {
|
||||
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
|
||||
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||
|
@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
|||
filename := path[len(gphttp.StaticFilePathPrefix):]
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
errPageLogger.Errorf("unable to load resource %s", filename)
|
||||
logger.Error().Msg("unable to load resource " + filename)
|
||||
return false
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
|
@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
|||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
default:
|
||||
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
|
||||
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename)
|
||||
logger.Err(err).Msg("unable to write resource " + filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var errPageLogger = logrus.WithField("middleware", "error_page")
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package errorpage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
|
@ -23,9 +24,10 @@ var (
|
|||
)
|
||||
|
||||
var setup = sync.OnceFunc(func() {
|
||||
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
|
||||
task := task.GlobalTask("error page")
|
||||
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
|
||||
loadContent()
|
||||
go watchDir()
|
||||
go watchDir(task)
|
||||
})
|
||||
|
||||
func GetStaticFile(filename string) ([]byte, bool) {
|
||||
|
@ -44,7 +46,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
|||
func loadContent() {
|
||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||
if err != nil {
|
||||
Logger.Error(err)
|
||||
logger.Err(err).Msg("failed to list error page resources")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
|
@ -53,19 +55,21 @@ func loadContent() {
|
|||
}
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
Logger.Errorf("failed to read error page resource %s: %s", file, err)
|
||||
logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
continue
|
||||
}
|
||||
file = path.Base(file)
|
||||
Logger.Infof("error page resource %s loaded", file)
|
||||
logging.Info().Msgf("error page resource %s loaded", file)
|
||||
fileContentMap.Store(file, content)
|
||||
}
|
||||
}
|
||||
|
||||
func watchDir() {
|
||||
eventCh, errCh := dirWatcher.Events(context.Background())
|
||||
func watchDir(task task.Task) {
|
||||
eventCh, errCh := dirWatcher.Events(task.Context())
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case event, ok := <-eventCh:
|
||||
if !ok {
|
||||
return
|
||||
|
@ -77,14 +81,14 @@ func watchDir() {
|
|||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
Logger.Infof("error page resource %s deleted", filename)
|
||||
logger.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
Logger.Infof("error page resource %s deleted", filename)
|
||||
logger.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
Logger.Errorf("error watching error page directory: %s", err)
|
||||
E.LogError("error watching error page directory", err, &logger)
|
||||
}
|
||||
}
|
||||
}
|
5
internal/net/http/middleware/errorpage/logger.go
Normal file
5
internal/net/http/middleware/errorpage/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package errorpage
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "errorpage").Logger()
|
|
@ -24,11 +24,12 @@ type (
|
|||
client http.Client
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
Address string
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
AddAuthCookiesToResponse []string
|
||||
transport http.RoundTripper
|
||||
Address string `json:"address"`
|
||||
TrustForwardHeader bool `json:"trustForwardHeader"`
|
||||
AuthResponseHeaders []string `json:"authResponseHeaders"`
|
||||
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
|
||||
|
||||
transport http.RoundTripper
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{
|
|||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
fa.forwardAuthOpts = new(forwardAuthOpts)
|
||||
err := Deserialize(optsRaw, fa.forwardAuthOpts)
|
||||
if err != nil {
|
||||
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = E.Check(url.Parse(fa.Address))
|
||||
if err != nil {
|
||||
return nil, E.Invalid("address", fa.Address)
|
||||
if _, err := url.Parse(fa.Address); err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
||||
fa.m = &Middleware{
|
||||
|
|
5
internal/net/http/middleware/logger.go
Normal file
5
internal/net/http/middleware/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package middleware
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "middleware").Logger()
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
|
@ -32,6 +33,8 @@ type (
|
|||
Middleware struct {
|
||||
_ U.NoCopy
|
||||
|
||||
zerolog.Logger
|
||||
|
||||
name string
|
||||
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
|
@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||
return m.withOptions(optsRaw)
|
||||
if m.withOptions != nil {
|
||||
m, err := m.withOptions(optsRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Logger = logger.With().Str("name", m.name).Logger()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{
|
||||
Logger: logger.With().Str("name", m.name).Logger(),
|
||||
name: m.name,
|
||||
before: m.before,
|
||||
modifyResponse: m.modifyResponse,
|
||||
|
@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
|
|||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
|
||||
middlewares = make([]*Middleware, 0, len(middlewaresMap))
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
invalidOpts := E.NewBuilder("invalid options")
|
||||
defer func() {
|
||||
invalidM.Add(invalidOpts.Build())
|
||||
invalidM.To(&res)
|
||||
}()
|
||||
errs := E.NewBuilder("middlewares compile error")
|
||||
invalidOpts := E.NewBuilder("options compile error")
|
||||
|
||||
for name, opts := range middlewaresMap {
|
||||
m, ok := Get(name)
|
||||
if !ok {
|
||||
invalidM.Add(E.NotExist("middleware", name))
|
||||
m, err := Get(name)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := m.WithOptionsClone(opts)
|
||||
m, err = m.WithOptionsClone(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
|
@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
|
|||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
return
|
||||
if invalidOpts.HasError() {
|
||||
errs.Add(invalidOpts.Error())
|
||||
}
|
||||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||
|
|
|
@ -4,64 +4,60 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("read middleware compose file", err)
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
return BuildMiddlewaresFromYAML(fileContent)
|
||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
|
||||
b := E.NewBuilder("middlewares compile errors")
|
||||
defer b.To(&outErr)
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("yaml unmarshal", err))
|
||||
return
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
middlewares := make(map[string]*Middleware)
|
||||
for name, defs := range rawMap {
|
||||
chainErr := E.NewBuilder("%s", name)
|
||||
chainErr := E.NewBuilder("")
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
|
||||
chainErr.Addf("item %d: missing field 'use'", i)
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, ok := Get(baseName)
|
||||
if !ok {
|
||||
base, ok = middlewares[baseName]
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
base, err := Get(baseName)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
chain = append(chain, m)
|
||||
}
|
||||
if chainErr.HasError() {
|
||||
b.Add(chainErr.Build())
|
||||
eb.Add(chainErr.Error().Subject(source))
|
||||
} else {
|
||||
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
|
||||
}
|
||||
}
|
||||
return
|
||||
return middlewares
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
|
@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
|||
}
|
||||
if len(modResps) > 0 {
|
||||
m.modifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware")
|
||||
errs := E.NewBuilder("modify response errors")
|
||||
for _, mr := range modResps {
|
||||
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
errs.Add(E.From(err).Subject(mr.name))
|
||||
}
|
||||
}
|
||||
return b.Build().Error()
|
||||
return errs.Error()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,10 +13,10 @@ import (
|
|||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
|
||||
ExpectNoError(t, err.Error())
|
||||
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
|
||||
ExpectNoError(t, err.Error())
|
||||
errs := E.NewBuilder("")
|
||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
E.Must(json.MarshalIndent(middlewares, "", " "))
|
||||
// t.Log(string(data))
|
||||
// TODO: test
|
||||
}
|
||||
|
|
|
@ -6,26 +6,37 @@ import (
|
|||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var middlewares map[string]*Middleware
|
||||
var allMiddlewares map[string]*Middleware
|
||||
|
||||
func Get(name string) (middleware *Middleware, ok bool) {
|
||||
middleware, ok = middlewares[U.ToLowerNoSnake(name)]
|
||||
return
|
||||
var (
|
||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return middlewares
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
// initialize middleware names and label parsers
|
||||
func init() {
|
||||
middlewares = map[string]*Middleware{
|
||||
allMiddlewares = map[string]*Middleware{
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
|
@ -39,10 +50,10 @@ func init() {
|
|||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"oauth2": OAuth2.m,
|
||||
// "oauth2": OAuth2.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
for name, m := range allMiddlewares {
|
||||
names[m] = append(names[m], http.CanonicalHeaderKey(name))
|
||||
}
|
||||
for m, names := range names {
|
||||
|
@ -55,27 +66,30 @@ func init() {
|
|||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
b := E.NewBuilder("failed to load middlewares")
|
||||
errs := E.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to list middleware definitions: %s", err)
|
||||
logger.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
mws, err := BuildMiddlewaresFromComposeFile(defFile)
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
if _, ok := middlewares[name]; ok {
|
||||
b.Add(E.Duplicated("middleware", name))
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
errs.Add(ErrDuplicatedMiddleware.Subject(name))
|
||||
continue
|
||||
}
|
||||
middlewares[U.ToLowerNoSnake(name)] = m
|
||||
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
||||
allMiddlewares[U.ToLowerNoSnake(name)] = m
|
||||
logger.Info().
|
||||
Str("name", name).
|
||||
Str("src", path.Base(defFile)).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
b.Add(err.Subject(path.Base(defFile)))
|
||||
}
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
if errs.HasError() {
|
||||
E.LogError(errs.About(), errs.Error(), &logger)
|
||||
}
|
||||
}
|
||||
|
||||
var logger = logrus.WithField("module", "middlewares")
|
||||
|
|
|
@ -12,9 +12,9 @@ type (
|
|||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
SetHeaders map[string]string `json:"setHeaders"`
|
||||
AddHeaders map[string]string `json:"addHeaders"`
|
||||
HideHeaders []string `json:"hideHeaders"`
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) {
|
|||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
|
||||
|
@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) {
|
|||
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
|
|
|
@ -13,11 +13,7 @@ type (
|
|||
m *Middleware
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyResponseOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
}
|
||||
modifyResponseOpts = modifyRequestOpts
|
||||
)
|
||||
|
||||
var ModifyResponse = &modifyResponse{
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) {
|
|||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||
|
@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) {
|
|||
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
|
||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
|
|
|
@ -1,129 +1,117 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
// import (
|
||||
// "encoding/json"
|
||||
// "fmt"
|
||||
// "net/http"
|
||||
// "net/url"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
// E "github.com/yusing/go-proxy/internal/error"
|
||||
// )
|
||||
|
||||
type oAuth2 struct {
|
||||
*oAuth2Opts
|
||||
m *Middleware
|
||||
}
|
||||
// type oAuth2 struct {
|
||||
// oAuth2Opts
|
||||
// m *Middleware
|
||||
// }
|
||||
|
||||
type oAuth2Opts struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string // Authorization Endpoint
|
||||
TokenURL string // Token Endpoint
|
||||
}
|
||||
// type oAuth2Opts struct {
|
||||
// ClientID string `validate:"required"`
|
||||
// ClientSecret string `validate:"required"`
|
||||
// AuthURL string `validate:"required"` // Authorization Endpoint
|
||||
// TokenURL string `validate:"required"` // Token Endpoint
|
||||
// }
|
||||
|
||||
var OAuth2 = &oAuth2{
|
||||
m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
}
|
||||
// var OAuth2 = &oAuth2{
|
||||
// m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
// }
|
||||
|
||||
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
oauth := new(oAuth2)
|
||||
oauth.m = &Middleware{
|
||||
impl: oauth,
|
||||
before: oauth.handleOAuth2,
|
||||
}
|
||||
oauth.oAuth2Opts = &oAuth2Opts{}
|
||||
err := Deserialize(opts, oauth.oAuth2Opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := E.NewBuilder("missing required fields")
|
||||
optV := reflect.ValueOf(oauth.oAuth2Opts)
|
||||
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
|
||||
if optV.FieldByName(field.Name).Len() == 0 {
|
||||
b.Add(E.Missing(field.Name))
|
||||
}
|
||||
}
|
||||
if b.HasError() {
|
||||
return nil, b.Build().Subject("oAuth2")
|
||||
}
|
||||
return oauth.m, nil
|
||||
}
|
||||
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
// oauth := new(oAuth2)
|
||||
// oauth.m = &Middleware{
|
||||
// impl: oauth,
|
||||
// before: oauth.handleOAuth2,
|
||||
// }
|
||||
// err := Deserialize(opts, &oauth.oAuth2Opts)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return oauth.m, nil
|
||||
// }
|
||||
|
||||
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||
// Check if the user is authenticated (you may use session, cookie, etc.)
|
||||
if !userIsAuthenticated(r) {
|
||||
// TODO: Redirect to OAuth2 auth URL
|
||||
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||
return
|
||||
}
|
||||
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||
// // Check if the user is authenticated (you may use session, cookie, etc.)
|
||||
// if !userIsAuthenticated(r) {
|
||||
// // TODO: Redirect to OAuth2 auth URL
|
||||
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||
// return
|
||||
// }
|
||||
|
||||
// If you have a token in the query string, process it
|
||||
if code := r.URL.Query().Get("code"); code != "" {
|
||||
// Exchange the authorization code for a token here
|
||||
// Use the TokenURL and authenticate the user
|
||||
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
|
||||
if err != nil {
|
||||
// handle error
|
||||
http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// // If you have a token in the query string, process it
|
||||
// if code := r.URL.Query().Get("code"); code != "" {
|
||||
// // Exchange the authorization code for a token here
|
||||
// // Use the TokenURL and authenticate the user
|
||||
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
|
||||
// Save token and user info based on your requirements
|
||||
saveToken(rw, token)
|
||||
// // Save token and user info based on your requirements
|
||||
// saveToken(rw, token)
|
||||
|
||||
// Redirect to the originally requested URL
|
||||
http.Redirect(rw, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
// // Redirect to the originally requested URL
|
||||
// http.Redirect(rw, r, "/", http.StatusFound)
|
||||
// return
|
||||
// }
|
||||
|
||||
// If user is authenticated, go to the next handler
|
||||
next(rw, r)
|
||||
}
|
||||
// // If user is authenticated, go to the next handler
|
||||
// next(rw, r)
|
||||
// }
|
||||
|
||||
func userIsAuthenticated(r *http.Request) bool {
|
||||
// Example: Check for a session or cookie
|
||||
session, err := r.Cookie("session_token")
|
||||
if err != nil || session.Value == "" {
|
||||
return false
|
||||
}
|
||||
// Validate the session_token if necessary
|
||||
return true
|
||||
}
|
||||
// func userIsAuthenticated(r *http.Request) bool {
|
||||
// // Example: Check for a session or cookie
|
||||
// session, err := r.Cookie("session_token")
|
||||
// if err != nil || session.Value == "" {
|
||||
// return false
|
||||
// }
|
||||
// // Validate the session_token if necessary
|
||||
// return true
|
||||
// }
|
||||
|
||||
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
|
||||
// Prepare the request body
|
||||
data := url.Values{
|
||||
"client_id": {opts.ClientID},
|
||||
"client_secret": {opts.ClientSecret},
|
||||
"code": {code},
|
||||
"grant_type": {"authorization_code"},
|
||||
"redirect_uri": {requestURI},
|
||||
}
|
||||
resp, err := http.PostForm(opts.TokenURL, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to request token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||
}
|
||||
// Decode the response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
|
||||
// // Prepare the request body
|
||||
// data := url.Values{
|
||||
// "client_id": {opts.ClientID},
|
||||
// "client_secret": {opts.ClientSecret},
|
||||
// "code": {code},
|
||||
// "grant_type": {"authorization_code"},
|
||||
// "redirect_uri": {requestURI},
|
||||
// }
|
||||
// resp, err := http.PostForm(opts.TokenURL, data)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("failed to request token: %w", err)
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
// if resp.StatusCode != http.StatusOK {
|
||||
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||
// }
|
||||
// // Decode the response
|
||||
// var tokenResp struct {
|
||||
// AccessToken string `json:"access_token"`
|
||||
// }
|
||||
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
// return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
// }
|
||||
// return tokenResp.AccessToken, nil
|
||||
// }
|
||||
|
||||
func saveToken(rw ResponseWriter, token string) {
|
||||
// Example: Save token in cookie
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: "auth_token",
|
||||
Value: token,
|
||||
// set other properties as necessary, such as Secure and HttpOnly
|
||||
})
|
||||
}
|
||||
// func saveToken(rw ResponseWriter, token string) {
|
||||
// // Example: Save token in cookie
|
||||
// http.SetCookie(rw, &http.Cookie{
|
||||
// Name: "auth_token",
|
||||
// Value: token,
|
||||
// // set other properties as necessary, such as Secure and HttpOnly
|
||||
// })
|
||||
// }
|
||||
|
|
|
@ -16,9 +16,9 @@ type realIP struct {
|
|||
|
||||
type realIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string
|
||||
Header string `json:"header"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR
|
||||
From []*types.CIDR `json:"from"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
|
@ -27,7 +27,7 @@ type realIPOpts struct {
|
|||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool
|
||||
Recursive bool `json:"recursive"`
|
||||
}
|
||||
|
||||
var RealIP = &realIP{
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
|||
}
|
||||
|
||||
ri, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
for i, CIDR := range ri.impl.(*realIP).From {
|
||||
|
@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) {
|
|||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mr, err := NewModifyRequest(optsMr)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
|
|
|
@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) {
|
|||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "http",
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
|
||||
}
|
||||
|
@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) {
|
|||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "https",
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Trace struct {
|
||||
|
@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
|||
return nil
|
||||
}
|
||||
return addTrace(&Trace{
|
||||
Time: U.FormatTime(time.Now()),
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: m.Fullname(),
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
})
|
||||
|
|
|
@ -57,8 +57,7 @@ func (w *ModifyResponseWriter) WriteHeader(code int) {
|
|||
}
|
||||
|
||||
if err := w.modifier(&resp); err != nil {
|
||||
w.modifierErr = err
|
||||
logger.Errorf("error modifying response: %s", err)
|
||||
w.modifierErr = fmt.Errorf("response modifier error: %w", err)
|
||||
w.w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
@ -69,6 +69,8 @@ type ProxyRequest struct {
|
|||
// 1xx responses are forwarded to the client if the underlying
|
||||
// transport supports ClientTrace.Got1xxResponse.
|
||||
type ReverseProxy struct {
|
||||
zerolog.Logger
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
@ -149,7 +151,12 @@ func NewReverseProxy(name string, target types.URL, transport http.RoundTripper)
|
|||
if transport == nil {
|
||||
panic("nil transport")
|
||||
}
|
||||
rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target}
|
||||
rp := &ReverseProxy{
|
||||
Logger: logger.With().Str("name", name).Logger(),
|
||||
Transport: transport,
|
||||
TargetName: name,
|
||||
TargetURL: target,
|
||||
}
|
||||
rp.ServeHTTP = rp.serveHTTP
|
||||
return rp
|
||||
}
|
||||
|
@ -195,9 +202,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
|||
switch {
|
||||
case errors.Is(err, context.Canceled),
|
||||
errors.Is(err, io.EOF):
|
||||
logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
|
||||
logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error")
|
||||
default:
|
||||
logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
|
||||
logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error")
|
||||
}
|
||||
if writeHeader {
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
|
@ -219,6 +226,10 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
|
|||
}
|
||||
|
||||
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := rw.(DummyResponseWriter); ok {
|
||||
return
|
||||
}
|
||||
|
||||
transport := p.Transport
|
||||
|
||||
ctx := req.Context()
|
||||
|
@ -453,6 +464,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
|||
resUpType := UpgradeType(res.Header)
|
||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
|
||||
|
@ -518,5 +530,3 @@ func IsPrint(s string) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var logger = logrus.WithField("module", "http")
|
||||
|
|
|
@ -9,10 +9,15 @@ import (
|
|||
|
||||
type CIDR net.IPNet
|
||||
|
||||
var (
|
||||
ErrInvalidCIDR = E.New("invalid CIDR")
|
||||
ErrInvalidCIDRType = E.New("invalid CIDR type")
|
||||
)
|
||||
|
||||
func (cidr *CIDR) ConvertFrom(val any) E.Error {
|
||||
cidrStr, ok := val.(string)
|
||||
if !ok {
|
||||
return E.TypeMismatch[string](val)
|
||||
return ErrInvalidCIDRType.Subjectf("%T", val)
|
||||
}
|
||||
|
||||
if !strings.Contains(cidrStr, "/") {
|
||||
|
@ -20,7 +25,7 @@ func (cidr *CIDR) ConvertFrom(val any) E.Error {
|
|||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidrStr)
|
||||
if err != nil {
|
||||
return E.Invalid("CIDR", cidr)
|
||||
return ErrInvalidCIDR.Subject(cidrStr)
|
||||
}
|
||||
*cidr = CIDR(*ipnet)
|
||||
return nil
|
||||
|
|
|
@ -5,9 +5,28 @@ import (
|
|||
"net"
|
||||
)
|
||||
|
||||
type Stream interface {
|
||||
fmt.Stringer
|
||||
net.Listener
|
||||
Setup() error
|
||||
Handle(conn net.Conn) error
|
||||
type (
|
||||
Stream interface {
|
||||
fmt.Stringer
|
||||
StreamListener
|
||||
Setup() error
|
||||
Handle(conn StreamConn) error
|
||||
}
|
||||
StreamListener interface {
|
||||
Addr() net.Addr
|
||||
Accept() (StreamConn, error)
|
||||
Close() error
|
||||
}
|
||||
StreamConn any
|
||||
NetListenerWrapper struct {
|
||||
net.Listener
|
||||
}
|
||||
)
|
||||
|
||||
func NetListener(l net.Listener) StreamListener {
|
||||
return NetListenerWrapper{Listener: l}
|
||||
}
|
||||
|
||||
func (l NetListenerWrapper) Accept() (StreamConn, error) {
|
||||
return l.Listener.Accept()
|
||||
}
|
||||
|
|
|
@ -1,22 +1,32 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
Dispatcher struct {
|
||||
task task.Task
|
||||
logCh chan *logrus.Entry
|
||||
logCh chan *LogMessage
|
||||
providers F.Set[Provider]
|
||||
}
|
||||
LogMessage struct {
|
||||
Level zerolog.Level
|
||||
Title, Message string
|
||||
}
|
||||
)
|
||||
|
||||
var dispatcher *Dispatcher
|
||||
|
||||
var ErrUnknownNotifProvider = E.New("unknown notification provider")
|
||||
|
||||
const dispatchErr = "notification dispatch error"
|
||||
|
||||
func init() {
|
||||
dispatcher = newNotifDispatcher()
|
||||
go dispatcher.start()
|
||||
|
@ -25,7 +35,7 @@ func init() {
|
|||
func newNotifDispatcher() *Dispatcher {
|
||||
return &Dispatcher{
|
||||
task: task.GlobalTask("notif dispatcher"),
|
||||
logCh: make(chan *logrus.Entry),
|
||||
logCh: make(chan *LogMessage),
|
||||
providers: F.NewSet[Provider](),
|
||||
}
|
||||
}
|
||||
|
@ -34,11 +44,13 @@ func GetDispatcher() *Dispatcher {
|
|||
return dispatcher
|
||||
}
|
||||
|
||||
func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.Error) {
|
||||
func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) {
|
||||
name := configSubTask.Name()
|
||||
createFunc, ok := Providers[name]
|
||||
if !ok {
|
||||
return nil, E.NotExist("provider", name)
|
||||
return nil, ErrUnknownNotifProvider.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, Providers)))
|
||||
}
|
||||
if provider, err := createFunc(cfg); err != nil {
|
||||
return nil, err
|
||||
|
@ -53,7 +65,6 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.
|
|||
|
||||
func (disp *Dispatcher) start() {
|
||||
defer dispatcher.task.Finish("dispatcher stopped")
|
||||
defer close(dispatcher.logCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -65,36 +76,39 @@ func (disp *Dispatcher) start() {
|
|||
}
|
||||
}
|
||||
|
||||
func (disp *Dispatcher) dispatch(entry *logrus.Entry) {
|
||||
func (disp *Dispatcher) dispatch(msg *LogMessage) {
|
||||
task := disp.task.Subtask("dispatch notif")
|
||||
defer task.Finish("notifs dispatched")
|
||||
defer task.Finish("notif dispatched")
|
||||
|
||||
errs := E.NewBuilder("errors sending notif")
|
||||
errs := E.NewBuilder(dispatchErr)
|
||||
disp.providers.RangeAllParallel(func(p Provider) {
|
||||
if err := p.Send(task.Context(), entry); err != nil {
|
||||
errs.Addf("%s: %s", p.Name(), err)
|
||||
if err := p.Send(task.Context(), msg); err != nil {
|
||||
errs.Add(E.PrependSubject(p.Name(), err))
|
||||
}
|
||||
})
|
||||
if err := errs.Build(); err != nil {
|
||||
logrus.Error("notif dispatcher failure: ", err)
|
||||
if errs.HasError() {
|
||||
E.LogError(errs.About(), errs.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Levels implements logrus.Hook.
|
||||
func (disp *Dispatcher) Levels() []logrus.Level {
|
||||
return []logrus.Level{
|
||||
logrus.WarnLevel,
|
||||
logrus.ErrorLevel,
|
||||
logrus.FatalLevel,
|
||||
logrus.PanicLevel,
|
||||
}
|
||||
}
|
||||
// Run implements zerolog.Hook.
|
||||
// func (disp *Dispatcher) Run(e *zerolog.Event, level zerolog.Level, message string) {
|
||||
// if strings.HasPrefix(message, dispatchErr) { // prevent recursion
|
||||
// return
|
||||
// }
|
||||
// switch level {
|
||||
// case zerolog.WarnLevel, zerolog.ErrorLevel, zerolog.FatalLevel, zerolog.PanicLevel:
|
||||
// disp.logCh <- &LogMessage{
|
||||
// Level: level,
|
||||
// Message: message,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Fire implements logrus.Hook.
|
||||
func (disp *Dispatcher) Fire(entry *logrus.Entry) error {
|
||||
if disp.providers.Size() == 0 {
|
||||
return nil
|
||||
func Notify(title, msg string) {
|
||||
dispatcher.logCh <- &LogMessage{
|
||||
Level: zerolog.InfoLevel,
|
||||
Title: title,
|
||||
Message: msg,
|
||||
}
|
||||
disp.logCh <- entry
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"net/url"
|
||||
|
||||
"github.com/gotify/server/v2/model"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
@ -39,7 +39,7 @@ func newGotifyClient(cfg map[string]any) (Provider, E.Error) {
|
|||
|
||||
url, uErr := url.Parse(client.URL)
|
||||
if uErr != nil {
|
||||
return nil, E.FailWith("parse url", uErr)
|
||||
return nil, E.Errorf("invalid gotify URL %s", client.URL)
|
||||
}
|
||||
|
||||
client.url = url
|
||||
|
@ -52,30 +52,23 @@ func (client *GotifyClient) Name() string {
|
|||
}
|
||||
|
||||
// Send implements NotifProvider.
|
||||
func (client *GotifyClient) Send(ctx context.Context, entry *logrus.Entry) error {
|
||||
func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error {
|
||||
var priority int
|
||||
var title string
|
||||
|
||||
switch entry.Level {
|
||||
case logrus.WarnLevel:
|
||||
switch logMsg.Level {
|
||||
case zerolog.WarnLevel:
|
||||
priority = 2
|
||||
title = "Warning"
|
||||
case logrus.ErrorLevel:
|
||||
case zerolog.ErrorLevel:
|
||||
priority = 5
|
||||
title = "Error"
|
||||
case logrus.FatalLevel, logrus.PanicLevel:
|
||||
case zerolog.FatalLevel, zerolog.PanicLevel:
|
||||
priority = 8
|
||||
title = "Critical"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if subjects := FieldsAsTitle(entry); subjects != "" {
|
||||
title = subjects + " " + title
|
||||
}
|
||||
|
||||
msg := &GotifyMessage{
|
||||
Title: title,
|
||||
Message: entry.Message,
|
||||
Title: logMsg.Title,
|
||||
Message: logMsg.Message,
|
||||
Priority: priority,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
func FieldsAsTitle(entry *logrus.Entry) string {
|
||||
if len(entry.Data) == 0 {
|
||||
return ""
|
||||
}
|
||||
var parts []string
|
||||
for k, v := range entry.Data {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", k, v))
|
||||
}
|
||||
parts[0] = U.Title(parts[0])
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
|
@ -3,14 +3,13 @@ package notif
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
Provider interface {
|
||||
Name() string
|
||||
Send(ctx context.Context, entry *logrus.Entry) error
|
||||
Send(ctx context.Context, logMsg *LogMessage) error
|
||||
}
|
||||
ProviderCreateFunc func(map[string]any) (Provider, E.Error)
|
||||
ProviderConfig map[string]any
|
||||
|
|
|
@ -23,18 +23,18 @@ func ValidateEntry(m *RawEntry) (Entry, E.Error) {
|
|||
|
||||
scheme, err := T.NewScheme(m.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
||||
var entry Entry
|
||||
e := E.NewBuilder("error validating entry")
|
||||
errs := E.NewBuilder("entry validation failed")
|
||||
if scheme.IsStream() {
|
||||
entry = validateStreamEntry(m, e)
|
||||
entry = validateStreamEntry(m, errs)
|
||||
} else {
|
||||
entry = validateRPEntry(m, scheme, e)
|
||||
entry = validateRPEntry(m, scheme, errs)
|
||||
}
|
||||
if err := e.Build(); err != nil {
|
||||
return nil, err
|
||||
if errs.HasError() {
|
||||
return nil, errs.Error()
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
|
|
@ -5,13 +5,14 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
|
@ -85,20 +86,20 @@ func (e *RawEntry) FillMissingFields() {
|
|||
} else if !isDocker {
|
||||
pp = "80"
|
||||
} else {
|
||||
logrus.Debugf("no port found for %s", e.Alias)
|
||||
logging.Debug().Msg("no port found for " + e.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// replace private port with public port if using public IP.
|
||||
if e.Host == cont.PublicIP {
|
||||
if p, ok := cont.PrivatePortMapping[pp]; ok {
|
||||
pp = U.PortString(p.PublicPort)
|
||||
pp = strutils.PortString(p.PublicPort)
|
||||
}
|
||||
}
|
||||
// replace public port with private port if using private IP.
|
||||
if e.Host == cont.PrivateIP {
|
||||
if p, ok := cont.PublicPortMapping[pp]; ok {
|
||||
pp = U.PortString(p.PrivatePort)
|
||||
pp = strutils.PortString(p.PrivatePort)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
|
|||
return rp.Idlewatcher
|
||||
}
|
||||
|
||||
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry {
|
||||
func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProxyEntry {
|
||||
cont := m.Container
|
||||
if cont == nil {
|
||||
cont = docker.DummyContainer
|
||||
|
@ -64,35 +64,26 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEnt
|
|||
lb = nil
|
||||
}
|
||||
|
||||
host, err := fields.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
host := E.Collect(errs, fields.ValidateHost, m.Host)
|
||||
port := E.Collect(errs, fields.ValidatePort, m.Port)
|
||||
pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns)
|
||||
url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port))
|
||||
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
|
||||
|
||||
port, err := fields.ValidatePort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
||||
b.Add(err)
|
||||
|
||||
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
|
||||
b.Add(err)
|
||||
|
||||
if err != nil {
|
||||
if errs.HasError() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ReverseProxyEntry{
|
||||
Raw: m,
|
||||
Alias: fields.NewAlias(m.Alias),
|
||||
Alias: fields.Alias(m.Alias),
|
||||
Scheme: s,
|
||||
URL: net.NewURL(url),
|
||||
NoTLSVerify: m.NoTLSVerify,
|
||||
PathPatterns: pathPatterns,
|
||||
PathPatterns: pathPats,
|
||||
HealthCheck: m.HealthCheck,
|
||||
LoadBalance: lb,
|
||||
Middlewares: m.Middlewares,
|
||||
Idlewatcher: idleWatcherCfg,
|
||||
Idlewatcher: iwCfg,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,34 +51,25 @@ func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
|
|||
return s.Idlewatcher
|
||||
}
|
||||
|
||||
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry {
|
||||
func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry {
|
||||
cont := m.Container
|
||||
if cont == nil {
|
||||
cont = docker.DummyContainer
|
||||
}
|
||||
|
||||
host, err := fields.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
host := E.Collect(errs, fields.ValidateHost, m.Host)
|
||||
port := E.Collect(errs, fields.ValidateStreamPort, m.Port)
|
||||
scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme)
|
||||
url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort))
|
||||
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
|
||||
|
||||
port, err := fields.ValidateStreamPort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
scheme, err := fields.ValidateStreamScheme(m.Scheme)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort)))
|
||||
b.Add(err)
|
||||
|
||||
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
|
||||
b.Add(err)
|
||||
|
||||
if b.HasError() {
|
||||
if errs.HasError() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &StreamEntry{
|
||||
Raw: m,
|
||||
Alias: fields.NewAlias(m.Alias),
|
||||
Alias: fields.Alias(m.Alias),
|
||||
Scheme: *scheme,
|
||||
URL: url,
|
||||
Host: host,
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
package fields
|
||||
|
||||
type (
|
||||
Alias string
|
||||
NewAlias = Alias
|
||||
)
|
||||
type Alias string
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
Host string
|
||||
Subdomain = Alias
|
||||
)
|
||||
|
||||
func ValidateHost[String ~string](s String) (Host, E.Error) {
|
||||
func ValidateHost[String ~string](s String) (Host, error) {
|
||||
return Host(s), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -13,12 +15,17 @@ type (
|
|||
|
||||
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
|
||||
|
||||
func ValidatePathPattern(s string) (PathPattern, E.Error) {
|
||||
var (
|
||||
ErrEmptyPathPattern = errors.New("path must not be empty")
|
||||
ErrInvalidPathPattern = errors.New("invalid path pattern")
|
||||
)
|
||||
|
||||
func ValidatePathPattern(s string) (PathPattern, error) {
|
||||
if len(s) == 0 {
|
||||
return "", E.Invalid("path", "must not be empty")
|
||||
return "", ErrEmptyPathPattern
|
||||
}
|
||||
if !pathPattern.MatchString(s) {
|
||||
return "", E.Invalid("path pattern", s)
|
||||
return "", fmt.Errorf("%w %q", ErrInvalidPathPattern, s)
|
||||
}
|
||||
return PathPattern(s), nil
|
||||
}
|
||||
|
@ -27,13 +34,15 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
|
|||
if len(s) == 0 {
|
||||
return []PathPattern{"/"}, nil
|
||||
}
|
||||
errs := E.NewBuilder("invalid path patterns")
|
||||
pp := make(PathPatterns, len(s))
|
||||
for i, v := range s {
|
||||
pattern, err := ValidatePathPattern(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
errs.Add(err)
|
||||
} else {
|
||||
pp[i] = pattern
|
||||
}
|
||||
pp[i] = pattern
|
||||
}
|
||||
return pp, nil
|
||||
return pp, errs.Error()
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
U "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -38,10 +38,10 @@ var invalidPatterns = []string{
|
|||
func TestPathPatternRegex(t *testing.T) {
|
||||
for _, pattern := range validPatterns {
|
||||
_, err := ValidatePathPattern(pattern)
|
||||
U.ExpectNoError(t, err.Error())
|
||||
U.ExpectNoError(t, err)
|
||||
}
|
||||
for _, pattern := range invalidPatterns {
|
||||
_, err := ValidatePathPattern(pattern)
|
||||
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error())
|
||||
U.ExpectTrue(t, errors.Is(err, ErrInvalidPathPattern))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,22 +4,25 @@ import (
|
|||
"strconv"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Port int
|
||||
|
||||
func ValidatePort[String ~string](v String) (Port, E.Error) {
|
||||
p, err := strconv.Atoi(string(v))
|
||||
var ErrPortOutOfRange = E.New("port out of range")
|
||||
|
||||
func ValidatePort[String ~string](v String) (Port, error) {
|
||||
p, err := strutils.Atoi(string(v))
|
||||
if err != nil {
|
||||
return ErrPort, E.Invalid("port number", v).With(err)
|
||||
return ErrPort, err
|
||||
}
|
||||
return ValidatePortInt(p)
|
||||
}
|
||||
|
||||
func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) {
|
||||
func ValidatePortInt[Int int | uint16](v Int) (Port, error) {
|
||||
p := Port(v)
|
||||
if !p.inBound() {
|
||||
return ErrPort, E.OutOfRange("port", p)
|
||||
return ErrPort, ErrPortOutOfRange.Subject(strconv.Itoa(int(p)))
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -6,12 +6,14 @@ import (
|
|||
|
||||
type Scheme string
|
||||
|
||||
func NewScheme[String ~string](s String) (Scheme, E.Error) {
|
||||
var ErrInvalidScheme = E.New("invalid scheme")
|
||||
|
||||
func NewScheme(s string) (Scheme, error) {
|
||||
switch s {
|
||||
case "http", "https", "tcp", "udp":
|
||||
return Scheme(s), nil
|
||||
}
|
||||
return "", E.Invalid("scheme", s)
|
||||
return "", ErrInvalidScheme.Subject(s)
|
||||
}
|
||||
|
||||
func (s Scheme) IsHTTP() bool { return s == "http" }
|
||||
|
|
|
@ -3,7 +3,6 @@ package fields
|
|||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
|
@ -12,7 +11,9 @@ type StreamPort struct {
|
|||
ProxyPort Port `json:"proxy"`
|
||||
}
|
||||
|
||||
func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
|
||||
var ErrStreamPortTooManyColons = E.New("too many colons")
|
||||
|
||||
func ValidateStreamPort(p string) (StreamPort, error) {
|
||||
split := strings.Split(p, ":")
|
||||
|
||||
switch len(split) {
|
||||
|
@ -21,36 +22,14 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
|
|||
case 2:
|
||||
break
|
||||
default:
|
||||
err = E.Invalid("stream port", p).With("too many colons")
|
||||
return
|
||||
return StreamPort{}, ErrStreamPortTooManyColons.Subject(p)
|
||||
}
|
||||
|
||||
listeningPort, err := ValidatePort(split[0])
|
||||
if err != nil {
|
||||
err = err.Subject("listening port")
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort, err := ValidatePort(split[1])
|
||||
|
||||
if err.Is(E.ErrOutOfRange) {
|
||||
err = err.Subject("proxy port")
|
||||
return
|
||||
} else if err != nil {
|
||||
proxyPort, err = parseNameToPort(split[1])
|
||||
if err != nil {
|
||||
err = E.Invalid("proxy port", proxyPort)
|
||||
return
|
||||
}
|
||||
listeningPort, lErr := ValidatePort(split[0])
|
||||
proxyPort, pErr := ValidatePort(split[1])
|
||||
if err := E.Join(lErr, pErr); err != nil {
|
||||
return StreamPort{}, err
|
||||
}
|
||||
|
||||
return StreamPort{listeningPort, proxyPort}, nil
|
||||
}
|
||||
|
||||
func parseNameToPort(name string) (Port, E.Error) {
|
||||
port, ok := common.ServiceNamePortMapTCP[name]
|
||||
if !ok {
|
||||
return ErrPort, E.Invalid("service", name)
|
||||
}
|
||||
return Port(port), nil
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -11,7 +11,6 @@ var validPorts = []string{
|
|||
"1234:5678",
|
||||
"0:2345",
|
||||
"2345",
|
||||
"1234:postgres",
|
||||
}
|
||||
|
||||
var invalidPorts = []string{
|
||||
|
@ -19,7 +18,6 @@ var invalidPorts = []string{
|
|||
"123:",
|
||||
"0:",
|
||||
":1234",
|
||||
"1234:1234:1234",
|
||||
"qwerty",
|
||||
"asdfgh:asdfgh",
|
||||
"1234:asdfgh",
|
||||
|
@ -32,17 +30,25 @@ var outOfRangePorts = []string{
|
|||
"0:65536",
|
||||
}
|
||||
|
||||
var tooManyColonsPorts = []string{
|
||||
"1234:1234:1234",
|
||||
}
|
||||
|
||||
func TestStreamPort(t *testing.T) {
|
||||
for _, port := range validPorts {
|
||||
_, err := ValidateStreamPort(port)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
}
|
||||
for _, port := range invalidPorts {
|
||||
_, err := ValidateStreamPort(port)
|
||||
ExpectError2(t, port, E.ErrInvalid, err.Error())
|
||||
ExpectError2(t, port, strconv.ErrSyntax, err)
|
||||
}
|
||||
for _, port := range outOfRangePorts {
|
||||
_, err := ValidateStreamPort(port)
|
||||
ExpectError2(t, port, E.ErrOutOfRange, err.Error())
|
||||
ExpectError2(t, port, ErrPortOutOfRange, err)
|
||||
}
|
||||
for _, port := range tooManyColonsPorts {
|
||||
_, err := ValidateStreamPort(port)
|
||||
ExpectError2(t, port, ErrStreamPortTooManyColons, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,22 +12,23 @@ type StreamScheme struct {
|
|||
ProxyScheme Scheme `json:"proxy"`
|
||||
}
|
||||
|
||||
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) {
|
||||
ss = &StreamScheme{}
|
||||
func ValidateStreamScheme(s string) (*StreamScheme, error) {
|
||||
ss := &StreamScheme{}
|
||||
parts := strings.Split(s, ":")
|
||||
if len(parts) == 1 {
|
||||
parts = []string{s, s}
|
||||
} else if len(parts) != 2 {
|
||||
return nil, E.Invalid("stream scheme", s)
|
||||
return nil, ErrInvalidScheme.Subject(s)
|
||||
}
|
||||
ss.ListeningScheme, err = NewScheme(parts[0])
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
}
|
||||
ss.ProxyScheme, err = NewScheme(parts[1])
|
||||
if err.HasError() {
|
||||
|
||||
var lErr, pErr error
|
||||
ss.ListeningScheme, lErr = NewScheme(parts[0])
|
||||
ss.ProxyScheme, pErr = NewScheme(parts[1])
|
||||
|
||||
if err := E.Join(lErr, pErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
|
|
37
internal/proxy/fields/stream_scheme_test.go
Normal file
37
internal/proxy/fields/stream_scheme_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
var (
|
||||
validStreamSchemes = []string{
|
||||
"tcp:tcp",
|
||||
"tcp:udp",
|
||||
"udp:tcp",
|
||||
"udp:udp",
|
||||
"tcp",
|
||||
"udp",
|
||||
}
|
||||
|
||||
invalidStreamSchemes = []string{
|
||||
"tcp:tcp:",
|
||||
"tcp:",
|
||||
":udp:",
|
||||
":udp",
|
||||
"top",
|
||||
}
|
||||
)
|
||||
|
||||
func TestNewStreamScheme(t *testing.T) {
|
||||
for _, s := range validStreamSchemes {
|
||||
_, err := ValidateStreamScheme(s)
|
||||
ExpectNoError(t, err)
|
||||
}
|
||||
for _, s := range invalidStreamSchemes {
|
||||
_, err := ValidateStreamScheme(s)
|
||||
ExpectError(t, ErrInvalidScheme, err)
|
||||
}
|
||||
}
|
|
@ -7,13 +7,13 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -33,6 +33,8 @@ type (
|
|||
rp *gphttp.ReverseProxy
|
||||
|
||||
task task.Task
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
|
||||
SubdomainKey = PT.Alias
|
||||
|
@ -88,6 +90,10 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
|
|||
ReverseProxyEntry: entry,
|
||||
rp: rp,
|
||||
task: task.DummyTask(),
|
||||
l: logger.With().
|
||||
Str("type", string(entry.Scheme)).
|
||||
Str("name", string(entry.Alias)).
|
||||
Logger(),
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
@ -107,11 +113,11 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
|
|||
defer httpRoutesMu.Unlock()
|
||||
|
||||
if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
|
||||
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
|
||||
r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled")
|
||||
if r.HealthCheck == nil {
|
||||
r.HealthCheck = new(health.HealthCheckConfig)
|
||||
}
|
||||
r.HealthCheck.Disable = true
|
||||
r.HealthCheck.Disable = false
|
||||
}
|
||||
|
||||
switch {
|
||||
|
@ -143,7 +149,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
|
|||
|
||||
if r.HealthMon != nil {
|
||||
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
|
||||
logrus.Warn(E.FailWith("health monitor", err))
|
||||
E.LogWarn("health monitor error", err, &r.l)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,15 +215,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
|||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||
if err != nil {
|
||||
if !middleware.ServeStaticErrorPageFile(w, r) {
|
||||
logrus.Error(E.Failure("request").
|
||||
Subjectf("%s %s", r.Method, r.URL.String()).
|
||||
With(err))
|
||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||
if ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if _, err := w.Write(errorPage); err != nil {
|
||||
logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err)
|
||||
logger.Err(err).Msg("failed to write error page")
|
||||
}
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue