From 07bce90521819e3176d50edb1125e7412216550a Mon Sep 17 00:00:00 2001 From: yusing Date: Tue, 11 Feb 2025 09:16:21 +0800 Subject: [PATCH] fixed some issues --- agent/pkg/agent/config.go | 6 + agent/pkg/env/env.go | 14 +- agent/pkg/handler/check_health.go | 6 +- agent/pkg/handler/check_health_test.go | 215 ++++++++++++++++++ internal/docker/idlewatcher/waker.go | 2 + .../http/reverseproxy/reverse_proxy_mod.go | 2 +- internal/route/provider/provider.go | 4 + internal/route/reverse_proxy.go | 2 +- .../watcher/health/monitor/agent_route.go | 2 +- 9 files changed, 243 insertions(+), 10 deletions(-) create mode 100644 agent/pkg/handler/check_health_test.go diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index f5d9e7d..72f94a2 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -101,6 +101,12 @@ func checkVersion(a, b string) bool { return withoutBuildTime(a) == withoutBuildTime(b) } +func (cfg *AgentConfig) Remove() { + agentMapMu.Lock() + defer agentMapMu.Unlock() + agents.Delete(cfg.Name()) +} + func (cfg *AgentConfig) load() E.Error { certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) if err != nil { diff --git a/agent/pkg/env/env.go b/agent/pkg/env/env.go index ad9ef53..70c1fd8 100644 --- a/agent/pkg/env/env.go +++ b/agent/pkg/env/env.go @@ -5,6 +5,7 @@ import ( "net" "os" "strings" + "sync" "github.com/yusing/go-proxy/internal/common" ) @@ -32,14 +33,11 @@ func init() { if err != nil { log.Fatalf("failed to parse allowed hosts: %v", err) } - if len(cidrs) == 0 { - log.Fatal("REGISTRATION_ALLOWED_HOSTS is empty") - } RegistrationAllowedCIDRs = cidrs } func toCIDRs(hosts []string) ([]*net.IPNet, error) { - var cidrs []*net.IPNet + cidrs := make([]*net.IPNet, 0, len(hosts)) for _, host := range hosts { if !strings.Contains(host, "/") { host += "/32" @@ -53,7 +51,15 @@ func toCIDRs(hosts []string) ([]*net.IPNet, error) { return cidrs, nil } +var warnOnce sync.Once + func IsAllowedHost(remoteAddr string) bool { + if len(RegistrationAllowedCIDRs) == 0 { + warnOnce.Do(func() { + log.Println("Warning: REGISTRATION_ALLOWED_HOSTS is empty, allowing all hosts") + }) + return true + } ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { ip = remoteAddr diff --git a/agent/pkg/handler/check_health.go b/agent/pkg/handler/check_health.go index 7cef3c0..4852b37 100644 --- a/agent/pkg/handler/check_health.go +++ b/agent/pkg/handler/check_health.go @@ -3,10 +3,10 @@ package handler import ( "net/http" "net/url" + "os" apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/net/types" - "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health/monitor" ) @@ -28,8 +28,8 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - ok, err := utils.FileExists(path) - result = &health.HealthCheckResult{Healthy: ok} + _, err := os.Stat(path) + result = &health.HealthCheckResult{Healthy: err == nil} if err != nil { result.Detail = err.Error() } diff --git a/agent/pkg/handler/check_health_test.go b/agent/pkg/handler/check_health_test.go new file mode 100644 index 0000000..4079998 --- /dev/null +++ b/agent/pkg/handler/check_health_test.go @@ -0,0 +1,215 @@ +package handler + +import ( + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/yusing/go-proxy/agent/pkg/agent" + . "github.com/yusing/go-proxy/internal/utils/testing" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +func TestCheckHealthHTTP(t *testing.T) { + tests := []struct { + name string + setupServer func() *httptest.Server + queryParams map[string]string + expectedStatus int + expectedHealthy bool + }{ + { + name: "Valid", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + }, + queryParams: map[string]string{ + "scheme": "http", + "host": "localhost", + "path": "/", + }, + expectedStatus: http.StatusOK, + expectedHealthy: true, + }, + { + name: "InvalidQuery", + setupServer: nil, + queryParams: map[string]string{ + "scheme": "http", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "ConnectionError", + setupServer: nil, + queryParams: map[string]string{ + "scheme": "http", + "host": "localhost:12345", + }, + expectedStatus: http.StatusOK, + expectedHealthy: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var server *httptest.Server + if tt.setupServer != nil { + server = tt.setupServer() + defer server.Close() + u, _ := url.Parse(server.URL) + tt.queryParams["scheme"] = u.Scheme + tt.queryParams["host"] = u.Host + tt.queryParams["path"] = u.Path + } + + recorder := httptest.NewRecorder() + query := url.Values{} + for key, value := range tt.queryParams { + query.Set(key, value) + } + request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) + CheckHealth(recorder, request) + + ExpectEqual(t, recorder.Code, tt.expectedStatus) + + if tt.expectedStatus == http.StatusOK { + var result health.HealthCheckResult + ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil) + ExpectEqual(t, result.Healthy, tt.expectedHealthy) + } + }) + } +} + +func TestCheckHealthFileServer(t *testing.T) { + tests := []struct { + name string + path string + expectedStatus int + expectedHealthy bool + expectedDetail string + }{ + { + name: "ValidPath", + path: t.TempDir(), + expectedStatus: http.StatusOK, + expectedHealthy: true, + expectedDetail: "", + }, + { + name: "InvalidPath", + path: "/invalid", + expectedStatus: http.StatusOK, + expectedHealthy: false, + expectedDetail: "stat /invalid: no such file or directory", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := url.Values{} + query.Set("scheme", "fileserver") + query.Set("path", tt.path) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) + CheckHealth(recorder, request) + + ExpectEqual(t, recorder.Code, tt.expectedStatus) + + var result health.HealthCheckResult + ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil) + ExpectEqual(t, result.Healthy, tt.expectedHealthy) + ExpectEqual(t, result.Detail, tt.expectedDetail) + }) + } +} + +func TestCheckHealthTCPUDP(t *testing.T) { + tcp, err := net.Listen("tcp", "localhost:0") + ExpectNoError(t, err) + go func() { + conn, err := tcp.Accept() + ExpectNoError(t, err) + conn.Close() + }() + + udp, err := net.ListenPacket("udp", "localhost:0") + ExpectNoError(t, err) + go func() { + buf := make([]byte, 1024) + n, addr, err := udp.ReadFrom(buf) + ExpectNoError(t, err) + ExpectEqual(t, string(buf[:n]), "ping") + _, _ = udp.WriteTo([]byte("pong"), addr) + udp.Close() + }() + + tests := []struct { + name string + scheme string + host string + port int + expectedStatus int + expectedHealthy bool + }{ + { + name: "ValidTCP", + scheme: "tcp", + host: "localhost", + port: tcp.Addr().(*net.TCPAddr).Port, + expectedStatus: http.StatusOK, + expectedHealthy: true, + }, + { + name: "InvalidHost", + scheme: "tcp", + host: "invalid", + port: 8080, + expectedStatus: http.StatusOK, + expectedHealthy: false, + }, + { + name: "ValidUDP", + scheme: "udp", + host: "localhost", + port: udp.LocalAddr().(*net.UDPAddr).Port, + expectedStatus: http.StatusOK, + expectedHealthy: true, + }, + { + name: "InvalidHost", + scheme: "udp", + host: "invalid", + port: 8080, + expectedStatus: http.StatusOK, + expectedHealthy: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := url.Values{} + query.Set("scheme", "tcp") + query.Set("host", tt.host) + query.Set("port", strconv.Itoa(tt.port)) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) + CheckHealth(recorder, request) + + ExpectEqual(t, recorder.Code, tt.expectedStatus) + + var result health.HealthCheckResult + ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil) + ExpectEqual(t, result.Healthy, tt.expectedHealthy) + }) + } +} diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 9a3ca13..c4ff23c 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -53,6 +53,8 @@ func newWaker(parent task.Parent, route route.Route, rp *reverseproxy.ReversePro } switch { + case route.IsAgent(): + waker.hc = monitor.NewAgentRouteMonitor(route.Agent(), hcCfg, monitor.AgentTargetFromURL(route.TargetURL())) case rp != nil: waker.hc = monitor.NewHTTPHealthChecker(route.TargetURL(), hcCfg) case stream != nil: diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/http/reverseproxy/reverse_proxy_mod.go index aeeb31d..eb3985c 100644 --- a/internal/net/http/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/http/reverseproxy/reverse_proxy_mod.go @@ -168,7 +168,7 @@ func copyHeader(dst, src http.Header) { } func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) { - reqURL := r.Host + r.RequestURI + reqURL := r.Host + r.URL.Path switch { case errors.Is(err, context.Canceled), errors.Is(err, io.EOF), diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index fda83d0..8cf4c03 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -127,6 +127,10 @@ func (p *Provider) Start(parent task.Parent) E.Error { if err := errs.Error(); err != nil { return err.Subject(p.String()) } + + if p.t == types.ProviderTypeAgent { + t.OnCancel("remove agent", p.ProviderImpl.(*AgentProvider).AgentConfig.Remove) + } return nil } diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index c97c53b..5e31db0 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -206,7 +206,7 @@ func (r *ReveseProxyRoute) newHealthMonitor() interface { health.HealthChecker } { if a := r.Agent(); a != nil { - target := monitor.AgentCheckHealthTargetFromURL(r.ProxyURL) + target := monitor.AgentTargetFromURL(r.ProxyURL) return monitor.NewAgentRouteMonitor(a, r.HealthCheck, target) } return monitor.NewHTTPHealthMonitor(r.ProxyURL, r.HealthCheck) diff --git a/internal/watcher/health/monitor/agent_route.go b/internal/watcher/health/monitor/agent_route.go index c5d42e6..4137c3c 100644 --- a/internal/watcher/health/monitor/agent_route.go +++ b/internal/watcher/health/monitor/agent_route.go @@ -24,7 +24,7 @@ type ( } ) -func AgentCheckHealthTargetFromURL(url *types.URL) *AgentCheckHealthTarget { +func AgentTargetFromURL(url *types.URL) *AgentCheckHealthTarget { return &AgentCheckHealthTarget{ Scheme: url.Scheme, Host: url.Host,