mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
refactor: remove forward auth, move module net/http to net/gphttp
This commit is contained in:
parent
c0c6e21a16
commit
5d2df3550b
69 changed files with 321 additions and 745 deletions
|
@ -1,29 +1,26 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type (
|
||||
AccessLogger struct {
|
||||
task *task.Task
|
||||
cfg *Config
|
||||
io AccessLogIO
|
||||
|
||||
buf bytes.Buffer // buffer for non-flushed log
|
||||
bufMu sync.RWMutex
|
||||
bufPool sync.Pool // buffer pool for formatting a single log line
|
||||
|
||||
flushThreshold int
|
||||
task *task.Task
|
||||
cfg *Config
|
||||
io AccessLogIO
|
||||
buffered *bufio.Writer
|
||||
|
||||
lineBufPool sync.Pool // buffer pool for formatting a single log line
|
||||
Formatter
|
||||
}
|
||||
|
||||
|
@ -44,14 +41,18 @@ type (
|
|||
)
|
||||
|
||||
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
|
||||
l := &AccessLogger{
|
||||
task: parent.Subtask("accesslog"),
|
||||
cfg: cfg,
|
||||
io: io,
|
||||
}
|
||||
if cfg.BufferSize < 1024 {
|
||||
if cfg.BufferSize == 0 {
|
||||
cfg.BufferSize = DefaultBufferSize
|
||||
}
|
||||
if cfg.BufferSize < 4096 {
|
||||
cfg.BufferSize = 4096
|
||||
}
|
||||
l := &AccessLogger{
|
||||
task: parent.Subtask("accesslog"),
|
||||
cfg: cfg,
|
||||
io: io,
|
||||
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
|
||||
}
|
||||
|
||||
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
|
||||
switch l.cfg.Format {
|
||||
|
@ -65,10 +66,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog
|
|||
panic("invalid access log format")
|
||||
}
|
||||
|
||||
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80%
|
||||
l.buf.Grow(int(cfg.BufferSize))
|
||||
l.bufPool.New = func() any {
|
||||
return new(bytes.Buffer)
|
||||
l.lineBufPool.New = func() any {
|
||||
return bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
}
|
||||
go l.start()
|
||||
return l
|
||||
|
@ -89,15 +88,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
|
|||
return
|
||||
}
|
||||
|
||||
line := l.bufPool.Get().(*bytes.Buffer)
|
||||
l.Format(line, req, res)
|
||||
line.WriteRune('\n')
|
||||
|
||||
l.bufMu.Lock()
|
||||
l.buf.Write(line.Bytes())
|
||||
line := l.lineBufPool.Get().(*bytes.Buffer)
|
||||
line.Reset()
|
||||
l.bufPool.Put(line)
|
||||
l.bufMu.Unlock()
|
||||
defer l.lineBufPool.Put(line)
|
||||
l.Formatter.Format(line, req, res)
|
||||
line.WriteRune('\n')
|
||||
l.write(line.Bytes())
|
||||
}
|
||||
|
||||
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
||||
|
@ -115,55 +111,53 @@ func (l *AccessLogger) Rotate() error {
|
|||
l.io.Lock()
|
||||
defer l.io.Unlock()
|
||||
|
||||
return l.cfg.Retention.rotateLogFile(l.io)
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Flush(force bool) {
|
||||
if l.buf.Len() == 0 {
|
||||
return
|
||||
}
|
||||
if force || l.buf.Len() >= l.flushThreshold {
|
||||
l.bufMu.RLock()
|
||||
l.write(l.buf.Bytes())
|
||||
l.buf.Reset()
|
||||
l.bufMu.RUnlock()
|
||||
}
|
||||
return l.rotate()
|
||||
}
|
||||
|
||||
func (l *AccessLogger) handleErr(err error) {
|
||||
E.LogError("failed to write access log", err)
|
||||
gperr.LogError("failed to write access log", err)
|
||||
}
|
||||
|
||||
func (l *AccessLogger) start() {
|
||||
defer func() {
|
||||
if l.buf.Len() > 0 { // flush last
|
||||
l.write(l.buf.Bytes())
|
||||
if err := l.Flush(); err != nil {
|
||||
l.handleErr(err)
|
||||
}
|
||||
l.io.Close()
|
||||
l.close()
|
||||
l.task.Finish(nil)
|
||||
}()
|
||||
|
||||
// periodic flush + threshold flush
|
||||
periodic := time.NewTicker(5 * time.Second)
|
||||
threshold := time.NewTicker(time.Second)
|
||||
defer periodic.Stop()
|
||||
defer threshold.Stop()
|
||||
// flushes the buffer every 30 seconds
|
||||
flushTicker := time.NewTicker(30 * time.Second)
|
||||
defer flushTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.task.Context().Done():
|
||||
return
|
||||
case <-periodic.C:
|
||||
l.Flush(true)
|
||||
case <-threshold.C:
|
||||
l.Flush(false)
|
||||
case <-flushTicker.C:
|
||||
if err := l.Flush(); err != nil {
|
||||
l.handleErr(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Flush() error {
|
||||
l.io.Lock()
|
||||
defer l.io.Unlock()
|
||||
return l.buffered.Flush()
|
||||
}
|
||||
|
||||
func (l *AccessLogger) close() {
|
||||
l.io.Lock()
|
||||
defer l.io.Unlock()
|
||||
l.io.Close()
|
||||
}
|
||||
|
||||
func (l *AccessLogger) write(data []byte) {
|
||||
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
|
||||
_, err := l.io.Write(data)
|
||||
_, err := l.buffered.Write(data)
|
||||
l.io.Unlock()
|
||||
if err != nil {
|
||||
l.handleErr(err)
|
|
@ -9,7 +9,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
|
@ -17,7 +17,7 @@ type (
|
|||
Cookies FieldConfig `json:"cookies"`
|
||||
}
|
||||
Config struct {
|
||||
BufferSize uint `json:"buffer_size" validate:"gte=1"`
|
||||
BufferSize int `json:"buffer_size"`
|
||||
Format Format `json:"format" validate:"oneof=common combined json"`
|
||||
Path string `json:"path" validate:"required"`
|
||||
Filters Filters `json:"filters"`
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
|
@ -3,7 +3,7 @@ package accesslog_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
|
@ -71,14 +71,14 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
|
|||
go func(l *AccessLogger) {
|
||||
defer wg.Done()
|
||||
parallelLog(l, req, resp, logCountPerLogger)
|
||||
l.Flush(true)
|
||||
l.Flush()
|
||||
}(logger)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expected := loggerCount * logCountPerLogger
|
||||
actual := file.Count()
|
||||
actual := file.LineCount()
|
||||
ExpectEqual(t, actual, expected)
|
||||
}
|
||||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ type (
|
|||
CIDR struct{ types.CIDR }
|
||||
)
|
||||
|
||||
var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter")
|
||||
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
|
||||
|
||||
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
|
||||
if len(f.Values) == 0 {
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
|
@ -49,7 +49,6 @@ func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
|
|||
return 0, io.EOF
|
||||
}
|
||||
n = copy(p, m.data[off:])
|
||||
m.position += int64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
@ -63,7 +62,7 @@ func (m *MockFile) Truncate(size int64) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *MockFile) Count() int {
|
||||
func (m *MockFile) LineCount() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return bytes.Count(m.data[:m.position], []byte("\n"))
|
||||
|
@ -72,3 +71,7 @@ func (m *MockFile) Count() int {
|
|||
func (m *MockFile) Len() int64 {
|
||||
return m.position
|
||||
}
|
||||
|
||||
func (m *MockFile) Content() []byte {
|
||||
return m.data[:m.position]
|
||||
}
|
56
internal/net/gphttp/accesslog/retention.go
Normal file
56
internal/net/gphttp/accesslog/retention.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Retention struct {
|
||||
Days uint64 `json:"days"`
|
||||
Last uint64 `json:"last"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidSyntax = gperr.New("invalid syntax")
|
||||
ErrZeroValue = gperr.New("zero value")
|
||||
)
|
||||
|
||||
var defaultChunkSize = 64 * 1024 // 64KB
|
||||
|
||||
// Syntax:
|
||||
//
|
||||
// <N> days|weeks|months
|
||||
//
|
||||
// last <N>
|
||||
//
|
||||
// Parse implements strutils.Parser.
|
||||
func (r *Retention) Parse(v string) (err error) {
|
||||
split := strutils.SplitSpace(v)
|
||||
if len(split) != 2 {
|
||||
return ErrInvalidSyntax.Subject(v)
|
||||
}
|
||||
switch split[0] {
|
||||
case "last":
|
||||
r.Last, err = strconv.ParseUint(split[1], 10, 64)
|
||||
default: // <N> days|weeks|months
|
||||
r.Days, err = strconv.ParseUint(split[0], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch split[1] {
|
||||
case "days":
|
||||
case "weeks":
|
||||
r.Days *= 7
|
||||
case "months":
|
||||
r.Days *= 30
|
||||
default:
|
||||
return ErrInvalidSyntax.Subject("unit " + split[1])
|
||||
}
|
||||
}
|
||||
if r.Days == 0 && r.Last == 0 {
|
||||
return ErrZeroValue
|
||||
}
|
||||
return
|
||||
}
|
33
internal/net/gphttp/accesslog/retention_test.go
Normal file
33
internal/net/gphttp/accesslog/retention_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package accesslog_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestParseRetention(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected *Retention
|
||||
shouldErr bool
|
||||
}{
|
||||
{"30 days", &Retention{Days: 30}, false},
|
||||
{"2 weeks", &Retention{Days: 14}, false},
|
||||
{"last 5", &Retention{Last: 5}, false},
|
||||
{"invalid input", &Retention{}, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
r := &Retention{}
|
||||
err := r.Parse(test.input)
|
||||
if !test.shouldErr {
|
||||
ExpectNoError(t, err)
|
||||
} else {
|
||||
ExpectDeepEqual(t, r, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,7 +3,7 @@ package accesslog
|
|||
import (
|
||||
"strconv"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@ type StatusCodeRange struct {
|
|||
End int
|
||||
}
|
||||
|
||||
var ErrInvalidStatusCodeRange = E.New("invalid status code range")
|
||||
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
|
||||
|
||||
func (r *StatusCodeRange) Includes(code int) bool {
|
||||
return r.Start <= code && code <= r.End
|
||||
|
@ -25,7 +25,7 @@ func (r *StatusCodeRange) Parse(v string) error {
|
|||
case 1:
|
||||
start, err := strconv.Atoi(split[0])
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
r.Start = start
|
||||
r.End = start
|
||||
|
@ -33,7 +33,7 @@ func (r *StatusCodeRange) Parse(v string) error {
|
|||
case 2:
|
||||
start, errStart := strconv.Atoi(split[0])
|
||||
end, errEnd := strconv.Atoi(split[1])
|
||||
if err := E.Join(errStart, errEnd); err != nil {
|
||||
if err := gperr.Join(errStart, errEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Start = start
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"mime"
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"net/http"
|
|
@ -6,8 +6,8 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
|
@ -23,10 +23,10 @@ func (lb *LoadBalancer) newIPHash() impl {
|
|||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err E.Error
|
||||
var err gperr.Error
|
||||
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||
if err != nil {
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
}
|
||||
return impl
|
||||
}
|
|
@ -6,10 +6,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
|
@ -54,7 +54,7 @@ func New(cfg *Config) *LoadBalancer {
|
|||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
|
||||
parent.OnCancel("lb_remove_route", func() {
|
||||
|
@ -125,12 +125,12 @@ func (lb *LoadBalancer) AddServer(srv Server) {
|
|||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if lb.pool.Has(srv.Name()) {
|
||||
old, _ := lb.pool.Load(srv.Name())
|
||||
if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning
|
||||
old, _ := lb.pool.Load(srv.Key())
|
||||
lb.sumWeight -= old.Weight()
|
||||
lb.impl.OnRemoveServer(old)
|
||||
}
|
||||
lb.pool.Store(srv.Name(), srv)
|
||||
lb.pool.Store(srv.Key(), srv)
|
||||
lb.sumWeight += srv.Weight()
|
||||
|
||||
lb.rebalance()
|
||||
|
@ -146,11 +146,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) {
|
|||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if !lb.pool.Has(srv.Name()) {
|
||||
if !lb.pool.Has(srv.Key()) {
|
||||
return
|
||||
}
|
||||
|
||||
lb.pool.Delete(srv.Name())
|
||||
lb.pool.Delete(srv.Key())
|
||||
|
||||
lb.sumWeight -= srv.Weight()
|
||||
lb.rebalance()
|
||||
|
@ -227,7 +227,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
||||
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
|
||||
// wake all servers
|
||||
for _, srv := range srvs {
|
||||
if err := srv.TryWake(); err != nil {
|
||||
|
@ -244,7 +244,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
extra := make(map[string]any)
|
||||
lb.pool.RangeAll(func(k string, v Server) {
|
||||
extra[v.Name()] = v
|
||||
extra[v.Key()] = v
|
||||
})
|
||||
|
||||
return (&monitor.JSONRepresentation{
|
|
@ -3,7 +3,7 @@ package loadbalancer
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
)
|
||||
|
||||
type (
|
|
@ -26,6 +26,7 @@ type (
|
|||
http.Handler
|
||||
health.HealthMonitor
|
||||
Name() string
|
||||
Key() string
|
||||
URL() *net.URL
|
||||
Weight() Weight
|
||||
SetWeight(weight Weight)
|
||||
|
@ -51,6 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
|
|||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
||||
srv := &server{
|
||||
weight: Weight(weight),
|
||||
url: net.MustParseURL("http://localhost"),
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
@ -63,6 +65,10 @@ func (srv *server) URL() *net.URL {
|
|||
return srv.url
|
||||
}
|
||||
|
||||
func (srv *server) Key() string {
|
||||
return srv.url.Host
|
||||
}
|
||||
|
||||
func (srv *server) Weight() Weight {
|
||||
return srv.weight
|
||||
}
|
||||
|
@ -78,9 +84,7 @@ func (srv *server) String() string {
|
|||
func (srv *server) TryWake() error {
|
||||
waker, ok := srv.Handler.(idlewatcher.Waker)
|
||||
if ok {
|
||||
if err := waker.Wake(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waker.Wake()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
@ -61,7 +61,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
errs := E.NewBuilder("")
|
||||
errs := gperr.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
deny = mids["deny@file"]
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -28,7 +29,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
cfCIDRsLastUpdate time.Time
|
||||
cfCIDRsLastUpdate atomic.Value[time.Time]
|
||||
cfCIDRsMu sync.Mutex
|
||||
|
||||
// RFC 1918.
|
||||
|
@ -68,14 +69,14 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
|
|||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
cfCIDRsMu.Lock()
|
||||
defer cfCIDRsMu.Unlock()
|
||||
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -88,7 +89,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
|
||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
|
@ -97,7 +98,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
cfCIDRsLastUpdate.Store(time.Now())
|
||||
logging.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
|
@ -9,14 +9,17 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||
)
|
||||
|
||||
type customErrorPage struct{}
|
||||
|
||||
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||
|
||||
const StaticFilePathPrefix = "/$gperrorpage/"
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
return !ServeStaticErrorPageFile(w, r)
|
||||
|
@ -34,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
|||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
} else {
|
||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
|
@ -49,8 +52,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
|||
if path != "" && path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
|
||||
filename := path[len(gphttp.StaticFilePathPrefix):]
|
||||
if strings.HasPrefix(path, StaticFilePathPrefix) {
|
||||
filename := path[len(StaticFilePathPrefix):]
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
logging.Error().Msg("unable to load resource " + filename)
|
||||
|
@ -59,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
|||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
|
||||
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||
default:
|
||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
|
@ -90,7 +90,7 @@ func watchDir() {
|
|||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
E.LogError("error watching error page directory", err)
|
||||
gperr.LogError("error watching error page directory", err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,15 +7,15 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = E.Error
|
||||
Error = gperr.Error
|
||||
|
||||
ReverseProxy = reverseproxy.ReverseProxy
|
||||
ProxyRequest = reverseproxy.ProxyRequest
|
||||
|
@ -80,7 +80,7 @@ func NewMiddleware[ImplType any]() *Middleware {
|
|||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.enableTrace()
|
||||
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
||||
logging.Trace().Msgf("middleware %s enabled trace", m.name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,7 +103,7 @@ func (m *Middleware) setup() {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -132,10 +132,10 @@ func (m *Middleware) finalize() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||
if m.construct == nil { // likely a middleware from compose
|
||||
if len(optsRaw) != 0 {
|
||||
return nil, E.New("additional options not allowed for middleware ").Subject(m.name)
|
||||
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
|||
return nil, err
|
||||
}
|
||||
if err := mid.finalize(); err != nil {
|
||||
return nil, E.From(err)
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
return mid, nil
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
|||
next(w, r)
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
|
@ -6,13 +6,13 @@ import (
|
|||
"path"
|
||||
"sort"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var ErrMissingMiddlewareUse = E.New("missing middleware 'use' field")
|
||||
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
eb.Add(err)
|
||||
|
@ -21,7 +21,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
|
|||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
|
@ -40,11 +40,11 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
|
|||
return middlewares
|
||||
}
|
||||
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
errs := E.NewBuilder("middlewares compile error")
|
||||
invalidOpts := E.NewBuilder("options compile error")
|
||||
errs := gperr.NewBuilder("middlewares compile error")
|
||||
invalidOpts := gperr.NewBuilder("options compile error")
|
||||
|
||||
for name, opts := range middlewaresMap {
|
||||
m, err := Get(name)
|
||||
|
@ -68,7 +68,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.
|
|||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, E.Error) {
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
|
||||
compiled, err := compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -77,8 +77,8 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (
|
|||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, E.Error) {
|
||||
chainErr := E.NewBuilder("")
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
|
||||
chainErr := gperr.NewBuilder("")
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ import (
|
|||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
errs := E.NewBuilder("")
|
||||
errs := gperr.NewBuilder("")
|
||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
Must(json.MarshalIndent(middlewares, "", " "))
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type middlewareChain struct {
|
||||
|
@ -51,10 +51,10 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
|||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := E.NewBuilder("modify response errors")
|
||||
errs := gperr.NewBuilder("modify response errors")
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
errs.Add(E.From(err).Subjectf("%d", i))
|
||||
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
|
@ -4,7 +4,7 @@ import (
|
|||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
@ -32,15 +32,11 @@ var allMiddlewares = map[string]*Middleware{
|
|||
|
||||
"cidrwhitelist": CIDRWhiteList,
|
||||
"ratelimit": RateLimiter,
|
||||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth,
|
||||
// "oauth2": OAuth2.m,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
||||
ErrUnknownMiddleware = gperr.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
|
@ -58,14 +54,14 @@ func All() map[string]*Middleware {
|
|||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
errs := E.NewBuilder("middleware compile errors")
|
||||
errs := gperr.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step
|
||||
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
|
@ -103,6 +99,6 @@ func LoadComposeFiles() {
|
|||
}
|
||||
}
|
||||
if errs.HasError() {
|
||||
E.LogError(errs.About(), errs.Error())
|
||||
gperr.LogError(errs.About(), errs.Error())
|
||||
}
|
||||
}
|
|
@ -1,12 +1,13 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
|
@ -24,7 +25,7 @@ var OIDC = NewMiddleware[oidcMiddleware]()
|
|||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return E.New("OIDC not enabled but OIDC middleware is used")
|
||||
return gperr.New("OIDC not enabled but OIDC middleware is used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -64,9 +65,6 @@ func (amw *oidcMiddleware) initSlow() error {
|
|||
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||
amw.auth = authProvider
|
||||
return nil
|
||||
|
@ -79,13 +77,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
|||
return false
|
||||
}
|
||||
|
||||
if err := amw.auth.CheckToken(r); err != nil {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
return false
|
||||
}
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.auth.LogoutCallbackHandler(w, r)
|
||||
return false
|
||||
}
|
||||
if err := amw.auth.CheckToken(r); err != nil {
|
||||
if errors.Is(err, auth.ErrMissingToken) {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
} else {
|
||||
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
|
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
|
|||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
|
||||
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
|
||||
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
|
@ -6,14 +6,14 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIPOpts(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": gphttp.HeaderXRealIP,
|
||||
"header": httpheaders.HeaderXRealIP,
|
||||
"from": []string{
|
||||
"127.0.0.0/8",
|
||||
"192.168.0.0/16",
|
||||
|
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
|||
"recursive": true,
|
||||
}
|
||||
optExpected := &RealIPOpts{
|
||||
Header: gphttp.HeaderXRealIP,
|
||||
Header: httpheaders.HeaderXRealIP,
|
||||
From: []*types.CIDR{
|
||||
{
|
||||
IP: net.ParseIP("127.0.0.0"),
|
||||
|
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
|||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
const (
|
||||
testHeader = gphttp.HeaderXRealIP
|
||||
testHeader = httpheaders.HeaderXRealIP
|
||||
testRealIP = "192.168.1.1"
|
||||
)
|
||||
opts := OptionsRaw{
|
|
@ -44,7 +44,7 @@ func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed b
|
|||
r.URL.Host = host
|
||||
}
|
||||
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect)
|
||||
|
||||
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
|
||||
return false
|
|
@ -13,7 +13,7 @@ func TestRedirectToHTTPs(t *testing.T) {
|
|||
reqURL: types.MustParseURL("http://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusMovedPermanently)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
|
||||
}
|
||||
|
|
@ -3,8 +3,8 @@ package middleware
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
)
|
||||
|
||||
// internal use only.
|
||||
|
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
|
|||
|
||||
// before implements RequestModifier.
|
||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Set(gphttp.HeaderUpstreamName, s.Name)
|
||||
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
|
||||
return true
|
||||
}
|
|
@ -8,19 +8,6 @@ theGreatPretender:
|
|||
- X-Test3
|
||||
- X-Test4
|
||||
|
||||
notAuthenticAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
addAuthCookiesToResponse:
|
||||
- session_id
|
||||
- user_id
|
||||
authResponseHeaders:
|
||||
- X-Auth-SessionID
|
||||
- X-Auth-UserID
|
||||
- use: CustomErrorPage
|
||||
|
||||
realIPAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: RealIP
|
||||
|
@ -30,9 +17,6 @@ realIPAuthentik:
|
|||
- "192.168.0.0/16"
|
||||
- "172.16.0.0/12"
|
||||
recursive: true
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
|
||||
testFakeRealIP:
|
||||
- use: ModifyRequest
|
|
@ -9,8 +9,8 @@ import (
|
|||
"net/http/httptest"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
@ -122,7 +122,7 @@ func (args *testArgs) bodyReader() io.Reader {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
|||
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||
}
|
||||
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult,
|
|||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
|
||||
return &TestResult{
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
|||
return nil
|
||||
}
|
||||
tr.URL = req.RequestURI
|
||||
tr.ReqHeaders = gphttp.HeaderToMap(req.Header)
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
|
||||
return tr
|
||||
}
|
||||
|
||||
|
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
|||
return nil
|
||||
}
|
||||
tr.URL = resp.Request.RequestURI
|
||||
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header)
|
||||
tr.RespHeaders = gphttp.HeaderToMap(resp.Header)
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
|
||||
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
|
||||
tr.RespStatus = resp.StatusCode
|
||||
return tr
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *http.Request) string {
|
||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
return upHost + ":" + upPort
|
||||
}
|
||||
return upHost
|
||||
},
|
||||
VarUpstreamURL: func(req *http.Request) string {
|
||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
}
|
||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
upAddr := upHost
|
||||
if upPort != "" {
|
||||
upAddr += ":" + upPort
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -20,10 +20,10 @@ var (
|
|||
|
||||
// before implements RequestModifier.
|
||||
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Del(gphttp.HeaderXForwardedFor)
|
||||
r.Header.Del(httpheaders.HeaderXForwardedFor)
|
||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
||||
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
||||
// Copyright (c) 2020-2024 Traefik Labs
|
||||
|
||||
package http
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"bufio"
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
34
internal/net/gphttp/transport.go
Normal file
34
internal/net/gphttp/transport.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package gphttp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DefaultDialer = net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
func NewTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: DefaultDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
// DisableCompression: true, // Prevent double compression
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
WriteBufferSize: 16 * 1024, // 16KB
|
||||
ReadBufferSize: 16 * 1024, // 16KB
|
||||
}
|
||||
}
|
||||
|
||||
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
|
||||
tr := NewTransport()
|
||||
tr.TLSClientConfig = tlsConfig
|
||||
return tr
|
||||
}
|
|
@ -1,198 +0,0 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Retention struct {
|
||||
Days uint64 `json:"days"`
|
||||
Last uint64 `json:"last"`
|
||||
}
|
||||
|
||||
const chunkSizeMax int64 = 128 * 1024 // 128KB
|
||||
|
||||
var (
|
||||
ErrInvalidSyntax = E.New("invalid syntax")
|
||||
ErrZeroValue = E.New("zero value")
|
||||
)
|
||||
|
||||
// Syntax:
|
||||
//
|
||||
// <N> days|weeks|months
|
||||
//
|
||||
// last <N>
|
||||
//
|
||||
// Parse implements strutils.Parser.
|
||||
func (r *Retention) Parse(v string) (err error) {
|
||||
split := strutils.SplitSpace(v)
|
||||
if len(split) != 2 {
|
||||
return ErrInvalidSyntax.Subject(v)
|
||||
}
|
||||
switch split[0] {
|
||||
case "last":
|
||||
r.Last, err = strconv.ParseUint(split[1], 10, 64)
|
||||
default: // <N> days|weeks|months
|
||||
r.Days, err = strconv.ParseUint(split[0], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch split[1] {
|
||||
case "days":
|
||||
case "weeks":
|
||||
r.Days *= 7
|
||||
case "months":
|
||||
r.Days *= 30
|
||||
default:
|
||||
return ErrInvalidSyntax.Subject("unit " + split[1])
|
||||
}
|
||||
}
|
||||
if r.Days == 0 && r.Last == 0 {
|
||||
return ErrZeroValue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
|
||||
lastN := int(r.Last)
|
||||
days := int(r.Days)
|
||||
|
||||
// Seek to end to get file size
|
||||
size, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize ring buffer for last N lines
|
||||
lines := make([][]byte, 0, lastN|(days*1000))
|
||||
pos := size
|
||||
unprocessed := 0
|
||||
|
||||
var chunk [chunkSizeMax]byte
|
||||
var lastLine []byte
|
||||
|
||||
var shouldStop func() bool
|
||||
if days > 0 {
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
shouldStop = func() bool {
|
||||
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
|
||||
}
|
||||
} else {
|
||||
shouldStop = func() bool {
|
||||
return len(lines) == lastN
|
||||
}
|
||||
}
|
||||
|
||||
// Read backwards until we have enough lines or reach start of file
|
||||
for pos > 0 {
|
||||
if pos > chunkSizeMax {
|
||||
pos -= chunkSizeMax
|
||||
} else {
|
||||
pos = 0
|
||||
}
|
||||
|
||||
// Seek to the current chunk
|
||||
if _, err = file.Seek(pos, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nRead int
|
||||
// Read the chunk
|
||||
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// last unprocessed bytes + read bytes
|
||||
curChunk := chunk[:unprocessed+nRead]
|
||||
unprocessed = len(curChunk)
|
||||
|
||||
// Split into lines
|
||||
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
|
||||
for !shouldStop() && scanner.Scan() {
|
||||
lastLine = scanner.Bytes()
|
||||
lines = append(lines, lastLine)
|
||||
unprocessed -= len(lastLine)
|
||||
}
|
||||
if shouldStop() {
|
||||
break
|
||||
}
|
||||
|
||||
// move unprocessed bytes to the beginning for next iteration
|
||||
copy(chunk[:], curChunk[unprocessed:])
|
||||
}
|
||||
|
||||
if days > 0 {
|
||||
// truncate to the end of the log within last N days
|
||||
return file.Truncate(pos)
|
||||
}
|
||||
|
||||
// write lines to buffer in reverse order
|
||||
// since we read them backwards
|
||||
var buf bytes.Buffer
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
buf.Write(lines[i])
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
|
||||
return writeTruncate(file, &buf)
|
||||
}
|
||||
|
||||
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
|
||||
// Seek to beginning and truncate
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buffered := bufio.NewWriter(file)
|
||||
// Write buffer back to file
|
||||
nWritten, err := buffered.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = buffered.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Truncate file
|
||||
if err = file.Truncate(int64(nWritten)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check bytes written == buffer size
|
||||
if nWritten != buf.Len() {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseLogTime(line []byte) (t time.Time) {
|
||||
if len(line) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var start, end int
|
||||
const jsonStart = len(`{"time":"`)
|
||||
const jsonEnd = jsonStart + len(LogTimeFormat)
|
||||
|
||||
if len(line) == '{' { // possibly json log
|
||||
start = jsonStart
|
||||
end = jsonEnd
|
||||
} else { // possibly common or combined format
|
||||
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
|
||||
start = bytes.IndexRune(line, '[')
|
||||
end = bytes.IndexRune(line[start+1:], ']')
|
||||
if start == -1 || end == -1 || start >= end {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
timeStr := line[start+1 : end]
|
||||
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
|
||||
return
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package accesslog_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestParseRetention(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected *Retention
|
||||
shouldErr bool
|
||||
}{
|
||||
{"30 days", &Retention{Days: 30}, false},
|
||||
{"2 weeks", &Retention{Days: 14}, false},
|
||||
{"last 5", &Retention{Last: 5}, false},
|
||||
{"invalid input", &Retention{}, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
r := &Retention{}
|
||||
err := r.Parse(test.input)
|
||||
if !test.shouldErr {
|
||||
ExpectNoError(t, err)
|
||||
} else {
|
||||
ExpectDeepEqual(t, r, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionCommonFormat(t *testing.T) {
|
||||
var file MockFile
|
||||
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
|
||||
Format: FormatCommon,
|
||||
BufferSize: 1024,
|
||||
})
|
||||
for range 10 {
|
||||
logger.Log(req, resp)
|
||||
}
|
||||
logger.Flush(true)
|
||||
// test.Finish(nil)
|
||||
|
||||
ExpectEqual(t, logger.Config().Retention, nil)
|
||||
ExpectTrue(t, file.Len() > 0)
|
||||
ExpectEqual(t, file.Count(), 10)
|
||||
|
||||
t.Run("keep last", func(t *testing.T) {
|
||||
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
|
||||
ExpectEqual(t, logger.Config().Retention.Days, 0)
|
||||
ExpectEqual(t, logger.Config().Retention.Last, 5)
|
||||
ExpectNoError(t, logger.Rotate())
|
||||
ExpectEqual(t, file.Count(), 5)
|
||||
})
|
||||
|
||||
_ = file.Truncate(0)
|
||||
|
||||
timeNow := time.Now()
|
||||
for i := range 10 {
|
||||
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
|
||||
return timeNow.AddDate(0, 0, -i)
|
||||
}
|
||||
logger.Log(req, resp)
|
||||
}
|
||||
logger.Flush(true)
|
||||
|
||||
// FIXME: keep days does not work
|
||||
t.Run("keep days", func(t *testing.T) {
|
||||
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
|
||||
ExpectEqual(t, logger.Config().Retention.Days, 3)
|
||||
ExpectEqual(t, logger.Config().Retention.Last, 0)
|
||||
ExpectNoError(t, logger.Rotate())
|
||||
ExpectEqual(t, file.Count(), 3)
|
||||
})
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDialer = net.Dialer{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
DefaultTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: defaultDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DisableCompression: true, // Prevent double compression
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
WriteBufferSize: 16 * 1024, // 16KB
|
||||
ReadBufferSize: 16 * 1024, // 16KB
|
||||
}
|
||||
DefaultTransportNoTLS = func() *http.Transport {
|
||||
clone := DefaultTransport.Clone()
|
||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
return clone
|
||||
}()
|
||||
)
|
||||
|
||||
const StaticFilePathPrefix = "/$gperrorpage/"
|
|
@ -1,221 +0,0 @@
|
|||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
|
||||
// Copyright (c) 2020-2024 Traefik Labs
|
||||
// Copyright (c) 2024 yusing
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
forwardAuth struct {
|
||||
ForwardAuthOpts
|
||||
Tracer
|
||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||
}
|
||||
ForwardAuthOpts struct {
|
||||
Address string `validate:"url,required"`
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
AddAuthCookiesToResponse []string
|
||||
}
|
||||
)
|
||||
|
||||
var ForwardAuth = NewMiddleware[forwardAuth]()
|
||||
|
||||
var faHTTPClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (fa *forwardAuth) setup() {
|
||||
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
|
||||
gphttp.RemoveHop(req.Header)
|
||||
|
||||
// Construct original URL for the redirect
|
||||
scheme := "http"
|
||||
if req.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||
|
||||
url := fa.Address
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
req.Context(),
|
||||
http.MethodGet,
|
||||
url,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
fa.AddTracef("new request err to %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
gphttp.CopyHeader(faReq.Header, req.Header)
|
||||
gphttp.RemoveHop(faReq.Header)
|
||||
|
||||
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||
fa.setAuthHeaders(req, faReq)
|
||||
// Set headers needed by Authentik
|
||||
faReq.Header.Set("X-Original-Url", originalURL)
|
||||
fa.AddTraceRequest("forward auth request", faReq)
|
||||
|
||||
faResp, err := faHTTPClient.Do(faReq)
|
||||
if err != nil {
|
||||
fa.AddTracef("failed to call %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer faResp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
fa.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||
fa.AddTraceResponse("forward auth response", faResp)
|
||||
gphttp.CopyHeader(w.Header(), faResp.Header)
|
||||
gphttp.RemoveHop(w.Header())
|
||||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
fa.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||
}
|
||||
|
||||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range fa.AuthResponseHeaders {
|
||||
key := http.CanonicalHeaderKey(key)
|
||||
req.Header.Del(key)
|
||||
if len(faResp.Header[key]) > 0 {
|
||||
req.Header[key] = append([]string(nil), faResp.Header[key]...)
|
||||
}
|
||||
}
|
||||
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
|
||||
authCookies := faResp.Cookies()
|
||||
|
||||
if len(authCookies) > 0 {
|
||||
fa.reqCookiesMap.Store(req, authCookies)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
|
||||
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
|
||||
fa.setAuthCookies(resp, cookies)
|
||||
fa.reqCookiesMap.Delete(resp.Request)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
|
||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cookies := resp.Cookies()
|
||||
resp.Header.Del("Set-Cookie")
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||
// this cookie is not an auth cookie, so add it back
|
||||
resp.Header.Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
}
|
||||
|
||||
for _, cookie := range authCookies {
|
||||
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||
// this cookie is an auth cookie, so add to resp
|
||||
resp.Header.Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if fa.TrustForwardHeader {
|
||||
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
}
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
|
||||
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
|
||||
switch {
|
||||
case xMethod != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
|
||||
case req.Method != "":
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
|
||||
default:
|
||||
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
|
||||
}
|
||||
|
||||
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
|
||||
switch {
|
||||
case xfp != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
|
||||
case req.TLS != nil:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
|
||||
default:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
|
||||
}
|
||||
|
||||
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
|
||||
}
|
||||
|
||||
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
|
||||
switch {
|
||||
case xfh != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
|
||||
case req.Host != "":
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
|
||||
default:
|
||||
faReq.Header.Del(gphttp.HeaderXForwardedHost)
|
||||
}
|
||||
|
||||
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
|
||||
switch {
|
||||
case xfURI != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
|
||||
case req.URL.RequestURI() != "":
|
||||
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
|
||||
default:
|
||||
faReq.Header.Del(gphttp.HeaderXForwardedURI)
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue