refactor: remove forward auth, move module net/http to net/gphttp

This commit is contained in:
yusing 2025-03-28 07:03:35 +08:00
parent c0c6e21a16
commit 5d2df3550b
69 changed files with 321 additions and 745 deletions

View file

@ -1,13 +1,14 @@
package accesslog package accesslog
import ( import (
"bufio"
"bytes" "bytes"
"io" "io"
"net/http" "net/http"
"sync" "sync"
"time" "time"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
) )
@ -17,13 +18,9 @@ type (
task *task.Task task *task.Task
cfg *Config cfg *Config
io AccessLogIO io AccessLogIO
buffered *bufio.Writer
buf bytes.Buffer // buffer for non-flushed log lineBufPool sync.Pool // buffer pool for formatting a single log line
bufMu sync.RWMutex
bufPool sync.Pool // buffer pool for formatting a single log line
flushThreshold int
Formatter Formatter
} }
@ -44,13 +41,17 @@ type (
) )
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize
}
if cfg.BufferSize < 4096 {
cfg.BufferSize = 4096
}
l := &AccessLogger{ l := &AccessLogger{
task: parent.Subtask("accesslog"), task: parent.Subtask("accesslog"),
cfg: cfg, cfg: cfg,
io: io, io: io,
} buffered: bufio.NewWriterSize(io, cfg.BufferSize),
if cfg.BufferSize < 1024 {
cfg.BufferSize = DefaultBufferSize
} }
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now} fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
@ -65,10 +66,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog
panic("invalid access log format") panic("invalid access log format")
} }
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80% l.lineBufPool.New = func() any {
l.buf.Grow(int(cfg.BufferSize)) return bytes.NewBuffer(make([]byte, 0, 1024))
l.bufPool.New = func() any {
return new(bytes.Buffer)
} }
go l.start() go l.start()
return l return l
@ -89,15 +88,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
return return
} }
line := l.bufPool.Get().(*bytes.Buffer) line := l.lineBufPool.Get().(*bytes.Buffer)
l.Format(line, req, res)
line.WriteRune('\n')
l.bufMu.Lock()
l.buf.Write(line.Bytes())
line.Reset() line.Reset()
l.bufPool.Put(line) defer l.lineBufPool.Put(line)
l.bufMu.Unlock() l.Formatter.Format(line, req, res)
line.WriteRune('\n')
l.write(line.Bytes())
} }
func (l *AccessLogger) LogError(req *http.Request, err error) { func (l *AccessLogger) LogError(req *http.Request, err error) {
@ -115,55 +111,53 @@ func (l *AccessLogger) Rotate() error {
l.io.Lock() l.io.Lock()
defer l.io.Unlock() defer l.io.Unlock()
return l.cfg.Retention.rotateLogFile(l.io) return l.rotate()
}
func (l *AccessLogger) Flush(force bool) {
if l.buf.Len() == 0 {
return
}
if force || l.buf.Len() >= l.flushThreshold {
l.bufMu.RLock()
l.write(l.buf.Bytes())
l.buf.Reset()
l.bufMu.RUnlock()
}
} }
func (l *AccessLogger) handleErr(err error) { func (l *AccessLogger) handleErr(err error) {
E.LogError("failed to write access log", err) gperr.LogError("failed to write access log", err)
} }
func (l *AccessLogger) start() { func (l *AccessLogger) start() {
defer func() { defer func() {
if l.buf.Len() > 0 { // flush last if err := l.Flush(); err != nil {
l.write(l.buf.Bytes()) l.handleErr(err)
} }
l.io.Close() l.close()
l.task.Finish(nil) l.task.Finish(nil)
}() }()
// periodic flush + threshold flush // flushes the buffer every 30 seconds
periodic := time.NewTicker(5 * time.Second) flushTicker := time.NewTicker(30 * time.Second)
threshold := time.NewTicker(time.Second) defer flushTicker.Stop()
defer periodic.Stop()
defer threshold.Stop()
for { for {
select { select {
case <-l.task.Context().Done(): case <-l.task.Context().Done():
return return
case <-periodic.C: case <-flushTicker.C:
l.Flush(true) if err := l.Flush(); err != nil {
case <-threshold.C: l.handleErr(err)
l.Flush(false) }
} }
} }
} }
func (l *AccessLogger) Flush() error {
l.io.Lock()
defer l.io.Unlock()
return l.buffered.Flush()
}
func (l *AccessLogger) close() {
l.io.Lock()
defer l.io.Unlock()
l.io.Close()
}
func (l *AccessLogger) write(data []byte) { func (l *AccessLogger) write(data []byte) {
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
_, err := l.io.Write(data) _, err := l.buffered.Write(data)
l.io.Unlock() l.io.Unlock()
if err != nil { if err != nil {
l.handleErr(err) l.handleErr(err)

View file

@ -9,7 +9,7 @@ import (
"testing" "testing"
"time" "time"
. "github.com/yusing/go-proxy/internal/net/http/accesslog" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )

View file

@ -17,7 +17,7 @@ type (
Cookies FieldConfig `json:"cookies"` Cookies FieldConfig `json:"cookies"`
} }
Config struct { Config struct {
BufferSize uint `json:"buffer_size" validate:"gte=1"` BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"` Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"` Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"` Filters Filters `json:"filters"`

View file

@ -4,7 +4,7 @@ import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/http/accesslog" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )

View file

@ -3,7 +3,7 @@ package accesslog_test
import ( import (
"testing" "testing"
. "github.com/yusing/go-proxy/internal/net/http/accesslog" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )

View file

@ -71,14 +71,14 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
go func(l *AccessLogger) { go func(l *AccessLogger) {
defer wg.Done() defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger) parallelLog(l, req, resp, logCountPerLogger)
l.Flush(true) l.Flush()
}(logger) }(logger)
} }
wg.Wait() wg.Wait()
expected := loggerCount * logCountPerLogger expected := loggerCount * logCountPerLogger
actual := file.Count() actual := file.LineCount()
ExpectEqual(t, actual, expected) ExpectEqual(t, actual, expected)
} }

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -27,7 +27,7 @@ type (
CIDR struct{ types.CIDR } CIDR struct{ types.CIDR }
) )
var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter") var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool { func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
if len(f.Values) == 0 { if len(f.Values) == 0 {

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"testing" "testing"
. "github.com/yusing/go-proxy/internal/net/http/accesslog" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )

View file

@ -49,7 +49,6 @@ func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
return 0, io.EOF return 0, io.EOF
} }
n = copy(p, m.data[off:]) n = copy(p, m.data[off:])
m.position += int64(n)
return n, nil return n, nil
} }
@ -63,7 +62,7 @@ func (m *MockFile) Truncate(size int64) error {
return nil return nil
} }
func (m *MockFile) Count() int { func (m *MockFile) LineCount() int {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return bytes.Count(m.data[:m.position], []byte("\n")) return bytes.Count(m.data[:m.position], []byte("\n"))
@ -72,3 +71,7 @@ func (m *MockFile) Count() int {
func (m *MockFile) Len() int64 { func (m *MockFile) Len() int64 {
return m.position return m.position
} }
func (m *MockFile) Content() []byte {
return m.data[:m.position]
}

View file

@ -0,0 +1,56 @@
package accesslog
import (
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
var (
ErrInvalidSyntax = gperr.New("invalid syntax")
ErrZeroValue = gperr.New("zero value")
)
var defaultChunkSize = 64 * 1024 // 64KB
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}

View file

@ -0,0 +1,33 @@
package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}

View file

@ -3,7 +3,7 @@ package accesslog
import ( import (
"strconv" "strconv"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -12,7 +12,7 @@ type StatusCodeRange struct {
End int End int
} }
var ErrInvalidStatusCodeRange = E.New("invalid status code range") var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
func (r *StatusCodeRange) Includes(code int) bool { func (r *StatusCodeRange) Includes(code int) bool {
return r.Start <= code && code <= r.End return r.Start <= code && code <= r.End
@ -25,7 +25,7 @@ func (r *StatusCodeRange) Parse(v string) error {
case 1: case 1:
start, err := strconv.Atoi(split[0]) start, err := strconv.Atoi(split[0])
if err != nil { if err != nil {
return E.From(err) return gperr.Wrap(err)
} }
r.Start = start r.Start = start
r.End = start r.End = start
@ -33,7 +33,7 @@ func (r *StatusCodeRange) Parse(v string) error {
case 2: case 2:
start, errStart := strconv.Atoi(split[0]) start, errStart := strconv.Atoi(split[0])
end, errEnd := strconv.Atoi(split[1]) end, errEnd := strconv.Atoi(split[1])
if err := E.Join(errStart, errEnd); err != nil { if err := gperr.Join(errStart, errEnd); err != nil {
return err return err
} }
r.Start = start r.Start = start

View file

@ -1,4 +1,4 @@
package http package gphttp
import ( import (
"mime" "mime"

View file

@ -1,4 +1,4 @@
package http package gphttp
import ( import (
"net/http" "net/http"

View file

@ -6,8 +6,8 @@ import (
"net/http" "net/http"
"sync" "sync"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
) )
type ipHash struct { type ipHash struct {
@ -23,10 +23,10 @@ func (lb *LoadBalancer) newIPHash() impl {
if len(lb.Options) == 0 { if len(lb.Options) == 0 {
return impl return impl
} }
var err E.Error var err gperr.Error
impl.realIP, err = middleware.RealIP.New(lb.Options) impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil { if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.l) gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
} }
return impl return impl
} }

View file

@ -6,10 +6,10 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
@ -54,7 +54,7 @@ func New(cfg *Config) *LoadBalancer {
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) E.Error { func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
lb.startTime = time.Now() lb.startTime = time.Now()
lb.task = parent.Subtask("loadbalancer."+lb.Link, false) lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
parent.OnCancel("lb_remove_route", func() { parent.OnCancel("lb_remove_route", func() {
@ -125,12 +125,12 @@ func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name()) { if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning
old, _ := lb.pool.Load(srv.Name()) old, _ := lb.pool.Load(srv.Key())
lb.sumWeight -= old.Weight() lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old) lb.impl.OnRemoveServer(old)
} }
lb.pool.Store(srv.Name(), srv) lb.pool.Store(srv.Key(), srv)
lb.sumWeight += srv.Weight() lb.sumWeight += srv.Weight()
lb.rebalance() lb.rebalance()
@ -146,11 +146,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name()) { if !lb.pool.Has(srv.Key()) {
return return
} }
lb.pool.Delete(srv.Name()) lb.pool.Delete(srv.Key())
lb.sumWeight -= srv.Weight() lb.sumWeight -= srv.Weight()
lb.rebalance() lb.rebalance()
@ -227,7 +227,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return return
} }
if r.Header.Get(common.HeaderCheckRedirect) != "" { if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
// wake all servers // wake all servers
for _, srv := range srvs { for _, srv := range srvs {
if err := srv.TryWake(); err != nil { if err := srv.TryWake(); err != nil {
@ -244,7 +244,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any) extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) { lb.pool.RangeAll(func(k string, v Server) {
extra[v.Name()] = v extra[v.Key()] = v
}) })
return (&monitor.JSONRepresentation{ return (&monitor.JSONRepresentation{

View file

@ -3,7 +3,7 @@ package loadbalancer
import ( import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )

View file

@ -1,7 +1,7 @@
package loadbalancer package loadbalancer
import ( import (
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
) )
type ( type (

View file

@ -26,6 +26,7 @@ type (
http.Handler http.Handler
health.HealthMonitor health.HealthMonitor
Name() string Name() string
Key() string
URL() *net.URL URL() *net.URL
Weight() Weight Weight() Weight
SetWeight(weight Weight) SetWeight(weight Weight)
@ -51,6 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server { func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{ srv := &server{
weight: Weight(weight), weight: Weight(weight),
url: net.MustParseURL("http://localhost"),
} }
return srv return srv
} }
@ -63,6 +65,10 @@ func (srv *server) URL() *net.URL {
return srv.url return srv.url
} }
func (srv *server) Key() string {
return srv.url.Host
}
func (srv *server) Weight() Weight { func (srv *server) Weight() Weight {
return srv.weight return srv.weight
} }
@ -78,9 +84,7 @@ func (srv *server) String() string {
func (srv *server) TryWake() error { func (srv *server) TryWake() error {
waker, ok := srv.Handler.(idlewatcher.Waker) waker, ok := srv.Handler.(idlewatcher.Waker)
if ok { if ok {
if err := waker.Wake(); err != nil { return waker.Wake()
return err
}
} }
return nil return nil
} }

View file

@ -1,4 +1,4 @@
package http package gphttp
import "net/http" import "net/http"

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"

View file

@ -7,7 +7,7 @@ import (
"strings" "strings"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -61,7 +61,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
} }
func TestCIDRWhitelist(t *testing.T) { func TestCIDRWhitelist(t *testing.T) {
errs := E.NewBuilder("") errs := gperr.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error()) ExpectNoError(t, errs.Error())
deny = mids["deny@file"] deny = mids["deny@file"]

View file

@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -28,7 +29,7 @@ const (
) )
var ( var (
cfCIDRsLastUpdate time.Time cfCIDRsLastUpdate atomic.Value[time.Time]
cfCIDRsMu sync.Mutex cfCIDRsMu sync.Mutex
// RFC 1918. // RFC 1918.
@ -68,14 +69,14 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
} }
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return return
} }
cfCIDRsMu.Lock() cfCIDRsMu.Lock()
defer cfCIDRsMu.Unlock() defer cfCIDRsMu.Unlock()
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return return
} }
@ -88,7 +89,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs), fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
) )
if err != nil { if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil return nil
} }
@ -97,7 +98,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
} }
} }
cfCIDRsLastUpdate = time.Now() cfCIDRsLastUpdate.Store(time.Now())
logging.Info().Msg("cloudflare CIDR range updated") logging.Info().Msg("cloudflare CIDR range updated")
return return
} }

View file

@ -9,14 +9,17 @@ import (
"strings" "strings"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
) )
type customErrorPage struct{} type customErrorPage struct{}
var CustomErrorPage = NewMiddleware[customErrorPage]() var CustomErrorPage = NewMiddleware[customErrorPage]()
const StaticFilePathPrefix = "/$gperrorpage/"
// before implements RequestModifier. // before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r) return !ServeStaticErrorPageFile(w, r)
@ -34,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Body.Close() resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage)) resp.ContentLength = int64(len(errorPage))
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else { } else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode) logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
} }
@ -49,8 +52,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
if path != "" && path[0] != '/' { if path != "" && path[0] != '/' {
path = "/" + path path = "/" + path
} }
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) { if strings.HasPrefix(path, StaticFilePathPrefix) {
filename := path[len(gphttp.StaticFilePathPrefix):] filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename) file, ok := errorpage.GetStaticFile(filename)
if !ok { if !ok {
logging.Error().Msg("unable to load resource " + filename) logging.Error().Msg("unable to load resource " + filename)
@ -59,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
switch ext { switch ext {
case ".html": case ".html":
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js": case ".js":
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css": case ".css":
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default: default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename) logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
} }

View file

@ -7,7 +7,7 @@ import (
"sync" "sync"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
@ -90,7 +90,7 @@ func watchDir() {
loadContent() loadContent()
} }
case err := <-errCh: case err := <-errCh:
E.LogError("error watching error page directory", err) gperr.LogError("error watching error page directory", err)
} }
} }
} }

