Initial abstract implementation of middlewares

This commit is contained in:
yusing 2024-09-25 14:12:40 +08:00
parent 48bf31fd0e
commit 72df9ff3e4
10 changed files with 244 additions and 17 deletions

View file

@ -1,7 +1,8 @@
name: Docker Image CI
on:
push: {}
push:
tags: ["*"]
env:
REGISTRY: ghcr.io

View file

@ -32,13 +32,13 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Addf("no domains specified")
b.Addf("%s", "no domains specified")
}
if cfg.Provider == "" {
b.Addf("no provider specified")
b.Addf("%s", "no provider specified")
}
if cfg.Email == "" {
b.Addf("no email specified")
b.Addf("%s", "no email specified")
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]

View file

@ -1,8 +1,10 @@
package common
import (
"net"
"os"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/utils"
)
@ -12,6 +14,10 @@ var (
ProxyHTTPAddr = GetEnv("GOPROXY_HTTP_ADDR", ":80")
ProxyHTTPSAddr = GetEnv("GOPROXY_HTTPS_ADDR", ":443")
APIHTTPAddr = GetEnv("GOPROXY_API_ADDR", "127.0.0.1:8888")
ProxyHTTPPort = getPort(ProxyHTTPAddr)
ProxyHTTPSPort = getPort(ProxyHTTPSAddr)
ProxyAPIPort = getPort(APIHTTPAddr)
)
func GetEnvBool(key string) bool {
@ -25,3 +31,11 @@ func GetEnv(key string, defaultValue string) string {
}
return value
}
func getPort(addr string) string {
_, port, err := net.SplitHostPort(addr)
if err != nil {
logrus.Fatalf("Invalid address: %s", addr)
}
return port
}

View file

@ -118,9 +118,9 @@ func (ne NestedError) With(s any) NestedError {
case string:
msg = ss
case fmt.Stringer:
return ne.append(ss.String())
return ne.appendMsg(ss.String())
default:
return ne.append(fmt.Sprint(s))
return ne.appendMsg(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
@ -207,7 +207,7 @@ func (ne NestedError) withError(err NestedError) NestedError {
return ne
}
func (ne NestedError) append(msg string) NestedError {
func (ne NestedError) appendMsg(msg string) NestedError {
if ne == nil {
return nil
}

View file

@ -143,6 +143,8 @@ type ReverseProxy struct {
// If nil, the default is to log the provided error and return
// a 502 Status Bad Gateway response.
ErrorHandler func(http.ResponseWriter, *http.Request, error)
ServeHTTP http.HandlerFunc
}
// A BufferPool is an interface for getting and returning temporary
@ -230,12 +232,16 @@ func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *Revers
}
}
}
return &ReverseProxy{Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target)
// pr.SetXForwarded()
setHeaders(pr.Out)
hideHeaders(pr.Out)
}, Transport: transport}
rp := &ReverseProxy{
Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target)
// pr.SetXForwarded()
setHeaders(pr.Out)
hideHeaders(pr.Out)
}, Transport: transport,
}
rp.ServeHTTP = rp.serveHTTP
return rp
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
@ -277,7 +283,7 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
return true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
@ -348,9 +354,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
outreq.Header.Del("Forwarded")
outreq.Header.Del("X-Forwarded-For")
outreq.Header.Del("X-Forwarded-Host")
outreq.Header.Del("X-Forwarded-Proto")
// outreq.Header.Del("X-Forwarded-For")
// outreq.Header.Del("X-Forwarded-Host")
// outreq.Header.Del("X-Forwarded-Proto")
pr := &ProxyRequest{
In: req,

View file

@ -0,0 +1,7 @@
package middleware
var AddXForwarded = &Middleware{
rewrite: func(r *ProxyRequest) {
r.SetXForwarded()
},
}

View file

@ -0,0 +1,135 @@
package middleware
import (
"net/http"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
)
type (
ReverseProxy = P.ReverseProxy
ProxyRequest = P.ProxyRequest
Request = http.Request
Response = http.Response
ResponseWriter = http.ResponseWriter
BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool)
RewriteFunc func(req *ProxyRequest)
ModifyResponseFunc func(res *Response) error
MiddlewareOptionsRaw map[string]string
MiddlewareOptions map[string]interface{}
Middleware struct {
name string
before BeforeFunc
rewrite RewriteFunc
modifyResponse ModifyResponseFunc
options MiddlewareOptions
validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError)
}
)
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) == 0 {
return m, nil
}
var opts MiddlewareOptions
var err E.NestedError
if m.validateOptions != nil {
if opts, err = m.validateOptions(optsRaw); err != nil {
return nil, err
}
}
return &Middleware{
name: m.name,
before: m.before,
rewrite: m.rewrite,
modifyResponse: m.modifyResponse,
options: opts,
}, nil
}
// TODO: check conflict
func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) {
out = rp
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer invalidM.Add(invalidOpts.Build())
defer invalidM.To(&err)
for name, opts := range middlewares {
m, ok := Get(name)
if !ok {
invalidM.Addf("%s", name)
continue
}
m, err = m.WithOptions(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modifyResponses = append(modifyResponses, m.modifyResponse)
}
}
if invalidM.HasError() {
return
}
if len(befores) > 0 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
for _, before := range befores {
if !before(w, r) {
return
}
}
rp.ServeHTTP(w, r)
}
}
if len(rewrites) > 0 {
rp.Rewrite = func(req *ProxyRequest) {
for _, rewrite := range rewrites {
rewrite(req)
}
}
}
if len(modifyResponses) > 0 {
rp.ModifyResponse = func(res *Response) error {
for _, modifyResponse := range modifyResponses {
if err := modifyResponse(res); err != nil {
return err
}
}
return nil
}
}
return
}

View file

@ -0,0 +1,34 @@
package middleware
import (
"fmt"
"strings"
)
var middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded, // nginx
"add_x_forwarded": AddXForwarded, // nginx
"trust_forward_header": AddXForwarded, // traefik alias
"redirect_http": RedirectHTTP,
}
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
return
}
// initialize middleware names
var _ = func() (_ bool) {
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
}
}
return
}()

View file

@ -0,0 +1,20 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/common"
)
var RedirectHTTP = &Middleware{
before: func(w ResponseWriter, r *Request) (continue_ bool) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
} else {
continue_ = true
}
return
},
}

View file

@ -0,0 +1,10 @@
package middleware
var SetXForwarded = &Middleware{
rewrite: func(r *ProxyRequest) {
r.Out.Header.Del("X-Forwarded-For")
r.Out.Header.Del("X-Forwarded-Host")
r.Out.Header.Del("X-Forwarded-Proto")
r.SetXForwarded()
},
}