fixed some issues

This commit is contained in:
yusing 2025-02-11 09:16:21 +08:00
parent 508b093278
commit 07bce90521
9 changed files with 243 additions and 10 deletions

View file

@ -101,6 +101,12 @@ func checkVersion(a, b string) bool {
return withoutBuildTime(a) == withoutBuildTime(b) return withoutBuildTime(a) == withoutBuildTime(b)
} }
func (cfg *AgentConfig) Remove() {
agentMapMu.Lock()
defer agentMapMu.Unlock()
agents.Delete(cfg.Name())
}
func (cfg *AgentConfig) load() E.Error { func (cfg *AgentConfig) load() E.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil { if err != nil {

14
agent/pkg/env/env.go vendored
View file

@ -5,6 +5,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
) )
@ -32,14 +33,11 @@ func init() {
if err != nil { if err != nil {
log.Fatalf("failed to parse allowed hosts: %v", err) log.Fatalf("failed to parse allowed hosts: %v", err)
} }
if len(cidrs) == 0 {
log.Fatal("REGISTRATION_ALLOWED_HOSTS is empty")
}
RegistrationAllowedCIDRs = cidrs RegistrationAllowedCIDRs = cidrs
} }
func toCIDRs(hosts []string) ([]*net.IPNet, error) { func toCIDRs(hosts []string) ([]*net.IPNet, error) {
var cidrs []*net.IPNet cidrs := make([]*net.IPNet, 0, len(hosts))
for _, host := range hosts { for _, host := range hosts {
if !strings.Contains(host, "/") { if !strings.Contains(host, "/") {
host += "/32" host += "/32"
@ -53,7 +51,15 @@ func toCIDRs(hosts []string) ([]*net.IPNet, error) {
return cidrs, nil return cidrs, nil
} }
var warnOnce sync.Once
func IsAllowedHost(remoteAddr string) bool { func IsAllowedHost(remoteAddr string) bool {
if len(RegistrationAllowedCIDRs) == 0 {
warnOnce.Do(func() {
log.Println("Warning: REGISTRATION_ALLOWED_HOSTS is empty, allowing all hosts")
})
return true
}
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
ip = remoteAddr ip = remoteAddr

View file

@ -3,10 +3,10 @@ package handler
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"os"
apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils" apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor" "github.com/yusing/go-proxy/internal/watcher/health/monitor"
) )
@ -28,8 +28,8 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return return
} }
ok, err := utils.FileExists(path) _, err := os.Stat(path)
result = &health.HealthCheckResult{Healthy: ok} result = &health.HealthCheckResult{Healthy: err == nil}
if err != nil { if err != nil {
result.Detail = err.Error() result.Detail = err.Error()
} }

View file

@ -0,0 +1,215 @@
package handler
import (
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"github.com/yusing/go-proxy/agent/pkg/agent"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/watcher/health"
)
func TestCheckHealthHTTP(t *testing.T) {
tests := []struct {
name string
setupServer func() *httptest.Server
queryParams map[string]string
expectedStatus int
expectedHealthy bool
}{
{
name: "Valid",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
},
queryParams: map[string]string{
"scheme": "http",
"host": "localhost",
"path": "/",
},
expectedStatus: http.StatusOK,
expectedHealthy: true,
},
{
name: "InvalidQuery",
setupServer: nil,
queryParams: map[string]string{
"scheme": "http",
},
expectedStatus: http.StatusBadRequest,
},
{
name: "ConnectionError",
setupServer: nil,
queryParams: map[string]string{
"scheme": "http",
"host": "localhost:12345",
},
expectedStatus: http.StatusOK,
expectedHealthy: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var server *httptest.Server
if tt.setupServer != nil {
server = tt.setupServer()
defer server.Close()
u, _ := url.Parse(server.URL)
tt.queryParams["scheme"] = u.Scheme
tt.queryParams["host"] = u.Host
tt.queryParams["path"] = u.Path
}
recorder := httptest.NewRecorder()
query := url.Values{}
for key, value := range tt.queryParams {
query.Set(key, value)
}
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
CheckHealth(recorder, request)
ExpectEqual(t, recorder.Code, tt.expectedStatus)
if tt.expectedStatus == http.StatusOK {
var result health.HealthCheckResult
ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil)
ExpectEqual(t, result.Healthy, tt.expectedHealthy)
}
})
}
}
func TestCheckHealthFileServer(t *testing.T) {
tests := []struct {
name string
path string
expectedStatus int
expectedHealthy bool
expectedDetail string
}{
{
name: "ValidPath",
path: t.TempDir(),
expectedStatus: http.StatusOK,
expectedHealthy: true,
expectedDetail: "",
},
{
name: "InvalidPath",
path: "/invalid",
expectedStatus: http.StatusOK,
expectedHealthy: false,
expectedDetail: "stat /invalid: no such file or directory",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := url.Values{}
query.Set("scheme", "fileserver")
query.Set("path", tt.path)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
CheckHealth(recorder, request)
ExpectEqual(t, recorder.Code, tt.expectedStatus)
var result health.HealthCheckResult
ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil)
ExpectEqual(t, result.Healthy, tt.expectedHealthy)
ExpectEqual(t, result.Detail, tt.expectedDetail)
})
}
}
func TestCheckHealthTCPUDP(t *testing.T) {
tcp, err := net.Listen("tcp", "localhost:0")
ExpectNoError(t, err)
go func() {
conn, err := tcp.Accept()
ExpectNoError(t, err)
conn.Close()
}()
udp, err := net.ListenPacket("udp", "localhost:0")
ExpectNoError(t, err)
go func() {
buf := make([]byte, 1024)
n, addr, err := udp.ReadFrom(buf)
ExpectNoError(t, err)
ExpectEqual(t, string(buf[:n]), "ping")
_, _ = udp.WriteTo([]byte("pong"), addr)
udp.Close()
}()
tests := []struct {
name string
scheme string
host string
port int
expectedStatus int
expectedHealthy bool
}{
{
name: "ValidTCP",
scheme: "tcp",
host: "localhost",
port: tcp.Addr().(*net.TCPAddr).Port,
expectedStatus: http.StatusOK,
expectedHealthy: true,
},
{
name: "InvalidHost",
scheme: "tcp",
host: "invalid",
port: 8080,
expectedStatus: http.StatusOK,
expectedHealthy: false,
},
{
name: "ValidUDP",
scheme: "udp",
host: "localhost",
port: udp.LocalAddr().(*net.UDPAddr).Port,
expectedStatus: http.StatusOK,
expectedHealthy: true,
},
{
name: "InvalidHost",
scheme: "udp",
host: "invalid",
port: 8080,
expectedStatus: http.StatusOK,
expectedHealthy: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := url.Values{}
query.Set("scheme", "tcp")
query.Set("host", tt.host)
query.Set("port", strconv.Itoa(tt.port))
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
CheckHealth(recorder, request)
ExpectEqual(t, recorder.Code, tt.expectedStatus)
var result health.HealthCheckResult
ExpectEqual(t, json.Unmarshal(recorder.Body.Bytes(), &result), nil)
ExpectEqual(t, result.Healthy, tt.expectedHealthy)
})
}
}

