GoDoxy/internal/net/http/middleware/middleware.go

216 lines
4.9 KiB
Go

package middleware
import (
"encoding/json"
"net/http"
"reflect"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.Error
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
name string
construct ImplNewFunc
impl any
}
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
)
func NewMiddleware[ImplType any]() *Middleware {
// type check
switch any(new(ImplType)).(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 {
return nil
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
}
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil {
if optsRaw != nil {
panic("bug: middleware already constructed")
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
mid.finalize()
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
}, "", " ")
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return exec.modifyResponse(resp)
})
}
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
// TODO: check conflict or duplicates.
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
middlewares = append(middlewares, m)
}
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
return middlewares, errs.Error()
}
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rp, middlewares)
return
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
}
}
}
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mr.modifyResponse
}
}
}