View file

@ -7,15 +7,15 @@ import (
"sort" "sort"
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
type ( type (
Error = E.Error Error = gperr.Error
ReverseProxy = reverseproxy.ReverseProxy ReverseProxy = reverseproxy.ReverseProxy
ProxyRequest = reverseproxy.ProxyRequest ProxyRequest = reverseproxy.ProxyRequest
@ -80,7 +80,7 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() { func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok { if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace() tracer.enableTrace()
logging.Debug().Msgf("middleware %s enabled trace", m.name) logging.Trace().Msgf("middleware %s enabled trace", m.name)
} }
} }
@ -103,7 +103,7 @@ func (m *Middleware) setup() {
} }
} }
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error { func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
if len(optsRaw) == 0 { if len(optsRaw) == 0 {
return nil return nil
} }
@ -132,10 +132,10 @@ func (m *Middleware) finalize() error {
return nil return nil
} }
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
if m.construct == nil { // likely a middleware from compose if m.construct == nil { // likely a middleware from compose
if len(optsRaw) != 0 { if len(optsRaw) != 0 {
return nil, E.New("additional options not allowed for middleware ").Subject(m.name) return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
} }
return m, nil return m, nil
} }
@ -145,7 +145,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
return nil, err return nil, err
} }
if err := mid.finalize(); err != nil { if err := mid.finalize(); err != nil {
return nil, E.From(err) return nil, gperr.Wrap(err)
} }
return mid, nil return mid, nil
} }
@ -196,7 +196,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
next(w, r) next(w, r)
} }
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
var middlewares []*Middleware var middlewares []*Middleware
middlewares, err = compileMiddlewares(middlewaresMap) middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil { if err != nil {

View file

@ -6,13 +6,13 @@ import (
"path" "path"
"sort" "sort"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
var ErrMissingMiddlewareUse = E.New("missing middleware 'use' field") var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware { func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath) fileContent, err := os.ReadFile(filePath)
if err != nil { if err != nil {
eb.Add(err) eb.Add(err)
@ -21,7 +21,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
} }
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap) err := yaml.Unmarshal(data, &rawMap)
if err != nil { if err != nil {
@ -40,11 +40,11 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
return middlewares return middlewares
} }
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap)) middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error") errs := gperr.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error") invalidOpts := gperr.NewBuilder("options compile error")
for name, opts := range middlewaresMap { for name, opts := range middlewaresMap {
m, err := Get(name) m, err := Get(name)
@ -68,7 +68,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.
return middlewares, errs.Error() return middlewares, errs.Error()
} }
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, E.Error) { func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
compiled, err := compileMiddlewares(middlewaresMap) compiled, err := compileMiddlewares(middlewaresMap)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,8 +77,8 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (
} }
// TODO: check conflict or duplicates. // TODO: check conflict or duplicates.
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, E.Error) { func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
chainErr := E.NewBuilder("") chainErr := gperr.NewBuilder("")
chain := make([]*Middleware, 0, len(defs)) chain := make([]*Middleware, 0, len(defs))
for i, def := range defs { for i, def := range defs {
if def["use"] == nil || def["use"] == "" { if def["use"] == nil || def["use"] == "" {

View file

@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -13,7 +13,7 @@ import (
var testMiddlewareCompose []byte var testMiddlewareCompose []byte
func TestBuild(t *testing.T) { func TestBuild(t *testing.T) {
errs := E.NewBuilder("") errs := gperr.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs) middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error()) ExpectNoError(t, errs.Error())
Must(json.MarshalIndent(middlewares, "", " ")) Must(json.MarshalIndent(middlewares, "", " "))

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
) )
type middlewareChain struct { type middlewareChain struct {
@ -51,10 +51,10 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 { if len(m.modResps) == 0 {
return nil return nil
} }
errs := E.NewBuilder("modify response errors") errs := gperr.NewBuilder("modify response errors")
for i, mr := range m.modResps { for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil { if err := mr.modifyResponse(resp); err != nil {
errs.Add(E.From(err).Subjectf("%d", i)) errs.Add(gperr.Wrap(err).Subjectf("%d", i))
} }
} }
return errs.Error() return errs.Error()

View file

@ -4,7 +4,7 @@ import (
"path" "path"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
@ -32,15 +32,11 @@ var allMiddlewares = map[string]*Middleware{
"cidrwhitelist": CIDRWhiteList, "cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter, "ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
} }
var ( var (
ErrUnknownMiddleware = E.New("unknown middleware") ErrUnknownMiddleware = gperr.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware") ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
) )
func Get(name string) (*Middleware, Error) { func Get(name string) (*Middleware, Error) {
@ -58,14 +54,14 @@ func All() map[string]*Middleware {
} }
func LoadComposeFiles() { func LoadComposeFiles() {
errs := E.NewBuilder("middleware compile errors") errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil { if err != nil {
logging.Err(err).Msg("failed to list middleware definitions") logging.Err(err).Msg("failed to list middleware definitions")
return return
} }
for _, defFile := range middlewareDefs { for _, defFile := range middlewareDefs {
voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs) mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
if len(mws) == 0 { if len(mws) == 0 {
continue continue
@ -103,6 +99,6 @@ func LoadComposeFiles() {
} }
} }
if errs.HasError() { if errs.HasError() {
E.LogError(errs.About(), errs.Error()) gperr.LogError(errs.About(), errs.Error())
} }
} }

View file

@ -1,12 +1,13 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
) )
type oidcMiddleware struct { type oidcMiddleware struct {
@ -24,7 +25,7 @@ var OIDC = NewMiddleware[oidcMiddleware]()
func (amw *oidcMiddleware) finalize() error { func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() { if !auth.IsOIDCEnabled() {
return E.New("OIDC not enabled but OIDC middleware is used") return gperr.New("OIDC not enabled but OIDC middleware is used")
} }
return nil return nil
} }
@ -64,9 +65,6 @@ func (amw *oidcMiddleware) initSlow() error {
amw.authMux = http.NewServeMux() amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
amw.auth = authProvider amw.auth = authProvider
return nil return nil
@ -79,13 +77,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
return false return false
} }
if err := amw.auth.CheckToken(r); err != nil {
amw.authMux.ServeHTTP(w, r)
return false
}
if r.URL.Path == auth.OIDCLogoutPath { if r.URL.Path == auth.OIDCLogoutPath {
amw.auth.LogoutCallbackHandler(w, r) amw.auth.LogoutCallbackHandler(w, r)
return false return false
} }
if err := amw.auth.CheckToken(r); err != nil {
if errors.Is(err, auth.ErrMissingToken) {
amw.authMux.ServeHTTP(w, r)
} else {
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
}
return false
}
return true return true
} }

View file

@ -4,7 +4,7 @@ import (
"net" "net"
"net/http" "net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
req.RemoteAddr = lastNonTrustedIP req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP) ri.AddTracef("set real ip %s", lastNonTrustedIP)
} }

