refactor: remove forward auth, move module net/http to net/gphttp

This commit is contained in:
yusing 2025-03-28 07:03:35 +08:00
parent c0c6e21a16
commit 5d2df3550b
69 changed files with 321 additions and 745 deletions

View file

@ -1,29 +1,26 @@
package accesslog
import (
"bufio"
"bytes"
"io"
"net/http"
"sync"
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type (
AccessLogger struct {
task *task.Task
cfg *Config
io AccessLogIO
buf bytes.Buffer // buffer for non-flushed log
bufMu sync.RWMutex
bufPool sync.Pool // buffer pool for formatting a single log line
flushThreshold int
task *task.Task
cfg *Config
io AccessLogIO
buffered *bufio.Writer
lineBufPool sync.Pool // buffer pool for formatting a single log line
Formatter
}
@ -44,14 +41,18 @@ type (
)
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
l := &AccessLogger{
task: parent.Subtask("accesslog"),
cfg: cfg,
io: io,
}
if cfg.BufferSize < 1024 {
if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize
}
if cfg.BufferSize < 4096 {
cfg.BufferSize = 4096
}
l := &AccessLogger{
task: parent.Subtask("accesslog"),
cfg: cfg,
io: io,
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
}
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
switch l.cfg.Format {
@ -65,10 +66,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog
panic("invalid access log format")
}
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80%
l.buf.Grow(int(cfg.BufferSize))
l.bufPool.New = func() any {
return new(bytes.Buffer)
l.lineBufPool.New = func() any {
return bytes.NewBuffer(make([]byte, 0, 1024))
}
go l.start()
return l
@ -89,15 +88,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
return
}
line := l.bufPool.Get().(*bytes.Buffer)
l.Format(line, req, res)
line.WriteRune('\n')
l.bufMu.Lock()
l.buf.Write(line.Bytes())
line := l.lineBufPool.Get().(*bytes.Buffer)
line.Reset()
l.bufPool.Put(line)
l.bufMu.Unlock()
defer l.lineBufPool.Put(line)
l.Formatter.Format(line, req, res)
line.WriteRune('\n')
l.write(line.Bytes())
}
func (l *AccessLogger) LogError(req *http.Request, err error) {
@ -115,55 +111,53 @@ func (l *AccessLogger) Rotate() error {
l.io.Lock()
defer l.io.Unlock()
return l.cfg.Retention.rotateLogFile(l.io)
}
func (l *AccessLogger) Flush(force bool) {
if l.buf.Len() == 0 {
return
}
if force || l.buf.Len() >= l.flushThreshold {
l.bufMu.RLock()
l.write(l.buf.Bytes())
l.buf.Reset()
l.bufMu.RUnlock()
}
return l.rotate()
}
func (l *AccessLogger) handleErr(err error) {
E.LogError("failed to write access log", err)
gperr.LogError("failed to write access log", err)
}
func (l *AccessLogger) start() {
defer func() {
if l.buf.Len() > 0 { // flush last
l.write(l.buf.Bytes())
if err := l.Flush(); err != nil {
l.handleErr(err)
}
l.io.Close()
l.close()
l.task.Finish(nil)
}()
// periodic flush + threshold flush
periodic := time.NewTicker(5 * time.Second)
threshold := time.NewTicker(time.Second)
defer periodic.Stop()
defer threshold.Stop()
// flushes the buffer every 30 seconds
flushTicker := time.NewTicker(30 * time.Second)
defer flushTicker.Stop()
for {
select {
case <-l.task.Context().Done():
return
case <-periodic.C:
l.Flush(true)
case <-threshold.C:
l.Flush(false)
case <-flushTicker.C:
if err := l.Flush(); err != nil {
l.handleErr(err)
}
}
}
}
func (l *AccessLogger) Flush() error {
l.io.Lock()
defer l.io.Unlock()
return l.buffered.Flush()
}
func (l *AccessLogger) close() {
l.io.Lock()
defer l.io.Unlock()
l.io.Close()
}
func (l *AccessLogger) write(data []byte) {
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
_, err := l.io.Write(data)
_, err := l.buffered.Write(data)
l.io.Unlock()
if err != nil {
l.handleErr(err)

View file

@ -9,7 +9,7 @@ import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -17,7 +17,7 @@ type (
Cookies FieldConfig `json:"cookies"`
}
Config struct {
BufferSize uint `json:"buffer_size" validate:"gte=1"`
BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"`

View file

@ -4,7 +4,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -3,7 +3,7 @@ package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -71,14 +71,14 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
go func(l *AccessLogger) {
defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger)
l.Flush(true)
l.Flush()
}(logger)
}
wg.Wait()
expected := loggerCount * logCountPerLogger
actual := file.Count()
actual := file.LineCount()
ExpectEqual(t, actual, expected)
}

View file

@ -5,7 +5,7 @@ import (
"net/http"
"strings"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -27,7 +27,7 @@ type (
CIDR struct{ types.CIDR }
)
var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter")
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
if len(f.Values) == 0 {

View file

@ -4,7 +4,7 @@ import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -49,7 +49,6 @@ func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
return 0, io.EOF
}
n = copy(p, m.data[off:])
m.position += int64(n)
return n, nil
}
@ -63,7 +62,7 @@ func (m *MockFile) Truncate(size int64) error {
return nil
}
func (m *MockFile) Count() int {
func (m *MockFile) LineCount() int {
m.Lock()
defer m.Unlock()
return bytes.Count(m.data[:m.position], []byte("\n"))
@ -72,3 +71,7 @@ func (m *MockFile) Count() int {
func (m *MockFile) Len() int64 {
return m.position
}
func (m *MockFile) Content() []byte {
return m.data[:m.position]
}

View file

@ -0,0 +1,56 @@
package accesslog
import (
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
var (
ErrInvalidSyntax = gperr.New("invalid syntax")
ErrZeroValue = gperr.New("zero value")
)
var defaultChunkSize = 64 * 1024 // 64KB
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}

View file

@ -0,0 +1,33 @@
package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}

View file

@ -3,7 +3,7 @@ package accesslog
import (
"strconv"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -12,7 +12,7 @@ type StatusCodeRange struct {
End int
}
var ErrInvalidStatusCodeRange = E.New("invalid status code range")
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
func (r *StatusCodeRange) Includes(code int) bool {
return r.Start <= code && code <= r.End
@ -25,7 +25,7 @@ func (r *StatusCodeRange) Parse(v string) error {
case 1:
start, err := strconv.Atoi(split[0])
if err != nil {
return E.From(err)
return gperr.Wrap(err)
}
r.Start = start
r.End = start
@ -33,7 +33,7 @@ func (r *StatusCodeRange) Parse(v string) error {
case 2:
start, errStart := strconv.Atoi(split[0])
end, errEnd := strconv.Atoi(split[1])
if err := E.Join(errStart, errEnd); err != nil {
if err := gperr.Join(errStart, errEnd); err != nil {
return err
}
r.Start = start

View file

@ -1,4 +1,4 @@
package http
package gphttp
import (
"mime"

View file

@ -1,4 +1,4 @@
package http
package gphttp
import (
"net/http"

View file

@ -6,8 +6,8 @@ import (
"net/http"
"sync"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
)
type ipHash struct {
@ -23,10 +23,10 @@ func (lb *LoadBalancer) newIPHash() impl {
if len(lb.Options) == 0 {
return impl
}
var err E.Error
var err gperr.Error
impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
}
return impl
}

View file

@ -6,10 +6,10 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
@ -54,7 +54,7 @@ func New(cfg *Config) *LoadBalancer {
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
lb.startTime = time.Now()
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
parent.OnCancel("lb_remove_route", func() {
@ -125,12 +125,12 @@ func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name()) {
old, _ := lb.pool.Load(srv.Name())
if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning
old, _ := lb.pool.Load(srv.Key())
lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name(), srv)
lb.pool.Store(srv.Key(), srv)
lb.sumWeight += srv.Weight()
lb.rebalance()
@ -146,11 +146,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name()) {
if !lb.pool.Has(srv.Key()) {
return
}
lb.pool.Delete(srv.Name())
lb.pool.Delete(srv.Key())
lb.sumWeight -= srv.Weight()
lb.rebalance()
@ -227,7 +227,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
if r.Header.Get(common.HeaderCheckRedirect) != "" {
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
// wake all servers
for _, srv := range srvs {
if err := srv.TryWake(); err != nil {
@ -244,7 +244,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) {
extra[v.Name()] = v
extra[v.Key()] = v
})
return (&monitor.JSONRepresentation{

View file

@ -3,7 +3,7 @@ package loadbalancer
import (
"testing"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -1,7 +1,7 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
)
type (

View file

@ -26,6 +26,7 @@ type (
http.Handler
health.HealthMonitor
Name() string
Key() string
URL() *net.URL
Weight() Weight
SetWeight(weight Weight)
@ -51,6 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{
weight: Weight(weight),
url: net.MustParseURL("http://localhost"),
}
return srv
}
@ -63,6 +65,10 @@ func (srv *server) URL() *net.URL {
return srv.url
}
func (srv *server) Key() string {
return srv.url.Host
}
func (srv *server) Weight() Weight {
return srv.weight
}
@ -78,9 +84,7 @@ func (srv *server) String() string {
func (srv *server) TryWake() error {
waker, ok := srv.Handler.(idlewatcher.Waker)
if ok {
if err := waker.Wake(); err != nil {
return err
}
return waker.Wake()
}
return nil
}

View file

@ -1,4 +1,4 @@
package http
package gphttp
import "net/http"

View file

@ -5,7 +5,7 @@ import (
"net/http"
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/http"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"

View file

@ -7,7 +7,7 @@ import (
"strings"
"testing"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -61,7 +61,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
}
func TestCIDRWhitelist(t *testing.T) {
errs := E.NewBuilder("")
errs := gperr.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]

View file

@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -28,7 +29,7 @@ const (
)
var (
cfCIDRsLastUpdate time.Time
cfCIDRsLastUpdate atomic.Value[time.Time]
cfCIDRsMu sync.Mutex
// RFC 1918.
@ -68,14 +69,14 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return
}
cfCIDRsMu.Lock()
defer cfCIDRsMu.Unlock()
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return
}
@ -88,7 +89,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
@ -97,7 +98,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
}
}
cfCIDRsLastUpdate = time.Now()
cfCIDRsLastUpdate.Store(time.Now())
logging.Info().Msg("cloudflare CIDR range updated")
return
}

View file

@ -9,14 +9,17 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
)
type customErrorPage struct{}
var CustomErrorPage = NewMiddleware[customErrorPage]()
const StaticFilePathPrefix = "/$gperrorpage/"
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
@ -34,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
@ -49,8 +52,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
filename := path[len(gphttp.StaticFilePathPrefix):]
if strings.HasPrefix(path, StaticFilePathPrefix) {
filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
logging.Error().Msg("unable to load resource " + filename)
@ -59,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
}

View file

@ -7,7 +7,7 @@ import (
"sync"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
@ -90,7 +90,7 @@ func watchDir() {
loadContent()
}
case err := <-errCh:
E.LogError("error watching error page directory", err)
gperr.LogError("error watching error page directory", err)
}
}
}

View file

@ -7,15 +7,15 @@ import (
"sort"
"strings"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.Error
Error = gperr.Error
ReverseProxy = reverseproxy.ReverseProxy
ProxyRequest = reverseproxy.ProxyRequest
@ -80,7 +80,7 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
logging.Debug().Msgf("middleware %s enabled trace", m.name)
logging.Trace().Msgf("middleware %s enabled trace", m.name)
}
}
@ -103,7 +103,7 @@ func (m *Middleware) setup() {
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
if len(optsRaw) == 0 {
return nil
}
@ -132,10 +132,10 @@ func (m *Middleware) finalize() error {
return nil
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
if m.construct == nil { // likely a middleware from compose
if len(optsRaw) != 0 {
return nil, E.New("additional options not allowed for middleware ").Subject(m.name)
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
}
return m, nil
}
@ -145,7 +145,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
return nil, err
}
if err := mid.finalize(); err != nil {
return nil, E.From(err)
return nil, gperr.Wrap(err)
}
return mid, nil
}
@ -196,7 +196,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
next(w, r)
}
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
var middlewares []*Middleware
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {

View file

@ -6,13 +6,13 @@ import (
"path"
"sort"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"gopkg.in/yaml.v3"
)
var ErrMissingMiddlewareUse = E.New("missing middleware 'use' field")
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
eb.Add(err)
@ -21,7 +21,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
@ -40,11 +40,11 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
return middlewares
}
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error")
errs := gperr.NewBuilder("middlewares compile error")
invalidOpts := gperr.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, err := Get(name)
@ -68,7 +68,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.
return middlewares, errs.Error()
}
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, E.Error) {
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
compiled, err := compileMiddlewares(middlewaresMap)
if err != nil {
return nil, err
@ -77,8 +77,8 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, E.Error) {
chainErr := E.NewBuilder("")
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
chainErr := gperr.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {

View file

@ -5,7 +5,7 @@ import (
"encoding/json"
"testing"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -13,7 +13,7 @@ import (
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
errs := E.NewBuilder("")
errs := gperr.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
Must(json.MarshalIndent(middlewares, "", " "))

View file

@ -4,7 +4,7 @@ import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
)
type middlewareChain struct {
@ -51,10 +51,10 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := E.NewBuilder("modify response errors")
errs := gperr.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(E.From(err).Subjectf("%d", i))
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
}
}
return errs.Error()

View file

@ -4,7 +4,7 @@ import (
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
@ -32,15 +32,11 @@ var allMiddlewares = map[string]*Middleware{
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
}
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
ErrUnknownMiddleware = gperr.New("unknown middleware")
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
@ -58,14 +54,14 @@ func All() map[string]*Middleware {
}
func LoadComposeFiles() {
errs := E.NewBuilder("middleware compile errors")
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
if len(mws) == 0 {
continue
@ -103,6 +99,6 @@ func LoadComposeFiles() {
}
}
if errs.HasError() {
E.LogError(errs.About(), errs.Error())
gperr.LogError(errs.About(), errs.Error())
}
}

View file

@ -1,12 +1,13 @@
package middleware
import (
"errors"
"net/http"
"sync"
"sync/atomic"
"github.com/yusing/go-proxy/internal/api/v1/auth"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/gperr"
)
type oidcMiddleware struct {
@ -24,7 +25,7 @@ var OIDC = NewMiddleware[oidcMiddleware]()
func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() {
return E.New("OIDC not enabled but OIDC middleware is used")
return gperr.New("OIDC not enabled but OIDC middleware is used")
}
return nil
}
@ -64,9 +65,6 @@ func (amw *oidcMiddleware) initSlow() error {
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
amw.auth = authProvider
return nil
@ -79,13 +77,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
return false
}
if err := amw.auth.CheckToken(r); err != nil {
amw.authMux.ServeHTTP(w, r)
return false
}
if r.URL.Path == auth.OIDCLogoutPath {
amw.auth.LogoutCallbackHandler(w, r)
return false
}
if err := amw.auth.CheckToken(r); err != nil {
if errors.Is(err, auth.ErrMissingToken) {
amw.authMux.ServeHTTP(w, r)
} else {
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
}
return false
}
return true
}

View file

@ -4,7 +4,7 @@ import (
"net"
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
)
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -6,14 +6,14 @@ import (
"strings"
"testing"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": gphttp.HeaderXRealIP,
"header": httpheaders.HeaderXRealIP,
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true,
}
optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP,
Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) {
const (
testHeader = gphttp.HeaderXRealIP
testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{

View file

@ -44,7 +44,7 @@ func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed b
r.URL.Host = host
}
http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect)
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
return false

View file

@ -13,7 +13,7 @@ func TestRedirectToHTTPs(t *testing.T) {
reqURL: types.MustParseURL("http://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusMovedPermanently)
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
}

View file

@ -3,8 +3,8 @@ package middleware
import (
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
)
// internal use only.
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamName, s.Name)
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -8,19 +8,6 @@ theGreatPretender:
- X-Test3
- X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
@ -30,9 +17,6 @@ realIPAuthentik:
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP:
- use: ModifyRequest

View file

@ -9,8 +9,8 @@ import (
"net/http/httptest"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -122,7 +122,7 @@ func (args *testArgs) bodyReader() io.Reader {
return nil
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
@ -136,7 +136,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
return newMiddlewaresTest([]*Middleware{mid}, args)
}
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
@ -163,7 +163,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult,
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
return nil, gperr.Wrap(err)
}
return &TestResult{

View file

@ -4,7 +4,7 @@ import (
"net/http"
"sync"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(req.Header)
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr
}
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header)
tr.RespHeaders = gphttp.HeaderToMap(resp.Header)
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}

View file

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
}
return upHost
},
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" {
return ""
}
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort

View file

@ -5,7 +5,7 @@ import (
"net/http"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
@ -20,10 +20,10 @@ var (
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor)
r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
}
return true
}

View file

@ -1,7 +1,7 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs
package http
package gphttp
import (
"bufio"

View file

@ -1,4 +1,4 @@
package http
package gphttp
import "net/http"

View file

@ -1,4 +1,4 @@
package http
package gphttp
import "net/http"

View file

@ -0,0 +1,34 @@
package gphttp
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var DefaultDialer = net.Dialer{
Timeout: 5 * time.Second,
}
func NewTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: DefaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// DisableCompression: true, // Prevent double compression
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
}
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
tr := NewTransport()
tr.TLSClientConfig = tlsConfig
return tr
}

View file

@ -1,198 +0,0 @@
package accesslog
import (
"bufio"
"bytes"
"io"
"strconv"
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
const chunkSizeMax int64 = 128 * 1024 // 128KB
var (
ErrInvalidSyntax = E.New("invalid syntax")
ErrZeroValue = E.New("zero value")
)
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
lastN := int(r.Last)
days := int(r.Days)
// Seek to end to get file size
size, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
// Initialize ring buffer for last N lines
lines := make([][]byte, 0, lastN|(days*1000))
pos := size
unprocessed := 0
var chunk [chunkSizeMax]byte
var lastLine []byte
var shouldStop func() bool
if days > 0 {
cutoff := time.Now().AddDate(0, 0, -days)
shouldStop = func() bool {
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
}
} else {
shouldStop = func() bool {
return len(lines) == lastN
}
}
// Read backwards until we have enough lines or reach start of file
for pos > 0 {
if pos > chunkSizeMax {
pos -= chunkSizeMax
} else {
pos = 0
}
// Seek to the current chunk
if _, err = file.Seek(pos, io.SeekStart); err != nil {
return err
}
var nRead int
// Read the chunk
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
return err
}
// last unprocessed bytes + read bytes
curChunk := chunk[:unprocessed+nRead]
unprocessed = len(curChunk)
// Split into lines
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
for !shouldStop() && scanner.Scan() {
lastLine = scanner.Bytes()
lines = append(lines, lastLine)
unprocessed -= len(lastLine)
}
if shouldStop() {
break
}
// move unprocessed bytes to the beginning for next iteration
copy(chunk[:], curChunk[unprocessed:])
}
if days > 0 {
// truncate to the end of the log within last N days
return file.Truncate(pos)
}
// write lines to buffer in reverse order
// since we read them backwards
var buf bytes.Buffer
for i := len(lines) - 1; i >= 0; i-- {
buf.Write(lines[i])
buf.WriteRune('\n')
}
return writeTruncate(file, &buf)
}
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
// Seek to beginning and truncate
if _, err := file.Seek(0, 0); err != nil {
return err
}
buffered := bufio.NewWriter(file)
// Write buffer back to file
nWritten, err := buffered.Write(buf.Bytes())
if err != nil {
return err
}
if err = buffered.Flush(); err != nil {
return err
}
// Truncate file
if err = file.Truncate(int64(nWritten)); err != nil {
return err
}
// check bytes written == buffer size
if nWritten != buf.Len() {
return io.ErrShortWrite
}
return
}
func parseLogTime(line []byte) (t time.Time) {
if len(line) == 0 {
return
}
var start, end int
const jsonStart = len(`{"time":"`)
const jsonEnd = jsonStart + len(LogTimeFormat)
if len(line) == '{' { // possibly json log
start = jsonStart
end = jsonEnd
} else { // possibly common or combined format
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
start = bytes.IndexRune(line, '[')
end = bytes.IndexRune(line[start+1:], ']')
if start == -1 || end == -1 || start >= end {
return
}
}
timeStr := line[start+1 : end]
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
return
}

View file

@ -1,81 +0,0 @@
package accesslog_test
import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}
func TestRetentionCommonFormat(t *testing.T) {
var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
Format: FormatCommon,
BufferSize: 1024,
})
for range 10 {
logger.Log(req, resp)
}
logger.Flush(true)
// test.Finish(nil)
ExpectEqual(t, logger.Config().Retention, nil)
ExpectTrue(t, file.Len() > 0)
ExpectEqual(t, file.Count(), 10)
t.Run("keep last", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
ExpectEqual(t, logger.Config().Retention.Days, 0)
ExpectEqual(t, logger.Config().Retention.Last, 5)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 5)
})
_ = file.Truncate(0)
timeNow := time.Now()
for i := range 10 {
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
return timeNow.AddDate(0, 0, -i)
}
logger.Log(req, resp)
}
logger.Flush(true)
// FIXME: keep days does not work
t.Run("keep days", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
ExpectEqual(t, logger.Config().Retention.Days, 3)
ExpectEqual(t, logger.Config().Retention.Last, 0)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 3)
})
}

View file

@ -1,34 +0,0 @@
package http
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
defaultDialer = net.Dialer{
Timeout: 60 * time.Second,
}
DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: true, // Prevent double compression
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
DefaultTransportNoTLS = func() *http.Transport {
clone := DefaultTransport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
)
const StaticFilePathPrefix = "/$gperrorpage/"

View file

@ -1,221 +0,0 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"slices"
"strings"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
forwardAuth struct {
ForwardAuthOpts
Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
ForwardAuthOpts struct {
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
// setup implements MiddlewareWithSetup.
func (fa *forwardAuth) setup() {
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
}
// before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
url,
nil,
)
if err != nil {
fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
gphttp.CopyHeader(faReq.Header, req.Header)
gphttp.RemoveHop(faReq.Header)
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
faReq.Header.Set("X-Original-Url", originalURL)
fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) > 0 {
fa.reqCookiesMap.Store(req, authCookies)
}
return true
}
// modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
default:
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
}
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
default:
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
}
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
}
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
default:
faReq.Header.Del(gphttp.HeaderXForwardedHost)
}
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(gphttp.HeaderXForwardedURI)
}
}