diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 53442fe..84e3c8d 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -1,7 +1,8 @@ name: Docker Image CI on: - push: {} + push: + tags: ["*"] env: REGISTRY: ghcr.io diff --git a/src/autocert/config.go b/src/autocert/config.go index e61792c..4143e8b 100644 --- a/src/autocert/config.go +++ b/src/autocert/config.go @@ -32,13 +32,13 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) { if cfg.Provider != ProviderLocal { if len(cfg.Domains) == 0 { - b.Addf("no domains specified") + b.Addf("%s", "no domains specified") } if cfg.Provider == "" { - b.Addf("no provider specified") + b.Addf("%s", "no provider specified") } if cfg.Email == "" { - b.Addf("no email specified") + b.Addf("%s", "no email specified") } // check if provider is implemented _, ok := providersGenMap[cfg.Provider] diff --git a/src/common/env.go b/src/common/env.go index 9885be5..ed9ac59 100644 --- a/src/common/env.go +++ b/src/common/env.go @@ -1,8 +1,10 @@ package common import ( + "net" "os" + "github.com/sirupsen/logrus" U "github.com/yusing/go-proxy/utils" ) @@ -12,6 +14,10 @@ var ( ProxyHTTPAddr = GetEnv("GOPROXY_HTTP_ADDR", ":80") ProxyHTTPSAddr = GetEnv("GOPROXY_HTTPS_ADDR", ":443") APIHTTPAddr = GetEnv("GOPROXY_API_ADDR", "127.0.0.1:8888") + + ProxyHTTPPort = getPort(ProxyHTTPAddr) + ProxyHTTPSPort = getPort(ProxyHTTPSAddr) + ProxyAPIPort = getPort(APIHTTPAddr) ) func GetEnvBool(key string) bool { @@ -25,3 +31,11 @@ func GetEnv(key string, defaultValue string) string { } return value } + +func getPort(addr string) string { + _, port, err := net.SplitHostPort(addr) + if err != nil { + logrus.Fatalf("Invalid address: %s", addr) + } + return port +} diff --git a/src/error/error.go b/src/error/error.go index 46f4bc8..1aab67b 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -118,9 +118,9 @@ func (ne NestedError) With(s any) NestedError { case string: msg = ss case fmt.Stringer: - return ne.append(ss.String()) + return ne.appendMsg(ss.String()) default: - return ne.append(fmt.Sprint(s)) + return ne.appendMsg(fmt.Sprint(s)) } return ne.withError(From(errors.New(msg))) } @@ -207,7 +207,7 @@ func (ne NestedError) withError(err NestedError) NestedError { return ne } -func (ne NestedError) append(msg string) NestedError { +func (ne NestedError) appendMsg(msg string) NestedError { if ne == nil { return nil } diff --git a/src/proxy/reverse_proxy_mod.go b/src/proxy/reverse_proxy_mod.go index 518d732..6e75e57 100644 --- a/src/proxy/reverse_proxy_mod.go +++ b/src/proxy/reverse_proxy_mod.go @@ -143,6 +143,8 @@ type ReverseProxy struct { // If nil, the default is to log the provided error and return // a 502 Status Bad Gateway response. ErrorHandler func(http.ResponseWriter, *http.Request, error) + + ServeHTTP http.HandlerFunc } // A BufferPool is an interface for getting and returning temporary @@ -230,12 +232,16 @@ func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *Revers } } } - return &ReverseProxy{Rewrite: func(pr *ProxyRequest) { - rewriteRequestURL(pr.Out, target) - // pr.SetXForwarded() - setHeaders(pr.Out) - hideHeaders(pr.Out) - }, Transport: transport} + rp := &ReverseProxy{ + Rewrite: func(pr *ProxyRequest) { + rewriteRequestURL(pr.Out, target) + // pr.SetXForwarded() + setHeaders(pr.Out) + hideHeaders(pr.Out) + }, Transport: transport, + } + rp.ServeHTTP = rp.serveHTTP + return rp } func rewriteRequestURL(req *http.Request, target *url.URL) { @@ -277,7 +283,7 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response return true } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport ctx := req.Context() @@ -348,9 +354,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } outreq.Header.Del("Forwarded") - outreq.Header.Del("X-Forwarded-For") - outreq.Header.Del("X-Forwarded-Host") - outreq.Header.Del("X-Forwarded-Proto") + // outreq.Header.Del("X-Forwarded-For") + // outreq.Header.Del("X-Forwarded-Host") + // outreq.Header.Del("X-Forwarded-Proto") pr := &ProxyRequest{ In: req, diff --git a/src/route/middleware/add_x_forwarded.go b/src/route/middleware/add_x_forwarded.go new file mode 100644 index 0000000..bc8a25d --- /dev/null +++ b/src/route/middleware/add_x_forwarded.go @@ -0,0 +1,7 @@ +package middleware + +var AddXForwarded = &Middleware{ + rewrite: func(r *ProxyRequest) { + r.SetXForwarded() + }, +} diff --git a/src/route/middleware/middleware.go b/src/route/middleware/middleware.go new file mode 100644 index 0000000..5aedd92 --- /dev/null +++ b/src/route/middleware/middleware.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "net/http" + + E "github.com/yusing/go-proxy/error" + P "github.com/yusing/go-proxy/proxy" +) + +type ( + ReverseProxy = P.ReverseProxy + ProxyRequest = P.ProxyRequest + Request = http.Request + Response = http.Response + ResponseWriter = http.ResponseWriter + + BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool) + RewriteFunc func(req *ProxyRequest) + ModifyResponseFunc func(res *Response) error + + MiddlewareOptionsRaw map[string]string + MiddlewareOptions map[string]interface{} + + Middleware struct { + name string + + before BeforeFunc + rewrite RewriteFunc + modifyResponse ModifyResponseFunc + + options MiddlewareOptions + validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError) + } +) + +func (m *Middleware) Name() string { + return m.name +} + +func (m *Middleware) String() string { + return m.name +} + +func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) { + if len(optsRaw) == 0 { + return m, nil + } + + var opts MiddlewareOptions + var err E.NestedError + + if m.validateOptions != nil { + if opts, err = m.validateOptions(optsRaw); err != nil { + return nil, err + } + } + + return &Middleware{ + name: m.name, + before: m.before, + rewrite: m.rewrite, + modifyResponse: m.modifyResponse, + options: opts, + }, nil +} + +// TODO: check conflict +func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) { + out = rp + + befores := make([]BeforeFunc, 0, len(middlewares)) + rewrites := make([]RewriteFunc, 0, len(middlewares)) + modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares)) + + invalidM := E.NewBuilder("invalid middlewares") + invalidOpts := E.NewBuilder("invalid options") + defer invalidM.Add(invalidOpts.Build()) + defer invalidM.To(&err) + + for name, opts := range middlewares { + m, ok := Get(name) + if !ok { + invalidM.Addf("%s", name) + continue + } + m, err = m.WithOptions(opts) + if err != nil { + invalidOpts.Add(err.Subject(name)) + continue + } + if m.before != nil { + befores = append(befores, m.before) + } + if m.rewrite != nil { + rewrites = append(rewrites, m.rewrite) + } + if m.modifyResponse != nil { + modifyResponses = append(modifyResponses, m.modifyResponse) + } + } + + if invalidM.HasError() { + return + } + + if len(befores) > 0 { + rp.ServeHTTP = func(w ResponseWriter, r *Request) { + for _, before := range befores { + if !before(w, r) { + return + } + } + rp.ServeHTTP(w, r) + } + } + if len(rewrites) > 0 { + rp.Rewrite = func(req *ProxyRequest) { + for _, rewrite := range rewrites { + rewrite(req) + } + } + } + if len(modifyResponses) > 0 { + rp.ModifyResponse = func(res *Response) error { + for _, modifyResponse := range modifyResponses { + if err := modifyResponse(res); err != nil { + return err + } + } + return nil + } + } + + return +} diff --git a/src/route/middleware/middlewares.go b/src/route/middleware/middlewares.go new file mode 100644 index 0000000..3c32ffd --- /dev/null +++ b/src/route/middleware/middlewares.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "fmt" + "strings" +) + +var middlewares = map[string]*Middleware{ + "set_x_forwarded": SetXForwarded, // nginx + "add_x_forwarded": AddXForwarded, // nginx + "trust_forward_header": AddXForwarded, // traefik alias + "redirect_http": RedirectHTTP, +} + +func Get(name string) (middleware *Middleware, ok bool) { + middleware, ok = middlewares[name] + return +} + +// initialize middleware names +var _ = func() (_ bool) { + names := make(map[*Middleware][]string) + for name, m := range middlewares { + names[m] = append(names[m], name) + } + for m, names := range names { + if len(names) > 1 { + m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", ")) + } else { + m.name = names[0] + } + } + return +}() diff --git a/src/route/middleware/redirect_http.go b/src/route/middleware/redirect_http.go new file mode 100644 index 0000000..613a2b0 --- /dev/null +++ b/src/route/middleware/redirect_http.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "net/http" + + "github.com/yusing/go-proxy/common" +) + +var RedirectHTTP = &Middleware{ + before: func(w ResponseWriter, r *Request) (continue_ bool) { + if r.TLS == nil { + r.URL.Scheme = "https" + r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort + http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) + } else { + continue_ = true + } + return + }, +} diff --git a/src/route/middleware/set_x_forwarded.go b/src/route/middleware/set_x_forwarded.go new file mode 100644 index 0000000..2cec4aa --- /dev/null +++ b/src/route/middleware/set_x_forwarded.go @@ -0,0 +1,10 @@ +package middleware + +var SetXForwarded = &Middleware{ + rewrite: func(r *ProxyRequest) { + r.Out.Header.Del("X-Forwarded-For") + r.Out.Header.Del("X-Forwarded-Host") + r.Out.Header.Del("X-Forwarded-Proto") + r.SetXForwarded() + }, +}