diff --git a/internal/common/http.go b/internal/common/http.go index a9ea8b5..d2bfde7 100644 --- a/internal/common/http.go +++ b/internal/common/http.go @@ -16,6 +16,7 @@ var ( Proxy: http.ProxyFromEnvironment, DialContext: defaultDialer.DialContext, MaxIdleConnsPerHost: 1000, + IdleConnTimeout: 90 * time.Second, } DefaultTransportNoTLS = func() *http.Transport { var clone = DefaultTransport.Clone() diff --git a/internal/config/query.go b/internal/config/query.go index c113041..ebb1f2d 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -71,17 +71,11 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig { if item.Category == "" { item.Category = "Docker" } - if item.Icon == "" { - item.Icon = "🐳" - } item.SourceType = string(PR.ProviderTypeDocker) } else if p.GetType() == PR.ProviderTypeFile { if item.Category == "" { item.Category = "Others" } - if item.Icon == "" { - item.Icon = "🔗" - } item.SourceType = string(PR.ProviderTypeFile) } @@ -90,6 +84,7 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig { item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port) } } + item.AltURL = r.URL().String() hpCfg.Add(&item) }) diff --git a/internal/homepage/homepage.go b/internal/homepage/homepage.go index 0a08407..5ef178f 100644 --- a/internal/homepage/homepage.go +++ b/internal/homepage/homepage.go @@ -8,13 +8,14 @@ type ( Show bool `yaml:"show" json:"show"` Name string `yaml:"name" json:"name"` Icon string `yaml:"icon" json:"icon"` - URL string `yaml:"url" json:"url"` // URL or unicodes + URL string `yaml:"url" json:"url"` // alias + domain Category string `yaml:"category" json:"category"` Description string `yaml:"description" json:"description"` WidgetConfig map[string]any `yaml:",flow" json:"widget_config"` SourceType string `yaml:"-" json:"source_type"` Initialized bool `yaml:"-" json:"-"` + AltURL string `yaml:"-" json:"alt_url"` // original proxy target } ) diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 923011a..0f8fb5c 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -25,6 +25,8 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/net/http/httpguts" + + U "github.com/yusing/go-proxy/internal/utils" ) // A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. @@ -418,9 +420,11 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - _, err = io.Copy(rw, res.Body) + err = U.Copy2(req.Context(), rw, res.Body) if err != nil { - p.errorHandler(rw, req, err, true) + if !errors.Is(err, context.Canceled) { + p.errorHandler(rw, req, err, true) + } res.Body.Close() return } @@ -525,17 +529,9 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true) return } - errc := make(chan error, 1) - go func() { - _, err := io.Copy(conn, backConn) - errc <- err - }() - go func() { - _, err := io.Copy(backConn, conn) - errc <- err - }() - <-errc + bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn) + bdp.Start() } func IsPrint(s string) bool { diff --git a/internal/route/http.go b/internal/route/http.go index dadfabb..ac63673 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -27,7 +27,7 @@ type ( PathPatterns PT.PathPatterns `json:"path_patterns"` entry *P.ReverseProxyEntry - mux *http.ServeMux + mux http.Handler handler *ReverseProxy regIdleWatcher func() E.NestedError @@ -36,16 +36,24 @@ type ( URL url.URL SubdomainKey = PT.Alias + + ReverseProxyHandler struct { + *ReverseProxy + } ) var ( findMuxFunc = findMuxAnyDomain - httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]() + httpRoutes = F.NewMapOf[string, *HTTPRoute]() httpRoutesMu sync.Mutex globalMux = http.NewServeMux() // TODO: support regex subdomain matching ) +func (rp ReverseProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rp.ReverseProxy.ServeHTTP(w, r) +} + func SetFindMuxDomains(domains []string) { if len(domains) == 0 { findMuxFunc = findMuxAnyDomain @@ -134,12 +142,17 @@ func (r *HTTPRoute) Start() E.NestedError { return nil } - r.mux = http.NewServeMux() - for _, p := range r.PathPatterns { - r.mux.HandleFunc(string(p), r.handler.ServeHTTP) + if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" { + r.mux = ReverseProxyHandler{r.handler} + } else { + mux := http.NewServeMux() + for _, p := range r.PathPatterns { + mux.HandleFunc(string(p), r.handler.ServeHTTP) + } + r.mux = mux } - httpRoutes.Store(r.Alias, r) + httpRoutes.Store(string(r.Alias), r) return nil } @@ -157,7 +170,7 @@ func (r *HTTPRoute) Stop() E.NestedError { } r.mux = nil - httpRoutes.Delete(r.Alias) + httpRoutes.Delete(string(r.Alias)) return nil } @@ -194,21 +207,21 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { mux.ServeHTTP(w, r) } -func findMuxAnyDomain(host string) (*http.ServeMux, error) { +func findMuxAnyDomain(host string) (http.Handler, error) { hostSplit := strings.Split(host, ".") n := len(hostSplit) if n <= 2 { return nil, fmt.Errorf("missing subdomain in url") } sd := strings.Join(hostSplit[:n-2], ".") - if r, ok := httpRoutes.Load(PT.Alias(sd)); ok { + if r, ok := httpRoutes.Load(sd); ok { return r.mux, nil } return nil, fmt.Errorf("no such route: %s", sd) } -func findMuxByDomains(domains []string) func(host string) (*http.ServeMux, error) { - return func(host string) (*http.ServeMux, error) { +func findMuxByDomains(domains []string) func(host string) (http.Handler, error) { + return func(host string) (http.Handler, error) { var subdomain string for _, domain := range domains { @@ -223,7 +236,7 @@ func findMuxByDomains(domains []string) func(host string) (*http.ServeMux, error if len(subdomain) == len(host) { // not matched return nil, fmt.Errorf("%s does not match any base domain", host) } - if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok { + if r, ok := httpRoutes.Load(subdomain); ok { return r.mux, nil } return nil, fmt.Errorf("no such route: %s", subdomain) diff --git a/internal/route/route.go b/internal/route/route.go index 80cd40a..0a179c1 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -74,7 +74,7 @@ func (rt *route) Type() RouteType { } func (rt *route) URL() *url.URL { - url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host)) + url, _ := url.Parse(fmt.Sprintf("%s://%s:%s", rt.entry.Scheme, rt.entry.Host, rt.entry.Port)) return url } diff --git a/internal/utils/io.go b/internal/utils/io.go index 4274de0..2875f45 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -110,6 +110,10 @@ func Copy(dst *ContextWriter, src *ContextReader) error { return err } +func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { + return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) +} + func LoadJson[T any](path string, pointer *T) E.NestedError { data, err := E.Check(os.ReadFile(path)) if err.HasError() {