View file

@ -6,14 +6,14 @@ import (
"strings" "strings"
"testing" "testing"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestSetRealIPOpts(t *testing.T) { func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{ opts := OptionsRaw{
"header": gphttp.HeaderXRealIP, "header": httpheaders.HeaderXRealIP,
"from": []string{ "from": []string{
"127.0.0.0/8", "127.0.0.0/8",
"192.168.0.0/16", "192.168.0.0/16",
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true, "recursive": true,
} }
optExpected := &RealIPOpts{ optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP, Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{ From: []*types.CIDR{
{ {
IP: net.ParseIP("127.0.0.0"), IP: net.ParseIP("127.0.0.0"),
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) { func TestSetRealIP(t *testing.T) {
const ( const (
testHeader = gphttp.HeaderXRealIP testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1" testRealIP = "192.168.1.1"
) )
opts := OptionsRaw{ opts := OptionsRaw{

View file

@ -44,7 +44,7 @@ func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed b
r.URL.Host = host r.URL.Host = host
} }
http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently) http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect)
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https") logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
return false return false

View file

@ -13,7 +13,7 @@ func TestRedirectToHTTPs(t *testing.T) {
reqURL: types.MustParseURL("http://example.com"), reqURL: types.MustParseURL("http://example.com"),
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusMovedPermanently) ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com") ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
} }

View file

@ -3,8 +3,8 @@ package middleware
import ( import (
"net/http" "net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
) )
// internal use only. // internal use only.
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
// before implements RequestModifier. // before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamName, s.Name) r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme) r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host) r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port) r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true return true
} }

View file

@ -8,19 +8,6 @@ theGreatPretender:
- X-Test3 - X-Test3
- X-Test4 - X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik: realIPAuthentik:
- use: RedirectHTTP - use: RedirectHTTP
- use: RealIP - use: RealIP
@ -30,9 +17,6 @@ realIPAuthentik:
- "192.168.0.0/16" - "192.168.0.0/16"
- "172.16.0.0/12" - "172.16.0.0/12"
recursive: true recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP: testFakeRealIP:
- use: ModifyRequest - use: ModifyRequest

View file

@ -9,8 +9,8 @@ import (
"net/http/httptest" "net/http/httptest"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -122,7 +122,7 @@ func (args *testArgs) bodyReader() io.Reader {
return nil return nil
} }
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil { if args == nil {
args = new(testArgs) args = new(testArgs)
} }
@ -136,7 +136,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
return newMiddlewaresTest([]*Middleware{mid}, args) return newMiddlewaresTest([]*Middleware{mid}, args)
} }
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) { func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil { if args == nil {
args = new(testArgs) args = new(testArgs)
} }
@ -163,7 +163,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult,
data, err := io.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, E.From(err) return nil, gperr.Wrap(err)
} }
return &TestResult{ return &TestResult{

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"sync" "sync"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
) )
type ( type (
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
return nil return nil
} }
tr.URL = req.RequestURI tr.URL = req.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(req.Header) tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr return tr
} }
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
return nil return nil
} }
tr.URL = resp.Request.RequestURI tr.URL = resp.Request.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header) tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = gphttp.HeaderToMap(resp.Header) tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode tr.RespStatus = resp.StatusCode
return tr return tr
} }

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
) )
type ( type (
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return "" return ""
}, },
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) }, VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string { VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" { if upPort != "" {
return upHost + ":" + upPort return upHost + ":" + upPort
} }
return upHost return upHost
}, },
VarUpstreamURL: func(req *http.Request) string { VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" { if upScheme == "" {
return "" return ""
} }
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost upAddr := upHost
if upPort != "" { if upPort != "" {
upAddr += ":" + upPort upAddr += ":" + upPort

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"strings" "strings"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
) )
type ( type (
@ -20,10 +20,10 @@ var (
// before implements RequestModifier. // before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor) r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr) clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil { if err == nil {
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP) r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
} }
return true return true
} }

