refactor: organize code (#90)

* fix: improved sync.Pool handling

* refactor: ping if-flow and remove timeout

* refactor: enhance favicon fetching with context support and improve cache management

- Added context support to favicon fetching functions to handle timeouts and cancellations.
- Improved cache entry structure to include content type and utilize atomic values for last access time.
- Implemented maximum cache size and entry limits to optimize memory usage.
- Updated error handling for HTTP requests and refined the logic for managing redirects.

* fix: log formatting

* feat(pool): add checkExists method to debug build to detect unexpected behavior

* chore: cont. 0866feb

* refactor: unify route handling by consolidating route query methods with Pool

- Replaced direct calls to routequery with a new routes package for better organization and maintainability.
- Updated various components to utilize the new routes methods for fetching health information, homepage configurations, and route aliases.
- Enhanced the overall structure of the routing logic to improve clarity and reduce redundancy.

* chore: uncomment icon list cache code

* refactor: update task management code

- Rename needFinish to waitFinish
- Fixed some tasks not being waited they should be
- Adjusted mutex usage in the directory watcher to utilize read-write locks for improved concurrency management.

* refactor: enhance idlewatcher logging and exit handling

* fix(server): ensure HTTP handler is set only if initialized

* refactor(accesslog): replace JSON log entry struct with zerolog for improved logging efficiency, updated test

* refactor: remove test run code

---------

Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
Yuzerion 2025-04-17 15:30:05 +08:00 committed by GitHub
parent a35ac33bd5
commit 04f806239d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 511 additions and 448 deletions

View file

@ -10,7 +10,7 @@ var Agents = agents{pool.New[*AgentConfig]("agents")}
func (agents agents) Get(agentAddrOrDockerHost string) (*AgentConfig, bool) { func (agents agents) Get(agentAddrOrDockerHost string) (*AgentConfig, bool) {
if !IsDockerHostAgent(agentAddrOrDockerHost) { if !IsDockerHostAgent(agentAddrOrDockerHost) {
return agents.Base().Load(agentAddrOrDockerHost) return agents.Get(agentAddrOrDockerHost)
} }
return agents.Base().Load(GetAgentAddrFromDockerHost(agentAddrOrDockerHost)) return agents.Get(GetAgentAddrFromDockerHost(agentAddrOrDockerHost))
} }

View file

@ -18,7 +18,7 @@ import (
"github.com/yusing/go-proxy/internal/metrics/systeminfo" "github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/metrics/uptime" "github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/migrations" "github.com/yusing/go-proxy/migrations"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
@ -124,7 +124,7 @@ func main() {
switch args.Command { switch args.Command {
case common.CommandListRoutes: case common.CommandListRoutes:
cfg.StartProxyProviders() cfg.StartProxyProviders()
printJSON(routequery.RoutesByAlias()) printJSON(routes.ByAlias())
return return
case common.CommandListConfigs: case common.CommandListConfigs:
printJSON(cfg.Value()) printJSON(cfg.Value())

View file

@ -36,7 +36,7 @@ func GetFavIcon(w http.ResponseWriter, req *http.Request) {
gphttp.ClientError(w, err, http.StatusBadRequest) gphttp.ClientError(w, err, http.StatusBadRequest)
return return
} }
fetchResult := homepage.FetchFavIconFromURL(&iconURL) fetchResult := homepage.FetchFavIconFromURL(req.Context(), &iconURL)
if !fetchResult.OK() { if !fetchResult.OK() {
http.Error(w, fetchResult.ErrMsg, fetchResult.StatusCode) http.Error(w, fetchResult.ErrMsg, fetchResult.StatusCode)
return return
@ -47,7 +47,7 @@ func GetFavIcon(w http.ResponseWriter, req *http.Request) {
} }
// try with route.Icon // try with route.Icon
r, ok := routes.GetHTTPRoute(alias) r, ok := routes.HTTP.Get(alias)
if !ok { if !ok {
gphttp.ClientError(w, errors.New("no such route"), http.StatusNotFound) gphttp.ClientError(w, errors.New("no such route"), http.StatusNotFound)
return return
@ -59,7 +59,7 @@ func GetFavIcon(w http.ResponseWriter, req *http.Request) {
if hp.Icon.IconSource == homepage.IconSourceRelative { if hp.Icon.IconSource == homepage.IconSourceRelative {
result = homepage.FindIcon(req.Context(), r, hp.Icon.Value) result = homepage.FindIcon(req.Context(), r, hp.Icon.Value)
} else { } else {
result = homepage.FetchFavIconFromURL(hp.Icon) result = homepage.FetchFavIconFromURL(req.Context(), hp.Icon)
} }
} else { } else {
// try extract from "link[rel=icon]" // try extract from "link[rel=icon]"

View file

@ -5,9 +5,9 @@ import (
"time" "time"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes"
) )
func Health(w http.ResponseWriter, r *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
gpwebsocket.DynamicJSONHandler(w, r, routequery.HealthMap, 1*time.Second) gpwebsocket.DynamicJSONHandler(w, r, routes.HealthMap, 1*time.Second)
} }

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -45,7 +45,7 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
gphttp.RespondJSON(w, r, route) gphttp.RespondJSON(w, r, route)
} }
case ListRoutes: case ListRoutes:
gphttp.RespondJSON(w, r, routequery.RoutesByAlias(route.RouteType(r.FormValue("type")))) gphttp.RespondJSON(w, r, routes.ByAlias(route.RouteType(r.FormValue("type"))))
case ListFiles: case ListFiles:
listFiles(w, r) listFiles(w, r)
case ListMiddlewares: case ListMiddlewares:
@ -55,11 +55,11 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
case ListMatchDomains: case ListMatchDomains:
gphttp.RespondJSON(w, r, cfg.Value().MatchDomains) gphttp.RespondJSON(w, r, cfg.Value().MatchDomains)
case ListHomepageConfig: case ListHomepageConfig:
gphttp.RespondJSON(w, r, routequery.HomepageConfig(r.FormValue("category"), r.FormValue("provider"))) gphttp.RespondJSON(w, r, routes.HomepageConfig(r.FormValue("category"), r.FormValue("provider")))
case ListRouteProviders: case ListRouteProviders:
gphttp.RespondJSON(w, r, cfg.RouteProviderList()) gphttp.RespondJSON(w, r, cfg.RouteProviderList())
case ListHomepageCategories: case ListHomepageCategories:
gphttp.RespondJSON(w, r, routequery.HomepageCategories()) gphttp.RespondJSON(w, r, routes.HomepageCategories())
case ListIcons: case ListIcons:
limit, err := strconv.Atoi(r.FormValue("limit")) limit, err := strconv.Atoi(r.FormValue("limit"))
if err != nil { if err != nil {
@ -83,9 +83,9 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
// otherwise, return a single Route with alias which or nil if not found. // otherwise, return a single Route with alias which or nil if not found.
func listRoute(which string) any { func listRoute(which string) any {
if which == "" || which == "all" { if which == "" || which == "all" {
return routequery.RoutesByAlias() return routes.ByAlias()
} }
routes := routequery.RoutesByAlias() routes := routes.ByAlias()
route, ok := routes[which] route, ok := routes[which]
if !ok { if !ok {
return nil return nil

View file

@ -46,7 +46,7 @@ const (
) )
func initClientCleaner() { func initClientCleaner() {
cleaner := task.RootTask("docker_clients_cleaner") cleaner := task.RootTask("docker_clients_cleaner", false)
go func() { go func() {
ticker := time.NewTicker(cleanInterval) ticker := time.NewTicker(cleanInterval)
defer ticker.Stop() defer ticker.Stop()

View file

@ -12,7 +12,6 @@ import (
"github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage" "github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -20,7 +19,7 @@ import (
type Entrypoint struct { type Entrypoint struct {
middleware *middleware.Middleware middleware *middleware.Middleware
accessLogger *accesslog.AccessLogger accessLogger *accesslog.AccessLogger
findRouteFunc func(host string) (route.HTTPRoute, error) findRouteFunc func(host string) (routes.HTTPRoute, error)
} }
var ErrNoSuchRoute = errors.New("no such route") var ErrNoSuchRoute = errors.New("no such route")
@ -108,7 +107,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func findRouteAnyDomain(host string) (route.HTTPRoute, error) { func findRouteAnyDomain(host string) (routes.HTTPRoute, error) {
hostSplit := strutils.SplitRune(host, '.') hostSplit := strutils.SplitRune(host, '.')
target := hostSplit[0] target := hostSplit[0]
@ -118,19 +117,19 @@ func findRouteAnyDomain(host string) (route.HTTPRoute, error) {
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, target) return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, target)
} }
func findRouteByDomains(domains []string) func(host string) (route.HTTPRoute, error) { func findRouteByDomains(domains []string) func(host string) (routes.HTTPRoute, error) {
return func(host string) (route.HTTPRoute, error) { return func(host string) (routes.HTTPRoute, error) {
for _, domain := range domains { for _, domain := range domains {
if strings.HasSuffix(host, domain) { if strings.HasSuffix(host, domain) {
target := strings.TrimSuffix(host, domain) target := strings.TrimSuffix(host, domain)
if r, ok := routes.GetHTTPRoute(target); ok { if r, ok := routes.HTTP.Get(target); ok {
return r, nil return r, nil
} }
} }
} }
// fallback to exact match // fallback to exact match
if r, ok := routes.GetHTTPRoute(host); ok { if r, ok := routes.HTTP.Get(host); ok {
return r, nil return r, nil
} }
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host) return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)

View file

@ -8,21 +8,29 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
var ( var ep = NewEntrypoint()
r route.ReveseProxyRoute
ep = NewEntrypoint() func addRoute(alias string) *route.ReveseProxyRoute {
) r := &route.ReveseProxyRoute{
Route: &route.Route{
Alias: alias,
},
}
routes.HTTP.Add(r)
return r
}
func run(t *testing.T, match []string, noMatch []string) { func run(t *testing.T, match []string, noMatch []string) {
t.Helper() t.Helper()
t.Cleanup(routes.TestClear) t.Cleanup(routes.Clear)
t.Cleanup(func() { ep.SetFindRouteDomains(nil) }) t.Cleanup(func() { ep.SetFindRouteDomains(nil) })
for _, test := range match { for _, test := range match {
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
r := addRoute(test)
found, err := ep.findRouteFunc(test) found, err := ep.findRouteFunc(test)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectTrue(t, found == &r) ExpectTrue(t, found == r)
}) })
} }
@ -35,7 +43,7 @@ func run(t *testing.T, match []string, noMatch []string) {
} }
func TestFindRouteAnyDomain(t *testing.T) { func TestFindRouteAnyDomain(t *testing.T) {
routes.SetHTTPRoute("app1", &r) addRoute("app1")
tests := []string{ tests := []string{
"app1.com", "app1.com",
@ -66,7 +74,7 @@ func TestFindRouteExactHostMatch(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
routes.SetHTTPRoute(test, &r) addRoute(test)
} }
run(t, tests, testsNoMatch) run(t, tests, testsNoMatch)
@ -78,7 +86,7 @@ func TestFindRouteByDomains(t *testing.T) {
".sub.domain.com", ".sub.domain.com",
}) })
routes.SetHTTPRoute("app1", &r) addRoute("app1")
tests := []string{ tests := []string{
"app1.domain.com", "app1.domain.com",
@ -103,7 +111,7 @@ func TestFindRouteByDomainsExactMatch(t *testing.T) {
".sub.domain.com", ".sub.domain.com",
}) })
routes.SetHTTPRoute("app1.foo.bar", &r) addRoute("app1")
tests := []string{ tests := []string{
"app1.foo.bar", // exact match "app1.foo.bar", // exact match

View file

@ -6,14 +6,14 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
) )
func log(_ string, err error, level zerolog.Level, logger ...*zerolog.Logger) { func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
var l *zerolog.Logger var l *zerolog.Logger
if len(logger) > 0 { if len(logger) > 0 {
l = logger[0] l = logger[0]
} else { } else {
l = logging.GetLogger() l = logging.GetLogger()
} }
l.WithLevel(level).Msg(err.Error()) l.WithLevel(level).Msg(New(highlight(msg)).With(err).Error())
} }
func LogFatal(msg string, err error, logger ...*zerolog.Logger) { func LogFatal(msg string, err error, logger ...*zerolog.Logger) {

View file

@ -7,15 +7,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMultiline(t *testing.T) {
multiline := Multiline()
multiline.Addf("line 1 %s", "test")
multiline.Adds("line 2")
multiline.AddLines([]any{1, "2", 3.0, net.IPv4(127, 0, 0, 1)})
t.Error(New("result").With(multiline))
t.Error(multiline.Subject("subject").Withf("inner"))
}
func TestWrapMultiline(t *testing.T) { func TestWrapMultiline(t *testing.T) {
multiline := Multiline() multiline := Multiline()
var wrapper error = wrap(multiline) var wrapper error = wrap(multiline)

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"slices"
"strings" "strings"
"time" "time"
@ -24,8 +25,10 @@ type FetchResult struct {
contentType string contentType string
} }
const faviconFetchTimeout = 3 * time.Second
func (res *FetchResult) OK() bool { func (res *FetchResult) OK() bool {
return res.Icon != nil return len(res.Icon) > 0
} }
func (res *FetchResult) ContentType() string { func (res *FetchResult) ContentType() string {
@ -40,39 +43,55 @@ func (res *FetchResult) ContentType() string {
const maxRedirectDepth = 5 const maxRedirectDepth = 5
func FetchFavIconFromURL(iconURL *IconURL) *FetchResult { func FetchFavIconFromURL(ctx context.Context, iconURL *IconURL) *FetchResult {
switch iconURL.IconSource { switch iconURL.IconSource {
case IconSourceAbsolute: case IconSourceAbsolute:
return fetchIconAbsolute(iconURL.URL()) return fetchIconAbsolute(ctx, iconURL.URL())
case IconSourceRelative: case IconSourceRelative:
return &FetchResult{StatusCode: http.StatusBadRequest, ErrMsg: "unexpected relative icon"} return &FetchResult{StatusCode: http.StatusBadRequest, ErrMsg: "unexpected relative icon"}
case IconSourceWalkXCode, IconSourceSelfhSt: case IconSourceWalkXCode, IconSourceSelfhSt:
return fetchKnownIcon(iconURL) return fetchKnownIcon(ctx, iconURL)
} }
return &FetchResult{StatusCode: http.StatusBadRequest, ErrMsg: "invalid icon source"} return &FetchResult{StatusCode: http.StatusBadRequest, ErrMsg: "invalid icon source"}
} }
func fetchIconAbsolute(url string) *FetchResult { func fetchIconAbsolute(ctx context.Context, url string) *FetchResult {
if result := loadIconCache(url); result != nil { if result := loadIconCache(url); result != nil {
return result return result
} }
resp, err := gphttp.Get(url) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil || resp.StatusCode != http.StatusOK { if err != nil {
if err == nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
err = errors.New(resp.Status) return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}
} }
return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: err.Error()}
}
resp, err := gphttp.Do(req)
if err == nil {
defer resp.Body.Close()
}
if err != nil || resp.StatusCode != http.StatusOK {
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "connection error"} return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "connection error"}
} }
defer resp.Body.Close()
icon, err := io.ReadAll(resp.Body) icon, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "internal error"} return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "internal error"}
} }
storeIconCache(url, icon) if len(icon) == 0 {
return &FetchResult{Icon: icon} return &FetchResult{StatusCode: http.StatusNotFound, ErrMsg: "empty icon"}
}
res := &FetchResult{Icon: icon}
if contentType := resp.Header.Get("Content-Type"); contentType != "" {
res.contentType = contentType
}
// else leave it empty
storeIconCache(url, res)
return res
} }
var nameSanitizer = strings.NewReplacer( var nameSanitizer = strings.NewReplacer(
@ -86,44 +105,53 @@ func sanitizeName(name string) string {
return strings.ToLower(nameSanitizer.Replace(name)) return strings.ToLower(nameSanitizer.Replace(name))
} }
func fetchKnownIcon(url *IconURL) *FetchResult { func fetchKnownIcon(ctx context.Context, url *IconURL) *FetchResult {
// if icon isn't in the list, no need to fetch // if icon isn't in the list, no need to fetch
if !url.HasIcon() { if !url.HasIcon() {
return &FetchResult{StatusCode: http.StatusNotFound, ErrMsg: "no such icon"} return &FetchResult{StatusCode: http.StatusNotFound, ErrMsg: "no such icon"}
} }
return fetchIconAbsolute(url.URL()) return fetchIconAbsolute(ctx, url.URL())
} }
func fetchIcon(filetype, filename string) *FetchResult { func fetchIcon(ctx context.Context, filetype, filename string) *FetchResult {
result := fetchKnownIcon(NewSelfhStIconURL(filename, filetype)) result := fetchKnownIcon(ctx, NewSelfhStIconURL(filename, filetype))
if result.Icon == nil { if result.OK() {
return result return result
} }
return fetchKnownIcon(NewWalkXCodeIconURL(filename, filetype)) return fetchKnownIcon(ctx, NewWalkXCodeIconURL(filename, filetype))
} }
func FindIcon(ctx context.Context, r route, uri string) *FetchResult { func FindIcon(ctx context.Context, r route, uri string) *FetchResult {
key := routeKey(r) if result := loadIconCache(r.Key()); result != nil {
if result := loadIconCache(key); result != nil {
return result return result
} }
result := fetchIcon("png", sanitizeName(r.Reference())) result := fetchIcon(ctx, "png", sanitizeName(r.Reference()))
if !result.OK() { if !result.OK() {
if r, ok := r.(httpRoute); ok { if r, ok := r.(httpRoute); ok {
// fallback to parse html // fallback to parse html
result = findIconSlow(ctx, r, uri, 0) result = findIconSlow(ctx, r, uri, nil)
} }
} }
if result.OK() { if result.OK() {
storeIconCache(key, result.Icon) storeIconCache(r.Key(), result)
} }
return result return result
} }
func findIconSlow(ctx context.Context, r httpRoute, uri string, depth int) *FetchResult { func findIconSlow(ctx context.Context, r httpRoute, uri string, stack []string) *FetchResult {
ctx, cancel := context.WithTimeoutCause(ctx, 3*time.Second, errors.New("favicon request timeout")) select {
case <-ctx.Done():
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}
default:
}
if len(stack) > maxRedirectDepth {
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "too many redirects"}
}
ctx, cancel := context.WithTimeoutCause(ctx, faviconFetchTimeout, errors.New("favicon request timeout"))
defer cancel() defer cancel()
newReq, err := http.NewRequestWithContext(ctx, "GET", r.TargetURL().String(), nil) newReq, err := http.NewRequestWithContext(ctx, "GET", r.TargetURL().String(), nil)
@ -149,14 +177,13 @@ func findIconSlow(ctx context.Context, r httpRoute, uri string, depth int) *Fetc
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "connection error"} return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "connection error"}
default: default:
if loc := c.Header().Get("Location"); loc != "" { if loc := c.Header().Get("Location"); loc != "" {
if depth > maxRedirectDepth {
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "too many redirects"}
}
loc = strutils.SanitizeURI(loc) loc = strutils.SanitizeURI(loc)
if loc == "/" || loc == newReq.URL.Path { if loc == "/" || loc == newReq.URL.Path || slices.Contains(stack, loc) {
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "circular redirect"} return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "circular redirect"}
} }
return findIconSlow(ctx, r, loc, depth+1) // append current path to stack
// handles redirect to the same path with different query
return findIconSlow(ctx, r, loc, append(stack, newReq.URL.Path))
} }
} }
return &FetchResult{StatusCode: c.status, ErrMsg: "upstream error: " + string(c.data)} return &FetchResult{StatusCode: c.status, ErrMsg: "upstream error: " + string(c.data)}
@ -188,8 +215,8 @@ func findIconSlow(ctx context.Context, r httpRoute, uri string, depth int) *Fetc
} }
switch { switch {
case strings.HasPrefix(href, "http://"), strings.HasPrefix(href, "https://"): case strings.HasPrefix(href, "http://"), strings.HasPrefix(href, "https://"):
return fetchIconAbsolute(href) return fetchIconAbsolute(ctx, href)
default: default:
return findIconSlow(ctx, r, href, 0) return findIconSlow(ctx, r, href, append(stack, newReq.URL.Path))
} }
} }

