From 5d2df3550b697e1ea0c43a26696b83933f81cd8b Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 28 Mar 2025 07:03:35 +0800 Subject: [PATCH] refactor: remove forward auth, move module net/http to net/gphttp --- .../accesslog/access_logger.go | 104 ++++----- .../accesslog/access_logger_test.go | 2 +- .../net/{http => gphttp}/accesslog/config.go | 2 +- .../{http => gphttp}/accesslog/config_test.go | 2 +- .../net/{http => gphttp}/accesslog/fields.go | 0 .../{http => gphttp}/accesslog/fields_test.go | 2 +- .../{http => gphttp}/accesslog/file_logger.go | 0 .../accesslog/file_logger_test.go | 4 +- .../net/{http => gphttp}/accesslog/filter.go | 4 +- .../{http => gphttp}/accesslog/filter_test.go | 2 +- .../{http => gphttp}/accesslog/formatter.go | 0 .../{http => gphttp}/accesslog/mock_file.go | 7 +- internal/net/gphttp/accesslog/retention.go | 56 +++++ .../net/gphttp/accesslog/retention_test.go | 33 +++ .../accesslog/status_code_range.go | 8 +- internal/net/{http => gphttp}/content_type.go | 2 +- .../net/{http => gphttp}/content_type_test.go | 2 +- .../{http => gphttp}/loadbalancer/ip_hash.go | 8 +- .../loadbalancer/least_conn.go | 0 .../loadbalancer/loadbalancer.go | 22 +- .../loadbalancer/loadbalancer_test.go | 2 +- .../loadbalancer/round_robin.go | 0 .../{http => gphttp}/loadbalancer/types.go | 2 +- .../loadbalancer/types/config.go | 0 .../loadbalancer/types/mode.go | 0 .../loadbalancer/types/server.go | 10 +- .../loadbalancer/types/weight.go | 0 internal/net/{http => gphttp}/methods.go | 2 +- .../middleware/cidr_whitelist.go | 2 +- .../middleware/cidr_whitelist_test.go | 4 +- .../middleware/cloudflare_real_ip.go | 11 +- .../middleware/custom_error_page.go | 21 +- .../middleware/errorpage/error_page.go | 4 +- .../metrics_logger/metrics_logger.go | 0 .../metrics_logger/metrics_response_writer.go | 0 .../{http => gphttp}/middleware/middleware.go | 20 +- .../middleware/middleware_builder.go | 20 +- .../middleware/middleware_builder_test.go | 4 +- .../middleware/middleware_chain.go | 6 +- .../middleware/middleware_test.go | 0 .../middleware/middlewares.go | 16 +- .../middleware/modify_request.go | 0 .../middleware/modify_request_test.go | 0 .../middleware/modify_response.go | 0 .../middleware/modify_response_test.go | 0 .../net/{http => gphttp}/middleware/oidc.go | 20 +- .../{http => gphttp}/middleware/rate_limit.go | 0 .../middleware/rate_limit_test.go | 0 .../{http => gphttp}/middleware/real_ip.go | 4 +- .../middleware/real_ip_test.go | 8 +- .../middleware/redirect_http.go | 2 +- .../middleware/redirect_http_test.go | 2 +- .../middleware/set_upstream_headers.go | 12 +- .../test_data/cidr_whitelist_test.yml | 0 .../test_data/middleware_compose.yml | 16 -- .../middleware/test_data/sample_headers.json | 0 .../{http => gphttp}/middleware/test_utils.go | 10 +- .../net/{http => gphttp}/middleware/trace.go | 8 +- .../net/{http => gphttp}/middleware/tracer.go | 0 .../net/{http => gphttp}/middleware/vars.go | 20 +- .../middleware/x_forwarded.go | 6 +- .../modify_response_writer.go | 2 +- internal/net/{http => gphttp}/serve_mux.go | 2 +- internal/net/{http => gphttp}/status_code.go | 2 +- internal/net/gphttp/transport.go | 34 +++ internal/net/http/accesslog/retention.go | 198 ---------------- internal/net/http/accesslog/retention_test.go | 81 ------- internal/net/http/common.go | 34 --- internal/net/http/middleware/forward_auth.go | 221 ------------------ 69 files changed, 321 insertions(+), 745 deletions(-) rename internal/net/{http => gphttp}/accesslog/access_logger.go (64%) rename internal/net/{http => gphttp}/accesslog/access_logger_test.go (98%) rename internal/net/{http => gphttp}/accesslog/config.go (95%) rename internal/net/{http => gphttp}/accesslog/config_test.go (97%) rename internal/net/{http => gphttp}/accesslog/fields.go (100%) rename internal/net/{http => gphttp}/accesslog/fields_test.go (97%) rename internal/net/{http => gphttp}/accesslog/file_logger.go (100%) rename internal/net/{http => gphttp}/accesslog/file_logger_test.go (97%) rename internal/net/{http => gphttp}/accesslog/filter.go (94%) rename internal/net/{http => gphttp}/accesslog/filter_test.go (98%) rename internal/net/{http => gphttp}/accesslog/formatter.go (100%) rename internal/net/{http => gphttp}/accesslog/mock_file.go (92%) create mode 100644 internal/net/gphttp/accesslog/retention.go create mode 100644 internal/net/gphttp/accesslog/retention_test.go rename internal/net/{http => gphttp}/accesslog/status_code_range.go (81%) rename internal/net/{http => gphttp}/content_type.go (98%) rename internal/net/{http => gphttp}/content_type_test.go (99%) rename internal/net/{http => gphttp}/loadbalancer/ip_hash.go (89%) rename internal/net/{http => gphttp}/loadbalancer/least_conn.go (100%) rename internal/net/{http => gphttp}/loadbalancer/loadbalancer.go (92%) rename internal/net/{http => gphttp}/loadbalancer/loadbalancer_test.go (95%) rename internal/net/{http => gphttp}/loadbalancer/round_robin.go (100%) rename internal/net/{http => gphttp}/loadbalancer/types.go (72%) rename internal/net/{http => gphttp}/loadbalancer/types/config.go (100%) rename internal/net/{http => gphttp}/loadbalancer/types/mode.go (100%) rename internal/net/{http => gphttp}/loadbalancer/types/server.go (91%) rename internal/net/{http => gphttp}/loadbalancer/types/weight.go (100%) rename internal/net/{http => gphttp}/methods.go (95%) rename internal/net/{http => gphttp}/middleware/cidr_whitelist.go (97%) rename internal/net/{http => gphttp}/middleware/cidr_whitelist_test.go (96%) rename internal/net/{http => gphttp}/middleware/cloudflare_real_ip.go (88%) rename internal/net/{http => gphttp}/middleware/custom_error_page.go (72%) rename internal/net/{http => gphttp}/middleware/errorpage/error_page.go (95%) rename internal/net/{http => gphttp}/middleware/metrics_logger/metrics_logger.go (100%) rename internal/net/{http => gphttp}/middleware/metrics_logger/metrics_response_writer.go (100%) rename internal/net/{http => gphttp}/middleware/middleware.go (90%) rename internal/net/{http => gphttp}/middleware/middleware_builder.go (79%) rename internal/net/{http => gphttp}/middleware/middleware_builder_test.go (85%) rename internal/net/{http => gphttp}/middleware/middleware_chain.go (90%) rename internal/net/{http => gphttp}/middleware/middleware_test.go (100%) rename internal/net/{http => gphttp}/middleware/middlewares.go (85%) rename internal/net/{http => gphttp}/middleware/modify_request.go (100%) rename internal/net/{http => gphttp}/middleware/modify_request_test.go (100%) rename internal/net/{http => gphttp}/middleware/modify_response.go (100%) rename internal/net/{http => gphttp}/middleware/modify_response_test.go (100%) rename internal/net/{http => gphttp}/middleware/oidc.go (85%) rename internal/net/{http => gphttp}/middleware/rate_limit.go (100%) rename internal/net/{http => gphttp}/middleware/rate_limit_test.go (100%) rename internal/net/{http => gphttp}/middleware/real_ip.go (95%) rename internal/net/{http => gphttp}/middleware/real_ip_test.go (90%) rename internal/net/{http => gphttp}/middleware/redirect_http.go (93%) rename internal/net/{http => gphttp}/middleware/redirect_http_test.go (90%) rename internal/net/{http => gphttp}/middleware/set_upstream_headers.go (63%) rename internal/net/{http => gphttp}/middleware/test_data/cidr_whitelist_test.yml (100%) rename internal/net/{http => gphttp}/middleware/test_data/middleware_compose.yml (55%) rename internal/net/{http => gphttp}/middleware/test_data/sample_headers.json (100%) rename internal/net/{http => gphttp}/middleware/test_utils.go (94%) rename internal/net/{http => gphttp}/middleware/trace.go (87%) rename internal/net/{http => gphttp}/middleware/tracer.go (100%) rename internal/net/{http => gphttp}/middleware/vars.go (90%) rename internal/net/{http => gphttp}/middleware/x_forwarded.go (82%) rename internal/net/{http => gphttp}/modify_response_writer.go (99%) rename internal/net/{http => gphttp}/serve_mux.go (97%) rename internal/net/{http => gphttp}/status_code.go (93%) create mode 100644 internal/net/gphttp/transport.go delete mode 100644 internal/net/http/accesslog/retention.go delete mode 100644 internal/net/http/accesslog/retention_test.go delete mode 100644 internal/net/http/common.go delete mode 100644 internal/net/http/middleware/forward_auth.go diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/gphttp/accesslog/access_logger.go similarity index 64% rename from internal/net/http/accesslog/access_logger.go rename to internal/net/gphttp/accesslog/access_logger.go index 16c10d4..9a6fc8a 100644 --- a/internal/net/http/accesslog/access_logger.go +++ b/internal/net/gphttp/accesslog/access_logger.go @@ -1,29 +1,26 @@ package accesslog import ( + "bufio" "bytes" "io" "net/http" "sync" "time" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) type ( AccessLogger struct { - task *task.Task - cfg *Config - io AccessLogIO - - buf bytes.Buffer // buffer for non-flushed log - bufMu sync.RWMutex - bufPool sync.Pool // buffer pool for formatting a single log line - - flushThreshold int + task *task.Task + cfg *Config + io AccessLogIO + buffered *bufio.Writer + lineBufPool sync.Pool // buffer pool for formatting a single log line Formatter } @@ -44,14 +41,18 @@ type ( ) func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { - l := &AccessLogger{ - task: parent.Subtask("accesslog"), - cfg: cfg, - io: io, - } - if cfg.BufferSize < 1024 { + if cfg.BufferSize == 0 { cfg.BufferSize = DefaultBufferSize } + if cfg.BufferSize < 4096 { + cfg.BufferSize = 4096 + } + l := &AccessLogger{ + task: parent.Subtask("accesslog"), + cfg: cfg, + io: io, + buffered: bufio.NewWriterSize(io, cfg.BufferSize), + } fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now} switch l.cfg.Format { @@ -65,10 +66,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog panic("invalid access log format") } - l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80% - l.buf.Grow(int(cfg.BufferSize)) - l.bufPool.New = func() any { - return new(bytes.Buffer) + l.lineBufPool.New = func() any { + return bytes.NewBuffer(make([]byte, 0, 1024)) } go l.start() return l @@ -89,15 +88,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) { return } - line := l.bufPool.Get().(*bytes.Buffer) - l.Format(line, req, res) - line.WriteRune('\n') - - l.bufMu.Lock() - l.buf.Write(line.Bytes()) + line := l.lineBufPool.Get().(*bytes.Buffer) line.Reset() - l.bufPool.Put(line) - l.bufMu.Unlock() + defer l.lineBufPool.Put(line) + l.Formatter.Format(line, req, res) + line.WriteRune('\n') + l.write(line.Bytes()) } func (l *AccessLogger) LogError(req *http.Request, err error) { @@ -115,55 +111,53 @@ func (l *AccessLogger) Rotate() error { l.io.Lock() defer l.io.Unlock() - return l.cfg.Retention.rotateLogFile(l.io) -} - -func (l *AccessLogger) Flush(force bool) { - if l.buf.Len() == 0 { - return - } - if force || l.buf.Len() >= l.flushThreshold { - l.bufMu.RLock() - l.write(l.buf.Bytes()) - l.buf.Reset() - l.bufMu.RUnlock() - } + return l.rotate() } func (l *AccessLogger) handleErr(err error) { - E.LogError("failed to write access log", err) + gperr.LogError("failed to write access log", err) } func (l *AccessLogger) start() { defer func() { - if l.buf.Len() > 0 { // flush last - l.write(l.buf.Bytes()) + if err := l.Flush(); err != nil { + l.handleErr(err) } - l.io.Close() + l.close() l.task.Finish(nil) }() - // periodic flush + threshold flush - periodic := time.NewTicker(5 * time.Second) - threshold := time.NewTicker(time.Second) - defer periodic.Stop() - defer threshold.Stop() + // flushes the buffer every 30 seconds + flushTicker := time.NewTicker(30 * time.Second) + defer flushTicker.Stop() for { select { case <-l.task.Context().Done(): return - case <-periodic.C: - l.Flush(true) - case <-threshold.C: - l.Flush(false) + case <-flushTicker.C: + if err := l.Flush(); err != nil { + l.handleErr(err) + } } } } +func (l *AccessLogger) Flush() error { + l.io.Lock() + defer l.io.Unlock() + return l.buffered.Flush() +} + +func (l *AccessLogger) close() { + l.io.Lock() + defer l.io.Unlock() + l.io.Close() +} + func (l *AccessLogger) write(data []byte) { l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers - _, err := l.io.Write(data) + _, err := l.buffered.Write(data) l.io.Unlock() if err != nil { l.handleErr(err) diff --git a/internal/net/http/accesslog/access_logger_test.go b/internal/net/gphttp/accesslog/access_logger_test.go similarity index 98% rename from internal/net/http/accesslog/access_logger_test.go rename to internal/net/gphttp/accesslog/access_logger_test.go index 2398ca7..012d8eb 100644 --- a/internal/net/http/accesslog/access_logger_test.go +++ b/internal/net/gphttp/accesslog/access_logger_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - . "github.com/yusing/go-proxy/internal/net/http/accesslog" + . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/task" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/accesslog/config.go b/internal/net/gphttp/accesslog/config.go similarity index 95% rename from internal/net/http/accesslog/config.go rename to internal/net/gphttp/accesslog/config.go index 820c302..a1dbe2f 100644 --- a/internal/net/http/accesslog/config.go +++ b/internal/net/gphttp/accesslog/config.go @@ -17,7 +17,7 @@ type ( Cookies FieldConfig `json:"cookies"` } Config struct { - BufferSize uint `json:"buffer_size" validate:"gte=1"` + BufferSize int `json:"buffer_size"` Format Format `json:"format" validate:"oneof=common combined json"` Path string `json:"path" validate:"required"` Filters Filters `json:"filters"` diff --git a/internal/net/http/accesslog/config_test.go b/internal/net/gphttp/accesslog/config_test.go similarity index 97% rename from internal/net/http/accesslog/config_test.go rename to internal/net/gphttp/accesslog/config_test.go index 125b501..e8de01d 100644 --- a/internal/net/http/accesslog/config_test.go +++ b/internal/net/gphttp/accesslog/config_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/docker" - . "github.com/yusing/go-proxy/internal/net/http/accesslog" + . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/utils" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/accesslog/fields.go b/internal/net/gphttp/accesslog/fields.go similarity index 100% rename from internal/net/http/accesslog/fields.go rename to internal/net/gphttp/accesslog/fields.go diff --git a/internal/net/http/accesslog/fields_test.go b/internal/net/gphttp/accesslog/fields_test.go similarity index 97% rename from internal/net/http/accesslog/fields_test.go rename to internal/net/gphttp/accesslog/fields_test.go index feac44d..1cfa370 100644 --- a/internal/net/http/accesslog/fields_test.go +++ b/internal/net/gphttp/accesslog/fields_test.go @@ -3,7 +3,7 @@ package accesslog_test import ( "testing" - . "github.com/yusing/go-proxy/internal/net/http/accesslog" + . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/accesslog/file_logger.go b/internal/net/gphttp/accesslog/file_logger.go similarity index 100% rename from internal/net/http/accesslog/file_logger.go rename to internal/net/gphttp/accesslog/file_logger.go diff --git a/internal/net/http/accesslog/file_logger_test.go b/internal/net/gphttp/accesslog/file_logger_test.go similarity index 97% rename from internal/net/http/accesslog/file_logger_test.go rename to internal/net/gphttp/accesslog/file_logger_test.go index ffa7aab..0321a85 100644 --- a/internal/net/http/accesslog/file_logger_test.go +++ b/internal/net/gphttp/accesslog/file_logger_test.go @@ -71,14 +71,14 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { go func(l *AccessLogger) { defer wg.Done() parallelLog(l, req, resp, logCountPerLogger) - l.Flush(true) + l.Flush() }(logger) } wg.Wait() expected := loggerCount * logCountPerLogger - actual := file.Count() + actual := file.LineCount() ExpectEqual(t, actual, expected) } diff --git a/internal/net/http/accesslog/filter.go b/internal/net/gphttp/accesslog/filter.go similarity index 94% rename from internal/net/http/accesslog/filter.go rename to internal/net/gphttp/accesslog/filter.go index 822101e..c0c3e29 100644 --- a/internal/net/http/accesslog/filter.go +++ b/internal/net/gphttp/accesslog/filter.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -27,7 +27,7 @@ type ( CIDR struct{ types.CIDR } ) -var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter") +var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter") func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool { if len(f.Values) == 0 { diff --git a/internal/net/http/accesslog/filter_test.go b/internal/net/gphttp/accesslog/filter_test.go similarity index 98% rename from internal/net/http/accesslog/filter_test.go rename to internal/net/gphttp/accesslog/filter_test.go index 7160dce..a934a7b 100644 --- a/internal/net/http/accesslog/filter_test.go +++ b/internal/net/gphttp/accesslog/filter_test.go @@ -4,7 +4,7 @@ import ( "net/http" "testing" - . "github.com/yusing/go-proxy/internal/net/http/accesslog" + . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/utils/strutils" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/accesslog/formatter.go b/internal/net/gphttp/accesslog/formatter.go similarity index 100% rename from internal/net/http/accesslog/formatter.go rename to internal/net/gphttp/accesslog/formatter.go diff --git a/internal/net/http/accesslog/mock_file.go b/internal/net/gphttp/accesslog/mock_file.go similarity index 92% rename from internal/net/http/accesslog/mock_file.go rename to internal/net/gphttp/accesslog/mock_file.go index f960429..54c30c6 100644 --- a/internal/net/http/accesslog/mock_file.go +++ b/internal/net/gphttp/accesslog/mock_file.go @@ -49,7 +49,6 @@ func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) { return 0, io.EOF } n = copy(p, m.data[off:]) - m.position += int64(n) return n, nil } @@ -63,7 +62,7 @@ func (m *MockFile) Truncate(size int64) error { return nil } -func (m *MockFile) Count() int { +func (m *MockFile) LineCount() int { m.Lock() defer m.Unlock() return bytes.Count(m.data[:m.position], []byte("\n")) @@ -72,3 +71,7 @@ func (m *MockFile) Count() int { func (m *MockFile) Len() int64 { return m.position } + +func (m *MockFile) Content() []byte { + return m.data[:m.position] +} diff --git a/internal/net/gphttp/accesslog/retention.go b/internal/net/gphttp/accesslog/retention.go new file mode 100644 index 0000000..f0b5e2a --- /dev/null +++ b/internal/net/gphttp/accesslog/retention.go @@ -0,0 +1,56 @@ +package accesslog + +import ( + "strconv" + + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type Retention struct { + Days uint64 `json:"days"` + Last uint64 `json:"last"` +} + +var ( + ErrInvalidSyntax = gperr.New("invalid syntax") + ErrZeroValue = gperr.New("zero value") +) + +var defaultChunkSize = 64 * 1024 // 64KB + +// Syntax: +// +// days|weeks|months +// +// last +// +// Parse implements strutils.Parser. +func (r *Retention) Parse(v string) (err error) { + split := strutils.SplitSpace(v) + if len(split) != 2 { + return ErrInvalidSyntax.Subject(v) + } + switch split[0] { + case "last": + r.Last, err = strconv.ParseUint(split[1], 10, 64) + default: // days|weeks|months + r.Days, err = strconv.ParseUint(split[0], 10, 64) + if err != nil { + return + } + switch split[1] { + case "days": + case "weeks": + r.Days *= 7 + case "months": + r.Days *= 30 + default: + return ErrInvalidSyntax.Subject("unit " + split[1]) + } + } + if r.Days == 0 && r.Last == 0 { + return ErrZeroValue + } + return +} diff --git a/internal/net/gphttp/accesslog/retention_test.go b/internal/net/gphttp/accesslog/retention_test.go new file mode 100644 index 0000000..1efd18d --- /dev/null +++ b/internal/net/gphttp/accesslog/retention_test.go @@ -0,0 +1,33 @@ +package accesslog_test + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestParseRetention(t *testing.T) { + tests := []struct { + input string + expected *Retention + shouldErr bool + }{ + {"30 days", &Retention{Days: 30}, false}, + {"2 weeks", &Retention{Days: 14}, false}, + {"last 5", &Retention{Last: 5}, false}, + {"invalid input", &Retention{}, true}, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + r := &Retention{} + err := r.Parse(test.input) + if !test.shouldErr { + ExpectNoError(t, err) + } else { + ExpectDeepEqual(t, r, test.expected) + } + }) + } +} diff --git a/internal/net/http/accesslog/status_code_range.go b/internal/net/gphttp/accesslog/status_code_range.go similarity index 81% rename from internal/net/http/accesslog/status_code_range.go rename to internal/net/gphttp/accesslog/status_code_range.go index 599f119..7ec94a2 100644 --- a/internal/net/http/accesslog/status_code_range.go +++ b/internal/net/gphttp/accesslog/status_code_range.go @@ -3,7 +3,7 @@ package accesslog import ( "strconv" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -12,7 +12,7 @@ type StatusCodeRange struct { End int } -var ErrInvalidStatusCodeRange = E.New("invalid status code range") +var ErrInvalidStatusCodeRange = gperr.New("invalid status code range") func (r *StatusCodeRange) Includes(code int) bool { return r.Start <= code && code <= r.End @@ -25,7 +25,7 @@ func (r *StatusCodeRange) Parse(v string) error { case 1: start, err := strconv.Atoi(split[0]) if err != nil { - return E.From(err) + return gperr.Wrap(err) } r.Start = start r.End = start @@ -33,7 +33,7 @@ func (r *StatusCodeRange) Parse(v string) error { case 2: start, errStart := strconv.Atoi(split[0]) end, errEnd := strconv.Atoi(split[1]) - if err := E.Join(errStart, errEnd); err != nil { + if err := gperr.Join(errStart, errEnd); err != nil { return err } r.Start = start diff --git a/internal/net/http/content_type.go b/internal/net/gphttp/content_type.go similarity index 98% rename from internal/net/http/content_type.go rename to internal/net/gphttp/content_type.go index 3b4be65..dee78ff 100644 --- a/internal/net/http/content_type.go +++ b/internal/net/gphttp/content_type.go @@ -1,4 +1,4 @@ -package http +package gphttp import ( "mime" diff --git a/internal/net/http/content_type_test.go b/internal/net/gphttp/content_type_test.go similarity index 99% rename from internal/net/http/content_type_test.go rename to internal/net/gphttp/content_type_test.go index ee4ea56..f5bba69 100644 --- a/internal/net/http/content_type_test.go +++ b/internal/net/gphttp/content_type_test.go @@ -1,4 +1,4 @@ -package http +package gphttp import ( "net/http" diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/gphttp/loadbalancer/ip_hash.go similarity index 89% rename from internal/net/http/loadbalancer/ip_hash.go rename to internal/net/gphttp/loadbalancer/ip_hash.go index 384f7cf..d8a54ed 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/gphttp/loadbalancer/ip_hash.go @@ -6,8 +6,8 @@ import ( "net/http" "sync" - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/middleware" ) type ipHash struct { @@ -23,10 +23,10 @@ func (lb *LoadBalancer) newIPHash() impl { if len(lb.Options) == 0 { return impl } - var err E.Error + var err gperr.Error impl.realIP, err = middleware.RealIP.New(lb.Options) if err != nil { - E.LogError("invalid real_ip options, ignoring", err, &impl.l) + gperr.LogError("invalid real_ip options, ignoring", err, &impl.l) } return impl } diff --git a/internal/net/http/loadbalancer/least_conn.go b/internal/net/gphttp/loadbalancer/least_conn.go similarity index 100% rename from internal/net/http/loadbalancer/least_conn.go rename to internal/net/gphttp/loadbalancer/least_conn.go diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/gphttp/loadbalancer/loadbalancer.go similarity index 92% rename from internal/net/http/loadbalancer/loadbalancer.go rename to internal/net/gphttp/loadbalancer/loadbalancer.go index b453648..3d474cf 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/gphttp/loadbalancer/loadbalancer.go @@ -6,10 +6,10 @@ import ( "time" "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health" @@ -54,7 +54,7 @@ func New(cfg *Config) *LoadBalancer { } // Start implements task.TaskStarter. -func (lb *LoadBalancer) Start(parent task.Parent) E.Error { +func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error { lb.startTime = time.Now() lb.task = parent.Subtask("loadbalancer."+lb.Link, false) parent.OnCancel("lb_remove_route", func() { @@ -125,12 +125,12 @@ func (lb *LoadBalancer) AddServer(srv Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - if lb.pool.Has(srv.Name()) { - old, _ := lb.pool.Load(srv.Name()) + if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning + old, _ := lb.pool.Load(srv.Key()) lb.sumWeight -= old.Weight() lb.impl.OnRemoveServer(old) } - lb.pool.Store(srv.Name(), srv) + lb.pool.Store(srv.Key(), srv) lb.sumWeight += srv.Weight() lb.rebalance() @@ -146,11 +146,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - if !lb.pool.Has(srv.Name()) { + if !lb.pool.Has(srv.Key()) { return } - lb.pool.Delete(srv.Name()) + lb.pool.Delete(srv.Key()) lb.sumWeight -= srv.Weight() lb.rebalance() @@ -227,7 +227,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) return } - if r.Header.Get(common.HeaderCheckRedirect) != "" { + if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" { // wake all servers for _, srv := range srvs { if err := srv.TryWake(); err != nil { @@ -244,7 +244,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { extra := make(map[string]any) lb.pool.RangeAll(func(k string, v Server) { - extra[v.Name()] = v + extra[v.Key()] = v }) return (&monitor.JSONRepresentation{ diff --git a/internal/net/http/loadbalancer/loadbalancer_test.go b/internal/net/gphttp/loadbalancer/loadbalancer_test.go similarity index 95% rename from internal/net/http/loadbalancer/loadbalancer_test.go rename to internal/net/gphttp/loadbalancer/loadbalancer_test.go index 25199b5..03f2bfc 100644 --- a/internal/net/http/loadbalancer/loadbalancer_test.go +++ b/internal/net/gphttp/loadbalancer/loadbalancer_test.go @@ -3,7 +3,7 @@ package loadbalancer import ( "testing" - "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/loadbalancer/round_robin.go b/internal/net/gphttp/loadbalancer/round_robin.go similarity index 100% rename from internal/net/http/loadbalancer/round_robin.go rename to internal/net/gphttp/loadbalancer/round_robin.go diff --git a/internal/net/http/loadbalancer/types.go b/internal/net/gphttp/loadbalancer/types.go similarity index 72% rename from internal/net/http/loadbalancer/types.go rename to internal/net/gphttp/loadbalancer/types.go index 36b45ad..2a83a8f 100644 --- a/internal/net/http/loadbalancer/types.go +++ b/internal/net/gphttp/loadbalancer/types.go @@ -1,7 +1,7 @@ package loadbalancer import ( - "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" ) type ( diff --git a/internal/net/http/loadbalancer/types/config.go b/internal/net/gphttp/loadbalancer/types/config.go similarity index 100% rename from internal/net/http/loadbalancer/types/config.go rename to internal/net/gphttp/loadbalancer/types/config.go diff --git a/internal/net/http/loadbalancer/types/mode.go b/internal/net/gphttp/loadbalancer/types/mode.go similarity index 100% rename from internal/net/http/loadbalancer/types/mode.go rename to internal/net/gphttp/loadbalancer/types/mode.go diff --git a/internal/net/http/loadbalancer/types/server.go b/internal/net/gphttp/loadbalancer/types/server.go similarity index 91% rename from internal/net/http/loadbalancer/types/server.go rename to internal/net/gphttp/loadbalancer/types/server.go index e0e7b5e..15abc0a 100644 --- a/internal/net/http/loadbalancer/types/server.go +++ b/internal/net/gphttp/loadbalancer/types/server.go @@ -26,6 +26,7 @@ type ( http.Handler health.HealthMonitor Name() string + Key() string URL() *net.URL Weight() Weight SetWeight(weight Weight) @@ -51,6 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server { srv := &server{ weight: Weight(weight), + url: net.MustParseURL("http://localhost"), } return srv } @@ -63,6 +65,10 @@ func (srv *server) URL() *net.URL { return srv.url } +func (srv *server) Key() string { + return srv.url.Host +} + func (srv *server) Weight() Weight { return srv.weight } @@ -78,9 +84,7 @@ func (srv *server) String() string { func (srv *server) TryWake() error { waker, ok := srv.Handler.(idlewatcher.Waker) if ok { - if err := waker.Wake(); err != nil { - return err - } + return waker.Wake() } return nil } diff --git a/internal/net/http/loadbalancer/types/weight.go b/internal/net/gphttp/loadbalancer/types/weight.go similarity index 100% rename from internal/net/http/loadbalancer/types/weight.go rename to internal/net/gphttp/loadbalancer/types/weight.go diff --git a/internal/net/http/methods.go b/internal/net/gphttp/methods.go similarity index 95% rename from internal/net/http/methods.go rename to internal/net/gphttp/methods.go index caca564..6e49202 100644 --- a/internal/net/http/methods.go +++ b/internal/net/gphttp/methods.go @@ -1,4 +1,4 @@ -package http +package gphttp import "net/http" diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/gphttp/middleware/cidr_whitelist.go similarity index 97% rename from internal/net/http/middleware/cidr_whitelist.go rename to internal/net/gphttp/middleware/cidr_whitelist.go index e123c86..6b9271f 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/gphttp/middleware/cidr_whitelist.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/go-playground/validator/v10" - gphttp "github.com/yusing/go-proxy/internal/net/http" + gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/gphttp/middleware/cidr_whitelist_test.go similarity index 96% rename from internal/net/http/middleware/cidr_whitelist_test.go rename to internal/net/gphttp/middleware/cidr_whitelist_test.go index 64fc9e8..a8c7cee 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/gphttp/middleware/cidr_whitelist_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/utils" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -61,7 +61,7 @@ func TestCIDRWhitelistValidation(t *testing.T) { } func TestCIDRWhitelist(t *testing.T) { - errs := E.NewBuilder("") + errs := gperr.NewBuilder("") mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) ExpectNoError(t, errs.Error()) deny = mids["deny@file"] diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/gphttp/middleware/cloudflare_real_ip.go similarity index 88% rename from internal/net/http/middleware/cloudflare_real_ip.go rename to internal/net/gphttp/middleware/cloudflare_real_ip.go index 845bc20..4fdcf2e 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/gphttp/middleware/cloudflare_real_ip.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -28,7 +29,7 @@ const ( ) var ( - cfCIDRsLastUpdate time.Time + cfCIDRsLastUpdate atomic.Value[time.Time] cfCIDRsMu sync.Mutex // RFC 1918. @@ -68,14 +69,14 @@ func (cri *cloudflareRealIP) getTracer() *Tracer { } func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { - if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { + if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval { return } cfCIDRsMu.Lock() defer cfCIDRsMu.Unlock() - if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { + if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval { return } @@ -88,7 +89,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs), ) if err != nil { - cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) + cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)) logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) return nil } @@ -97,7 +98,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { } } - cfCIDRsLastUpdate = time.Now() + cfCIDRsLastUpdate.Store(time.Now()) logging.Info().Msg("cloudflare CIDR range updated") return } diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/gphttp/middleware/custom_error_page.go similarity index 72% rename from internal/net/http/middleware/custom_error_page.go rename to internal/net/gphttp/middleware/custom_error_page.go index 730422f..76110be 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/gphttp/middleware/custom_error_page.go @@ -9,14 +9,17 @@ import ( "strings" "github.com/yusing/go-proxy/internal/logging" - gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" + gphttp "github.com/yusing/go-proxy/internal/net/gphttp" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage" ) type customErrorPage struct{} var CustomErrorPage = NewMiddleware[customErrorPage]() +const StaticFilePathPrefix = "/$gperrorpage/" + // before implements RequestModifier. func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { return !ServeStaticErrorPageFile(w, r) @@ -34,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error { resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.ContentLength = int64(len(errorPage)) - resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) - resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") + resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage))) + resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8") } else { logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } @@ -49,8 +52,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo if path != "" && path[0] != '/' { path = "/" + path } - if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) { - filename := path[len(gphttp.StaticFilePathPrefix):] + if strings.HasPrefix(path, StaticFilePathPrefix) { + filename := path[len(StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { logging.Error().Msg("unable to load resource " + filename) @@ -59,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo ext := filepath.Ext(filename) switch ext { case ".html": - w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8") case ".js": - w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8") case ".css": - w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8") default: logging.Error().Msgf("unexpected file type %q for %s", ext, filename) } diff --git a/internal/net/http/middleware/errorpage/error_page.go b/internal/net/gphttp/middleware/errorpage/error_page.go similarity index 95% rename from internal/net/http/middleware/errorpage/error_page.go rename to internal/net/gphttp/middleware/errorpage/error_page.go index 2fb09e1..a1acec6 100644 --- a/internal/net/http/middleware/errorpage/error_page.go +++ b/internal/net/gphttp/middleware/errorpage/error_page.go @@ -7,7 +7,7 @@ import ( "sync" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" @@ -90,7 +90,7 @@ func watchDir() { loadContent() } case err := <-errCh: - E.LogError("error watching error page directory", err) + gperr.LogError("error watching error page directory", err) } } } diff --git a/internal/net/http/middleware/metrics_logger/metrics_logger.go b/internal/net/gphttp/middleware/metrics_logger/metrics_logger.go similarity index 100% rename from internal/net/http/middleware/metrics_logger/metrics_logger.go rename to internal/net/gphttp/middleware/metrics_logger/metrics_logger.go diff --git a/internal/net/http/middleware/metrics_logger/metrics_response_writer.go b/internal/net/gphttp/middleware/metrics_logger/metrics_response_writer.go similarity index 100% rename from internal/net/http/middleware/metrics_logger/metrics_response_writer.go rename to internal/net/gphttp/middleware/metrics_logger/metrics_response_writer.go diff --git a/internal/net/http/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go similarity index 90% rename from internal/net/http/middleware/middleware.go rename to internal/net/gphttp/middleware/middleware.go index b206d26..ccf6705 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -7,15 +7,15 @@ import ( "sort" "strings" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" - gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + gphttp "github.com/yusing/go-proxy/internal/net/gphttp" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/utils" ) type ( - Error = E.Error + Error = gperr.Error ReverseProxy = reverseproxy.ReverseProxy ProxyRequest = reverseproxy.ProxyRequest @@ -80,7 +80,7 @@ func NewMiddleware[ImplType any]() *Middleware { func (m *Middleware) enableTrace() { if tracer, ok := m.impl.(MiddlewareWithTracer); ok { tracer.enableTrace() - logging.Debug().Msgf("middleware %s enabled trace", m.name) + logging.Trace().Msgf("middleware %s enabled trace", m.name) } } @@ -103,7 +103,7 @@ func (m *Middleware) setup() { } } -func (m *Middleware) apply(optsRaw OptionsRaw) E.Error { +func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error { if len(optsRaw) == 0 { return nil } @@ -132,10 +132,10 @@ func (m *Middleware) finalize() error { return nil } -func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { +func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) { if m.construct == nil { // likely a middleware from compose if len(optsRaw) != 0 { - return nil, E.New("additional options not allowed for middleware ").Subject(m.name) + return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name) } return m, nil } @@ -145,7 +145,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { return nil, err } if err := mid.finalize(); err != nil { - return nil, E.From(err) + return nil, gperr.Wrap(err) } return mid, nil } @@ -196,7 +196,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * next(w, r) } -func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { +func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) { var middlewares []*Middleware middlewares, err = compileMiddlewares(middlewaresMap) if err != nil { diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/gphttp/middleware/middleware_builder.go similarity index 79% rename from internal/net/http/middleware/middleware_builder.go rename to internal/net/gphttp/middleware/middleware_builder.go index 8ea5403..c7be596 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/gphttp/middleware/middleware_builder.go @@ -6,13 +6,13 @@ import ( "path" "sort" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "gopkg.in/yaml.v3" ) -var ErrMissingMiddlewareUse = E.New("missing middleware 'use' field") +var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field") -func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware { +func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware { fileContent, err := os.ReadFile(filePath) if err != nil { eb.Add(err) @@ -21,7 +21,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string] return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) } -func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { +func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware { var rawMap map[string][]map[string]any err := yaml.Unmarshal(data, &rawMap) if err != nil { @@ -40,11 +40,11 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str return middlewares } -func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { +func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) { middlewares := make([]*Middleware, 0, len(middlewaresMap)) - errs := E.NewBuilder("middlewares compile error") - invalidOpts := E.NewBuilder("options compile error") + errs := gperr.NewBuilder("middlewares compile error") + invalidOpts := gperr.NewBuilder("options compile error") for name, opts := range middlewaresMap { m, err := Get(name) @@ -68,7 +68,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E. return middlewares, errs.Error() } -func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, E.Error) { +func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) { compiled, err := compileMiddlewares(middlewaresMap) if err != nil { return nil, err @@ -77,8 +77,8 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) ( } // TODO: check conflict or duplicates. -func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, E.Error) { - chainErr := E.NewBuilder("") +func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) { + chainErr := gperr.NewBuilder("") chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"] == "" { diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/gphttp/middleware/middleware_builder_test.go similarity index 85% rename from internal/net/http/middleware/middleware_builder_test.go rename to internal/net/gphttp/middleware/middleware_builder_test.go index 2c9828c..08e8402 100644 --- a/internal/net/http/middleware/middleware_builder_test.go +++ b/internal/net/gphttp/middleware/middleware_builder_test.go @@ -5,7 +5,7 @@ import ( "encoding/json" "testing" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -13,7 +13,7 @@ import ( var testMiddlewareCompose []byte func TestBuild(t *testing.T) { - errs := E.NewBuilder("") + errs := gperr.NewBuilder("") middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs) ExpectNoError(t, errs.Error()) Must(json.MarshalIndent(middlewares, "", " ")) diff --git a/internal/net/http/middleware/middleware_chain.go b/internal/net/gphttp/middleware/middleware_chain.go similarity index 90% rename from internal/net/http/middleware/middleware_chain.go rename to internal/net/gphttp/middleware/middleware_chain.go index 932d278..cf1258b 100644 --- a/internal/net/http/middleware/middleware_chain.go +++ b/internal/net/gphttp/middleware/middleware_chain.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" ) type middlewareChain struct { @@ -51,10 +51,10 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error { if len(m.modResps) == 0 { return nil } - errs := E.NewBuilder("modify response errors") + errs := gperr.NewBuilder("modify response errors") for i, mr := range m.modResps { if err := mr.modifyResponse(resp); err != nil { - errs.Add(E.From(err).Subjectf("%d", i)) + errs.Add(gperr.Wrap(err).Subjectf("%d", i)) } } return errs.Error() diff --git a/internal/net/http/middleware/middleware_test.go b/internal/net/gphttp/middleware/middleware_test.go similarity index 100% rename from internal/net/http/middleware/middleware_test.go rename to internal/net/gphttp/middleware/middleware_test.go diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/gphttp/middleware/middlewares.go similarity index 85% rename from internal/net/http/middleware/middlewares.go rename to internal/net/gphttp/middleware/middlewares.go index 261954b..184de94 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/gphttp/middleware/middlewares.go @@ -4,7 +4,7 @@ import ( "path" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -32,15 +32,11 @@ var allMiddlewares = map[string]*Middleware{ "cidrwhitelist": CIDRWhiteList, "ratelimit": RateLimiter, - - // !experimental - "forwardauth": ForwardAuth, - // "oauth2": OAuth2.m, } var ( - ErrUnknownMiddleware = E.New("unknown middleware") - ErrDuplicatedMiddleware = E.New("duplicated middleware") + ErrUnknownMiddleware = gperr.New("unknown middleware") + ErrDuplicatedMiddleware = gperr.New("duplicated middleware") ) func Get(name string) (*Middleware, Error) { @@ -58,14 +54,14 @@ func All() map[string]*Middleware { } func LoadComposeFiles() { - errs := E.NewBuilder("middleware compile errors") + errs := gperr.NewBuilder("middleware compile errors") middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) if err != nil { logging.Err(err).Msg("failed to list middleware definitions") return } for _, defFile := range middlewareDefs { - voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step + voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs) if len(mws) == 0 { continue @@ -103,6 +99,6 @@ func LoadComposeFiles() { } } if errs.HasError() { - E.LogError(errs.About(), errs.Error()) + gperr.LogError(errs.About(), errs.Error()) } } diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/gphttp/middleware/modify_request.go similarity index 100% rename from internal/net/http/middleware/modify_request.go rename to internal/net/gphttp/middleware/modify_request.go diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/gphttp/middleware/modify_request_test.go similarity index 100% rename from internal/net/http/middleware/modify_request_test.go rename to internal/net/gphttp/middleware/modify_request_test.go diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/gphttp/middleware/modify_response.go similarity index 100% rename from internal/net/http/middleware/modify_response.go rename to internal/net/gphttp/middleware/modify_response.go diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/gphttp/middleware/modify_response_test.go similarity index 100% rename from internal/net/http/middleware/modify_response_test.go rename to internal/net/gphttp/middleware/modify_response_test.go diff --git a/internal/net/http/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go similarity index 85% rename from internal/net/http/middleware/oidc.go rename to internal/net/gphttp/middleware/oidc.go index 3af1ca3..231422b 100644 --- a/internal/net/http/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -1,12 +1,13 @@ package middleware import ( + "errors" "net/http" "sync" "sync/atomic" "github.com/yusing/go-proxy/internal/api/v1/auth" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" ) type oidcMiddleware struct { @@ -24,7 +25,7 @@ var OIDC = NewMiddleware[oidcMiddleware]() func (amw *oidcMiddleware) finalize() error { if !auth.IsOIDCEnabled() { - return E.New("OIDC not enabled but OIDC middleware is used") + return gperr.New("OIDC not enabled but OIDC middleware is used") } return nil } @@ -64,9 +65,6 @@ func (amw *oidcMiddleware) initSlow() error { amw.authMux = http.NewServeMux() amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) - amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - }) amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) amw.auth = authProvider return nil @@ -79,13 +77,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce return false } - if err := amw.auth.CheckToken(r); err != nil { - amw.authMux.ServeHTTP(w, r) - return false - } if r.URL.Path == auth.OIDCLogoutPath { amw.auth.LogoutCallbackHandler(w, r) return false } + if err := amw.auth.CheckToken(r); err != nil { + if errors.Is(err, auth.ErrMissingToken) { + amw.authMux.ServeHTTP(w, r) + } else { + auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) + } + return false + } return true } diff --git a/internal/net/http/middleware/rate_limit.go b/internal/net/gphttp/middleware/rate_limit.go similarity index 100% rename from internal/net/http/middleware/rate_limit.go rename to internal/net/gphttp/middleware/rate_limit.go diff --git a/internal/net/http/middleware/rate_limit_test.go b/internal/net/gphttp/middleware/rate_limit_test.go similarity index 100% rename from internal/net/http/middleware/rate_limit_test.go rename to internal/net/gphttp/middleware/rate_limit_test.go diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/gphttp/middleware/real_ip.go similarity index 95% rename from internal/net/http/middleware/real_ip.go rename to internal/net/gphttp/middleware/real_ip.go index 0b5a53d..ed11d12 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/gphttp/middleware/real_ip.go @@ -4,7 +4,7 @@ import ( "net" "net/http" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/types" ) @@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) { req.RemoteAddr = lastNonTrustedIP req.Header.Set(ri.Header, lastNonTrustedIP) - req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) + req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP) ri.AddTracef("set real ip %s", lastNonTrustedIP) } diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/gphttp/middleware/real_ip_test.go similarity index 90% rename from internal/net/http/middleware/real_ip_test.go rename to internal/net/gphttp/middleware/real_ip_test.go index 02f5bd5..372862d 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/gphttp/middleware/real_ip_test.go @@ -6,14 +6,14 @@ import ( "strings" "testing" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSetRealIPOpts(t *testing.T) { opts := OptionsRaw{ - "header": gphttp.HeaderXRealIP, + "header": httpheaders.HeaderXRealIP, "from": []string{ "127.0.0.0/8", "192.168.0.0/16", @@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) { "recursive": true, } optExpected := &RealIPOpts{ - Header: gphttp.HeaderXRealIP, + Header: httpheaders.HeaderXRealIP, From: []*types.CIDR{ { IP: net.ParseIP("127.0.0.0"), @@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) { func TestSetRealIP(t *testing.T) { const ( - testHeader = gphttp.HeaderXRealIP + testHeader = httpheaders.HeaderXRealIP testRealIP = "192.168.1.1" ) opts := OptionsRaw{ diff --git a/internal/net/http/middleware/redirect_http.go b/internal/net/gphttp/middleware/redirect_http.go similarity index 93% rename from internal/net/http/middleware/redirect_http.go rename to internal/net/gphttp/middleware/redirect_http.go index 1419d01..26ec2cc 100644 --- a/internal/net/http/middleware/redirect_http.go +++ b/internal/net/gphttp/middleware/redirect_http.go @@ -44,7 +44,7 @@ func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed b r.URL.Host = host } - http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently) + http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect) logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https") return false diff --git a/internal/net/http/middleware/redirect_http_test.go b/internal/net/gphttp/middleware/redirect_http_test.go similarity index 90% rename from internal/net/http/middleware/redirect_http_test.go rename to internal/net/gphttp/middleware/redirect_http_test.go index 24c2662..eccd33c 100644 --- a/internal/net/http/middleware/redirect_http_test.go +++ b/internal/net/gphttp/middleware/redirect_http_test.go @@ -13,7 +13,7 @@ func TestRedirectToHTTPs(t *testing.T) { reqURL: types.MustParseURL("http://example.com"), }) ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusMovedPermanently) + ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect) ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com") } diff --git a/internal/net/http/middleware/set_upstream_headers.go b/internal/net/gphttp/middleware/set_upstream_headers.go similarity index 63% rename from internal/net/http/middleware/set_upstream_headers.go rename to internal/net/gphttp/middleware/set_upstream_headers.go index 009fc84..434c4cd 100644 --- a/internal/net/http/middleware/set_upstream_headers.go +++ b/internal/net/gphttp/middleware/set_upstream_headers.go @@ -3,8 +3,8 @@ package middleware import ( "net/http" - gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" ) // internal use only. @@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware { // before implements RequestModifier. func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - r.Header.Set(gphttp.HeaderUpstreamName, s.Name) - r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme) - r.Header.Set(gphttp.HeaderUpstreamHost, s.Host) - r.Header.Set(gphttp.HeaderUpstreamPort, s.Port) + r.Header.Set(httpheaders.HeaderUpstreamName, s.Name) + r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme) + r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host) + r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port) return true } diff --git a/internal/net/http/middleware/test_data/cidr_whitelist_test.yml b/internal/net/gphttp/middleware/test_data/cidr_whitelist_test.yml similarity index 100% rename from internal/net/http/middleware/test_data/cidr_whitelist_test.yml rename to internal/net/gphttp/middleware/test_data/cidr_whitelist_test.yml diff --git a/internal/net/http/middleware/test_data/middleware_compose.yml b/internal/net/gphttp/middleware/test_data/middleware_compose.yml similarity index 55% rename from internal/net/http/middleware/test_data/middleware_compose.yml rename to internal/net/gphttp/middleware/test_data/middleware_compose.yml index 4ec3040..2e1718b 100644 --- a/internal/net/http/middleware/test_data/middleware_compose.yml +++ b/internal/net/gphttp/middleware/test_data/middleware_compose.yml @@ -8,19 +8,6 @@ theGreatPretender: - X-Test3 - X-Test4 -notAuthenticAuthentik: - - use: RedirectHTTP - - use: ForwardAuth - address: https://authentik.company - trustForwardHeader: true - addAuthCookiesToResponse: - - session_id - - user_id - authResponseHeaders: - - X-Auth-SessionID - - X-Auth-UserID - - use: CustomErrorPage - realIPAuthentik: - use: RedirectHTTP - use: RealIP @@ -30,9 +17,6 @@ realIPAuthentik: - "192.168.0.0/16" - "172.16.0.0/12" recursive: true - - use: ForwardAuth - address: https://authentik.company - trustForwardHeader: true testFakeRealIP: - use: ModifyRequest diff --git a/internal/net/http/middleware/test_data/sample_headers.json b/internal/net/gphttp/middleware/test_data/sample_headers.json similarity index 100% rename from internal/net/http/middleware/test_data/sample_headers.json rename to internal/net/gphttp/middleware/test_data/sample_headers.json diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/gphttp/middleware/test_utils.go similarity index 94% rename from internal/net/http/middleware/test_utils.go rename to internal/net/gphttp/middleware/test_utils.go index 0adb1a5..2bd208b 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/gphttp/middleware/test_utils.go @@ -9,8 +9,8 @@ import ( "net/http/httptest" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -122,7 +122,7 @@ func (args *testArgs) bodyReader() io.Reader { return nil } -func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { +func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) { if args == nil { args = new(testArgs) } @@ -136,7 +136,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E return newMiddlewaresTest([]*Middleware{mid}, args) } -func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) { +func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) { if args == nil { args = new(testArgs) } @@ -163,7 +163,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, data, err := io.ReadAll(resp.Body) if err != nil { - return nil, E.From(err) + return nil, gperr.Wrap(err) } return &TestResult{ diff --git a/internal/net/http/middleware/trace.go b/internal/net/gphttp/middleware/trace.go similarity index 87% rename from internal/net/http/middleware/trace.go rename to internal/net/gphttp/middleware/trace.go index c3b0c73..4c9550a 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/gphttp/middleware/trace.go @@ -4,7 +4,7 @@ import ( "net/http" "sync" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" ) type ( @@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace { return nil } tr.URL = req.RequestURI - tr.ReqHeaders = gphttp.HeaderToMap(req.Header) + tr.ReqHeaders = httpheaders.HeaderToMap(req.Header) return tr } @@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace { return nil } tr.URL = resp.Request.RequestURI - tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header) - tr.RespHeaders = gphttp.HeaderToMap(resp.Header) + tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header) + tr.RespHeaders = httpheaders.HeaderToMap(resp.Header) tr.RespStatus = resp.StatusCode return tr } diff --git a/internal/net/http/middleware/tracer.go b/internal/net/gphttp/middleware/tracer.go similarity index 100% rename from internal/net/http/middleware/tracer.go rename to internal/net/gphttp/middleware/tracer.go diff --git a/internal/net/http/middleware/vars.go b/internal/net/gphttp/middleware/vars.go similarity index 90% rename from internal/net/http/middleware/vars.go rename to internal/net/gphttp/middleware/vars.go index 0830542..472ca72 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/gphttp/middleware/vars.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" ) type ( @@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ return "" }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, - VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) }, - VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, - VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, - VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) }, + VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) }, + VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) }, + VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) }, VarUpstreamAddr: func(req *http.Request) string { - upHost := req.Header.Get(gphttp.HeaderUpstreamHost) - upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) + upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) if upPort != "" { return upHost + ":" + upPort } return upHost }, VarUpstreamURL: func(req *http.Request) string { - upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) + upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme) if upScheme == "" { return "" } - upHost := req.Header.Get(gphttp.HeaderUpstreamHost) - upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) + upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) upAddr := upHost if upPort != "" { upAddr += ":" + upPort diff --git a/internal/net/http/middleware/x_forwarded.go b/internal/net/gphttp/middleware/x_forwarded.go similarity index 82% rename from internal/net/http/middleware/x_forwarded.go rename to internal/net/gphttp/middleware/x_forwarded.go index ff8a558..a2de34b 100644 --- a/internal/net/http/middleware/x_forwarded.go +++ b/internal/net/gphttp/middleware/x_forwarded.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" ) type ( @@ -20,10 +20,10 @@ var ( // before implements RequestModifier. func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - r.Header.Del(gphttp.HeaderXForwardedFor) + r.Header.Del(httpheaders.HeaderXForwardedFor) clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err == nil { - r.Header.Set(gphttp.HeaderXForwardedFor, clientIP) + r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP) } return true } diff --git a/internal/net/http/modify_response_writer.go b/internal/net/gphttp/modify_response_writer.go similarity index 99% rename from internal/net/http/modify_response_writer.go rename to internal/net/gphttp/modify_response_writer.go index a8c7b89..41f846d 100644 --- a/internal/net/http/modify_response_writer.go +++ b/internal/net/gphttp/modify_response_writer.go @@ -1,7 +1,7 @@ // Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go) // Copyright (c) 2020-2024 Traefik Labs -package http +package gphttp import ( "bufio" diff --git a/internal/net/http/serve_mux.go b/internal/net/gphttp/serve_mux.go similarity index 97% rename from internal/net/http/serve_mux.go rename to internal/net/gphttp/serve_mux.go index 0e487b6..902c0e2 100644 --- a/internal/net/http/serve_mux.go +++ b/internal/net/gphttp/serve_mux.go @@ -1,4 +1,4 @@ -package http +package gphttp import "net/http" diff --git a/internal/net/http/status_code.go b/internal/net/gphttp/status_code.go similarity index 93% rename from internal/net/http/status_code.go rename to internal/net/gphttp/status_code.go index 8235805..25977df 100644 --- a/internal/net/http/status_code.go +++ b/internal/net/gphttp/status_code.go @@ -1,4 +1,4 @@ -package http +package gphttp import "net/http" diff --git a/internal/net/gphttp/transport.go b/internal/net/gphttp/transport.go new file mode 100644 index 0000000..d633ee6 --- /dev/null +++ b/internal/net/gphttp/transport.go @@ -0,0 +1,34 @@ +package gphttp + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +var DefaultDialer = net.Dialer{ + Timeout: 5 * time.Second, +} + +func NewTransport() *http.Transport { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: DefaultDialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + // DisableCompression: true, // Prevent double compression + ResponseHeaderTimeout: 60 * time.Second, + WriteBufferSize: 16 * 1024, // 16KB + ReadBufferSize: 16 * 1024, // 16KB + } +} + +func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport { + tr := NewTransport() + tr.TLSClientConfig = tlsConfig + return tr +} diff --git a/internal/net/http/accesslog/retention.go b/internal/net/http/accesslog/retention.go deleted file mode 100644 index da31544..0000000 --- a/internal/net/http/accesslog/retention.go +++ /dev/null @@ -1,198 +0,0 @@ -package accesslog - -import ( - "bufio" - "bytes" - "io" - "strconv" - "time" - - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/utils/strutils" -) - -type Retention struct { - Days uint64 `json:"days"` - Last uint64 `json:"last"` -} - -const chunkSizeMax int64 = 128 * 1024 // 128KB - -var ( - ErrInvalidSyntax = E.New("invalid syntax") - ErrZeroValue = E.New("zero value") -) - -// Syntax: -// -// days|weeks|months -// -// last -// -// Parse implements strutils.Parser. -func (r *Retention) Parse(v string) (err error) { - split := strutils.SplitSpace(v) - if len(split) != 2 { - return ErrInvalidSyntax.Subject(v) - } - switch split[0] { - case "last": - r.Last, err = strconv.ParseUint(split[1], 10, 64) - default: // days|weeks|months - r.Days, err = strconv.ParseUint(split[0], 10, 64) - if err != nil { - return - } - switch split[1] { - case "days": - case "weeks": - r.Days *= 7 - case "months": - r.Days *= 30 - default: - return ErrInvalidSyntax.Subject("unit " + split[1]) - } - } - if r.Days == 0 && r.Last == 0 { - return ErrZeroValue - } - return -} - -func (r *Retention) rotateLogFile(file AccessLogIO) (err error) { - lastN := int(r.Last) - days := int(r.Days) - - // Seek to end to get file size - size, err := file.Seek(0, io.SeekEnd) - if err != nil { - return err - } - - // Initialize ring buffer for last N lines - lines := make([][]byte, 0, lastN|(days*1000)) - pos := size - unprocessed := 0 - - var chunk [chunkSizeMax]byte - var lastLine []byte - - var shouldStop func() bool - if days > 0 { - cutoff := time.Now().AddDate(0, 0, -days) - shouldStop = func() bool { - return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff) - } - } else { - shouldStop = func() bool { - return len(lines) == lastN - } - } - - // Read backwards until we have enough lines or reach start of file - for pos > 0 { - if pos > chunkSizeMax { - pos -= chunkSizeMax - } else { - pos = 0 - } - - // Seek to the current chunk - if _, err = file.Seek(pos, io.SeekStart); err != nil { - return err - } - - var nRead int - // Read the chunk - if nRead, err = file.Read(chunk[unprocessed:]); err != nil { - return err - } - - // last unprocessed bytes + read bytes - curChunk := chunk[:unprocessed+nRead] - unprocessed = len(curChunk) - - // Split into lines - scanner := bufio.NewScanner(bytes.NewReader(curChunk)) - for !shouldStop() && scanner.Scan() { - lastLine = scanner.Bytes() - lines = append(lines, lastLine) - unprocessed -= len(lastLine) - } - if shouldStop() { - break - } - - // move unprocessed bytes to the beginning for next iteration - copy(chunk[:], curChunk[unprocessed:]) - } - - if days > 0 { - // truncate to the end of the log within last N days - return file.Truncate(pos) - } - - // write lines to buffer in reverse order - // since we read them backwards - var buf bytes.Buffer - for i := len(lines) - 1; i >= 0; i-- { - buf.Write(lines[i]) - buf.WriteRune('\n') - } - - return writeTruncate(file, &buf) -} - -func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) { - // Seek to beginning and truncate - if _, err := file.Seek(0, 0); err != nil { - return err - } - - buffered := bufio.NewWriter(file) - // Write buffer back to file - nWritten, err := buffered.Write(buf.Bytes()) - if err != nil { - return err - } - if err = buffered.Flush(); err != nil { - return err - } - - // Truncate file - if err = file.Truncate(int64(nWritten)); err != nil { - return err - } - - // check bytes written == buffer size - if nWritten != buf.Len() { - return io.ErrShortWrite - } - return -} - -func parseLogTime(line []byte) (t time.Time) { - if len(line) == 0 { - return - } - - var start, end int - const jsonStart = len(`{"time":"`) - const jsonEnd = jsonStart + len(LogTimeFormat) - - if len(line) == '{' { // possibly json log - start = jsonStart - end = jsonEnd - } else { // possibly common or combined format - // Format: - - [02/Jan/2006:15:04:05 -0700] ... - start = bytes.IndexRune(line, '[') - end = bytes.IndexRune(line[start+1:], ']') - if start == -1 || end == -1 || start >= end { - return - } - } - - timeStr := line[start+1 : end] - t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error - return -} diff --git a/internal/net/http/accesslog/retention_test.go b/internal/net/http/accesslog/retention_test.go deleted file mode 100644 index ea51fe2..0000000 --- a/internal/net/http/accesslog/retention_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package accesslog_test - -import ( - "testing" - "time" - - . "github.com/yusing/go-proxy/internal/net/http/accesslog" - "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" -) - -func TestParseRetention(t *testing.T) { - tests := []struct { - input string - expected *Retention - shouldErr bool - }{ - {"30 days", &Retention{Days: 30}, false}, - {"2 weeks", &Retention{Days: 14}, false}, - {"last 5", &Retention{Last: 5}, false}, - {"invalid input", &Retention{}, true}, - } - - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - r := &Retention{} - err := r.Parse(test.input) - if !test.shouldErr { - ExpectNoError(t, err) - } else { - ExpectDeepEqual(t, r, test.expected) - } - }) - } -} - -func TestRetentionCommonFormat(t *testing.T) { - var file MockFile - logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{ - Format: FormatCommon, - BufferSize: 1024, - }) - for range 10 { - logger.Log(req, resp) - } - logger.Flush(true) - // test.Finish(nil) - - ExpectEqual(t, logger.Config().Retention, nil) - ExpectTrue(t, file.Len() > 0) - ExpectEqual(t, file.Count(), 10) - - t.Run("keep last", func(t *testing.T) { - logger.Config().Retention = strutils.MustParse[*Retention]("last 5") - ExpectEqual(t, logger.Config().Retention.Days, 0) - ExpectEqual(t, logger.Config().Retention.Last, 5) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.Count(), 5) - }) - - _ = file.Truncate(0) - - timeNow := time.Now() - for i := range 10 { - logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time { - return timeNow.AddDate(0, 0, -i) - } - logger.Log(req, resp) - } - logger.Flush(true) - - // FIXME: keep days does not work - t.Run("keep days", func(t *testing.T) { - logger.Config().Retention = strutils.MustParse[*Retention]("3 days") - ExpectEqual(t, logger.Config().Retention.Days, 3) - ExpectEqual(t, logger.Config().Retention.Last, 0) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.Count(), 3) - }) -} diff --git a/internal/net/http/common.go b/internal/net/http/common.go deleted file mode 100644 index 8671cf3..0000000 --- a/internal/net/http/common.go +++ /dev/null @@ -1,34 +0,0 @@ -package http - -import ( - "crypto/tls" - "net" - "net/http" - "time" -) - -var ( - defaultDialer = net.Dialer{ - Timeout: 60 * time.Second, - } - DefaultTransport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: defaultDialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: true, // Prevent double compression - ResponseHeaderTimeout: 60 * time.Second, - WriteBufferSize: 16 * 1024, // 16KB - ReadBufferSize: 16 * 1024, // 16KB - } - DefaultTransportNoTLS = func() *http.Transport { - clone := DefaultTransport.Clone() - clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - return clone - }() -) - -const StaticFilePathPrefix = "/$gperrorpage/" diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go deleted file mode 100644 index a505ff4..0000000 --- a/internal/net/http/middleware/forward_auth.go +++ /dev/null @@ -1,221 +0,0 @@ -// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go) -// Copyright (c) 2020-2024 Traefik Labs -// Copyright (c) 2024 yusing - -package middleware - -import ( - "io" - "net" - "net/http" - "slices" - "strings" - "time" - - gphttp "github.com/yusing/go-proxy/internal/net/http" - F "github.com/yusing/go-proxy/internal/utils/functional" -) - -type ( - forwardAuth struct { - ForwardAuthOpts - Tracer - reqCookiesMap F.Map[*http.Request, []*http.Cookie] - } - ForwardAuthOpts struct { - Address string `validate:"url,required"` - TrustForwardHeader bool - AuthResponseHeaders []string - AddAuthCookiesToResponse []string - } -) - -var ForwardAuth = NewMiddleware[forwardAuth]() - -var faHTTPClient = &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(r *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, -} - -// setup implements MiddlewareWithSetup. -func (fa *forwardAuth) setup() { - fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]() -} - -// before implements RequestModifier. -func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) { - gphttp.RemoveHop(req.Header) - - // Construct original URL for the redirect - scheme := "http" - if req.TLS != nil { - scheme = "https" - } - originalURL := scheme + "://" + req.Host + req.RequestURI - - url := fa.Address - faReq, err := http.NewRequestWithContext( - req.Context(), - http.MethodGet, - url, - nil, - ) - if err != nil { - fa.AddTracef("new request err to %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - gphttp.CopyHeader(faReq.Header, req.Header) - gphttp.RemoveHop(faReq.Header) - - faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) - fa.setAuthHeaders(req, faReq) - // Set headers needed by Authentik - faReq.Header.Set("X-Original-Url", originalURL) - fa.AddTraceRequest("forward auth request", faReq) - - faResp, err := faHTTPClient.Do(faReq) - if err != nil { - fa.AddTracef("failed to call %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer faResp.Body.Close() - - body, err := io.ReadAll(faResp.Body) - if err != nil { - fa.AddTracef("failed to read response body from %s", url).WithError(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { - fa.AddTraceResponse("forward auth response", faResp) - gphttp.CopyHeader(w.Header(), faResp.Header) - gphttp.RemoveHop(w.Header()) - - redirectURL, err := faResp.Location() - if err != nil { - fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) - w.WriteHeader(http.StatusInternalServerError) - return - } else if redirectURL.String() != "" { - w.Header().Set("Location", redirectURL.String()) - fa.AddTracef("%s", "redirect to "+redirectURL.String()) - } - - w.WriteHeader(faResp.StatusCode) - - if _, err = w.Write(body); err != nil { - fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) - } - return - } - - for _, key := range fa.AuthResponseHeaders { - key := http.CanonicalHeaderKey(key) - req.Header.Del(key) - if len(faResp.Header[key]) > 0 { - req.Header[key] = append([]string(nil), faResp.Header[key]...) - } - } - - req.RequestURI = req.URL.RequestURI() - - authCookies := faResp.Cookies() - - if len(authCookies) > 0 { - fa.reqCookiesMap.Store(req, authCookies) - } - return true -} - -// modifyResponse implements ResponseModifier. -func (fa *forwardAuth) modifyResponse(resp *http.Response) error { - if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok { - fa.setAuthCookies(resp, cookies) - fa.reqCookiesMap.Delete(resp.Request) - } - return nil -} - -func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) { - if len(fa.AddAuthCookiesToResponse) == 0 { - return - } - - cookies := resp.Cookies() - resp.Header.Del("Set-Cookie") - - for _, cookie := range cookies { - if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { - // this cookie is not an auth cookie, so add it back - resp.Header.Add("Set-Cookie", cookie.String()) - } - } - - for _, cookie := range authCookies { - if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { - // this cookie is an auth cookie, so add to resp - resp.Header.Add("Set-Cookie", cookie.String()) - } - } -} - -func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) { - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - if fa.TrustForwardHeader { - if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - } - faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP) - } - - xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod) - switch { - case xMethod != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod) - case req.Method != "": - faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) - default: - faReq.Header.Del(gphttp.HeaderXForwardedMethod) - } - - xfp := req.Header.Get(gphttp.HeaderXForwardedProto) - switch { - case xfp != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp) - case req.TLS != nil: - faReq.Header.Set(gphttp.HeaderXForwardedProto, "https") - default: - faReq.Header.Set(gphttp.HeaderXForwardedProto, "http") - } - - if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader { - faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp) - } - - xfh := req.Header.Get(gphttp.HeaderXForwardedHost) - switch { - case xfh != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh) - case req.Host != "": - faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) - default: - faReq.Header.Del(gphttp.HeaderXForwardedHost) - } - - xfURI := req.Header.Get(gphttp.HeaderXForwardedURI) - switch { - case xfURI != "" && fa.TrustForwardHeader: - faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI) - case req.URL.RequestURI() != "": - faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI()) - default: - faReq.Header.Del(gphttp.HeaderXForwardedURI) - } -}