View file

@ -1,7 +1,7 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go) // Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs // Copyright (c) 2020-2024 Traefik Labs
package http package gphttp
import ( import (
"bufio" "bufio"

View file

@ -1,4 +1,4 @@
package http package gphttp
import "net/http" import "net/http"

View file

@ -1,4 +1,4 @@
package http package gphttp
import "net/http" import "net/http"

View file

@ -0,0 +1,34 @@
package gphttp
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var DefaultDialer = net.Dialer{
Timeout: 5 * time.Second,
}
func NewTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: DefaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// DisableCompression: true, // Prevent double compression
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
}
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
tr := NewTransport()
tr.TLSClientConfig = tlsConfig
return tr
}

View file

@ -1,198 +0,0 @@
package accesslog
import (
"bufio"
"bytes"
"io"
"strconv"
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
const chunkSizeMax int64 = 128 * 1024 // 128KB
var (
ErrInvalidSyntax = E.New("invalid syntax")
ErrZeroValue = E.New("zero value")
)
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
lastN := int(r.Last)
days := int(r.Days)
// Seek to end to get file size
size, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
// Initialize ring buffer for last N lines
lines := make([][]byte, 0, lastN|(days*1000))
pos := size
unprocessed := 0
var chunk [chunkSizeMax]byte
var lastLine []byte
var shouldStop func() bool
if days > 0 {
cutoff := time.Now().AddDate(0, 0, -days)
shouldStop = func() bool {
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
}
} else {
shouldStop = func() bool {
return len(lines) == lastN
}
}
// Read backwards until we have enough lines or reach start of file
for pos > 0 {
if pos > chunkSizeMax {
pos -= chunkSizeMax
} else {
pos = 0
}
// Seek to the current chunk
if _, err = file.Seek(pos, io.SeekStart); err != nil {
return err
}
var nRead int
// Read the chunk
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
return err
}
// last unprocessed bytes + read bytes
curChunk := chunk[:unprocessed+nRead]
unprocessed = len(curChunk)
// Split into lines
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
for !shouldStop() && scanner.Scan() {
lastLine = scanner.Bytes()
lines = append(lines, lastLine)
unprocessed -= len(lastLine)
}
if shouldStop() {
break
}
// move unprocessed bytes to the beginning for next iteration
copy(chunk[:], curChunk[unprocessed:])
}
if days > 0 {
// truncate to the end of the log within last N days
return file.Truncate(pos)
}
// write lines to buffer in reverse order
// since we read them backwards
var buf bytes.Buffer
for i := len(lines) - 1; i >= 0; i-- {
buf.Write(lines[i])
buf.WriteRune('\n')
}
return writeTruncate(file, &buf)
}
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
// Seek to beginning and truncate
if _, err := file.Seek(0, 0); err != nil {
return err
}
buffered := bufio.NewWriter(file)
// Write buffer back to file
nWritten, err := buffered.Write(buf.Bytes())
if err != nil {
return err
}
if err = buffered.Flush(); err != nil {
return err
}
// Truncate file
if err = file.Truncate(int64(nWritten)); err != nil {
return err
}
// check bytes written == buffer size
if nWritten != buf.Len() {
return io.ErrShortWrite
}
return
}
func parseLogTime(line []byte) (t time.Time) {
if len(line) == 0 {
return
}
var start, end int
const jsonStart = len(`{"time":"`)
const jsonEnd = jsonStart + len(LogTimeFormat)
if len(line) == '{' { // possibly json log
start = jsonStart
end = jsonEnd
} else { // possibly common or combined format
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
start = bytes.IndexRune(line, '[')
end = bytes.IndexRune(line[start+1:], ']')
if start == -1 || end == -1 || start >= end {
return
}
}
timeStr := line[start+1 : end]
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
return
}