View file

@ -53,6 +53,8 @@ func newWaker(parent task.Parent, route route.Route, rp *reverseproxy.ReversePro
} }
switch { switch {
case route.IsAgent():
waker.hc = monitor.NewAgentRouteMonitor(route.Agent(), hcCfg, monitor.AgentTargetFromURL(route.TargetURL()))
case rp != nil: case rp != nil:
waker.hc = monitor.NewHTTPHealthChecker(route.TargetURL(), hcCfg) waker.hc = monitor.NewHTTPHealthChecker(route.TargetURL(), hcCfg)
case stream != nil: case stream != nil:

View file

@ -168,7 +168,7 @@ func copyHeader(dst, src http.Header) {
} }
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) { func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
reqURL := r.Host + r.RequestURI reqURL := r.Host + r.URL.Path
switch { switch {
case errors.Is(err, context.Canceled), case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF), errors.Is(err, io.EOF),

View file

@ -127,6 +127,10 @@ func (p *Provider) Start(parent task.Parent) E.Error {
if err := errs.Error(); err != nil { if err := errs.Error(); err != nil {
return err.Subject(p.String()) return err.Subject(p.String())
} }
if p.t == types.ProviderTypeAgent {
t.OnCancel("remove agent", p.ProviderImpl.(*AgentProvider).AgentConfig.Remove)
}
return nil return nil
} }

View file

@ -206,7 +206,7 @@ func (r *ReveseProxyRoute) newHealthMonitor() interface {
health.HealthChecker health.HealthChecker
} { } {
if a := r.Agent(); a != nil { if a := r.Agent(); a != nil {
target := monitor.AgentCheckHealthTargetFromURL(r.ProxyURL) target := monitor.AgentTargetFromURL(r.ProxyURL)
return monitor.NewAgentRouteMonitor(a, r.HealthCheck, target) return monitor.NewAgentRouteMonitor(a, r.HealthCheck, target)
} }
return monitor.NewHTTPHealthMonitor(r.ProxyURL, r.HealthCheck) return monitor.NewHTTPHealthMonitor(r.ProxyURL, r.HealthCheck)

View file

@ -24,7 +24,7 @@ type (
} }
) )
func AgentCheckHealthTargetFromURL(url *types.URL) *AgentCheckHealthTarget { func AgentTargetFromURL(url *types.URL) *AgentCheckHealthTarget {
return &AgentCheckHealthTarget{ return &AgentCheckHealthTarget{
Scheme: url.Scheme, Scheme: url.Scheme,
Host: url.Host, Host: url.Host,