View file

@ -1,6 +1,7 @@
package homepage package homepage
import ( import (
"encoding/base64"
"sync" "sync"
"time" "time"
@ -10,11 +11,13 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic"
) )
type cacheEntry struct { type cacheEntry struct {
Icon []byte `json:"icon"` Icon []byte `json:"icon"`
LastAccess time.Time `json:"lastAccess"` ContentType string `json:"content_type"`
LastAccess atomic.Value[time.Time] `json:"last_access"`
} }
// cache key can be absolute url or route name. // cache key can be absolute url or route name.
@ -25,7 +28,9 @@ var (
const ( const (
iconCacheTTL = 3 * 24 * time.Hour iconCacheTTL = 3 * 24 * time.Hour
cleanUpInterval = time.Hour cleanUpInterval = time.Minute
maxCacheSize = 1024 * 1024 // 1MB
maxCacheEntries = 100
) )
func InitIconCache() { func InitIconCache() {
@ -77,19 +82,29 @@ func pruneExpiredIconCache() {
nPruned++ nPruned++
} }
} }
if len(iconCache) > maxCacheEntries {
newIconCache := make(map[string]*cacheEntry, maxCacheEntries)
i := 0
for key, icon := range iconCache {
if i == maxCacheEntries {
break
}
if !icon.IsExpired() {
newIconCache[key] = icon
i++
}
}
iconCache = newIconCache
}
if nPruned > 0 { if nPruned > 0 {
logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache") logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
} }
} }
func routeKey(r route) string {
return r.ProviderName() + ":" + r.TargetName()
}
func PruneRouteIconCache(route route) { func PruneRouteIconCache(route route) {
iconCacheMu.Lock() iconCacheMu.Lock()
defer iconCacheMu.Unlock() defer iconCacheMu.Unlock()
delete(iconCache, routeKey(route)) delete(iconCache, route.Key())
} }
func loadIconCache(key string) *FetchResult { func loadIconCache(key string) *FetchResult {
@ -97,41 +112,49 @@ func loadIconCache(key string) *FetchResult {
defer iconCacheMu.RUnlock() defer iconCacheMu.RUnlock()
icon, ok := iconCache[key] icon, ok := iconCache[key]
if ok && icon != nil { if ok && len(icon.Icon) > 0 {
logging.Debug(). logging.Debug().
Str("key", key). Str("key", key).
Msg("icon found in cache") Msg("icon found in cache")
icon.LastAccess = time.Now() icon.LastAccess.Store(time.Now())
return &FetchResult{Icon: icon.Icon} return &FetchResult{Icon: icon.Icon, contentType: icon.ContentType}
} }
return nil return nil
} }
func storeIconCache(key string, icon []byte) { func storeIconCache(key string, result *FetchResult) {
icon := result.Icon
if len(icon) > maxCacheSize {
logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
return
}
iconCacheMu.Lock() iconCacheMu.Lock()
defer iconCacheMu.Unlock() defer iconCacheMu.Unlock()
iconCache[key] = &cacheEntry{Icon: icon, LastAccess: time.Now()} entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
entry.LastAccess.Store(time.Now())
iconCache[key] = entry
logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
} }
func (e *cacheEntry) IsExpired() bool { func (e *cacheEntry) IsExpired() bool {
return time.Since(e.LastAccess) > iconCacheTTL return time.Since(e.LastAccess.Load()) > iconCacheTTL
} }
func (e *cacheEntry) UnmarshalJSON(data []byte) error { func (e *cacheEntry) UnmarshalJSON(data []byte) error {
attempt := struct { // check if data is json
Icon []byte `json:"icon"` if json.Valid(data) {
LastAccess time.Time `json:"lastAccess"` err := json.Unmarshal(data, &e)
}{} // return only if unmarshal is successful
err := json.Unmarshal(data, &attempt) // otherwise fallback to base64
if err == nil { if err == nil {
e.Icon = attempt.Icon return nil
e.LastAccess = attempt.LastAccess }
return nil
} }
// fallback to bytes // fallback to base64
err = json.Unmarshal(data, &e.Icon) icon, err := base64.StdEncoding.DecodeString(string(data))
if err == nil { if err == nil {
e.LastAccess = time.Now() e.Icon = icon
e.LastAccess.Store(time.Now())
return nil return nil
} }
return err return err

View file

@ -60,15 +60,15 @@ func InitIconListCache() {
DisplayNames: make(ReferenceDisplayNameMap), DisplayNames: make(ReferenceDisplayNameMap),
IconList: []string{}, IconList: []string{},
} }
// err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache) err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache)
// if err != nil { if err != nil {
// logging.Error().Err(err).Msg("failed to load icon list cache config") logging.Error().Err(err).Msg("failed to load icon list cache config")
// } else if len(iconsCache.IconList) > 0 { } else if len(iconsCache.IconList) > 0 {
// logging.Info(). logging.Info().
// Int("icons", len(iconsCache.IconList)). Int("icons", len(iconsCache.IconList)).
// Int("display_names", len(iconsCache.DisplayNames)). Int("display_names", len(iconsCache.DisplayNames)).
// Msg("icon list cache loaded") Msg("icon list cache loaded")
// } }
} }
func ListAvailableIcons() (*Cache, error) { func ListAvailableIcons() (*Cache, error) {

View file

@ -3,10 +3,12 @@ package homepage
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/yusing/go-proxy/internal/utils/pool"
) )
type route interface { type route interface {
TargetName() string pool.Object
ProviderName() string ProviderName() string
Reference() string Reference() string
TargetURL() *url.URL TargetURL() *url.URL

View file

@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/atomic"
@ -80,7 +80,7 @@ var (
const reqTimeout = 3 * time.Second const reqTimeout = 3 * time.Second
// TODO: fix stream type // TODO: fix stream type
func NewWatcher(parent task.Parent, r route.Route) (*Watcher, error) { func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
cfg := r.IdlewatcherConfig() cfg := r.IdlewatcherConfig()
key := cfg.Key() key := cfg.Key()
@ -126,9 +126,9 @@ func NewWatcher(parent task.Parent, r route.Route) (*Watcher, error) {
Logger() Logger()
switch r := r.(type) { switch r := r.(type) {
case route.ReverseProxyRoute: case routes.ReverseProxyRoute:
w.rp = r.ReverseProxy() w.rp = r.ReverseProxy()
case route.StreamRoute: case routes.StreamRoute:
w.stream = r w.stream = r
default: default:
return nil, gperr.New("unexpected route type") return nil, gperr.New("unexpected route type")
@ -153,14 +153,14 @@ func NewWatcher(parent task.Parent, r route.Route) (*Watcher, error) {
w.state.Store(&containerState{status: status}) w.state.Store(&containerState{status: status})
w.task = parent.Subtask("idlewatcher."+r.TargetName(), true) w.task = parent.Subtask("idlewatcher."+r.Name(), true)
watcherMapMu.Lock() watcherMapMu.Lock()
defer watcherMapMu.Unlock() defer watcherMapMu.Unlock()
watcherMap[key] = w watcherMap[key] = w
go func() { go func() {
cause := w.watchUntilDestroy() cause := w.watchUntilDestroy()
if cause.Is(causeContainerDestroy) { if cause.Is(causeContainerDestroy) || cause.Is(task.ErrProgramExiting) {
watcherMapMu.Lock() watcherMapMu.Lock()
defer watcherMapMu.Unlock() defer watcherMapMu.Unlock()
delete(watcherMap, key) delete(watcherMap, key)
@ -173,7 +173,11 @@ func NewWatcher(parent task.Parent, r route.Route) (*Watcher, error) {
w.provider.Close() w.provider.Close()
w.task.Finish(cause) w.task.Finish(cause)
}() }()
w.l.Info().Msg("idlewatcher started") if exists {
w.l.Info().Msg("idlewatcher reloaded")
} else {
w.l.Info().Msg("idlewatcher started")
}
return w, nil return w, nil
} }

View file

@ -133,7 +133,7 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
} }
func (p *Poller[T, AggregateT]) Start() { func (p *Poller[T, AggregateT]) Start() {
t := task.RootTask("poller." + p.name) t := task.RootTask("poller."+p.name, true)
err := p.load() err := p.load()
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {

View file

@ -10,15 +10,14 @@ import (
"github.com/yusing/go-proxy/internal/metrics/period" "github.com/yusing/go-proxy/internal/metrics/period"
metricsutils "github.com/yusing/go-proxy/internal/metrics/utils" metricsutils "github.com/yusing/go-proxy/internal/metrics/utils"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/route/routes/routequery"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/pkg/json" "github.com/yusing/go-proxy/pkg/json"
) )
type ( type (
StatusByAlias struct { StatusByAlias struct {
Map json.Map[*routequery.HealthInfoRaw] `json:"statuses"` Map json.Map[*routes.HealthInfoRaw] `json:"statuses"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
Aggregated = json.MapSlice[any] Aggregated = json.MapSlice[any]
) )
@ -27,7 +26,7 @@ var Poller = period.NewPoller("uptime", getStatuses, aggregateStatuses)
func getStatuses(ctx context.Context, _ *StatusByAlias) (*StatusByAlias, error) { func getStatuses(ctx context.Context, _ *StatusByAlias) (*StatusByAlias, error) {
return &StatusByAlias{ return &StatusByAlias{
Map: routequery.HealthInfo(), Map: routes.HealthInfo(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}, nil }, nil
} }
@ -111,7 +110,7 @@ func (rs RouteStatuses) aggregate(limit int, offset int) Aggregated {
"avg_latency": latency, "avg_latency": latency,
"statuses": statuses, "statuses": statuses,
} }
r, ok := routes.GetRoute(alias) r, ok := routes.HTTP.Get(alias)
if ok { if ok {
result[i]["display_name"] = r.HomepageConfig().Name result[i]["display_name"] = r.HomepageConfig().Name
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/synk"
) )
type ( type (
@ -20,7 +21,7 @@ type (
io AccessLogIO io AccessLogIO
buffered *bufio.Writer buffered *bufio.Writer
lineBufPool sync.Pool // buffer pool for formatting a single log line lineBufPool *synk.BytesPool // buffer pool for formatting a single log line
Formatter Formatter
} }
@ -78,7 +79,7 @@ func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *Acc
cfg.BufferSize = 4096 cfg.BufferSize = 4096
} }
l := &AccessLogger{ l := &AccessLogger{
task: parent.Subtask("accesslog"), task: parent.Subtask("accesslog."+io.Name(), true),
cfg: cfg, cfg: cfg,
io: io, io: io,
buffered: bufio.NewWriterSize(io, cfg.BufferSize), buffered: bufio.NewWriterSize(io, cfg.BufferSize),
@ -96,9 +97,7 @@ func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *Acc
panic("invalid access log format") panic("invalid access log format")
} }
l.lineBufPool.New = func() any { l.lineBufPool = synk.NewBytesPool(1024, synk.DefaultMaxBytes)
return bytes.NewBuffer(make([]byte, 0, 1024))
}
go l.start() go l.start()
return l return l
} }
@ -118,12 +117,11 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
return return
} }
line := l.lineBufPool.Get().(*bytes.Buffer) line := l.lineBufPool.Get()
line.Reset()
defer l.lineBufPool.Put(line) defer l.lineBufPool.Put(line)
l.Formatter.Format(line, req, res) l.Formatter.Format(bytes.NewBuffer(line), req, res)
line.WriteRune('\n') line = append(line, '\n')
l.write(line.Bytes()) l.write(line)
} }
func (l *AccessLogger) LogError(req *http.Request, err error) { func (l *AccessLogger) LogError(req *http.Request, err error) {

View file

@ -23,7 +23,7 @@ const (
referer = "https://www.google.com/" referer = "https://www.google.com/"
proto = "HTTP/1.1" proto = "HTTP/1.1"
ua = "Go-http-client/1.1" ua = "Go-http-client/1.1"
status = http.StatusOK status = http.StatusNotFound
contentLength = 100 contentLength = 100
method = http.MethodGet method = http.MethodGet
) )
@ -99,6 +99,25 @@ func TestAccessLoggerRedactQuery(t *testing.T) {
) )
} }
type JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
URI string `json:"uri"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
t.Helper() t.Helper()
config.Format = FormatJSON config.Format = FormatJSON
@ -125,4 +144,7 @@ func TestAccessLoggerJSON(t *testing.T) {
ExpectEqual(t, entry.UserAgent, ua) ExpectEqual(t, entry.UserAgent, ua)
ExpectEqual(t, len(entry.Headers), 0) ExpectEqual(t, len(entry.Headers), 0)
ExpectEqual(t, len(entry.Cookies), 0) ExpectEqual(t, len(entry.Cookies), 0)
if status >= 400 {
ExpectEqual(t, entry.Error, http.StatusText(status))
}
} }

View file

@ -8,9 +8,8 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/pkg/json" "github.com/yusing/go-proxy/pkg/json"
"github.com/yusing/go-proxy/internal/logging"
) )
type ( type (
@ -20,25 +19,6 @@ type (
} }
CombinedFormatter struct{ CommonFormatter } CombinedFormatter struct{ CommonFormatter }
JSONFormatter struct{ CommonFormatter } JSONFormatter struct{ CommonFormatter }
JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
URI string `json:"uri"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
) )
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700" const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
@ -109,37 +89,36 @@ func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.
headers := f.cfg.Headers.ProcessHeaders(req.Header) headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie") headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies()) cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
contentType := res.Header.Get("Content-Type")
entry := JSONLogEntry{ queryBytes, _ := json.Marshal(query)
Time: f.GetTimeNow().Format(LogTimeFormat), headersBytes, _ := json.Marshal(headers)
IP: clientIP(req), cookiesBytes, _ := json.Marshal(cookies)
Method: req.Method,
Scheme: scheme(req), logger := zerolog.New(line).With().Logger()
Host: req.Host, event := logger.Info().
URI: requestURI(req.URL, query), Str("time", f.GetTimeNow().Format(LogTimeFormat)).
Protocol: req.Proto, Str("ip", clientIP(req)).
Status: res.StatusCode, Str("method", req.Method).
ContentType: res.Header.Get("Content-Type"), Str("scheme", scheme(req)).
Size: res.ContentLength, Str("host", req.Host).
Referer: req.Referer(), Str("uri", requestURI(req.URL, query)).
UserAgent: req.UserAgent(), Str("protocol", req.Proto).
Query: query, Int("status", res.StatusCode).
Headers: headers, Str("type", contentType).
Cookies: cookies, Int64("size", res.ContentLength).
} Str("referer", req.Referer()).
Str("useragent", req.UserAgent()).
RawJSON("query", queryBytes).
RawJSON("headers", headersBytes).
RawJSON("cookies", cookiesBytes)
if res.StatusCode >= 400 { if res.StatusCode >= 400 {
entry.Error = res.Status if res.Status != "" {
} event.Str("error", res.Status)
} else {
if entry.ContentType == "" { event.Str("error", http.StatusText(res.StatusCode))
// try to get content type from request }
entry.ContentType = req.Header.Get("Content-Type")
}
marshaller := json.NewEncoder(line)
err := marshaller.Encode(entry)
if err != nil {
logging.Err(err).Msg("failed to marshal json log")
} }
event.Send()
} }

View file

@ -24,4 +24,5 @@ var (
Get = httpClient.Get Get = httpClient.Get
Post = httpClient.Post Post = httpClient.Post
Head = httpClient.Head Head = httpClient.Head
Do = httpClient.Do
) )

View file

@ -1,6 +1,7 @@
package loadbalancer package loadbalancer
import ( import (
"fmt"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -10,8 +11,8 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/pool"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -30,7 +31,7 @@ type (
task *task.Task task *task.Task
pool Pool pool pool.Pool[Server]
poolMu sync.Mutex poolMu sync.Mutex
sumWeight Weight sumWeight Weight
@ -45,7 +46,7 @@ const maxWeight Weight = 100
func New(cfg *Config) *LoadBalancer { func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{ lb := &LoadBalancer{
Config: new(Config), Config: new(Config),
pool: types.NewServerPool(), pool: pool.New[Server]("loadbalancer." + cfg.Link),
l: logging.With().Str("name", cfg.Link).Logger(), l: logging.With().Str("name", cfg.Link).Logger(),
} }
lb.UpdateConfigIfNeeded(cfg) lb.UpdateConfigIfNeeded(cfg)
@ -55,16 +56,14 @@ func New(cfg *Config) *LoadBalancer {
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error { func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
lb.startTime = time.Now() lb.startTime = time.Now()
lb.task = parent.Subtask("loadbalancer."+lb.Link, false) lb.task = parent.Subtask("loadbalancer."+lb.Link, true)
parent.OnCancel("lb_remove_route", func() { lb.task.OnCancel("cleanup", func() {
routes.DeleteHTTPRoute(lb.Link)
})
lb.task.OnFinished("cleanup", func() {
if lb.impl != nil { if lb.impl != nil {
lb.pool.RangeAll(func(k string, v Server) { for _, srv := range lb.pool.Iter {
lb.impl.OnRemoveServer(v) lb.impl.OnRemoveServer(srv)
}) }
} }
lb.task.Finish(nil)
}) })
return nil return nil
} }
@ -90,9 +89,9 @@ func (lb *LoadBalancer) updateImpl() {
default: // should happen in test only default: // should happen in test only
lb.impl = lb.newRoundRobin() lb.impl = lb.newRoundRobin()
} }
lb.pool.RangeAll(func(_ string, srv Server) { for _, srv := range lb.pool.Iter {
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
}) }
} }
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
@ -124,12 +123,12 @@ func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Key()) { // FIXME: this should be a warning if old, ok := lb.pool.Get(srv.Key()); ok { // FIXME: this should be a warning
old, _ := lb.pool.Load(srv.Key())
lb.sumWeight -= old.Weight() lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old) lb.impl.OnRemoveServer(old)
lb.pool.Del(old)
} }
lb.pool.Store(srv.Key(), srv) lb.pool.Add(srv)
lb.sumWeight += srv.Weight() lb.sumWeight += srv.Weight()
lb.rebalance() lb.rebalance()
@ -145,11 +144,11 @@ func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Key()) { if _, ok := lb.pool.Get(srv.Key()); !ok {
return return
} }
lb.pool.Delete(srv.Key()) lb.pool.Del(srv)
lb.sumWeight -= srv.Weight() lb.sumWeight -= srv.Weight()
lb.rebalance() lb.rebalance()
@ -178,15 +177,15 @@ func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == 0 { // distribute evenly if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / Weight(poolSize) weightEach := maxWeight / Weight(poolSize)
remainder := maxWeight % Weight(poolSize) remainder := maxWeight % Weight(poolSize)
lb.pool.RangeAll(func(_ string, s Server) { for _, srv := range lb.pool.Iter {
w := weightEach w := weightEach
lb.sumWeight += weightEach lb.sumWeight += weightEach
if remainder > 0 { if remainder > 0 {
w++ w++
remainder-- remainder--
} }
s.SetWeight(w) srv.SetWeight(w)
}) }
return return
} }
@ -194,30 +193,29 @@ func (lb *LoadBalancer) rebalance() {
scaleFactor := float64(maxWeight) / float64(lb.sumWeight) scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0 lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s Server) { for _, srv := range lb.pool.Iter {
s.SetWeight(Weight(float64(s.Weight()) * scaleFactor)) srv.SetWeight(Weight(float64(srv.Weight()) * scaleFactor))
lb.sumWeight += s.Weight() lb.sumWeight += srv.Weight()
}) }
delta := maxWeight - lb.sumWeight delta := maxWeight - lb.sumWeight
if delta == 0 { if delta == 0 {
return return
} }
lb.pool.Range(func(_ string, s Server) bool { for _, srv := range lb.pool.Iter {
if delta == 0 { if delta == 0 {
return false break
} }
if delta > 0 { if delta > 0 {
s.SetWeight(s.Weight() + 1) srv.SetWeight(srv.Weight() + 1)
lb.sumWeight++ lb.sumWeight++
delta-- delta--
} else { } else {
s.SetWeight(s.Weight() - 1) srv.SetWeight(srv.Weight() - 1)
lb.sumWeight-- lb.sumWeight--
delta++ delta++
} }
return true }
})
} }
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
@ -242,13 +240,16 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// MarshalMap implements health.HealthMonitor. // MarshalMap implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalMap() map[string]any { func (lb *LoadBalancer) MarshalMap() map[string]any {
extra := make(map[string]any) extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) { for _, srv := range lb.pool.Iter {
extra[v.Key()] = v extra[srv.Key()] = srv
}) }
status, numHealthy := lb.status()
return (&health.JSONRepresentation{ return (&health.JSONRepresentation{
Name: lb.Name(), Name: lb.Name(),
Status: lb.Status(), Status: status,
Detail: fmt.Sprintf("%d/%d servers are healthy", numHealthy, lb.pool.Size()),
Started: lb.startTime, Started: lb.startTime,
Uptime: lb.Uptime(), Uptime: lb.Uptime(),
Extra: map[string]any{ Extra: map[string]any{
@ -265,22 +266,26 @@ func (lb *LoadBalancer) Name() string {
// Status implements health.HealthMonitor. // Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status { func (lb *LoadBalancer) Status() health.Status {
status, _ := lb.status()
return status
}
func (lb *LoadBalancer) status() (status health.Status, numHealthy int) {
if lb.pool.Size() == 0 { if lb.pool.Size() == 0 {
return health.StatusUnknown return health.StatusUnknown, 0
} }
isHealthy := true // should be healthy if at least one server is healthy
lb.pool.Range(func(_ string, srv Server) bool { numHealthy = 0
if srv.Status().Bad() { for _, srv := range lb.pool.Iter {
isHealthy = false if srv.Status().Good() {
return false numHealthy++
} }
return true
})
if !isHealthy {
return health.StatusUnhealthy
} }
return health.StatusHealthy if numHealthy == 0 {
return health.StatusUnhealthy, numHealthy
}
return health.StatusHealthy, numHealthy
} }
// Uptime implements health.HealthMonitor. // Uptime implements health.HealthMonitor.
@ -291,9 +296,9 @@ func (lb *LoadBalancer) Uptime() time.Duration {
// Latency implements health.HealthMonitor. // Latency implements health.HealthMonitor.
func (lb *LoadBalancer) Latency() time.Duration { func (lb *LoadBalancer) Latency() time.Duration {
var sum time.Duration var sum time.Duration
lb.pool.RangeAll(func(_ string, srv Server) { for _, srv := range lb.pool.Iter {
sum += srv.Latency() sum += srv.Latency()
}) }
return sum return sum
} }
@ -304,10 +309,10 @@ func (lb *LoadBalancer) String() string {
func (lb *LoadBalancer) availServers() []Server { func (lb *LoadBalancer) availServers() []Server {
avail := make([]Server, 0, lb.pool.Size()) avail := make([]Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv Server) { for _, srv := range lb.pool.Iter {
if srv.Status().Good() { if srv.Status().Good() {
avail = append(avail, srv) avail = append(avail, srv)
} }
}) }
return avail return avail
} }

View file

@ -7,7 +7,6 @@ import (
type ( type (
Server = types.Server Server = types.Server
Servers = []types.Server Servers = []types.Server
Pool = types.Pool
Weight = types.Weight Weight = types.Weight
Config = types.Config Config = types.Config
Mode = types.Mode Mode = types.Mode

View file

@ -6,7 +6,6 @@ import (
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -32,12 +31,8 @@ type (
SetWeight(weight Weight) SetWeight(weight Weight)
TryWake() error TryWake() error
} }
Pool = F.Map[string, Server]
) )
var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url *url.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server { func NewServer(name string, url *url.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &server{ srv := &server{
name: name, name: name,

View file

@ -96,7 +96,9 @@ func (s *Server) Start(parent task.Parent) {
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig), TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
} }
Start(subtask, h3, &s.l) Start(subtask, h3, &s.l)
s.http.Handler = advertiseHTTP3(s.http.Handler, h3) if s.http != nil {
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
}
s.https.Handler = advertiseHTTP3(s.https.Handler, h3) s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
} }

View file

@ -94,24 +94,20 @@ func Ping(ctx context.Context, ip net.IP) (bool, error) {
} }
} }
var pingDialer = &net.Dialer{
Timeout: 2 * time.Second,
}
// PingWithTCPFallback pings the IP address using ICMP and TCP fallback. // PingWithTCPFallback pings the IP address using ICMP and TCP fallback.
// //
// If the ICMP ping fails due to permission error, it will try to connect to the specified port. // If the ICMP ping fails due to permission error, it will try to connect to the specified port.
func PingWithTCPFallback(ctx context.Context, ip net.IP, port int) (bool, error) { func PingWithTCPFallback(ctx context.Context, ip net.IP, port int) (bool, error) {
ok, err := Ping(ctx, ip) ok, err := Ping(ctx, ip)
if err != nil { if err == nil {
if !errors.Is(err, os.ErrPermission) {
return false, err
}
} else {
return ok, nil return ok, nil
} }
if !errors.Is(err, os.ErrPermission) {
return false, err
}
conn, err := pingDialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", ip, port)) var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", ip, port))
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -1,13 +1,11 @@
package gpnet package gpnet
import ( import (
"fmt"
"net" "net"
) )
type ( type (
Stream interface { Stream interface {
fmt.Stringer
StreamListener StreamListener
Setup() error Setup() error
Handle(conn StreamConn) error Handle(conn StreamConn) error

View file

@ -33,7 +33,7 @@ const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent task.Parent) *Dispatcher { func StartNotifDispatcher(parent task.Parent) *Dispatcher {
dispatcher = &Dispatcher{ dispatcher = &Dispatcher{
task: parent.Subtask("notification"), task: parent.Subtask("notification", true),
logCh: make(chan *LogMessage), logCh: make(chan *LogMessage),
providers: F.NewSet[Provider](), providers: F.NewSet[Provider](),
} }
@ -86,7 +86,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
if true { if true {
return return
} }
task := disp.task.Subtask("dispatcher") task := disp.task.Subtask("dispatcher", true)
defer task.Finish("notif dispatched") defer task.Finish("notif dispatched")
errs := gperr.NewBuilder(dispatchErr) errs := gperr.NewBuilder(dispatchErr)

View file

@ -57,7 +57,7 @@ func NewFileServer(base *Route) (*FileServer, gperr.Error) {
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (s *FileServer) Start(parent task.Parent) gperr.Error { func (s *FileServer) Start(parent task.Parent) gperr.Error {
s.task = parent.Subtask("fileserver."+s.TargetName(), false) s.task = parent.Subtask("fileserver."+s.Name(), false)
pathPatterns := s.PathPatterns pathPatterns := s.PathPatterns
switch { switch {
@ -92,7 +92,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
} }
if common.PrometheusEnabled { if common.PrometheusEnabled {
metricsLogger := metricslogger.NewMetricsLogger(s.TargetName()) metricsLogger := metricslogger.NewMetricsLogger(s.Name())
s.handler = metricsLogger.GetHandler(s.handler) s.handler = metricsLogger.GetHandler(s.handler)
s.task.OnCancel("reset_metrics", metricsLogger.ResetMetrics) s.task.OnCancel("reset_metrics", metricsLogger.ResetMetrics)
} }
@ -104,9 +104,9 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
} }
} }
routes.SetHTTPRoute(s.TargetName(), s) routes.HTTP.Add(s)
s.task.OnCancel("entrypoint_remove_route", func() { s.task.OnCancel("entrypoint_remove_route", func() {
routes.DeleteHTTPRoute(s.TargetName()) routes.HTTP.Del(s)
}) })
return nil return nil
} }

View file

@ -59,7 +59,7 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
} }
} }
service := base.TargetName() service := base.Name()
rp := reverseproxy.NewReverseProxy(service, proxyURL, trans) rp := reverseproxy.NewReverseProxy(service, proxyURL, trans)
if len(base.Middlewares) > 0 { if len(base.Middlewares) > 0 {
@ -90,16 +90,12 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
return r, nil return r, nil
} }
func (r *ReveseProxyRoute) String() string {
return r.TargetName()
}
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
if existing, ok := routes.GetHTTPRoute(r.TargetName()); ok && !r.UseLoadBalance() { if existing, ok := routes.HTTP.Get(r.Key()); ok && !r.UseLoadBalance() {
return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName()) return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName())
} }
r.task = parent.Subtask("http."+r.TargetName(), false) r.task = parent.Subtask("http."+r.Name(), false)
switch { switch {
case r.UseIdleWatcher(): case r.UseIdleWatcher():
@ -132,7 +128,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
r.handler = r.rp r.handler = r.rp
default: default:
logging.Warn(). logging.Warn().
Str("route", r.TargetName()). Str("route", r.Name()).
Msg("`path_patterns` for reverse proxy is deprecated. Use `rules` instead.") Msg("`path_patterns` for reverse proxy is deprecated. Use `rules` instead.")
mux := gphttp.NewServeMux() mux := gphttp.NewServeMux()
patErrs := gperr.NewBuilder("invalid path pattern(s)") patErrs := gperr.NewBuilder("invalid path pattern(s)")
@ -148,7 +144,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
} }
if len(r.Rules) > 0 { if len(r.Rules) > 0 {
r.handler = r.Rules.BuildHandler(r.TargetName(), r.handler) r.handler = r.Rules.BuildHandler(r.Name(), r.handler)
} }
if r.HealthMon != nil { if r.HealthMon != nil {
@ -158,7 +154,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
} }
if common.PrometheusEnabled { if common.PrometheusEnabled {
metricsLogger := metricslogger.NewMetricsLogger(r.TargetName()) metricsLogger := metricslogger.NewMetricsLogger(r.Name())
r.handler = metricsLogger.GetHandler(r.handler) r.handler = metricsLogger.GetHandler(r.handler)
r.task.OnCancel("reset_metrics", metricsLogger.ResetMetrics) r.task.OnCancel("reset_metrics", metricsLogger.ResetMetrics)
} }
@ -166,9 +162,9 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
if r.UseLoadBalance() { if r.UseLoadBalance() {
r.addToLoadBalancer(parent) r.addToLoadBalancer(parent)
} else { } else {
routes.SetHTTPRoute(r.TargetName(), r) routes.HTTP.Add(r)
r.task.OnFinished("entrypoint_remove_route", func() { r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteHTTPRoute(r.TargetName()) routes.HTTP.Del(r)
}) })
} }
@ -201,7 +197,7 @@ func (r *ReveseProxyRoute) HealthMonitor() health.HealthMonitor {
func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) { func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
cfg := r.LoadBalance cfg := r.LoadBalance
l, ok := routes.GetHTTPRoute(cfg.Link) l, ok := routes.HTTP.Get(cfg.Link)
var linked *ReveseProxyRoute var linked *ReveseProxyRoute
if ok { if ok {
linked = l.(*ReveseProxyRoute) linked = l.(*ReveseProxyRoute)
@ -222,7 +218,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
loadBalancer: lb, loadBalancer: lb,
handler: lb, handler: lb,
} }
routes.SetHTTPRoute(cfg.Link, linked) routes.HTTP.Add(linked)
r.task.OnFinished("entrypoint_remove_route", func() {
routes.HTTP.Del(linked)
})
} }
r.loadBalancer = lb r.loadBalancer = lb

View file

@ -24,6 +24,7 @@ import (
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/route/rules" "github.com/yusing/go-proxy/internal/route/rules"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
@ -62,7 +63,7 @@ type (
LisURL *url.URL `json:"lurl,omitempty"` LisURL *url.URL `json:"lurl,omitempty"`
ProxyURL *url.URL `json:"purl,omitempty"` ProxyURL *url.URL `json:"purl,omitempty"`
impl route.Route impl routes.Route
} }
Routes map[string]*Route Routes map[string]*Route
) )
@ -77,17 +78,6 @@ func (r Routes) Contains(alias string) bool {
func (r *Route) Validate() (err gperr.Error) { func (r *Route) Validate() (err gperr.Error) {
r.Finalize() r.Finalize()
// return error if route is localhost:<godoxy_port>
switch r.Host {
case "localhost", "127.0.0.1":
switch r.Port.Proxy {
case common.ProxyHTTPPort, common.ProxyHTTPSPort, common.APIHTTPPort:
if r.Scheme.IsReverseProxy() || r.Scheme == route.SchemeTCP {
return gperr.Errorf("localhost:%d is reserved for godoxy", r.Port.Proxy)
}
}
}
if r.Idlewatcher != nil && r.Idlewatcher.Proxmox != nil { if r.Idlewatcher != nil && r.Idlewatcher.Proxmox != nil {
node := r.Idlewatcher.Proxmox.Node node := r.Idlewatcher.Proxmox.Node
vmid := r.Idlewatcher.Proxmox.VMID vmid := r.Idlewatcher.Proxmox.VMID
@ -152,6 +142,17 @@ func (r *Route) Validate() (err gperr.Error) {
} }
} }
// return error if route is localhost:<godoxy_port>
switch r.Host {
case "localhost", "127.0.0.1":
switch r.Port.Proxy {
case common.ProxyHTTPPort, common.ProxyHTTPSPort, common.APIHTTPPort:
if r.Scheme.IsReverseProxy() || r.Scheme == route.SchemeTCP {
return gperr.Errorf("localhost:%d is reserved for godoxy", r.Port.Proxy)
}
}
}
errs := gperr.NewBuilder("entry validation failed") errs := gperr.NewBuilder("entry validation failed")
if r.Scheme == route.SchemeFileServer { if r.Scheme == route.SchemeFileServer {
@ -227,7 +228,17 @@ func (r *Route) ProviderName() string {
return r.Provider return r.Provider
} }
func (r *Route) TargetName() string { // Name implements pool.Object.
func (r *Route) Name() string {
return r.Alias
}
// Key implements pool.Object.
func (r *Route) Key() string {
return r.Alias
}
func (r *Route) String() string {
return r.Alias return r.Alias
} }

View file

@ -1,15 +1,14 @@
package routequery package routes
import ( import (
"time" "time"
"github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
func getHealthInfo(r route.Route) map[string]string { func getHealthInfo(r Route) map[string]string {
mon := r.HealthMonitor() mon := r.HealthMonitor()
if mon == nil { if mon == nil {
return map[string]string{ return map[string]string{
@ -30,7 +29,7 @@ type HealthInfoRaw struct {
Latency time.Duration `json:"latency"` Latency time.Duration `json:"latency"`
} }
func getHealthInfoRaw(r route.Route) *HealthInfoRaw { func getHealthInfoRaw(r Route) *HealthInfoRaw {
mon := r.HealthMonitor() mon := r.HealthMonitor()
if mon == nil { if mon == nil {
return &HealthInfoRaw{ return &HealthInfoRaw{
@ -45,69 +44,69 @@ func getHealthInfoRaw(r route.Route) *HealthInfoRaw {
} }
func HealthMap() map[string]map[string]string { func HealthMap() map[string]map[string]string {
healthMap := make(map[string]map[string]string, routes.NumRoutes()) healthMap := make(map[string]map[string]string, NumRoutes())
routes.RangeRoutes(func(alias string, r route.Route) { for alias, r := range Iter {
healthMap[alias] = getHealthInfo(r) healthMap[alias] = getHealthInfo(r)
}) }
return healthMap return healthMap
} }
func HealthInfo() map[string]*HealthInfoRaw { func HealthInfo() map[string]*HealthInfoRaw {
healthMap := make(map[string]*HealthInfoRaw, routes.NumRoutes()) healthMap := make(map[string]*HealthInfoRaw, NumRoutes())
routes.RangeRoutes(func(alias string, r route.Route) { for alias, r := range Iter {
healthMap[alias] = getHealthInfoRaw(r) healthMap[alias] = getHealthInfoRaw(r)
}) }
return healthMap return healthMap
} }
func HomepageCategories() []string { func HomepageCategories() []string {
check := make(map[string]struct{}) check := make(map[string]struct{})
categories := make([]string, 0) categories := make([]string, 0)
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) { for _, r := range HTTP.Iter {
item := r.HomepageConfig() item := r.HomepageConfig()
if item == nil || item.Category == "" { if item == nil || item.Category == "" {
return continue
} }
if _, ok := check[item.Category]; ok { if _, ok := check[item.Category]; ok {
return continue
} }
check[item.Category] = struct{}{} check[item.Category] = struct{}{}
categories = append(categories, item.Category) categories = append(categories, item.Category)
}) }
return categories return categories
} }
func HomepageConfig(categoryFilter, providerFilter string) homepage.Homepage { func HomepageConfig(categoryFilter, providerFilter string) homepage.Homepage {
hp := make(homepage.Homepage) hp := make(homepage.Homepage)
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) { for _, r := range HTTP.Iter {
if providerFilter != "" && r.ProviderName() != providerFilter { if providerFilter != "" && r.ProviderName() != providerFilter {
return continue
} }
item := r.HomepageItem() item := r.HomepageItem()
if categoryFilter != "" && item.Category != categoryFilter { if categoryFilter != "" && item.Category != categoryFilter {
return continue
} }
hp.Add(item) hp.Add(item)
}) }
return hp return hp
} }
func RoutesByAlias(typeFilter ...route.RouteType) map[string]route.Route { func ByAlias(typeFilter ...route.RouteType) map[string]Route {
rts := make(map[string]route.Route) rts := make(map[string]Route)
if len(typeFilter) == 0 || typeFilter[0] == "" { if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []route.RouteType{route.RouteTypeHTTP, route.RouteTypeStream} typeFilter = []route.RouteType{route.RouteTypeHTTP, route.RouteTypeStream}
} }
for _, t := range typeFilter { for _, t := range typeFilter {
switch t { switch t {
case route.RouteTypeHTTP: case route.RouteTypeHTTP:
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) { for alias, r := range HTTP.Iter {
rts[alias] = r rts[alias] = r
}) }
case route.RouteTypeStream: case route.RouteTypeStream:
routes.GetStreamRoutes().RangeAll(func(alias string, r route.StreamRoute) { for alias, r := range Stream.Iter {
rts[alias] = r rts[alias] = r
}) }
} }
} }
return rts return rts

View file

@ -1,4 +1,4 @@
package route package routes
import ( import (
"net/http" "net/http"
@ -10,6 +10,7 @@ import (
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/pool"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
@ -21,8 +22,8 @@ type (
Route interface { Route interface {
task.TaskStarter task.TaskStarter
task.TaskFinisher task.TaskFinisher
pool.Object
ProviderName() string ProviderName() string
TargetName() string
TargetURL() *url.URL TargetURL() *url.URL
HealthMonitor() health.HealthMonitor HealthMonitor() health.HealthMonitor
Reference() string Reference() string

View file

@ -1,78 +1,49 @@
package routes package routes
import ( import (
route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/utils/pool"
F "github.com/yusing/go-proxy/internal/utils/functional"
) )
var ( var (
httpRoutes = F.NewMapOf[string, route.HTTPRoute]() HTTP = pool.New[HTTPRoute]("http_routes")
streamRoutes = F.NewMapOf[string, route.StreamRoute]() Stream = pool.New[StreamRoute]("stream_routes")
) )
func RangeRoutes(callback func(alias string, r route.Route)) { func Iter(yield func(alias string, r Route) bool) {
httpRoutes.RangeAll(func(alias string, r route.HTTPRoute) { for k, r := range HTTP.Iter {
callback(alias, r) if !yield(k, r) {
}) break
streamRoutes.RangeAll(func(alias string, r route.StreamRoute) { }
callback(alias, r) }
}) for k, r := range Stream.Iter {
if !yield(k, r) {
break
}
}
} }
func NumRoutes() int { func NumRoutes() int {
return httpRoutes.Size() + streamRoutes.Size() return HTTP.Size() + Stream.Size()
} }
func GetHTTPRoutes() F.Map[string, route.HTTPRoute] { func Clear() {
return httpRoutes HTTP.Clear()
Stream.Clear()
} }
func GetStreamRoutes() F.Map[string, route.StreamRoute] { func GetHTTPRouteOrExact(alias, host string) (HTTPRoute, bool) {
return streamRoutes r, ok := HTTP.Get(alias)
}
func GetHTTPRouteOrExact(alias, host string) (route.HTTPRoute, bool) {
r, ok := httpRoutes.Load(alias)
if ok { if ok {
return r, true return r, true
} }
// try find with exact match // try find with exact match
return httpRoutes.Load(host) return HTTP.Get(host)
} }
func GetHTTPRoute(alias string) (route.HTTPRoute, bool) { func Get(alias string) (Route, bool) {
return httpRoutes.Load(alias) r, ok := HTTP.Get(alias)
}
func GetStreamRoute(alias string) (route.StreamRoute, bool) {
return streamRoutes.Load(alias)
}
func GetRoute(alias string) (route.Route, bool) {
r, ok := httpRoutes.Load(alias)
if ok { if ok {
return r, true return r, true
} }
return streamRoutes.Load(alias) return Stream.Get(alias)
}
func SetHTTPRoute(alias string, r route.HTTPRoute) {
httpRoutes.Store(alias, r)
}
func SetStreamRoute(alias string, r route.StreamRoute) {
streamRoutes.Store(alias, r)
}
func DeleteHTTPRoute(alias string) {
httpRoutes.Delete(alias)
}
func DeleteStreamRoute(alias string) {
streamRoutes.Delete(alias)
}
func TestClear() {
httpRoutes = F.NewMapOf[string, route.HTTPRoute]()
streamRoutes = F.NewMapOf[string, route.StreamRoute]()
} }

View file

@ -10,7 +10,6 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor" "github.com/yusing/go-proxy/internal/watcher/health/monitor"
@ -29,31 +28,24 @@ type StreamRoute struct {
l zerolog.Logger l zerolog.Logger
} }
func NewStreamRoute(base *Route) (route.Route, gperr.Error) { func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
// TODO: support non-coherent scheme // TODO: support non-coherent scheme
return &StreamRoute{ return &StreamRoute{
Route: base, Route: base,
l: logging.With(). l: logging.With().
Str("type", string(base.Scheme)). Str("type", string(base.Scheme)).
Str("name", base.TargetName()). Str("name", base.Name()).
Logger(), Logger(),
}, nil }, nil
} }
func (r *StreamRoute) String() string {
return "stream " + r.TargetName()
}
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (r *StreamRoute) Start(parent task.Parent) gperr.Error { func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
if existing, ok := routes.GetStreamRoute(r.TargetName()); ok { if existing, ok := routes.Stream.Get(r.Key()); ok {
return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName()) return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName())
} }
r.task = parent.Subtask("stream." + r.TargetName()) r.task = parent.Subtask("stream."+r.Name(), true)
r.Stream = NewStream(r) r.Stream = NewStream(r)
parent.OnCancel("finish", func() {
r.task.Finish(nil)
})
switch { switch {
case r.UseIdleWatcher(): case r.UseIdleWatcher():
@ -83,9 +75,9 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
go r.acceptConnections() go r.acceptConnections()
routes.SetStreamRoute(r.TargetName(), r) routes.Stream.Add(r)
r.task.OnFinished("entrypoint_remove_route", func() { r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteStreamRoute(r.TargetName()) routes.Stream.Del(r)
}) })
return nil return nil
} }

View file

@ -57,7 +57,7 @@ func (t *Task) callbackList() []map[string]any {
func (t *Task) MarshalMap() map[string]any { func (t *Task) MarshalMap() map[string]any {
return map[string]any{ return map[string]any{
"name": t.name, "name": t.name,
"need_finish": strconv.FormatBool(t.needFinish), "need_finish": strconv.FormatBool(t.waitFinish),
"childrens": t.children, "childrens": t.children,
"callbacks": t.callbackList(), "callbacks": t.callbackList(),
"finish_called": t.finishedCalled, "finish_called": t.finishedCalled,

View file

@ -45,7 +45,7 @@ type (
callbacks map[*Callback]struct{} callbacks map[*Callback]struct{}
callbacksDone chan struct{} callbacksDone chan struct{}
needFinish bool waitFinish bool
finished chan struct{} finished chan struct{}
// finishedCalled == 1 Finish has been called // finishedCalled == 1 Finish has been called
// but does not mean that the task is finished yet // but does not mean that the task is finished yet
@ -59,7 +59,7 @@ type (
} }
Parent interface { Parent interface {
Context() context.Context Context() context.Context
Subtask(name string, needFinish ...bool) *Task Subtask(name string, waitFinish bool) *Task
Name() string Name() string
Finish(reason any) Finish(reason any)
OnCancel(name string, f func()) OnCancel(name string, f func())
@ -141,13 +141,11 @@ func (t *Task) finish(reason any) {
// Subtask returns a new subtask with the given name, derived from the parent's context. // Subtask returns a new subtask with the given name, derived from the parent's context.
// //
// This should not be called after Finish is called. // This should not be called after Finish is called.
func (t *Task) Subtask(name string, needFinish ...bool) *Task { func (t *Task) Subtask(name string, waitFinish bool) *Task {
nf := len(needFinish) == 0 || needFinish[0]
ctx, cancel := context.WithCancelCause(t.ctx) ctx, cancel := context.WithCancelCause(t.ctx)
child := &Task{ child := &Task{
parent: t, parent: t,
needFinish: nf, waitFinish: waitFinish,
finished: make(chan struct{}), finished: make(chan struct{}),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -161,7 +159,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
allTasks.Add(child) allTasks.Add(child)
t.addChildCount() t.addChildCount()
if !nf { if !waitFinish {
go func() { go func() {
<-child.ctx.Done() <-child.ctx.Done()
child.Finish(nil) child.Finish(nil)

View file

@ -17,7 +17,7 @@ func TestChildTaskCancellation(t *testing.T) {
t.Cleanup(testCleanup) t.Cleanup(testCleanup)
parent := testTask() parent := testTask()
child := parent.Subtask("") child := parent.Subtask("child", false)
go func() { go func() {
defer child.Finish(nil) defer child.Finish(nil)

View file

@ -25,8 +25,8 @@ func testCleanup() {
} }
// RootTask returns a new Task with the given name, derived from the root context. // RootTask returns a new Task with the given name, derived from the root context.
func RootTask(name string, needFinish ...bool) *Task { func RootTask(name string, needFinish bool) *Task {
return root.Subtask(name, needFinish...) return root.Subtask(name, needFinish)
} }
func newRoot() *Task { func newRoot() *Task {
@ -66,6 +66,9 @@ func GracefulShutdown(timeout time.Duration) (err error) {
return return
case <-after: case <-after:
logging.Warn().Msgf("Timeout waiting for %d tasks to finish", allTasks.Size()) logging.Warn().Msgf("Timeout waiting for %d tasks to finish", allTasks.Size())
for t := range allTasks.Range {
logging.Warn().Msgf("Task %s is still running", t.name)
}
return context.DeadlineExceeded return context.DeadlineExceeded
} }
} }

View file

@ -9,6 +9,7 @@ import (
"syscall" "syscall"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/synk"
) )
// TODO: move to "utils/io". // TODO: move to "utils/io".
@ -117,24 +118,21 @@ func getHttpFlusher(dst io.Writer) httpFlusher {
return nil return nil
} }
const ( const copyBufSize = 32 * 1024
copyBufSize = 32 * 1024
)
var copyBufPool = sync.Pool{ var copyBufPool = synk.NewBytesPool(copyBufSize, synk.DefaultMaxBytes)
New: func() any {
return make([]byte, copyBufSize)
},
}
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// This is a copy of io.Copy with context and HTTP flusher handling // This is a copy of io.Copy with context and HTTP flusher handling
// Author: yusing <yusing@6uo.me>. // Author: yusing <yusing@6uo.me>.
func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
var buf []byte buf := copyBufPool.Get()
defer copyBufPool.Put(buf)
var size int
if l, ok := src.Reader.(*io.LimitedReader); ok { if l, ok := src.Reader.(*io.LimitedReader); ok {
size := copyBufSize size = copyBufSize
if int64(size) > l.N { if int64(size) > l.N {
if l.N < 1 { if l.N < 1 {
size = 1 size = 1
@ -142,10 +140,8 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
size = int(l.N) size = int(l.N)
} }
} }
buf = make([]byte, 0, size)
} else { } else {
buf = copyBufPool.Get().([]byte) size = cap(buf)
defer copyBufPool.Put(buf[:0])
} }
// close both as soon as one of them is done // close both as soon as one of them is done
wCloser, wCanClose := dst.Writer.(io.Closer) wCloser, wCanClose := dst.Writer.(io.Closer)
@ -179,7 +175,7 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
flusher := getHttpFlusher(dst.Writer) flusher := getHttpFlusher(dst.Writer)
canFlush := flusher != nil canFlush := flusher != nil
for { for {
nr, er := src.Reader.Read(buf[:copyBufSize]) nr, er := src.Reader.Read(buf[:size])
if nr > 0 { if nr > 0 {
nw, ew := dst.Writer.Write(buf[0:nr]) nw, ew := dst.Writer.Write(buf[0:nr])
if nw < 0 || nr < nw { if nw < 0 || nr < nw {

View file

@ -3,25 +3,23 @@ package pool
import ( import (
"sort" "sort"
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/functional"
) )
type ( type (
Pool[T Object] struct { Pool[T Object] struct {
m functional.Map[string, T] m *xsync.MapOf[string, T]
name string name string
} }
Object interface { Object interface {
Key() string Key() string
Name() string Name() string
utils.MapMarshaler
} }
) )
func New[T Object](name string) Pool[T] { func New[T Object](name string) Pool[T] {
return Pool[T]{functional.NewMapOf[string, T](), name} return Pool[T]{xsync.NewMapOf[string, T](), name}
} }
func (p Pool[T]) Name() string { func (p Pool[T]) Name() string {
@ -29,6 +27,7 @@ func (p Pool[T]) Name() string {
} }
func (p Pool[T]) Add(obj T) { func (p Pool[T]) Add(obj T) {
p.checkExists(obj.Key())
p.m.Store(obj.Key(), obj) p.m.Store(obj.Key(), obj)
logging.Info().Msgf("%s: added %s", p.name, obj.Name()) logging.Info().Msgf("%s: added %s", p.name, obj.Name())
} }
@ -50,8 +49,8 @@ func (p Pool[T]) Clear() {
p.m.Clear() p.m.Clear()
} }
func (p Pool[T]) Base() functional.Map[string, T] { func (p Pool[T]) Iter(fn func(k string, v T) bool) {
return p.m p.m.Range(fn)
} }
func (p Pool[T]) Slice() []T { func (p Pool[T]) Slice() []T {
@ -64,11 +63,3 @@ func (p Pool[T]) Slice() []T {
}) })
return slice return slice
} }
func (p Pool[T]) Iter(fn func(k string, v T) bool) {
p.m.Range(fn)
}
func (p Pool[T]) IterAll(fn func(k string, v T)) {
p.m.RangeAll(fn)
}

View file

@ -0,0 +1,15 @@
//go:build debug
package pool
import (
"runtime/debug"
"github.com/yusing/go-proxy/internal/logging"
)
func (p Pool[T]) checkExists(key string) {
if _, ok := p.m.Load(key); ok {
logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
}
}

View file

@ -0,0 +1,7 @@
//go:build !debug
package pool
func (p Pool[T]) checkExists(key string) {
// no-op in production
}

View file

@ -0,0 +1,42 @@
package synk
import "sync"
type (
// Pool is a wrapper of sync.Pool that limits the size of the object.
Pool[T any] struct {
pool sync.Pool
maxSize int
}
BytesPool = Pool[byte]
)
const (
DefaultInitBytes = 1024
DefaultMaxBytes = 1024 * 1024
)
func NewPool[T any](initSize int, maxSize int) *Pool[T] {
return &Pool[T]{
pool: sync.Pool{
New: func() any {
return make([]T, 0, initSize)
},
},
maxSize: maxSize,
}
}
func NewBytesPool(initSize int, maxSize int) *BytesPool {
return NewPool[byte](initSize, maxSize)
}
func (p *Pool[T]) Get() []T {
return p.pool.Get().([]T)
}
func (p *Pool[T]) Put(b []T) {
if cap(b) <= p.maxSize {
p.pool.Put(b[:0])
}
}

View file

@ -21,7 +21,7 @@ type DirWatcher struct {
w *fsnotify.Watcher w *fsnotify.Watcher
fwMap map[string]*fileWatcher fwMap map[string]*fileWatcher
mu sync.Mutex mu sync.RWMutex
eventCh chan Event eventCh chan Event
errCh chan gperr.Error errCh chan gperr.Error
@ -56,7 +56,7 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
fwMap: make(map[string]*fileWatcher), fwMap: make(map[string]*fileWatcher),
eventCh: make(chan Event), eventCh: make(chan Event),
errCh: make(chan gperr.Error), errCh: make(chan gperr.Error),
task: parent.Subtask("dir_watcher(" + dirPath + ")"), task: parent.Subtask("dir_watcher("+dirPath+")", true),
} }
go helper.start() go helper.start()
return helper return helper
@ -95,7 +95,9 @@ func (h *DirWatcher) cleanup() {
close(fw.eventCh) close(fw.eventCh)
close(fw.errCh) close(fw.errCh)
} }
h.fwMap = nil
h.task.Finish(nil) h.task.Finish(nil)
h.Info().Msg("directory watcher closed")
} }
func (h *DirWatcher) start() { func (h *DirWatcher) start() {
@ -143,9 +145,9 @@ func (h *DirWatcher) start() {
} }
// send event to file watcher too // send event to file watcher too
h.mu.Lock() h.mu.RLock()
w, ok := h.fwMap[relPath] w, ok := h.fwMap[relPath]
h.mu.Unlock() h.mu.RUnlock()
if ok { if ok {
select { select {
case w.eventCh <- msg: case w.eventCh <- msg:

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
@ -37,15 +37,15 @@ type (
var ErrNegativeInterval = errors.New("negative interval") var ErrNegativeInterval = errors.New("negative interval")
func NewMonitor(r route.Route) health.HealthMonCheck { func NewMonitor(r routes.Route) health.HealthMonCheck {
var mon health.HealthMonCheck var mon health.HealthMonCheck
if r.IsAgent() { if r.IsAgent() {
mon = NewAgentProxiedMonitor(r.Agent(), r.HealthCheckConfig(), AgentTargetFromURL(r.TargetURL())) mon = NewAgentProxiedMonitor(r.Agent(), r.HealthCheckConfig(), AgentTargetFromURL(r.TargetURL()))
} else { } else {
switch r := r.(type) { switch r := r.(type) {
case route.HTTPRoute: case routes.HTTPRoute:
mon = NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheckConfig()) mon = NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheckConfig())
case route.StreamRoute: case routes.StreamRoute:
mon = NewRawHealthMonitor(r.TargetURL(), r.HealthCheckConfig()) mon = NewRawHealthMonitor(r.TargetURL(), r.HealthCheckConfig())
default: default:
logging.Panic().Msgf("unexpected route type: %T", r) logging.Panic().Msgf("unexpected route type: %T", r)
@ -58,7 +58,7 @@ func NewMonitor(r route.Route) health.HealthMonCheck {
return mon return mon
} }
r.Task().OnCancel("close_docker_client", client.Close) r.Task().OnCancel("close_docker_client", client.Close)
return NewDockerHealthMonitor(client, cont.ContainerID, r.TargetName(), r.HealthCheckConfig(), mon) return NewDockerHealthMonitor(client, cont.ContainerID, r.Name(), r.HealthCheckConfig(), mon)
} }
return mon return mon
} }
@ -88,7 +88,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
} }
mon.service = parent.Name() mon.service = parent.Name()
mon.task = parent.Subtask("health_monitor") mon.task = parent.Subtask("health_monitor", true)
go func() { go func() {
logger := logging.With().Str("name", mon.service).Logger() logger := logging.With().Str("name", mon.service).Logger()

View file

@ -2,9 +2,9 @@ package json
import ( import (
"reflect" "reflect"
"sync"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
"github.com/yusing/go-proxy/internal/utils/synk"
) )
type Marshaler interface { type Marshaler interface {
@ -38,8 +38,8 @@ var (
// //
// - It does not support maps other than string-keyed maps. // - It does not support maps other than string-keyed maps.
func Marshal(v any) ([]byte, error) { func Marshal(v any) ([]byte, error) {
buf := newBytes() buf := bytesPool.Get()
defer putBytes(buf) defer bytesPool.Put(buf)
return cloneBytes(appendMarshal(reflect.ValueOf(v), buf)), nil return cloneBytes(appendMarshal(reflect.ValueOf(v), buf)), nil
} }
@ -47,21 +47,9 @@ func MarshalTo(v any, buf []byte) []byte {
return appendMarshal(reflect.ValueOf(v), buf) return appendMarshal(reflect.ValueOf(v), buf)
} }
const bufSize = 8192 const initBufSize = 4096
var bytesPool = sync.Pool{ var bytesPool = synk.NewBytesPool(initBufSize, synk.DefaultMaxBytes)
New: func() any {
return make([]byte, 0, bufSize)
},
}
func newBytes() []byte {
return bytesPool.Get().([]byte)
}
func putBytes(buf []byte) {
bytesPool.Put(buf[:0])
}
func cloneBytes(buf []byte) (res []byte) { func cloneBytes(buf []byte) (res []byte) {
return append(res, buf...) return append(res, buf...)