mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-26 21:53:16 +02:00
refactor: remove forward auth, move module net/http to net/gphttp
This commit is contained in:
parent
c0c6e21a16
commit
5d2df3550b
69 changed files with 321 additions and 745 deletions
|
@ -1,29 +1,26 @@
|
||||||
package accesslog
|
package accesslog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
AccessLogger struct {
|
AccessLogger struct {
|
||||||
task *task.Task
|
task *task.Task
|
||||||
cfg *Config
|
cfg *Config
|
||||||
io AccessLogIO
|
io AccessLogIO
|
||||||
|
buffered *bufio.Writer
|
||||||
buf bytes.Buffer // buffer for non-flushed log
|
|
||||||
bufMu sync.RWMutex
|
|
||||||
bufPool sync.Pool // buffer pool for formatting a single log line
|
|
||||||
|
|
||||||
flushThreshold int
|
|
||||||
|
|
||||||
|
lineBufPool sync.Pool // buffer pool for formatting a single log line
|
||||||
Formatter
|
Formatter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,14 +41,18 @@ type (
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
|
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
|
||||||
l := &AccessLogger{
|
if cfg.BufferSize == 0 {
|
||||||
task: parent.Subtask("accesslog"),
|
|
||||||
cfg: cfg,
|
|
||||||
io: io,
|
|
||||||
}
|
|
||||||
if cfg.BufferSize < 1024 {
|
|
||||||
cfg.BufferSize = DefaultBufferSize
|
cfg.BufferSize = DefaultBufferSize
|
||||||
}
|
}
|
||||||
|
if cfg.BufferSize < 4096 {
|
||||||
|
cfg.BufferSize = 4096
|
||||||
|
}
|
||||||
|
l := &AccessLogger{
|
||||||
|
task: parent.Subtask("accesslog"),
|
||||||
|
cfg: cfg,
|
||||||
|
io: io,
|
||||||
|
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
|
||||||
|
}
|
||||||
|
|
||||||
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
|
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
|
||||||
switch l.cfg.Format {
|
switch l.cfg.Format {
|
||||||
|
@ -65,10 +66,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog
|
||||||
panic("invalid access log format")
|
panic("invalid access log format")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80%
|
l.lineBufPool.New = func() any {
|
||||||
l.buf.Grow(int(cfg.BufferSize))
|
return bytes.NewBuffer(make([]byte, 0, 1024))
|
||||||
l.bufPool.New = func() any {
|
|
||||||
return new(bytes.Buffer)
|
|
||||||
}
|
}
|
||||||
go l.start()
|
go l.start()
|
||||||
return l
|
return l
|
||||||
|
@ -89,15 +88,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
line := l.bufPool.Get().(*bytes.Buffer)
|
line := l.lineBufPool.Get().(*bytes.Buffer)
|
||||||
l.Format(line, req, res)
|
|
||||||
line.WriteRune('\n')
|
|
||||||
|
|
||||||
l.bufMu.Lock()
|
|
||||||
l.buf.Write(line.Bytes())
|
|
||||||
line.Reset()
|
line.Reset()
|
||||||
l.bufPool.Put(line)
|
defer l.lineBufPool.Put(line)
|
||||||
l.bufMu.Unlock()
|
l.Formatter.Format(line, req, res)
|
||||||
|
line.WriteRune('\n')
|
||||||
|
l.write(line.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
||||||
|
@ -115,55 +111,53 @@ func (l *AccessLogger) Rotate() error {
|
||||||
l.io.Lock()
|
l.io.Lock()
|
||||||
defer l.io.Unlock()
|
defer l.io.Unlock()
|
||||||
|
|
||||||
return l.cfg.Retention.rotateLogFile(l.io)
|
return l.rotate()
|
||||||
}
|
|
||||||
|
|
||||||
func (l *AccessLogger) Flush(force bool) {
|
|
||||||
if l.buf.Len() == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if force || l.buf.Len() >= l.flushThreshold {
|
|
||||||
l.bufMu.RLock()
|
|
||||||
l.write(l.buf.Bytes())
|
|
||||||
l.buf.Reset()
|
|
||||||
l.bufMu.RUnlock()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) handleErr(err error) {
|
func (l *AccessLogger) handleErr(err error) {
|
||||||
E.LogError("failed to write access log", err)
|
gperr.LogError("failed to write access log", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) start() {
|
func (l *AccessLogger) start() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if l.buf.Len() > 0 { // flush last
|
if err := l.Flush(); err != nil {
|
||||||
l.write(l.buf.Bytes())
|
l.handleErr(err)
|
||||||
}
|
}
|
||||||
l.io.Close()
|
l.close()
|
||||||
l.task.Finish(nil)
|
l.task.Finish(nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// periodic flush + threshold flush
|
// flushes the buffer every 30 seconds
|
||||||
periodic := time.NewTicker(5 * time.Second)
|
flushTicker := time.NewTicker(30 * time.Second)
|
||||||
threshold := time.NewTicker(time.Second)
|
defer flushTicker.Stop()
|
||||||
defer periodic.Stop()
|
|
||||||
defer threshold.Stop()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.task.Context().Done():
|
case <-l.task.Context().Done():
|
||||||
return
|
return
|
||||||
case <-periodic.C:
|
case <-flushTicker.C:
|
||||||
l.Flush(true)
|
if err := l.Flush(); err != nil {
|
||||||
case <-threshold.C:
|
l.handleErr(err)
|
||||||
l.Flush(false)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *AccessLogger) Flush() error {
|
||||||
|
l.io.Lock()
|
||||||
|
defer l.io.Unlock()
|
||||||
|
return l.buffered.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *AccessLogger) close() {
|
||||||
|
l.io.Lock()
|
||||||
|
defer l.io.Unlock()
|
||||||
|
l.io.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) write(data []byte) {
|
func (l *AccessLogger) write(data []byte) {
|
||||||
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
|
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
|
||||||
_, err := l.io.Write(data)
|
_, err := l.buffered.Write(data)
|
||||||
l.io.Unlock()
|
l.io.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.handleErr(err)
|
l.handleErr(err)
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
|
@ -17,7 +17,7 @@ type (
|
||||||
Cookies FieldConfig `json:"cookies"`
|
Cookies FieldConfig `json:"cookies"`
|
||||||
}
|
}
|
||||||
Config struct {
|
Config struct {
|
||||||
BufferSize uint `json:"buffer_size" validate:"gte=1"`
|
BufferSize int `json:"buffer_size"`
|
||||||
Format Format `json:"format" validate:"oneof=common combined json"`
|
Format Format `json:"format" validate:"oneof=common combined json"`
|
||||||
Path string `json:"path" validate:"required"`
|
Path string `json:"path" validate:"required"`
|
||||||
Filters Filters `json:"filters"`
|
Filters Filters `json:"filters"`
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
|
@ -3,7 +3,7 @@ package accesslog_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
|
@ -71,14 +71,14 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
|
||||||
go func(l *AccessLogger) {
|
go func(l *AccessLogger) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
parallelLog(l, req, resp, logCountPerLogger)
|
parallelLog(l, req, resp, logCountPerLogger)
|
||||||
l.Flush(true)
|
l.Flush()
|
||||||
}(logger)
|
}(logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
expected := loggerCount * logCountPerLogger
|
expected := loggerCount * logCountPerLogger
|
||||||
actual := file.Count()
|
actual := file.LineCount()
|
||||||
ExpectEqual(t, actual, expected)
|
ExpectEqual(t, actual, expected)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
@ -27,7 +27,7 @@ type (
|
||||||
CIDR struct{ types.CIDR }
|
CIDR struct{ types.CIDR }
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter")
|
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
|
||||||
|
|
||||||
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
|
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
|
||||||
if len(f.Values) == 0 {
|
if len(f.Values) == 0 {
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
|
@ -49,7 +49,6 @@ func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
n = copy(p, m.data[off:])
|
n = copy(p, m.data[off:])
|
||||||
m.position += int64(n)
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,7 +62,7 @@ func (m *MockFile) Truncate(size int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockFile) Count() int {
|
func (m *MockFile) LineCount() int {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
return bytes.Count(m.data[:m.position], []byte("\n"))
|
return bytes.Count(m.data[:m.position], []byte("\n"))
|
||||||
|
@ -72,3 +71,7 @@ func (m *MockFile) Count() int {
|
||||||
func (m *MockFile) Len() int64 {
|
func (m *MockFile) Len() int64 {
|
||||||
return m.position
|
return m.position
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockFile) Content() []byte {
|
||||||
|
return m.data[:m.position]
|
||||||
|
}
|
56
internal/net/gphttp/accesslog/retention.go
Normal file
56
internal/net/gphttp/accesslog/retention.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package accesslog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Retention struct {
|
||||||
|
Days uint64 `json:"days"`
|
||||||
|
Last uint64 `json:"last"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidSyntax = gperr.New("invalid syntax")
|
||||||
|
ErrZeroValue = gperr.New("zero value")
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultChunkSize = 64 * 1024 // 64KB
|
||||||
|
|
||||||
|
// Syntax:
|
||||||
|
//
|
||||||
|
// <N> days|weeks|months
|
||||||
|
//
|
||||||
|
// last <N>
|
||||||
|
//
|
||||||
|
// Parse implements strutils.Parser.
|
||||||
|
func (r *Retention) Parse(v string) (err error) {
|
||||||
|
split := strutils.SplitSpace(v)
|
||||||
|
if len(split) != 2 {
|
||||||
|
return ErrInvalidSyntax.Subject(v)
|
||||||
|
}
|
||||||
|
switch split[0] {
|
||||||
|
case "last":
|
||||||
|
r.Last, err = strconv.ParseUint(split[1], 10, 64)
|
||||||
|
default: // <N> days|weeks|months
|
||||||
|
r.Days, err = strconv.ParseUint(split[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch split[1] {
|
||||||
|
case "days":
|
||||||
|
case "weeks":
|
||||||
|
r.Days *= 7
|
||||||
|
case "months":
|
||||||
|
r.Days *= 30
|
||||||
|
default:
|
||||||
|
return ErrInvalidSyntax.Subject("unit " + split[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Days == 0 && r.Last == 0 {
|
||||||
|
return ErrZeroValue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
33
internal/net/gphttp/accesslog/retention_test.go
Normal file
33
internal/net/gphttp/accesslog/retention_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package accesslog_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseRetention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected *Retention
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{"30 days", &Retention{Days: 30}, false},
|
||||||
|
{"2 weeks", &Retention{Days: 14}, false},
|
||||||
|
{"last 5", &Retention{Last: 5}, false},
|
||||||
|
{"invalid input", &Retention{}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.input, func(t *testing.T) {
|
||||||
|
r := &Retention{}
|
||||||
|
err := r.Parse(test.input)
|
||||||
|
if !test.shouldErr {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
} else {
|
||||||
|
ExpectDeepEqual(t, r, test.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,7 +3,7 @@ package accesslog
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ type StatusCodeRange struct {
|
||||||
End int
|
End int
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrInvalidStatusCodeRange = E.New("invalid status code range")
|
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
|
||||||
|
|
||||||
func (r *StatusCodeRange) Includes(code int) bool {
|
func (r *StatusCodeRange) Includes(code int) bool {
|
||||||
return r.Start <= code && code <= r.End
|
return r.Start <= code && code <= r.End
|
||||||
|
@ -25,7 +25,7 @@ func (r *StatusCodeRange) Parse(v string) error {
|
||||||
case 1:
|
case 1:
|
||||||
start, err := strconv.Atoi(split[0])
|
start, err := strconv.Atoi(split[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return gperr.Wrap(err)
|
||||||
}
|
}
|
||||||
r.Start = start
|
r.Start = start
|
||||||
r.End = start
|
r.End = start
|
||||||
|
@ -33,7 +33,7 @@ func (r *StatusCodeRange) Parse(v string) error {
|
||||||
case 2:
|
case 2:
|
||||||
start, errStart := strconv.Atoi(split[0])
|
start, errStart := strconv.Atoi(split[0])
|
||||||
end, errEnd := strconv.Atoi(split[1])
|
end, errEnd := strconv.Atoi(split[1])
|
||||||
if err := E.Join(errStart, errEnd); err != nil {
|
if err := gperr.Join(errStart, errEnd); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.Start = start
|
r.Start = start
|
|
@ -1,4 +1,4 @@
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"mime"
|
"mime"
|
|
@ -1,4 +1,4 @@
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ipHash struct {
|
type ipHash struct {
|
||||||
|
@ -23,10 +23,10 @@ func (lb *LoadBalancer) newIPHash() impl {
|
||||||
if len(lb.Options) == 0 {
|
if len(lb.Options) == 0 {
|
||||||
return impl
|
return impl
|
||||||
}
|
}
|
||||||
var err E.Error
|
var err gperr.Error
|
||||||
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||||
}
|
}
|
||||||
return impl
|
return impl
|
||||||
}
|
}
|
|
@ -6,10 +6,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||||
"github.com/yusing/go-proxy/internal/route/routes"
|
"github.com/yusing/go-proxy/internal/route/routes"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||||
|
@ -54,7 +54,7 @@ func New(cfg *Config) *LoadBalancer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start implements task.TaskStarter.
|
// Start implements task.TaskStarter.
|
||||||
func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
|
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
|
||||||
lb.startTime = time.Now()
|
lb.startTime = time.Now()
|
||||||
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
|
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
|
||||||
parent.OnCancel("lb_remove_route", func() {
|
parent.OnCancel("lb_remove_route", func() {
|
||||||
|
@ -125,12 +125,12 @@ func (lb *LoadBalancer) AddServer(srv Server) {
|
||||||
lb.poolMu.Lock()
|
lb.poolMu.Lock()
|
||||||
defer lb.poolMu.Unlock()
|
defer lb.poolMu.Unlock()
|
||||||
|
|
||||||
if lb.pool.Has(srv.Name()) {
|
if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning
|
||||||
old, _ := lb.pool.Load(srv.Name())
|
old, _ := lb.pool.Load(srv.Key())
|
||||||
lb.sumWeight -= old.Weight()
|
lb.sumWeight -= old.Weight()
|
||||||
lb.impl.OnRemoveServer(old)
|
lb.impl.OnRemoveServer(old)
|
||||||
}
|
}
|
||||||
lb.pool.Store(srv.Name(), srv)
|
lb.pool.Store(srv.Key(), srv)
|
||||||
lb.sumWeight += srv.Weight()
|
lb.sumWeight += srv.Weight()
|
||||||
|
|
||||||
lb.rebalance()
|
lb.rebalance()
|
||||||
|
@ -146,11 +146,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) {
|
||||||
lb.poolMu.Lock()
|
lb.poolMu.Lock()
|
||||||
defer lb.poolMu.Unlock()
|
defer lb.poolMu.Unlock()
|
||||||
|
|
||||||
if !lb.pool.Has(srv.Name()) {
|
if !lb.pool.Has(srv.Key()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lb.pool.Delete(srv.Name())
|
lb.pool.Delete(srv.Key())
|
||||||
|
|
||||||
lb.sumWeight -= srv.Weight()
|
lb.sumWeight -= srv.Weight()
|
||||||
lb.rebalance()
|
lb.rebalance()
|
||||||
|
@ -227,7 +227,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
|
||||||
// wake all servers
|
// wake all servers
|
||||||
for _, srv := range srvs {
|
for _, srv := range srvs {
|
||||||
if err := srv.TryWake(); err != nil {
|
if err := srv.TryWake(); err != nil {
|
||||||
|
@ -244,7 +244,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||||
extra := make(map[string]any)
|
extra := make(map[string]any)
|
||||||
lb.pool.RangeAll(func(k string, v Server) {
|
lb.pool.RangeAll(func(k string, v Server) {
|
||||||
extra[v.Name()] = v
|
extra[v.Key()] = v
|
||||||
})
|
})
|
||||||
|
|
||||||
return (&monitor.JSONRepresentation{
|
return (&monitor.JSONRepresentation{
|
|
@ -3,7 +3,7 @@ package loadbalancer
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package loadbalancer
|
package loadbalancer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
|
@ -26,6 +26,7 @@ type (
|
||||||
http.Handler
|
http.Handler
|
||||||
health.HealthMonitor
|
health.HealthMonitor
|
||||||
Name() string
|
Name() string
|
||||||
|
Key() string
|
||||||
URL() *net.URL
|
URL() *net.URL
|
||||||
Weight() Weight
|
Weight() Weight
|
||||||
SetWeight(weight Weight)
|
SetWeight(weight Weight)
|
||||||
|
@ -51,6 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
|
||||||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
||||||
srv := &server{
|
srv := &server{
|
||||||
weight: Weight(weight),
|
weight: Weight(weight),
|
||||||
|
url: net.MustParseURL("http://localhost"),
|
||||||
}
|
}
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
@ -63,6 +65,10 @@ func (srv *server) URL() *net.URL {
|
||||||
return srv.url
|
return srv.url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *server) Key() string {
|
||||||
|
return srv.url.Host
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *server) Weight() Weight {
|
func (srv *server) Weight() Weight {
|
||||||
return srv.weight
|
return srv.weight
|
||||||
}
|
}
|
||||||
|
@ -78,9 +84,7 @@ func (srv *server) String() string {
|
||||||
func (srv *server) TryWake() error {
|
func (srv *server) TryWake() error {
|
||||||
waker, ok := srv.Handler.(idlewatcher.Waker)
|
waker, ok := srv.Handler.(idlewatcher.Waker)
|
||||||
if ok {
|
if ok {
|
||||||
if err := waker.Wake(); err != nil {
|
return waker.Wake()
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
@ -61,7 +61,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRWhitelist(t *testing.T) {
|
func TestCIDRWhitelist(t *testing.T) {
|
||||||
errs := E.NewBuilder("")
|
errs := gperr.NewBuilder("")
|
||||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||||
ExpectNoError(t, errs.Error())
|
ExpectNoError(t, errs.Error())
|
||||||
deny = mids["deny@file"]
|
deny = mids["deny@file"]
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
cfCIDRsLastUpdate time.Time
|
cfCIDRsLastUpdate atomic.Value[time.Time]
|
||||||
cfCIDRsMu sync.Mutex
|
cfCIDRsMu sync.Mutex
|
||||||
|
|
||||||
// RFC 1918.
|
// RFC 1918.
|
||||||
|
@ -68,14 +69,14 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cfCIDRsMu.Lock()
|
cfCIDRsMu.Lock()
|
||||||
defer cfCIDRsMu.Unlock()
|
defer cfCIDRsMu.Unlock()
|
||||||
|
|
||||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +89,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
|
||||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -97,7 +98,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfCIDRsLastUpdate = time.Now()
|
cfCIDRsLastUpdate.Store(time.Now())
|
||||||
logging.Info().Msg("cloudflare CIDR range updated")
|
logging.Info().Msg("cloudflare CIDR range updated")
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -9,14 +9,17 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||||
)
|
)
|
||||||
|
|
||||||
type customErrorPage struct{}
|
type customErrorPage struct{}
|
||||||
|
|
||||||
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||||
|
|
||||||
|
const StaticFilePathPrefix = "/$gperrorpage/"
|
||||||
|
|
||||||
// before implements RequestModifier.
|
// before implements RequestModifier.
|
||||||
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
return !ServeStaticErrorPageFile(w, r)
|
return !ServeStaticErrorPageFile(w, r)
|
||||||
|
@ -34,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||||
resp.ContentLength = int64(len(errorPage))
|
resp.ContentLength = int64(len(errorPage))
|
||||||
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||||
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||||
} else {
|
} else {
|
||||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -49,8 +52,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
||||||
if path != "" && path[0] != '/' {
|
if path != "" && path[0] != '/' {
|
||||||
path = "/" + path
|
path = "/" + path
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
|
if strings.HasPrefix(path, StaticFilePathPrefix) {
|
||||||
filename := path[len(gphttp.StaticFilePathPrefix):]
|
filename := path[len(StaticFilePathPrefix):]
|
||||||
file, ok := errorpage.GetStaticFile(filename)
|
file, ok := errorpage.GetStaticFile(filename)
|
||||||
if !ok {
|
if !ok {
|
||||||
logging.Error().Msg("unable to load resource " + filename)
|
logging.Error().Msg("unable to load resource " + filename)
|
||||||
|
@ -59,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
||||||
ext := filepath.Ext(filename)
|
ext := filepath.Ext(filename)
|
||||||
switch ext {
|
switch ext {
|
||||||
case ".html":
|
case ".html":
|
||||||
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
|
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||||
case ".js":
|
case ".js":
|
||||||
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
|
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
|
||||||
case ".css":
|
case ".css":
|
||||||
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
|
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||||
default:
|
default:
|
||||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||||
}
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
|
@ -90,7 +90,7 @@ func watchDir() {
|
||||||
loadContent()
|
loadContent()
|
||||||
}
|
}
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
E.LogError("error watching error page directory", err)
|
gperr.LogError("error watching error page directory", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -7,15 +7,15 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Error = E.Error
|
Error = gperr.Error
|
||||||
|
|
||||||
ReverseProxy = reverseproxy.ReverseProxy
|
ReverseProxy = reverseproxy.ReverseProxy
|
||||||
ProxyRequest = reverseproxy.ProxyRequest
|
ProxyRequest = reverseproxy.ProxyRequest
|
||||||
|
@ -80,7 +80,7 @@ func NewMiddleware[ImplType any]() *Middleware {
|
||||||
func (m *Middleware) enableTrace() {
|
func (m *Middleware) enableTrace() {
|
||||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
tracer.enableTrace()
|
tracer.enableTrace()
|
||||||
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
logging.Trace().Msgf("middleware %s enabled trace", m.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ func (m *Middleware) setup() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||||
if len(optsRaw) == 0 {
|
if len(optsRaw) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -132,10 +132,10 @@ func (m *Middleware) finalize() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||||
if m.construct == nil { // likely a middleware from compose
|
if m.construct == nil { // likely a middleware from compose
|
||||||
if len(optsRaw) != 0 {
|
if len(optsRaw) != 0 {
|
||||||
return nil, E.New("additional options not allowed for middleware ").Subject(m.name)
|
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := mid.finalize(); err != nil {
|
if err := mid.finalize(); err != nil {
|
||||||
return nil, E.From(err)
|
return nil, gperr.Wrap(err)
|
||||||
}
|
}
|
||||||
return mid, nil
|
return mid, nil
|
||||||
}
|
}
|
||||||
|
@ -196,7 +196,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
||||||
var middlewares []*Middleware
|
var middlewares []*Middleware
|
||||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||||
if err != nil {
|
if err != nil {
|
|
@ -6,13 +6,13 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrMissingMiddlewareUse = E.New("missing middleware 'use' field")
|
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
|
||||||
|
|
||||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
|
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
|
||||||
fileContent, err := os.ReadFile(filePath)
|
fileContent, err := os.ReadFile(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
eb.Add(err)
|
eb.Add(err)
|
||||||
|
@ -21,7 +21,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
|
||||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
|
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
|
||||||
var rawMap map[string][]map[string]any
|
var rawMap map[string][]map[string]any
|
||||||
err := yaml.Unmarshal(data, &rawMap)
|
err := yaml.Unmarshal(data, &rawMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -40,11 +40,11 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
|
||||||
return middlewares
|
return middlewares
|
||||||
}
|
}
|
||||||
|
|
||||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
|
||||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||||
|
|
||||||
errs := E.NewBuilder("middlewares compile error")
|
errs := gperr.NewBuilder("middlewares compile error")
|
||||||
invalidOpts := E.NewBuilder("options compile error")
|
invalidOpts := gperr.NewBuilder("options compile error")
|
||||||
|
|
||||||
for name, opts := range middlewaresMap {
|
for name, opts := range middlewaresMap {
|
||||||
m, err := Get(name)
|
m, err := Get(name)
|
||||||
|
@ -68,7 +68,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.
|
||||||
return middlewares, errs.Error()
|
return middlewares, errs.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, E.Error) {
|
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
|
||||||
compiled, err := compileMiddlewares(middlewaresMap)
|
compiled, err := compileMiddlewares(middlewaresMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -77,8 +77,8 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check conflict or duplicates.
|
// TODO: check conflict or duplicates.
|
||||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, E.Error) {
|
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
|
||||||
chainErr := E.NewBuilder("")
|
chainErr := gperr.NewBuilder("")
|
||||||
chain := make([]*Middleware, 0, len(defs))
|
chain := make([]*Middleware, 0, len(defs))
|
||||||
for i, def := range defs {
|
for i, def := range defs {
|
||||||
if def["use"] == nil || def["use"] == "" {
|
if def["use"] == nil || def["use"] == "" {
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ import (
|
||||||
var testMiddlewareCompose []byte
|
var testMiddlewareCompose []byte
|
||||||
|
|
||||||
func TestBuild(t *testing.T) {
|
func TestBuild(t *testing.T) {
|
||||||
errs := E.NewBuilder("")
|
errs := gperr.NewBuilder("")
|
||||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||||
ExpectNoError(t, errs.Error())
|
ExpectNoError(t, errs.Error())
|
||||||
Must(json.MarshalIndent(middlewares, "", " "))
|
Must(json.MarshalIndent(middlewares, "", " "))
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type middlewareChain struct {
|
type middlewareChain struct {
|
||||||
|
@ -51,10 +51,10 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||||
if len(m.modResps) == 0 {
|
if len(m.modResps) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
errs := E.NewBuilder("modify response errors")
|
errs := gperr.NewBuilder("modify response errors")
|
||||||
for i, mr := range m.modResps {
|
for i, mr := range m.modResps {
|
||||||
if err := mr.modifyResponse(resp); err != nil {
|
if err := mr.modifyResponse(resp); err != nil {
|
||||||
errs.Add(E.From(err).Subjectf("%d", i))
|
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return errs.Error()
|
return errs.Error()
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
@ -32,15 +32,11 @@ var allMiddlewares = map[string]*Middleware{
|
||||||
|
|
||||||
"cidrwhitelist": CIDRWhiteList,
|
"cidrwhitelist": CIDRWhiteList,
|
||||||
"ratelimit": RateLimiter,
|
"ratelimit": RateLimiter,
|
||||||
|
|
||||||
// !experimental
|
|
||||||
"forwardauth": ForwardAuth,
|
|
||||||
// "oauth2": OAuth2.m,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
ErrUnknownMiddleware = gperr.New("unknown middleware")
|
||||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
|
||||||
)
|
)
|
||||||
|
|
||||||
func Get(name string) (*Middleware, Error) {
|
func Get(name string) (*Middleware, Error) {
|
||||||
|
@ -58,14 +54,14 @@ func All() map[string]*Middleware {
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadComposeFiles() {
|
func LoadComposeFiles() {
|
||||||
errs := E.NewBuilder("middleware compile errors")
|
errs := gperr.NewBuilder("middleware compile errors")
|
||||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Err(err).Msg("failed to list middleware definitions")
|
logging.Err(err).Msg("failed to list middleware definitions")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, defFile := range middlewareDefs {
|
for _, defFile := range middlewareDefs {
|
||||||
voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step
|
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
|
||||||
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
|
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
|
||||||
if len(mws) == 0 {
|
if len(mws) == 0 {
|
||||||
continue
|
continue
|
||||||
|
@ -103,6 +99,6 @@ func LoadComposeFiles() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errs.HasError() {
|
if errs.HasError() {
|
||||||
E.LogError(errs.About(), errs.Error())
|
gperr.LogError(errs.About(), errs.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,12 +1,13 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type oidcMiddleware struct {
|
type oidcMiddleware struct {
|
||||||
|
@ -24,7 +25,7 @@ var OIDC = NewMiddleware[oidcMiddleware]()
|
||||||
|
|
||||||
func (amw *oidcMiddleware) finalize() error {
|
func (amw *oidcMiddleware) finalize() error {
|
||||||
if !auth.IsOIDCEnabled() {
|
if !auth.IsOIDCEnabled() {
|
||||||
return E.New("OIDC not enabled but OIDC middleware is used")
|
return gperr.New("OIDC not enabled but OIDC middleware is used")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -64,9 +65,6 @@ func (amw *oidcMiddleware) initSlow() error {
|
||||||
|
|
||||||
amw.authMux = http.NewServeMux()
|
amw.authMux = http.NewServeMux()
|
||||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||||
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
||||||
})
|
|
||||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||||
amw.auth = authProvider
|
amw.auth = authProvider
|
||||||
return nil
|
return nil
|
||||||
|
@ -79,13 +77,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := amw.auth.CheckToken(r); err != nil {
|
|
||||||
amw.authMux.ServeHTTP(w, r)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if r.URL.Path == auth.OIDCLogoutPath {
|
if r.URL.Path == auth.OIDCLogoutPath {
|
||||||
amw.auth.LogoutCallbackHandler(w, r)
|
amw.auth.LogoutCallbackHandler(w, r)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if err := amw.auth.CheckToken(r); err != nil {
|
||||||
|
if errors.Is(err, auth.ErrMissingToken) {
|
||||||
|
amw.authMux.ServeHTTP(w, r)
|
||||||
|
} else {
|
||||||
|
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
|
||||||
|
|
||||||
req.RemoteAddr = lastNonTrustedIP
|
req.RemoteAddr = lastNonTrustedIP
|
||||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||||
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
|
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
|
||||||
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||||
}
|
}
|
|
@ -6,14 +6,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetRealIPOpts(t *testing.T) {
|
func TestSetRealIPOpts(t *testing.T) {
|
||||||
opts := OptionsRaw{
|
opts := OptionsRaw{
|
||||||
"header": gphttp.HeaderXRealIP,
|
"header": httpheaders.HeaderXRealIP,
|
||||||
"from": []string{
|
"from": []string{
|
||||||
"127.0.0.0/8",
|
"127.0.0.0/8",
|
||||||
"192.168.0.0/16",
|
"192.168.0.0/16",
|
||||||
|
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
||||||
"recursive": true,
|
"recursive": true,
|
||||||
}
|
}
|
||||||
optExpected := &RealIPOpts{
|
optExpected := &RealIPOpts{
|
||||||
Header: gphttp.HeaderXRealIP,
|
Header: httpheaders.HeaderXRealIP,
|
||||||
From: []*types.CIDR{
|
From: []*types.CIDR{
|
||||||
{
|
{
|
||||||
IP: net.ParseIP("127.0.0.0"),
|
IP: net.ParseIP("127.0.0.0"),
|
||||||
|
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
||||||
|
|
||||||
func TestSetRealIP(t *testing.T) {
|
func TestSetRealIP(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
testHeader = gphttp.HeaderXRealIP
|
testHeader = httpheaders.HeaderXRealIP
|
||||||
testRealIP = "192.168.1.1"
|
testRealIP = "192.168.1.1"
|
||||||
)
|
)
|
||||||
opts := OptionsRaw{
|
opts := OptionsRaw{
|
|
@ -44,7 +44,7 @@ func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed b
|
||||||
r.URL.Host = host
|
r.URL.Host = host
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
|
http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect)
|
||||||
|
|
||||||
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
|
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
|
||||||
return false
|
return false
|
|
@ -13,7 +13,7 @@ func TestRedirectToHTTPs(t *testing.T) {
|
||||||
reqURL: types.MustParseURL("http://example.com"),
|
reqURL: types.MustParseURL("http://example.com"),
|
||||||
})
|
})
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectEqual(t, result.ResponseStatus, http.StatusMovedPermanently)
|
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
|
||||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
|
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// internal use only.
|
// internal use only.
|
||||||
|
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
|
||||||
|
|
||||||
// before implements RequestModifier.
|
// before implements RequestModifier.
|
||||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
r.Header.Set(gphttp.HeaderUpstreamName, s.Name)
|
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
|
||||||
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
|
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
|
||||||
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
|
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
|
||||||
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
|
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
|
@ -8,19 +8,6 @@ theGreatPretender:
|
||||||
- X-Test3
|
- X-Test3
|
||||||
- X-Test4
|
- X-Test4
|
||||||
|
|
||||||
notAuthenticAuthentik:
|
|
||||||
- use: RedirectHTTP
|
|
||||||
- use: ForwardAuth
|
|
||||||
address: https://authentik.company
|
|
||||||
trustForwardHeader: true
|
|
||||||
addAuthCookiesToResponse:
|
|
||||||
- session_id
|
|
||||||
- user_id
|
|
||||||
authResponseHeaders:
|
|
||||||
- X-Auth-SessionID
|
|
||||||
- X-Auth-UserID
|
|
||||||
- use: CustomErrorPage
|
|
||||||
|
|
||||||
realIPAuthentik:
|
realIPAuthentik:
|
||||||
- use: RedirectHTTP
|
- use: RedirectHTTP
|
||||||
- use: RealIP
|
- use: RealIP
|
||||||
|
@ -30,9 +17,6 @@ realIPAuthentik:
|
||||||
- "192.168.0.0/16"
|
- "192.168.0.0/16"
|
||||||
- "172.16.0.0/12"
|
- "172.16.0.0/12"
|
||||||
recursive: true
|
recursive: true
|
||||||
- use: ForwardAuth
|
|
||||||
address: https://authentik.company
|
|
||||||
trustForwardHeader: true
|
|
||||||
|
|
||||||
testFakeRealIP:
|
testFakeRealIP:
|
||||||
- use: ModifyRequest
|
- use: ModifyRequest
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
@ -122,7 +122,7 @@ func (args *testArgs) bodyReader() io.Reader {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||||
if args == nil {
|
if args == nil {
|
||||||
args = new(testArgs)
|
args = new(testArgs)
|
||||||
}
|
}
|
||||||
|
@ -136,7 +136,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
return newMiddlewaresTest([]*Middleware{mid}, args)
|
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
|
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||||
if args == nil {
|
if args == nil {
|
||||||
args = new(testArgs)
|
args = new(testArgs)
|
||||||
}
|
}
|
||||||
|
@ -163,7 +163,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult,
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.From(err)
|
return nil, gperr.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &TestResult{
|
return &TestResult{
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
tr.URL = req.RequestURI
|
tr.URL = req.RequestURI
|
||||||
tr.ReqHeaders = gphttp.HeaderToMap(req.Header)
|
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
tr.URL = resp.Request.RequestURI
|
tr.URL = resp.Request.RequestURI
|
||||||
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header)
|
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
|
||||||
tr.RespHeaders = gphttp.HeaderToMap(resp.Header)
|
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
|
||||||
tr.RespStatus = resp.StatusCode
|
tr.RespStatus = resp.StatusCode
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) },
|
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
|
||||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
|
||||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
|
||||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
|
||||||
VarUpstreamAddr: func(req *http.Request) string {
|
VarUpstreamAddr: func(req *http.Request) string {
|
||||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||||
if upPort != "" {
|
if upPort != "" {
|
||||||
return upHost + ":" + upPort
|
return upHost + ":" + upPort
|
||||||
}
|
}
|
||||||
return upHost
|
return upHost
|
||||||
},
|
},
|
||||||
VarUpstreamURL: func(req *http.Request) string {
|
VarUpstreamURL: func(req *http.Request) string {
|
||||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
|
||||||
if upScheme == "" {
|
if upScheme == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||||
upAddr := upHost
|
upAddr := upHost
|
||||||
if upPort != "" {
|
if upPort != "" {
|
||||||
upAddr += ":" + upPort
|
upAddr += ":" + upPort
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -20,10 +20,10 @@ var (
|
||||||
|
|
||||||
// before implements RequestModifier.
|
// before implements RequestModifier.
|
||||||
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
r.Header.Del(gphttp.HeaderXForwardedFor)
|
r.Header.Del(httpheaders.HeaderXForwardedFor)
|
||||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
|
@ -1,7 +1,7 @@
|
||||||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
||||||
// Copyright (c) 2020-2024 Traefik Labs
|
// Copyright (c) 2020-2024 Traefik Labs
|
||||||
|
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
|
@ -1,4 +1,4 @@
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package http
|
package gphttp
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
34
internal/net/gphttp/transport.go
Normal file
34
internal/net/gphttp/transport.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package gphttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultDialer = net.Dialer{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: DefaultDialer.DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConnsPerHost: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
// DisableCompression: true, // Prevent double compression
|
||||||
|
ResponseHeaderTimeout: 60 * time.Second,
|
||||||
|
WriteBufferSize: 16 * 1024, // 16KB
|
||||||
|
ReadBufferSize: 16 * 1024, // 16KB
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
|
||||||
|
tr := NewTransport()
|
||||||
|
tr.TLSClientConfig = tlsConfig
|
||||||
|
return tr
|
||||||
|
}
|
|
@ -1,198 +0,0 @@
|
||||||
package accesslog
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Retention struct {
|
|
||||||
Days uint64 `json:"days"`
|
|
||||||
Last uint64 `json:"last"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const chunkSizeMax int64 = 128 * 1024 // 128KB
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidSyntax = E.New("invalid syntax")
|
|
||||||
ErrZeroValue = E.New("zero value")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Syntax:
|
|
||||||
//
|
|
||||||
// <N> days|weeks|months
|
|
||||||
//
|
|
||||||
// last <N>
|
|
||||||
//
|
|
||||||
// Parse implements strutils.Parser.
|
|
||||||
func (r *Retention) Parse(v string) (err error) {
|
|
||||||
split := strutils.SplitSpace(v)
|
|
||||||
if len(split) != 2 {
|
|
||||||
return ErrInvalidSyntax.Subject(v)
|
|
||||||
}
|
|
||||||
switch split[0] {
|
|
||||||
case "last":
|
|
||||||
r.Last, err = strconv.ParseUint(split[1], 10, 64)
|
|
||||||
default: // <N> days|weeks|months
|
|
||||||
r.Days, err = strconv.ParseUint(split[0], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
switch split[1] {
|
|
||||||
case "days":
|
|
||||||
case "weeks":
|
|
||||||
r.Days *= 7
|
|
||||||
case "months":
|
|
||||||
r.Days *= 30
|
|
||||||
default:
|
|
||||||
return ErrInvalidSyntax.Subject("unit " + split[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if r.Days == 0 && r.Last == 0 {
|
|
||||||
return ErrZeroValue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
|
|
||||||
lastN := int(r.Last)
|
|
||||||
days := int(r.Days)
|
|
||||||
|
|
||||||
// Seek to end to get file size
|
|
||||||
size, err := file.Seek(0, io.SeekEnd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize ring buffer for last N lines
|
|
||||||
lines := make([][]byte, 0, lastN|(days*1000))
|
|
||||||
pos := size
|
|
||||||
unprocessed := 0
|
|
||||||
|
|
||||||
var chunk [chunkSizeMax]byte
|
|
||||||
var lastLine []byte
|
|
||||||
|
|
||||||
var shouldStop func() bool
|
|
||||||
if days > 0 {
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -days)
|
|
||||||
shouldStop = func() bool {
|
|
||||||
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
shouldStop = func() bool {
|
|
||||||
return len(lines) == lastN
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read backwards until we have enough lines or reach start of file
|
|
||||||
for pos > 0 {
|
|
||||||
if pos > chunkSizeMax {
|
|
||||||
pos -= chunkSizeMax
|
|
||||||
} else {
|
|
||||||
pos = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Seek to the current chunk
|
|
||||||
if _, err = file.Seek(pos, io.SeekStart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var nRead int
|
|
||||||
// Read the chunk
|
|
||||||
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// last unprocessed bytes + read bytes
|
|
||||||
curChunk := chunk[:unprocessed+nRead]
|
|
||||||
unprocessed = len(curChunk)
|
|
||||||
|
|
||||||
// Split into lines
|
|
||||||
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
|
|
||||||
for !shouldStop() && scanner.Scan() {
|
|
||||||
lastLine = scanner.Bytes()
|
|
||||||
lines = append(lines, lastLine)
|
|
||||||
unprocessed -= len(lastLine)
|
|
||||||
}
|
|
||||||
if shouldStop() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// move unprocessed bytes to the beginning for next iteration
|
|
||||||
copy(chunk[:], curChunk[unprocessed:])
|
|
||||||
}
|
|
||||||
|
|
||||||
if days > 0 {
|
|
||||||
// truncate to the end of the log within last N days
|
|
||||||
return file.Truncate(pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
// write lines to buffer in reverse order
|
|
||||||
// since we read them backwards
|
|
||||||
var buf bytes.Buffer
|
|
||||||
for i := len(lines) - 1; i >= 0; i-- {
|
|
||||||
buf.Write(lines[i])
|
|
||||||
buf.WriteRune('\n')
|
|
||||||
}
|
|
||||||
|
|
||||||
return writeTruncate(file, &buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
|
|
||||||
// Seek to beginning and truncate
|
|
||||||
if _, err := file.Seek(0, 0); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
buffered := bufio.NewWriter(file)
|
|
||||||
// Write buffer back to file
|
|
||||||
nWritten, err := buffered.Write(buf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = buffered.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Truncate file
|
|
||||||
if err = file.Truncate(int64(nWritten)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check bytes written == buffer size
|
|
||||||
if nWritten != buf.Len() {
|
|
||||||
return io.ErrShortWrite
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLogTime(line []byte) (t time.Time) {
|
|
||||||
if len(line) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var start, end int
|
|
||||||
const jsonStart = len(`{"time":"`)
|
|
||||||
const jsonEnd = jsonStart + len(LogTimeFormat)
|
|
||||||
|
|
||||||
if len(line) == '{' { // possibly json log
|
|
||||||
start = jsonStart
|
|
||||||
end = jsonEnd
|
|
||||||
} else { // possibly common or combined format
|
|
||||||
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
|
|
||||||
start = bytes.IndexRune(line, '[')
|
|
||||||
end = bytes.IndexRune(line[start+1:], ']')
|
|
||||||
if start == -1 || end == -1 || start >= end {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
timeStr := line[start+1 : end]
|
|
||||||
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,81 +0,0 @@
|
||||||
package accesslog_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseRetention(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected *Retention
|
|
||||||
shouldErr bool
|
|
||||||
}{
|
|
||||||
{"30 days", &Retention{Days: 30}, false},
|
|
||||||
{"2 weeks", &Retention{Days: 14}, false},
|
|
||||||
{"last 5", &Retention{Last: 5}, false},
|
|
||||||
{"invalid input", &Retention{}, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.input, func(t *testing.T) {
|
|
||||||
r := &Retention{}
|
|
||||||
err := r.Parse(test.input)
|
|
||||||
if !test.shouldErr {
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
} else {
|
|
||||||
ExpectDeepEqual(t, r, test.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRetentionCommonFormat(t *testing.T) {
|
|
||||||
var file MockFile
|
|
||||||
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
|
|
||||||
Format: FormatCommon,
|
|
||||||
BufferSize: 1024,
|
|
||||||
})
|
|
||||||
for range 10 {
|
|
||||||
logger.Log(req, resp)
|
|
||||||
}
|
|
||||||
logger.Flush(true)
|
|
||||||
// test.Finish(nil)
|
|
||||||
|
|
||||||
ExpectEqual(t, logger.Config().Retention, nil)
|
|
||||||
ExpectTrue(t, file.Len() > 0)
|
|
||||||
ExpectEqual(t, file.Count(), 10)
|
|
||||||
|
|
||||||
t.Run("keep last", func(t *testing.T) {
|
|
||||||
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
|
|
||||||
ExpectEqual(t, logger.Config().Retention.Days, 0)
|
|
||||||
ExpectEqual(t, logger.Config().Retention.Last, 5)
|
|
||||||
ExpectNoError(t, logger.Rotate())
|
|
||||||
ExpectEqual(t, file.Count(), 5)
|
|
||||||
})
|
|
||||||
|
|
||||||
_ = file.Truncate(0)
|
|
||||||
|
|
||||||
timeNow := time.Now()
|
|
||||||
for i := range 10 {
|
|
||||||
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
|
|
||||||
return timeNow.AddDate(0, 0, -i)
|
|
||||||
}
|
|
||||||
logger.Log(req, resp)
|
|
||||||
}
|
|
||||||
logger.Flush(true)
|
|
||||||
|
|
||||||
// FIXME: keep days does not work
|
|
||||||
t.Run("keep days", func(t *testing.T) {
|
|
||||||
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
|
|
||||||
ExpectEqual(t, logger.Config().Retention.Days, 3)
|
|
||||||
ExpectEqual(t, logger.Config().Retention.Last, 0)
|
|
||||||
ExpectNoError(t, logger.Rotate())
|
|
||||||
ExpectEqual(t, file.Count(), 3)
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,34 +0,0 @@
|
||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
defaultDialer = net.Dialer{
|
|
||||||
Timeout: 60 * time.Second,
|
|
||||||
}
|
|
||||||
DefaultTransport = &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: defaultDialer.DialContext,
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConnsPerHost: 100,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
DisableCompression: true, // Prevent double compression
|
|
||||||
ResponseHeaderTimeout: 60 * time.Second,
|
|
||||||
WriteBufferSize: 16 * 1024, // 16KB
|
|
||||||
ReadBufferSize: 16 * 1024, // 16KB
|
|
||||||
}
|
|
||||||
DefaultTransportNoTLS = func() *http.Transport {
|
|
||||||
clone := DefaultTransport.Clone()
|
|
||||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
return clone
|
|
||||||
}()
|
|
||||||
)
|
|
||||||
|
|
||||||
const StaticFilePathPrefix = "/$gperrorpage/"
|
|
|
@ -1,221 +0,0 @@
|
||||||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
|
|
||||||
// Copyright (c) 2020-2024 Traefik Labs
|
|
||||||
// Copyright (c) 2024 yusing
|
|
||||||
|
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
forwardAuth struct {
|
|
||||||
ForwardAuthOpts
|
|
||||||
Tracer
|
|
||||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
|
||||||
}
|
|
||||||
ForwardAuthOpts struct {
|
|
||||||
Address string `validate:"url,required"`
|
|
||||||
TrustForwardHeader bool
|
|
||||||
AuthResponseHeaders []string
|
|
||||||
AddAuthCookiesToResponse []string
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
var ForwardAuth = NewMiddleware[forwardAuth]()
|
|
||||||
|
|
||||||
var faHTTPClient = &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup implements MiddlewareWithSetup.
|
|
||||||
func (fa *forwardAuth) setup() {
|
|
||||||
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
|
|
||||||
}
|
|
||||||
|
|
||||||
// before implements RequestModifier.
|
|
||||||
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
|
|
||||||
gphttp.RemoveHop(req.Header)
|
|
||||||
|
|
||||||
// Construct original URL for the redirect
|
|
||||||
scheme := "http"
|
|
||||||
if req.TLS != nil {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
|
||||||
originalURL := scheme + "://" + req.Host + req.RequestURI
|
|
||||||
|
|
||||||
url := fa.Address
|
|
||||||
faReq, err := http.NewRequestWithContext(
|
|
||||||
req.Context(),
|
|
||||||
http.MethodGet,
|
|
||||||
url,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
fa.AddTracef("new request err to %s", url).WithError(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gphttp.CopyHeader(faReq.Header, req.Header)
|
|
||||||
gphttp.RemoveHop(faReq.Header)
|
|
||||||
|
|
||||||
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
|
||||||
fa.setAuthHeaders(req, faReq)
|
|
||||||
// Set headers needed by Authentik
|
|
||||||
faReq.Header.Set("X-Original-Url", originalURL)
|
|
||||||
fa.AddTraceRequest("forward auth request", faReq)
|
|
||||||
|
|
||||||
faResp, err := faHTTPClient.Do(faReq)
|
|
||||||
if err != nil {
|
|
||||||
fa.AddTracef("failed to call %s", url).WithError(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer faResp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(faResp.Body)
|
|
||||||
if err != nil {
|
|
||||||
fa.AddTracef("failed to read response body from %s", url).WithError(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
fa.AddTraceResponse("forward auth response", faResp)
|
|
||||||
gphttp.CopyHeader(w.Header(), faResp.Header)
|
|
||||||
gphttp.RemoveHop(w.Header())
|
|
||||||
|
|
||||||
redirectURL, err := faResp.Location()
|
|
||||||
if err != nil {
|
|
||||||
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
} else if redirectURL.String() != "" {
|
|
||||||
w.Header().Set("Location", redirectURL.String())
|
|
||||||
fa.AddTracef("%s", "redirect to "+redirectURL.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(faResp.StatusCode)
|
|
||||||
|
|
||||||
if _, err = w.Write(body); err != nil {
|
|
||||||
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range fa.AuthResponseHeaders {
|
|
||||||
key := http.CanonicalHeaderKey(key)
|
|
||||||
req.Header.Del(key)
|
|
||||||
if len(faResp.Header[key]) > 0 {
|
|
||||||
req.Header[key] = append([]string(nil), faResp.Header[key]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.RequestURI = req.URL.RequestURI()
|
|
||||||
|
|
||||||
authCookies := faResp.Cookies()
|
|
||||||
|
|
||||||
if len(authCookies) > 0 {
|
|
||||||
fa.reqCookiesMap.Store(req, authCookies)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// modifyResponse implements ResponseModifier.
|
|
||||||
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
|
|
||||||
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
|
|
||||||
fa.setAuthCookies(resp, cookies)
|
|
||||||
fa.reqCookiesMap.Delete(resp.Request)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
|
|
||||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cookies := resp.Cookies()
|
|
||||||
resp.Header.Del("Set-Cookie")
|
|
||||||
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
|
||||||
// this cookie is not an auth cookie, so add it back
|
|
||||||
resp.Header.Add("Set-Cookie", cookie.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cookie := range authCookies {
|
|
||||||
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
|
||||||
// this cookie is an auth cookie, so add to resp
|
|
||||||
resp.Header.Add("Set-Cookie", cookie.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
|
|
||||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
|
||||||
if fa.TrustForwardHeader {
|
|
||||||
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
|
||||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
|
|
||||||
switch {
|
|
||||||
case xMethod != "" && fa.TrustForwardHeader:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
|
|
||||||
case req.Method != "":
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
|
|
||||||
default:
|
|
||||||
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
|
|
||||||
}
|
|
||||||
|
|
||||||
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
|
|
||||||
switch {
|
|
||||||
case xfp != "" && fa.TrustForwardHeader:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
|
|
||||||
case req.TLS != nil:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
|
|
||||||
default:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
|
|
||||||
}
|
|
||||||
|
|
||||||
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
|
|
||||||
}
|
|
||||||
|
|
||||||
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
|
|
||||||
switch {
|
|
||||||
case xfh != "" && fa.TrustForwardHeader:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
|
|
||||||
case req.Host != "":
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
|
|
||||||
default:
|
|
||||||
faReq.Header.Del(gphttp.HeaderXForwardedHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
|
|
||||||
switch {
|
|
||||||
case xfURI != "" && fa.TrustForwardHeader:
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
|
|
||||||
case req.URL.RequestURI() != "":
|
|
||||||
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
|
|
||||||
default:
|
|
||||||
faReq.Header.Del(gphttp.HeaderXForwardedURI)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Reference in a new issue