View file

@ -1,81 +0,0 @@
package accesslog_test
import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}
func TestRetentionCommonFormat(t *testing.T) {
var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
Format: FormatCommon,
BufferSize: 1024,
})
for range 10 {
logger.Log(req, resp)
}
logger.Flush(true)
// test.Finish(nil)
ExpectEqual(t, logger.Config().Retention, nil)
ExpectTrue(t, file.Len() > 0)
ExpectEqual(t, file.Count(), 10)
t.Run("keep last", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
ExpectEqual(t, logger.Config().Retention.Days, 0)
ExpectEqual(t, logger.Config().Retention.Last, 5)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 5)
})
_ = file.Truncate(0)
timeNow := time.Now()
for i := range 10 {
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
return timeNow.AddDate(0, 0, -i)
}
logger.Log(req, resp)
}
logger.Flush(true)
// FIXME: keep days does not work
t.Run("keep days", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
ExpectEqual(t, logger.Config().Retention.Days, 3)
ExpectEqual(t, logger.Config().Retention.Last, 0)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 3)
})
}

View file

@ -1,34 +0,0 @@
package http
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
defaultDialer = net.Dialer{
Timeout: 60 * time.Second,
}
DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: true, // Prevent double compression
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
DefaultTransportNoTLS = func() *http.Transport {
clone := DefaultTransport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
)
const StaticFilePathPrefix = "/$gperrorpage/"

View file

@ -1,221 +0,0 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"slices"
"strings"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
forwardAuth struct {
ForwardAuthOpts
Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
ForwardAuthOpts struct {
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
// setup implements MiddlewareWithSetup.
func (fa *forwardAuth) setup() {
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
}
// before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
url,
nil,
)
if err != nil {
fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
gphttp.CopyHeader(faReq.Header, req.Header)
gphttp.RemoveHop(faReq.Header)
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
faReq.Header.Set("X-Original-Url", originalURL)
fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) > 0 {
fa.reqCookiesMap.Store(req, authCookies)
}
return true
}
// modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
default:
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
}
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
default:
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
}
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
}
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
default:
faReq.Header.Del(gphttp.HeaderXForwardedHost)
}
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(gphttp.HeaderXForwardedURI)
}
}