mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 20:52:33 +02:00
249 lines
6 KiB
Go
249 lines
6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"maps"
|
|
"net/http"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/yusing/go-proxy/internal/gperr"
|
|
"github.com/yusing/go-proxy/internal/logging"
|
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
|
"github.com/yusing/go-proxy/internal/utils"
|
|
)
|
|
|
|
type (
|
|
Error = gperr.Error
|
|
|
|
ReverseProxy = reverseproxy.ReverseProxy
|
|
ProxyRequest = reverseproxy.ProxyRequest
|
|
|
|
ImplNewFunc = func() any
|
|
OptionsRaw = map[string]any
|
|
|
|
commonOptions = struct {
|
|
// priority is only applied for ReverseProxy.
|
|
//
|
|
// Middleware compose follows the order of the slice
|
|
//
|
|
// Default is 10, 0 is the highest
|
|
Priority int `json:"priority"`
|
|
Bypass Bypass `json:"bypass"`
|
|
}
|
|
|
|
Middleware struct {
|
|
name string
|
|
construct ImplNewFunc
|
|
impl any
|
|
|
|
commonOptions
|
|
}
|
|
ByPriority []*Middleware
|
|
|
|
RequestModifier interface {
|
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
|
}
|
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
|
MiddlewareWithSetup interface{ setup() }
|
|
MiddlewareFinalizer interface{ finalize() }
|
|
MiddlewareFinalizerWithError interface {
|
|
finalize() error
|
|
}
|
|
MiddlewareWithTracer interface {
|
|
enableTrace()
|
|
getTracer() *Tracer
|
|
}
|
|
)
|
|
|
|
// DefaultPriority is the priority documented as the default (10) on
// commonOptions.Priority; lower values sort first.
const DefaultPriority = 10
|
|
|
|
// Len reports the number of middlewares (sort.Interface).
func (m ByPriority) Len() int {
	return len(m)
}
|
|
// Swap exchanges the middlewares at positions i and j (sort.Interface).
func (m ByPriority) Swap(i, j int) {
	m[j], m[i] = m[i], m[j]
}
|
|
// Less reports whether the middleware at i has a strictly lower Priority
// value than the one at j (sort.Interface).
func (m ByPriority) Less(i, j int) bool {
	left, right := m[i], m[j]
	return left.Priority < right.Priority
}
|
|
|
|
func NewMiddleware[ImplType any]() *Middleware {
|
|
// type check
|
|
t := any(new(ImplType))
|
|
switch t.(type) {
|
|
case RequestModifier:
|
|
case ResponseModifier:
|
|
default:
|
|
panic("must implement RequestModifier or ResponseModifier")
|
|
}
|
|
_, hasFinializer := t.(MiddlewareFinalizer)
|
|
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
|
if hasFinializer && hasFinializerWithError {
|
|
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
|
}
|
|
return &Middleware{
|
|
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
|
construct: func() any { return new(ImplType) },
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) enableTrace() {
|
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
|
tracer.enableTrace()
|
|
logging.Trace().Msgf("middleware %s enabled trace", m.name)
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) getTracer() *Tracer {
|
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
|
return tracer.getTracer()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) setParent(parent *Middleware) {
|
|
if tracer := m.getTracer(); tracer != nil {
|
|
tracer.SetParent(parent.getTracer())
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) setup() {
|
|
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
|
setup.setup()
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
|
if len(optsRaw) == 0 {
|
|
return nil
|
|
}
|
|
commonOpts := map[string]any{
|
|
"priority": optsRaw["priority"],
|
|
"bypass": optsRaw["bypass"],
|
|
}
|
|
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
|
return err
|
|
}
|
|
optsRaw = maps.Clone(optsRaw)
|
|
for k := range commonOpts {
|
|
delete(optsRaw, k)
|
|
}
|
|
return utils.MapUnmarshalValidate(optsRaw, m.impl)
|
|
}
|
|
|
|
func (m *Middleware) finalize() error {
|
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
|
finalizer.finalize()
|
|
}
|
|
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
|
return finalizer.finalize()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
|
if m.construct == nil { // likely a middleware from compose
|
|
if len(optsRaw) != 0 {
|
|
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
|
|
}
|
|
return m, nil
|
|
}
|
|
mid := &Middleware{name: m.name, impl: m.construct()}
|
|
mid.setup()
|
|
if err := mid.apply(optsRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := mid.finalize(); err != nil {
|
|
return nil, gperr.Wrap(err)
|
|
}
|
|
mid.impl = mid.withCheckBypass()
|
|
return mid, nil
|
|
}
|
|
|
|
func (m *Middleware) Name() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) String() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|
type allOptions struct {
|
|
commonOptions
|
|
any
|
|
}
|
|
return json.MarshalIndent(map[string]any{
|
|
"name": m.name,
|
|
"options": allOptions{
|
|
commonOptions: m.commonOptions,
|
|
any: m.impl,
|
|
},
|
|
}, "", " ")
|
|
}
|
|
|
|
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
next(w, r)
|
|
}
|
|
|
|
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
|
return exec.modifyResponse(resp)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
|
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
|
return exec.modifyResponse(resp)
|
|
})
|
|
}
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
next(w, r)
|
|
}
|
|
|
|
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
|
var middlewares []*Middleware
|
|
middlewares, err = compileMiddlewares(middlewaresMap)
|
|
if err != nil {
|
|
return
|
|
}
|
|
patchReverseProxy(rp, middlewares)
|
|
return
|
|
}
|
|
|
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
|
sort.Sort(ByPriority(middlewares))
|
|
|
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
|
|
|
if before, ok := mid.impl.(RequestModifier); ok {
|
|
next := rp.HandlerFunc
|
|
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
if proceed := before.before(w, r); proceed {
|
|
next(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
if mr, ok := mid.impl.(ResponseModifier); ok {
|
|
if rp.ModifyResponse != nil {
|
|
ori := rp.ModifyResponse
|
|
rp.ModifyResponse = func(res *http.Response) error {
|
|
if err := mr.modifyResponse(res); err != nil {
|
|
return err
|
|
}
|
|
return ori(res)
|
|
}
|
|
} else {
|
|
rp.ModifyResponse = mr.modifyResponse
|
|
}
|
|
}
|
|
}
|