mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 20:52:33 +02:00
216 lines
4.9 KiB
Go
216 lines
4.9 KiB
Go
package middleware
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"reflect"
	"slices"
	"strings"

	E "github.com/yusing/go-proxy/internal/error"
	gphttp "github.com/yusing/go-proxy/internal/net/http"
	"github.com/yusing/go-proxy/internal/utils"
)
|
|
|
|
type (
|
|
Error = E.Error
|
|
|
|
ReverseProxy = gphttp.ReverseProxy
|
|
ProxyRequest = gphttp.ProxyRequest
|
|
|
|
ImplNewFunc = func() any
|
|
OptionsRaw = map[string]any
|
|
|
|
Middleware struct {
|
|
name string
|
|
construct ImplNewFunc
|
|
impl any
|
|
}
|
|
|
|
RequestModifier interface {
|
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
|
}
|
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
|
MiddlewareWithSetup interface{ setup() }
|
|
MiddlewareFinalizer interface{ finalize() }
|
|
MiddlewareWithTracer *struct{ *Tracer }
|
|
)
|
|
|
|
func NewMiddleware[ImplType any]() *Middleware {
|
|
// type check
|
|
switch any(new(ImplType)).(type) {
|
|
case RequestModifier:
|
|
case ResponseModifier:
|
|
default:
|
|
panic("must implement RequestModifier or ResponseModifier")
|
|
}
|
|
return &Middleware{
|
|
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
|
construct: func() any { return new(ImplType) },
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) enableTrace() {
|
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
|
tracer.Tracer = &Tracer{name: m.name}
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) getTracer() *Tracer {
|
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
|
return tracer.Tracer
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) setParent(parent *Middleware) {
|
|
if tracer := m.getTracer(); tracer != nil {
|
|
tracer.parent = parent.getTracer()
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) setup() {
|
|
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
|
setup.setup()
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
|
if len(optsRaw) == 0 {
|
|
return nil
|
|
}
|
|
return utils.Deserialize(optsRaw, m.impl)
|
|
}
|
|
|
|
func (m *Middleware) finalize() {
|
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
|
finalizer.finalize()
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
|
if m.construct == nil {
|
|
if optsRaw != nil {
|
|
panic("bug: middleware already constructed")
|
|
}
|
|
return m, nil
|
|
}
|
|
mid := &Middleware{name: m.name, impl: m.construct()}
|
|
mid.setup()
|
|
if err := mid.apply(optsRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
mid.finalize()
|
|
return mid, nil
|
|
}
|
|
|
|
func (m *Middleware) Name() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) String() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|
return json.MarshalIndent(map[string]any{
|
|
"name": m.name,
|
|
"options": m.impl,
|
|
}, "", " ")
|
|
}
|
|
|
|
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
next(w, r)
|
|
}
|
|
|
|
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
|
return exec.modifyResponse(resp)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
|
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
|
return exec.modifyResponse(resp)
|
|
})
|
|
}
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
next(w, r)
|
|
}
|
|
|
|
// TODO: check conflict or duplicates.
|
|
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
|
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
|
|
|
errs := E.NewBuilder("middlewares compile error")
|
|
invalidOpts := E.NewBuilder("options compile error")
|
|
|
|
for name, opts := range middlewaresMap {
|
|
m, err := Get(name)
|
|
if err != nil {
|
|
errs.Add(err)
|
|
continue
|
|
}
|
|
|
|
m, err = m.New(opts)
|
|
if err != nil {
|
|
invalidOpts.Add(err.Subject(name))
|
|
continue
|
|
}
|
|
middlewares = append(middlewares, m)
|
|
}
|
|
|
|
if invalidOpts.HasError() {
|
|
errs.Add(invalidOpts.Error())
|
|
}
|
|
return middlewares, errs.Error()
|
|
}
|
|
|
|
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
|
var middlewares []*Middleware
|
|
middlewares, err = compileMiddlewares(middlewaresMap)
|
|
if err != nil {
|
|
return
|
|
}
|
|
patchReverseProxy(rp, middlewares)
|
|
return
|
|
}
|
|
|
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
|
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
|
|
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
|
|
|
if before, ok := mid.impl.(RequestModifier); ok {
|
|
next := rp.HandlerFunc
|
|
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
if proceed := before.before(w, r); proceed {
|
|
next(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
if mr, ok := mid.impl.(ResponseModifier); ok {
|
|
if rp.ModifyResponse != nil {
|
|
ori := rp.ModifyResponse
|
|
rp.ModifyResponse = func(res *http.Response) error {
|
|
if err := mr.modifyResponse(res); err != nil {
|
|
return err
|
|
}
|
|
return ori(res)
|
|
}
|
|
} else {
|
|
rp.ModifyResponse = mr.modifyResponse
|
|
}
|
|
}
|
|
}
|