From eaf191e35080b54425f208d669dd99d7aa698785 Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 10 Feb 2025 09:36:37 +0800 Subject: [PATCH] implement godoxy-agent --- agent/cmd/args.go | 17 ++ agent/cmd/main.go | 100 +++++++++ agent/pkg/agent/config.go | 192 +++++++++++++++++ agent/pkg/agent/requests.go | 36 ++++ agent/pkg/agentproxy/headers.go | 27 +++ agent/pkg/certs/certs.go | 201 ++++++++++++++++++ agent/pkg/certs/zip.go | 76 +++++++ agent/pkg/env/env.go | 8 + agent/pkg/handler/check_health.go | 66 ++++++ agent/pkg/handler/docker_socket.go | 92 ++++++++ agent/pkg/handler/handler.go | 50 +++++ agent/pkg/handler/proxy_http.go | 59 +++++ agent/pkg/server/server.go | 51 +++++ cmd/main.go | 29 +-- cmd/new_agent.go | 46 ++++ ...main_production.go => pprof_production.go} | 0 cmd/{main_prof.go => pprof_prof.go} | 0 {internal => cmd}/setup.go | 2 +- internal/api/handler.go | 3 +- internal/api/v1/utils/logging.go | 1 - internal/api/v1/utils/ws.go | 2 +- internal/common/args.go | 55 ++--- internal/common/constants.go | 2 + internal/common/env.go | 4 + internal/common/ports.go | 22 +- internal/config/config.go | 3 + internal/config/types/config.go | 8 +- internal/docker/client.go | 62 +++--- internal/docker/container.go | 36 ++-- internal/error/log.go | 4 + internal/error/utils.go | 6 - internal/logging/html.go | 159 -------------- internal/logging/html_test.go | 30 --- .../v1 => logging/memlogger}/mem_logger.go | 38 +--- .../net/http/middleware/custom_error_page.go | 6 +- internal/net/http/server/server.go | 13 +- internal/net/http/{common.go => transport.go} | 28 +-- internal/route/fileserver.go | 2 +- internal/route/provider/agent.go | 34 +++ internal/route/provider/docker.go | 4 +- internal/route/provider/docker_test.go | 18 +- internal/route/provider/event_handler.go | 37 ++-- internal/route/provider/provider.go | 17 +- .../route/provider/types/provider_type.go | 1 + internal/route/reverse_proxy.go | 55 ++++- internal/route/route.go | 66 +++--- internal/route/routes/routequery/query.go | 6 +- internal/route/rules/do.go | 2 +- internal/route/rules/on_test.go | 3 +- internal/route/types/route.go | 4 + internal/utils/fs.go | 12 ++ internal/utils/wait_exit.go | 25 +++ internal/watcher/docker_watcher.go | 13 -- .../watcher/health/monitor/agent_route.go | 75 +++++++ internal/watcher/health/monitor/fileserver.go | 3 +- internal/watcher/health/types.go | 6 +- pkg/args.go | 29 +++ 57 files changed, 1479 insertions(+), 467 deletions(-) create mode 100644 agent/cmd/args.go create mode 100644 agent/cmd/main.go create mode 100644 agent/pkg/agent/config.go create mode 100644 agent/pkg/agent/requests.go create mode 100644 agent/pkg/agentproxy/headers.go create mode 100644 agent/pkg/certs/certs.go create mode 100644 agent/pkg/certs/zip.go create mode 100644 agent/pkg/env/env.go create mode 100644 agent/pkg/handler/check_health.go create mode 100644 agent/pkg/handler/docker_socket.go create mode 100644 agent/pkg/handler/handler.go create mode 100644 agent/pkg/handler/proxy_http.go create mode 100644 agent/pkg/server/server.go create mode 100644 cmd/new_agent.go rename cmd/{main_production.go => pprof_production.go} (100%) rename cmd/{main_prof.go => pprof_prof.go} (100%) rename {internal => cmd}/setup.go (99%) delete mode 100644 internal/logging/html.go delete mode 100644 internal/logging/html_test.go rename internal/{api/v1 => logging/memlogger}/mem_logger.go (82%) rename internal/net/http/{common.go => transport.go} (50%) create mode 100644 internal/route/provider/agent.go create mode 100644 internal/utils/wait_exit.go create mode 100644 internal/watcher/health/monitor/agent_route.go create mode 100644 pkg/args.go diff --git a/agent/cmd/args.go b/agent/cmd/args.go new file mode 100644 index 0000000..819fc29 --- /dev/null +++ b/agent/cmd/args.go @@ -0,0 +1,17 @@ +package main + +const ( + CommandStart = "" + CommandNewClient = "new-client" +) + +type agentCommandValidator struct{} + +func (v agentCommandValidator) IsCommandValid(cmd string) bool { + switch cmd { + case CommandStart, + CommandNewClient: + return true + } + return false +} diff --git a/agent/cmd/main.go b/agent/cmd/main.go new file mode 100644 index 0000000..4138489 --- /dev/null +++ b/agent/cmd/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "crypto/tls" + "encoding/base64" + "encoding/pem" + "fmt" + "net" + "os" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/agent/pkg/certs" + "github.com/yusing/go-proxy/agent/pkg/env" + "github.com/yusing/go-proxy/agent/pkg/server" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/logging/memlogger" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/pkg" + "gopkg.in/yaml.v3" +) + +func init() { + logging.InitLogger(zerolog.MultiLevelWriter(os.Stderr, memlogger.GetMemLogger())) +} + +func printNewClientHelp(ca *tls.Certificate) { + crt, key, err := certs.NewClientCert(ca) + if err != nil { + E.LogFatal("init SSL error", err) + } + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Certificate[0]}) + ip := machineIP() + host := fmt.Sprintf("%s:%d", ip, env.AgentPort) + cfgYAML, _ := yaml.Marshal(map[string]any{ + "providers": map[string]any{ + "agents": host, + }, + }) + + certsData, err := certs.ZipCert(caPEM, crt, key) + if err != nil { + E.LogFatal("marshal certs error", err) + } + + fmt.Printf("Add this host (%s) to main server config like below:\n", host) + fmt.Println(string(cfgYAML)) + fmt.Printf("On main server, run:\ngodoxy new-agent '%s' '%s'\n", host, base64.StdEncoding.EncodeToString(certsData)) +} + +func machineIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + return "" +} + +func main() { + args := pkg.GetArgs(agentCommandValidator{}) + + ca, srv, isNew, err := certs.InitCerts() + if err != nil { + E.LogFatal("init CA error", err) + } + + switch args.Command { + case CommandNewClient: + printNewClientHelp(ca) + return + } + + logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion()) + logging.Info().Msgf("Agent name: %s", env.AgentName) + + if isNew { + logging.Info().Msg("Initialization complete.") + logging.Info().Msg("New client cert created") + printNewClientHelp(ca) + logging.Info().Msg("Exiting... Clear the screen and start agent again") + logging.Info().Msg("To create more client certs, run `godoxy-agent new-client`") + return + } + + server.StartAgentServer(task.RootTask("agent", false), server.Options{ + CACert: ca, + ServerCert: srv, + Port: env.AgentPort, + }) + + utils.WaitExit(3) +} diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go new file mode 100644 index 0000000..9806fe2 --- /dev/null +++ b/agent/pkg/agent/config.go @@ -0,0 +1,192 @@ +package agent + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "net" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/agent/pkg/certs" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/pkg" + "golang.org/x/net/context" +) + +type ( + AgentConfig struct { + Addr string + + httpClient *http.Client + tlsConfig *tls.Config + name string + l zerolog.Logger + } +) + +const ( + EndpointVersion = "/version" + EndpointName = "/name" + EndpointCACert = "/ca-cert" + EndpointProxyHTTP = "/proxy/http" + EndpointHealth = "/health" + EndpointLogs = "/logs" + + AgentHost = certs.CertsDNSName + + APIEndpointBase = "/godoxy/agent" + APIBaseURL = "https://" + AgentHost + APIEndpointBase + + DockerHost = "https://" + AgentHost + + FakeDockerHostPrefix = "agent://" + FakeDockerHostPrefixLen = len(FakeDockerHostPrefix) +) + +var ( + agents = functional.NewMapOf[string, *AgentConfig]() + agentMapMu sync.RWMutex +) + +var ( + HTTPProxyURL = types.MustParseURL(APIBaseURL + EndpointProxyHTTP) + HTTPProxyURLStripLen = len(HTTPProxyURL.Path) +) + +func IsDockerHostAgent(dockerHost string) bool { + return strings.HasPrefix(dockerHost, FakeDockerHostPrefix) +} + +func GetAgentFromDockerHost(dockerHost string) (*AgentConfig, bool) { + if !IsDockerHostAgent(dockerHost) { + return nil, false + } + return agents.Load(dockerHost[FakeDockerHostPrefixLen:]) +} + +func (cfg *AgentConfig) FakeDockerHost() string { + return FakeDockerHostPrefix + cfg.Name() +} + +func (cfg *AgentConfig) Parse(addr string) error { + cfg.Addr = addr + return cfg.load() +} + +func (cfg *AgentConfig) errIfNameExists() E.Error { + agentMapMu.RLock() + defer agentMapMu.RUnlock() + agent, ok := agents.Load(cfg.Name()) + if ok { + return E.Errorf("agent with name %s (%s) already exists", cfg.Name(), agent.Addr) + } + return nil +} + +func (cfg *AgentConfig) load() E.Error { + certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) + if err != nil { + if os.IsNotExist(err) { + return E.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr) + } + return E.Wrap(err) + } + + ca, crt, key, err := certs.ExtractCert(certData) + if err != nil { + return E.Wrap(err) + } + + clientCert, err := tls.X509KeyPair(crt, key) + if err != nil { + return E.Wrap(err) + } + + // create tls config + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(ca) + if !ok { + return E.New("invalid CA certificate") + } + + cfg.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + } + + // create transport and http client + cfg.httpClient = cfg.NewHTTPClient() + + ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second) + defer cancel() + + // check agent version + version, _, err := cfg.Fetch(ctx, EndpointVersion) + if err != nil { + return E.Wrap(err) + } + + if string(version) != pkg.GetVersion() { + return E.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), string(version)) + } + + // get agent name + name, _, err := cfg.Fetch(ctx, EndpointName) + if err != nil { + return E.Wrap(err) + } + + // check if agent name is already used + cfg.name = string(name) + if err := cfg.errIfNameExists(); err != nil { + return err + } + + cfg.l = logging.With().Str("agent", cfg.name).Logger() + + agents.Store(cfg.name, cfg) + return nil +} + +func (cfg *AgentConfig) NewHTTPClient() *http.Client { + return &http.Client{ + Transport: cfg.Transport(), + } +} + +func (cfg *AgentConfig) Transport() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if addr != AgentHost+":443" { + return nil, &net.AddrError{Err: "invalid address", Addr: addr} + } + return gphttp.DefaultDialer.DialContext(ctx, network, cfg.Addr) + }, + TLSClientConfig: cfg.tlsConfig, + } +} + +func (cfg *AgentConfig) Name() string { + return cfg.name +} + +func (cfg *AgentConfig) String() string { + return "agent@" + cfg.Name() +} + +func (cfg *AgentConfig) MarshalText() ([]byte, error) { + return json.Marshal(map[string]string{ + "name": cfg.Name(), + "addr": cfg.Addr, + }) +} diff --git a/agent/pkg/agent/requests.go b/agent/pkg/agent/requests.go new file mode 100644 index 0000000..b36d365 --- /dev/null +++ b/agent/pkg/agent/requests.go @@ -0,0 +1,36 @@ +package agent + +import ( + "io" + "net/http" + + "github.com/coder/websocket" + "github.com/yusing/go-proxy/internal/logging" + "golang.org/x/net/context" +) + +func (cfg *AgentConfig) Do(ctx context.Context, method, endpoint string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, method, APIBaseURL+endpoint, body) + logging.Debug().Msgf("request: %s %s", method, req.URL.String()) + if err != nil { + return nil, err + } + return cfg.httpClient.Do(req) +} + +func (cfg *AgentConfig) Fetch(ctx context.Context, endpoint string) ([]byte, int, error) { + resp, err := cfg.Do(ctx, "GET", endpoint, nil) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + return data, resp.StatusCode, nil +} + +func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) { + return websocket.Dial(ctx, APIBaseURL+endpoint, &websocket.DialOptions{ + HTTPClient: cfg.NewHTTPClient(), + Host: AgentHost, + }) +} diff --git a/agent/pkg/agentproxy/headers.go b/agent/pkg/agentproxy/headers.go new file mode 100644 index 0000000..098b5bc --- /dev/null +++ b/agent/pkg/agentproxy/headers.go @@ -0,0 +1,27 @@ +package agentproxy + +import ( + "net/http" + "strconv" +) + +const ( + HeaderXProxyHost = "X-Proxy-Host" + HeaderXProxyHTTPS = "X-Proxy-Https" + HeaderXProxySkipTLSVerify = "X-Proxy-Skip-Tls-Verify" + HeaderXProxyResponseHeaderTimeout = "X-Proxy-Response-Header-Timeout" +) + +type AgentProxyHeaders struct { + Host string + IsHTTPS bool + SkipTLSVerify bool + ResponseHeaderTimeout int +} + +func SetAgentProxyHeaders(r *http.Request, headers *AgentProxyHeaders) { + r.Header.Set(HeaderXProxyHost, headers.Host) + r.Header.Set(HeaderXProxyHTTPS, strconv.FormatBool(headers.IsHTTPS)) + r.Header.Set(HeaderXProxySkipTLSVerify, strconv.FormatBool(headers.SkipTLSVerify)) + r.Header.Set(HeaderXProxyResponseHeaderTimeout, strconv.Itoa(headers.ResponseHeaderTimeout)) +} diff --git a/agent/pkg/certs/certs.go b/agent/pkg/certs/certs.go new file mode 100644 index 0000000..d9cc984 --- /dev/null +++ b/agent/pkg/certs/certs.go @@ -0,0 +1,201 @@ +package certs + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "math/big" + "os" + "time" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/utils" +) + +const ( + CertsDNSName = "godoxy.agent" + + caCertPath = "certs/ca.crt" + caKeyPath = "certs/ca.key" + srvCertPath = "certs/agent.crt" + srvKeyPath = "certs/agent.key" +) + +func loadCerts(certPath, keyPath string) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + return &cert, err +} + +func write(b []byte, f *os.File) error { + _, err := f.Write(b) + return err +} + +func saveCerts(certDER []byte, key *rsa.PrivateKey, certPath, keyPath string) ([]byte, []byte, error) { + certPEM, keyPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), + pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + if certPath == "" || keyPath == "" { + return certPEM, keyPEM, nil + } + + certFile, err := os.Create(certPath) + if err != nil { + return nil, nil, err + } + defer certFile.Close() + + keyFile, err := os.Create(keyPath) + if err != nil { + return nil, nil, err + } + defer keyFile.Close() + + return certPEM, keyPEM, errors.Join( + write(certPEM, certFile), + write(keyPEM, keyFile), + ) +} + +func checkExists(certPath, keyPath string) bool { + certExists, err := utils.FileExists(certPath) + if err != nil { + E.LogFatal("cert error", err) + } + keyExists, err := utils.FileExists(keyPath) + if err != nil { + E.LogFatal("key error", err) + } + return certExists && keyExists +} + +func InitCerts() (ca *tls.Certificate, srv *tls.Certificate, isNew bool, err error) { + if checkExists(caCertPath, caKeyPath) && checkExists(srvCertPath, srvKeyPath) { + logging.Info().Msg("Loading existing certs...") + ca, err = loadCerts(caCertPath, caKeyPath) + if err != nil { + return nil, nil, false, err + } + srv, err = loadCerts(srvCertPath, srvKeyPath) + if err != nil { + return nil, nil, false, err + } + + logging.Info().Msg("Verifying agent cert...") + + roots := x509.NewCertPool() + roots.AddCert(ca.Leaf) + + srvCert, err := x509.ParseCertificate(srv.Certificate[0]) + if err != nil { + return nil, nil, false, err + } + + // check if srv is signed by ca + if _, err := srvCert.Verify(x509.VerifyOptions{ + Roots: roots, + }); err == nil { + logging.Info().Msg("OK") + return ca, srv, false, nil + } + logging.Error().Msg("Agent cert and CA cert mismatch, regenerating") + } + + // Create the CA's certificate + caTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"GoDoxy"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, false, err + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return nil, nil, false, err + } + + certPEM, keyPEM, err := saveCerts(caCertDER, caKey, caCertPath, caKeyPath) + if err != nil { + return nil, nil, false, err + } + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, nil, false, err + } + + ca = &cert + + // Generate a new private key for the server certificate + serverKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, false, err + } + + srvTemplate := caTemplate + srvTemplate.Issuer = srvTemplate.Subject + srvTemplate.DNSNames = append(srvTemplate.DNSNames, CertsDNSName) + + srvCertDER, err := x509.CreateCertificate(rand.Reader, &srvTemplate, &caTemplate, &serverKey.PublicKey, caKey) + if err != nil { + return nil, nil, false, err + } + + certPEM, keyPEM, err = saveCerts(srvCertDER, serverKey, srvCertPath, srvKeyPath) + if err != nil { + return nil, nil, false, err + } + + cert, err = tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, nil, false, err + } + + srv = &cert + + return ca, srv, true, nil +} + +func NewClientCert(ca *tls.Certificate) ([]byte, []byte, error) { + // Generate the SSL's private key + sslKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + // Create the SSL's certificate + sslTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"GoDoxy"}, + CommonName: CertsDNSName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + // Sign the certificate with the CA + sslCertDER, err := x509.CreateCertificate(rand.Reader, sslTemplate, ca.Leaf, &sslKey.PublicKey, ca.PrivateKey) + if err != nil { + return nil, nil, err + } + + return saveCerts(sslCertDER, sslKey, "", "") +} diff --git a/agent/pkg/certs/zip.go b/agent/pkg/certs/zip.go new file mode 100644 index 0000000..89532ac --- /dev/null +++ b/agent/pkg/certs/zip.go @@ -0,0 +1,76 @@ +package certs + +import ( + "archive/zip" + "bytes" + "io" + "path/filepath" + + "github.com/yusing/go-proxy/internal/common" +) + +func writeFile(zipWriter *zip.Writer, name string, data []byte) error { + w, err := zipWriter.CreateHeader(&zip.FileHeader{ + Name: name, + Method: zip.Deflate, + }) + if err != nil { + return err + } + _, err = w.Write(data) + return err +} + +func readFile(f *zip.File) ([]byte, error) { + r, err := f.Open() + if err != nil { + return nil, err + } + defer r.Close() + return io.ReadAll(r) +} + +func ZipCert(ca, crt, key []byte) ([]byte, error) { + data := bytes.NewBuffer(nil) + zipWriter := zip.NewWriter(data) + defer zipWriter.Close() + + if err := writeFile(zipWriter, "ca.pem", ca); err != nil { + return nil, err + } + if err := writeFile(zipWriter, "cert.pem", crt); err != nil { + return nil, err + } + if err := writeFile(zipWriter, "key.pem", key); err != nil { + return nil, err + } + if err := zipWriter.Close(); err != nil { + return nil, err + } + return data.Bytes(), nil +} + +func AgentCertsFilename(host string) string { + return filepath.Join(common.AgentCertsBasePath, host+".zip") +} + +func ExtractCert(data []byte) (ca, crt, key []byte, err error) { + zipReader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return nil, nil, nil, err + } + for _, file := range zipReader.File { + switch file.Name { + case "ca.pem": + ca, err = readFile(file) + case "cert.pem": + crt, err = readFile(file) + case "key.pem": + key, err = readFile(file) + } + if err != nil { + return nil, nil, nil, err + } + } + return ca, crt, key, nil +} diff --git a/agent/pkg/env/env.go b/agent/pkg/env/env.go new file mode 100644 index 0000000..870f9f6 --- /dev/null +++ b/agent/pkg/env/env.go @@ -0,0 +1,8 @@ +package env + +import "github.com/yusing/go-proxy/internal/common" + +var ( + AgentName = common.GetEnvString("AGENT_NAME", "agent") + AgentPort = common.GetEnvInt("AGENT_PORT", 8890) +) diff --git a/agent/pkg/handler/check_health.go b/agent/pkg/handler/check_health.go new file mode 100644 index 0000000..7cef3c0 --- /dev/null +++ b/agent/pkg/handler/check_health.go @@ -0,0 +1,66 @@ +package handler + +import ( + "net/http" + "net/url" + + apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/watcher/health" + "github.com/yusing/go-proxy/internal/watcher/health/monitor" +) + +func CheckHealth(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + scheme := query.Get("scheme") + if scheme == "" { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + var result *health.HealthCheckResult + var err error + switch scheme { + case "fileserver": + path := query.Get("path") + if path == "" { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + ok, err := utils.FileExists(path) + result = &health.HealthCheckResult{Healthy: ok} + if err != nil { + result.Detail = err.Error() + } + case "http", "https": // path is optional + host := query.Get("host") + path := query.Get("path") + if host == "" { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + result, err = monitor.NewHTTPHealthChecker(types.NewURL(&url.URL{ + Scheme: scheme, + Host: host, + Path: path, + }), health.DefaultHealthConfig).CheckHealth() + case "tcp", "udp": + host := query.Get("host") + if host == "" { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + result, err = monitor.NewRawHealthChecker(types.NewURL(&url.URL{ + Scheme: scheme, + Host: host, + }), health.DefaultHealthConfig).CheckHealth() + } + + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + apiUtils.RespondJSON(w, r, result) +} diff --git a/agent/pkg/handler/docker_socket.go b/agent/pkg/handler/docker_socket.go new file mode 100644 index 0000000..a167b09 --- /dev/null +++ b/agent/pkg/handler/docker_socket.go @@ -0,0 +1,92 @@ +package handler + +import ( + "bufio" + "errors" + "io" + "net/http" + "strings" + + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/docker" + "github.com/yusing/go-proxy/internal/logging" + godoxyIO "github.com/yusing/go-proxy/internal/utils" +) + +func DockerSocketHandler() http.HandlerFunc { + dockerClient, err := docker.ConnectClient(common.DockerHostFromEnv) + if err != nil { + logging.Fatal().Err(err).Msg("failed to connect to docker client") + } + dockerDialerCallback := dockerClient.Dialer() + + return func(w http.ResponseWriter, r *http.Request) { + conn, err := dockerDialerCallback(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + // Create a done channel to handle cancellation + done := make(chan struct{}) + defer close(done) + + closed := false + + // Start a goroutine to monitor context cancellation + go func() { + select { + case <-r.Context().Done(): + closed = true + conn.Close() // Force close the connection when client disconnects + case <-done: + } + }() + + if err := r.Write(conn); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Set any response headers before writing the status code + for k, v := range resp.Header { + w.Header()[k] = v + } + w.WriteHeader(resp.StatusCode) + + // For event streams, we need to flush the writer to ensure + // events are sent immediately + if f, ok := w.(http.Flusher); ok && strings.HasSuffix(r.URL.Path, "/events") { + // Copy the body in chunks and flush after each write + buf := make([]byte, 2048) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + _, werr := w.Write(buf[:n]) + if werr != nil { + logging.Error().Err(werr).Msg("error writing docker event response") + break + } + f.Flush() + } + if err != nil { + if !closed && !errors.Is(err, io.EOF) { + logging.Error().Err(err).Msg("error reading docker event response") + } + return + } + } + } else { + // For non-event streams, just copy the body + godoxyIO.NewPipe(r.Context(), resp.Body, NopWriteCloser{w}).Start() + } + } +} diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go new file mode 100644 index 0000000..b596bda --- /dev/null +++ b/agent/pkg/handler/handler.go @@ -0,0 +1,50 @@ +package handler + +import ( + "fmt" + "io" + "net/http" + + "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/agent/pkg/env" + v1 "github.com/yusing/go-proxy/internal/api/v1" + "github.com/yusing/go-proxy/internal/logging/memlogger" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type ServeMux struct{ *http.ServeMux } + +func (mux ServeMux) HandleMethods(methods, endpoint string, handler http.HandlerFunc) { + for _, m := range strutils.CommaSeperatedList(methods) { + mux.ServeMux.HandleFunc(m+" "+agent.APIEndpointBase+endpoint, handler) + } +} + +func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) { + mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler) +} + +type NopWriteCloser struct { + io.Writer +} + +func (NopWriteCloser) Close() error { + return nil +} + +func NewHandler(caCertPEM []byte) http.Handler { + mux := ServeMux{http.NewServeMux()} + + mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP) + mux.HandleMethods("GET", agent.EndpointVersion, v1.GetVersion) + mux.HandleMethods("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, env.AgentName) + }) + mux.HandleMethods("GET", agent.EndpointCACert, func(w http.ResponseWriter, r *http.Request) { + w.Write(caCertPEM) + }) + mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth) + mux.HandleMethods("GET", agent.EndpointLogs, memlogger.LogsWS(nil)) + mux.ServeMux.HandleFunc("/", DockerSocketHandler()) + return mux +} diff --git a/agent/pkg/handler/proxy_http.go b/agent/pkg/handler/proxy_http.go new file mode 100644 index 0000000..badfeb9 --- /dev/null +++ b/agent/pkg/handler/proxy_http.go @@ -0,0 +1,59 @@ +package handler + +import ( + "crypto/tls" + "net/http" + "strconv" + "time" + + "github.com/yusing/go-proxy/agent/pkg/agent" + agentproxy "github.com/yusing/go-proxy/agent/pkg/agentproxy" + "github.com/yusing/go-proxy/internal/logging" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func ProxyHTTP(w http.ResponseWriter, r *http.Request) { + host := r.Header.Get(agentproxy.HeaderXProxyHost) + isHTTPs := strutils.ParseBool(r.Header.Get(agentproxy.HeaderXProxyHTTPS)) + skipTLSVerify := strutils.ParseBool(r.Header.Get(agentproxy.HeaderXProxySkipTLSVerify)) + responseHeaderTimeout, err := strconv.Atoi(r.Header.Get(agentproxy.HeaderXProxyResponseHeaderTimeout)) + if err != nil { + responseHeaderTimeout = 0 + } + + logging.Debug().Msgf("proxy http request: host=%s, isHTTPs=%t, skipTLSVerify=%t, responseHeaderTimeout=%d", host, isHTTPs, skipTLSVerify, responseHeaderTimeout) + + if host == "" { + http.Error(w, "missing required headers", http.StatusBadRequest) + return + } + + scheme := "http" + if isHTTPs { + scheme = "https" + } + + var transport *http.Transport + if skipTLSVerify { + transport = gphttp.NewTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + } else { + transport = gphttp.NewTransport() + } + + if responseHeaderTimeout > 0 { + transport = transport.Clone() + transport.ResponseHeaderTimeout = time.Duration(responseHeaderTimeout) * time.Second + } + + r.URL.Scheme = scheme + r.URL.Host = host + r.URL.Path = r.URL.Path[agent.HTTPProxyURLStripLen:] // strip the {API_BASE}/proxy/http prefix + + logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String()) + + rp := reverseproxy.NewReverseProxy("agent", types.NewURL(r.URL), transport) + rp.ServeHTTP(w, r) +} diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go new file mode 100644 index 0000000..a8c34c1 --- /dev/null +++ b/agent/pkg/server/server.go @@ -0,0 +1,51 @@ +package server + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "log" + "net" + "net/http" + + "github.com/yusing/go-proxy/agent/pkg/handler" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" +) + +type Options struct { + CACert, ServerCert *tls.Certificate + Port int +} + +func StartAgentServer(parent task.Parent, opt Options) { + caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: opt.CACert.Certificate[0]}) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + // Configure TLS + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*opt.ServerCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + if common.IsDebug { + tlsConfig.ClientAuth = tls.NoClientCert + } + l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port)) + if err != nil { + logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port") + return + } + defer l.Close() + + server := &http.Server{ + Handler: handler.NewHandler(caCertPEM), + TLSConfig: tlsConfig, + ErrorLog: log.New(logging.GetLogger(), "", 0), + } + server.Serve(tls.NewListener(l, tlsConfig)) +} diff --git a/cmd/main.go b/cmd/main.go index 4352af6..3ed9ca8 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -5,13 +5,9 @@ import ( "io" "log" "os" - "os/signal" - "syscall" - "time" "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal" - v1 "github.com/yusing/go-proxy/internal/api/v1" "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/api/v1/query" @@ -20,9 +16,10 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/route/routes/routequery" - "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/pkg" ) @@ -31,19 +28,21 @@ var rawLogger = log.New(os.Stdout, "", 0) func init() { var out io.Writer = os.Stderr if common.EnableLogStreaming { - out = zerolog.MultiLevelWriter(out, v1.GetMemLogger()) + out = zerolog.MultiLevelWriter(out, memlogger.GetMemLogger()) } logging.InitLogger(out) - // logging.AddHook(v1.GetMemLogger()) } func main() { initProfiling() - args := common.GetArgs() + args := pkg.GetArgs(common.MainServerCommandValidator{}) switch args.Command { case common.CommandSetup: - internal.Setup() + Setup() + return + case common.CommandNewAgent: + NewAgent(args.Args) return case common.CommandReload: if err := query.ReloadServer(); err != nil { @@ -141,17 +140,7 @@ func main() { config.WatchChanges() - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT) - signal.Notify(sig, syscall.SIGTERM) - signal.Notify(sig, syscall.SIGHUP) - - // wait for signal - <-sig - - // gracefully shutdown - logging.Info().Msg("shutting down") - _ = task.GracefulShutdown(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) + utils.WaitExit(cfg.Value().TimeoutShutdown) } func prepareDirectory(dir string) { diff --git a/cmd/new_agent.go b/cmd/new_agent.go new file mode 100644 index 0000000..05c9afe --- /dev/null +++ b/cmd/new_agent.go @@ -0,0 +1,46 @@ +package main + +import ( + "encoding/base64" + "log" + "net" + "os" + + "github.com/yusing/go-proxy/agent/pkg/certs" +) + +func NewAgent(args []string) { + if len(args) != 2 { + log.Fatalf("invalid arguments: %v", args) + } + host := args[0] + certDataBase64 := args[1] + + ip, _, err := net.SplitHostPort(host) + if err != nil { + log.Fatalf("invalid host: %v", err) + } + + _, err = net.ResolveIPAddr("ip", ip) + if err != nil { + log.Fatalf("invalid host: %v", err) + } + + certData, err := base64.StdEncoding.DecodeString(certDataBase64) + if err != nil { + log.Fatalf("invalid cert data: %v", err) + } + + f, err := os.OpenFile(certs.AgentCertsFilename(host), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + log.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + _, err = f.Write(certData) + if err != nil { + log.Fatalf("failed to write cert data: %v", err) + } + + log.Printf("agent cert created: %s", certs.AgentCertsFilename(host)) +} diff --git a/cmd/main_production.go b/cmd/pprof_production.go similarity index 100% rename from cmd/main_production.go rename to cmd/pprof_production.go diff --git a/cmd/main_prof.go b/cmd/pprof_prof.go similarity index 100% rename from cmd/main_prof.go rename to cmd/pprof_prof.go diff --git a/internal/setup.go b/cmd/setup.go similarity index 99% rename from internal/setup.go rename to cmd/setup.go index 60b3ec6..d0440fd 100644 --- a/internal/setup.go +++ b/cmd/setup.go @@ -1,4 +1,4 @@ -package internal +package main import ( "io" diff --git a/internal/api/handler.go b/internal/api/handler.go index 1ce6ea9..5e9e4b2 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -10,6 +10,7 @@ import ( "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -35,7 +36,7 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) mux.HandleFunc("GET", "/v1/health/ws", auth.RequireAuth(useCfg(cfg, v1.HealthWS))) - mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(useCfg(cfg, v1.LogsWS()))) + mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(memlogger.LogsWS(cfg))) mux.HandleFunc("GET", "/v1/favicon", auth.RequireAuth(favicon.GetFavIcon)) mux.HandleFunc("POST", "/v1/homepage/set", auth.RequireAuth(v1.SetHomePageOverrides)) diff --git a/internal/api/v1/utils/logging.go b/internal/api/v1/utils/logging.go index 194735f..6f06e31 100644 --- a/internal/api/v1/utils/logging.go +++ b/internal/api/v1/utils/logging.go @@ -9,7 +9,6 @@ import ( func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { return logging.WithLevel(level). - Str("module", "api"). Str("remote", r.RemoteAddr). Str("host", r.Host). Str("uri", r.Method+" "+r.RequestURI) diff --git a/internal/api/v1/utils/ws.go b/internal/api/v1/utils/ws.go index 127892d..4e417a0 100644 --- a/internal/api/v1/utils/ws.go +++ b/internal/api/v1/utils/ws.go @@ -22,7 +22,7 @@ func InitiateWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - if len(cfg.Value().MatchDomains) == 0 { + if cfg == nil || len(cfg.Value().MatchDomains) == 0 { warnNoMatchDomainOnce.Do(warnNoMatchDomains) originPats = []string{"*"} } else { diff --git a/internal/common/args.go b/internal/common/args.go index e9f6f57..cd825fe 100644 --- a/internal/common/args.go +++ b/internal/common/args.go @@ -1,18 +1,9 @@ package common -import ( - "flag" - "fmt" - "log" -) - -type Args struct { - Command string -} - const ( CommandStart = "" CommandSetup = "setup" + CommandNewAgent = "new-agent" CommandValidate = "validate" CommandListConfigs = "ls-config" CommandListRoutes = "ls-routes" @@ -23,34 +14,22 @@ const ( CommandDebugListMTrace = "debug-ls-mtrace" ) -var ValidCommands = []string{ - CommandStart, - CommandSetup, - CommandValidate, - CommandListConfigs, - CommandListRoutes, - CommandListIcons, - CommandReload, - CommandDebugListEntries, - CommandDebugListProviders, - CommandDebugListMTrace, -} +type MainServerCommandValidator struct{} -func GetArgs() Args { - var args Args - flag.Parse() - args.Command = flag.Arg(0) - if err := validateArg(args.Command); err != nil { - log.Fatalf("invalid command: %s", err) +func (v MainServerCommandValidator) IsCommandValid(cmd string) bool { + switch cmd { + case CommandStart, + CommandSetup, + CommandNewAgent, + CommandValidate, + CommandListConfigs, + CommandListRoutes, + CommandListIcons, + CommandReload, + CommandDebugListEntries, + CommandDebugListProviders, + CommandDebugListMTrace: + return true } - return args -} - -func validateArg(arg string) error { - for _, v := range ValidCommands { - if arg == v { - return nil - } - } - return fmt.Errorf("invalid command %q", arg) + return false } diff --git a/internal/common/constants.go b/internal/common/constants.go index cf6d0f2..230a389 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -30,6 +30,8 @@ const ( ComposeExampleFileName = "compose.example.yml" ErrorPagesBasePath = "error_pages" + + AgentCertsBasePath = "certs" ) var RequiredDirectories = []string{ diff --git a/internal/common/env.go b/internal/common/env.go index 8191b8c..a62f861 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -86,6 +86,10 @@ func GetEnvBool(key string, defaultValue bool) bool { return GetEnv(key, defaultValue, strconv.ParseBool) } +func GetEnvInt(key string, defaultValue int) int { + return GetEnv(key, defaultValue, strconv.Atoi) +} + func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL string) { addr = GetEnvString(key, defaultValue) if addr == "" { diff --git a/internal/common/ports.go b/internal/common/ports.go index 7816355..6bd3cfe 100644 --- a/internal/common/ports.go +++ b/internal/common/ports.go @@ -9,7 +9,7 @@ var ( "3000": true, } - ServiceNamePortMapTCP = map[string]int{ + ImageNamePortMapTCP = map[string]int{ "mssql": 1433, "mysql": 3306, "mariadb": 3306, @@ -19,27 +19,9 @@ var ( "memcached": 11211, "mongo": 27017, "minecraft-server": 25565, - - "ssh": 22, - "ftp": 21, - "smtp": 25, - "dns": 53, - "pop3": 110, - "imap": 143, } - ImageNamePortMap = func() (m map[string]int) { - m = make(map[string]int, len(ServiceNamePortMapTCP)+len(imageNamePortMap)) - for k, v := range ServiceNamePortMapTCP { - m[k] = v - } - for k, v := range imageNamePortMap { - m[k] = v - } - return - }() - - imageNamePortMap = map[string]int{ + ImageNamePortMapHTTP = map[string]int{ "adguardhome": 3000, "bazarr": 6767, "calibre-web": 8083, diff --git a/internal/config/config.go b/internal/config/config.go index c086156..bfc72d1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -287,6 +287,9 @@ func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { lenLongestName = len(p.String()) } } + for _, agent := range providers.Agents { + cfg.providers.Store(agent.Name(), proxy.NewAgentProvider(&agent)) + } cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { if err := p.LoadRoutes(); err != nil { errs.Add(err.Subject(p.String())) diff --git a/internal/config/types/config.go b/internal/config/types/config.go index e7a7e3d..b75d421 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -5,6 +5,7 @@ import ( "regexp" "github.com/go-playground/validator/v10" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/notif" @@ -23,9 +24,10 @@ type ( TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` } Providers struct { - Files []string `json:"include" validate:"dive,filepath"` - Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"` - Notification []notif.NotificationConfig `json:"notification"` + Files []string `json:"include" yaml:"include,omitempty" validate:"dive,filepath"` + Docker map[string]string `json:"docker" yaml:"docker,omitempty" validate:"dive,unix_addr|url"` + Agents []agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` + Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` } Entrypoint struct { Middlewares []map[string]any `json:"middlewares"` diff --git a/internal/docker/client.go b/internal/docker/client.go index f547334..4bb2a6b 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -2,12 +2,14 @@ package docker import ( "errors" + "fmt" "net/http" "sync" "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" "github.com/rs/zerolog" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" @@ -81,32 +83,44 @@ func ConnectClient(host string) (*SharedClient, error) { // create client var opt []client.Opt - switch host { - case "": - return nil, errors.New("empty docker host") - case common.DockerHostFromEnv: - opt = clientOptEnvHost - default: - helper, err := connhelper.GetConnectionHelper(host) - if err != nil { - logging.Panic().Err(err).Msg("failed to get connection helper") + if agent.IsDockerHostAgent(host) { + cfg, ok := agent.GetAgentFromDockerHost(host) + if !ok { + return nil, fmt.Errorf("agent not found for host: %s", host) } - if helper != nil { - httpClient := &http.Client{ - Transport: &http.Transport{ - DialContext: helper.Dialer, - }, + opt = []client.Opt{ + client.WithHost(agent.DockerHost), + client.WithHTTPClient(cfg.NewHTTPClient()), + client.WithAPIVersionNegotiation(), + } + } else { + switch host { + case "": + return nil, errors.New("empty docker host") + case common.DockerHostFromEnv: + opt = clientOptEnvHost + default: + helper, err := connhelper.GetConnectionHelper(host) + if err != nil { + logging.Panic().Err(err).Msg("failed to get connection helper") } - opt = []client.Opt{ - client.WithHTTPClient(httpClient), - client.WithHost(helper.Host), - client.WithAPIVersionNegotiation(), - client.WithDialContext(helper.Dialer), - } - } else { - opt = []client.Opt{ - client.WithHost(host), - client.WithAPIVersionNegotiation(), + if helper != nil { + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: helper.Dialer, + }, + } + opt = []client.Opt{ + client.WithHTTPClient(httpClient), + client.WithHost(helper.Host), + client.WithAPIVersionNegotiation(), + client.WithDialContext(helper.Dialer), + } + } else { + opt = []client.Opt{ + client.WithHost(host), + client.WithAPIVersionNegotiation(), + } } } } diff --git a/internal/docker/container.go b/internal/docker/container.go index b8e341f..11b56d7 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/docker/docker/api/types" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/logging" U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -21,13 +22,14 @@ type ( ContainerID string `json:"container_id"` ImageName string `json:"image_name"` + Agent *agent.AgentConfig `json:"agent"` + Labels map[string]string `json:"-"` PublicPortMapping PortMapping `json:"public_ports"` // non-zero publicPort:types.Port PrivatePortMapping PortMapping `json:"private_ports"` // privatePort:types.Port - PublicIP string `json:"public_ip"` - PrivateIP string `json:"private_ip"` - NetworkMode string `json:"network_mode"` + PublicHostname string `json:"public_hostname"` + PrivateHostname string `json:"private_hostname"` Aliases []string `json:"aliases"` IsExcluded bool `json:"is_excluded"` @@ -51,7 +53,8 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { for lbl := range c.Labels { if strings.HasPrefix(lbl, NSProxy+".") { isExplicit = true - break + } else { + delete(c.Labels, lbl) } } res = &Container{ @@ -64,7 +67,6 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { PublicPortMapping: helper.getPublicPortMapping(), PrivatePortMapping: helper.getPrivatePortMapping(), - NetworkMode: c.HostConfig.NetworkMode, Aliases: helper.getAliases(), IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), @@ -78,8 +80,13 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { StartEndpoint: helper.getDeleteLabel(LabelStartEndpoint), Running: c.Status == "running" || c.State == "running", } - res.setPrivateIP(helper) - res.setPublicIP() + + if agent.IsDockerHostAgent(dockerHost) { + res.Agent, _ = agent.GetAgentFromDockerHost(dockerHost) + } + + res.setPrivateHostname(helper) + res.setPublicHostname() return } @@ -115,29 +122,28 @@ func FromJSON(json types.ContainerJSON, dockerHost string) *Container { Networks: json.NetworkSettings.Networks, }, }, dockerHost) - cont.NetworkMode = string(json.HostConfig.NetworkMode) return cont } -func (c *Container) setPublicIP() { +func (c *Container) setPublicHostname() { if !c.Running { return } if strings.HasPrefix(c.DockerHost, "unix://") { - c.PublicIP = "127.0.0.1" + c.PublicHostname = "127.0.0.1" return } url, err := url.Parse(c.DockerHost) if err != nil { logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost) - c.PublicIP = "127.0.0.1" + c.PublicHostname = "127.0.0.1" return } - c.PublicIP = url.Hostname() + c.PublicHostname = url.Hostname() } -func (c *Container) setPrivateIP(helper containerHelper) { - if !strings.HasPrefix(c.DockerHost, "unix://") { +func (c *Container) setPrivateHostname(helper containerHelper) { + if !strings.HasPrefix(c.DockerHost, "unix://") && c.Agent == nil { return } if helper.NetworkSettings == nil { @@ -147,7 +153,7 @@ func (c *Container) setPrivateIP(helper containerHelper) { if v.IPAddress == "" { continue } - c.PrivateIP = v.IPAddress + c.PrivateHostname = v.IPAddress return } } diff --git a/internal/error/log.go b/internal/error/log.go index 7e79d1e..0729023 100644 --- a/internal/error/log.go +++ b/internal/error/log.go @@ -2,6 +2,7 @@ package err import ( "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/logging" ) @@ -14,6 +15,9 @@ func getLogger(logger ...*zerolog.Logger) *zerolog.Logger { //go:inline func LogFatal(msg string, err error, logger ...*zerolog.Logger) { + if common.IsDebug { + LogPanic(msg, err, logger...) + } getLogger(logger...).Fatal().Msg(err.Error()) } diff --git a/internal/error/utils.go b/internal/error/utils.go index e4440c2..db346cb 100644 --- a/internal/error/utils.go +++ b/internal/error/utils.go @@ -66,9 +66,3 @@ func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn eb.Add(err) return result } - -func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T { - result, err := fn(arg1, arg2) - eb.Add(err) - return result -} diff --git a/internal/logging/html.go b/internal/logging/html.go deleted file mode 100644 index 5d25b51..0000000 --- a/internal/logging/html.go +++ /dev/null @@ -1,159 +0,0 @@ -package logging - -import ( - "errors" - "fmt" - "time" - - "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/common" -) - -var levelHTMLFormats = [][]byte{ - []byte(` TRC `), - []byte(` DBG `), - []byte(` INF `), - []byte(` WRN `), - []byte(` ERR `), - []byte(` FTL `), - []byte(` PAN `), -} - -var colorToClass = map[string]string{ - "1": "log-bold", - "3": "log-italic", - "4": "log-underline", - "30": "log-black", - "31": "log-red", - "32": "log-green", - "33": "log-yellow", - "34": "log-blue", - "35": "log-magenta", - "36": "log-cyan", - "37": "log-white", - "90": "log-bright-black", - "91": "log-red", - "92": "log-bright-green", - "93": "log-bright-yellow", - "94": "log-bright-blue", - "95": "log-bright-magenta", - "96": "log-bright-cyan", - "97": "log-bright-white", -} - -// FormatMessageToHTMLBytes converts text with ANSI color codes to HTML with class names. -// ANSI codes are mapped to classes via a static map, and reset codes ([0m) close all spans. -// Time complexity is O(n) with minimal allocations. -func FormatMessageToHTMLBytes(msg string, buf []byte) ([]byte, error) { - buf = append(buf, ""...) - var stack []string - lastPos := 0 - - for i := 0; i < len(msg); { - if msg[i] == '\x1b' && i+1 < len(msg) && msg[i+1] == '[' { - if lastPos < i { - escapeAndAppend(msg[lastPos:i], &buf) - } - i += 2 // Skip \x1b[ - - start := i - for ; i < len(msg) && msg[i] != 'm'; i++ { - if !isANSICodeChar(msg[i]) { - return nil, fmt.Errorf("invalid ANSI char: %c", msg[i]) - } - } - - if i >= len(msg) { - return nil, errors.New("unterminated ANSI sequence") - } - - codeStr := msg[start:i] - i++ // Skip 'm' - lastPos = i - - startPart := 0 - for j := 0; j <= len(codeStr); j++ { - if j == len(codeStr) || codeStr[j] == ';' { - part := codeStr[startPart:j] - if part == "" { - return nil, errors.New("empty code part") - } - - if part == "0" { - for range stack { - buf = append(buf, ""...) - } - stack = stack[:0] - } else { - className, ok := colorToClass[part] - if !ok { - return nil, fmt.Errorf("invalid ANSI code: %s", part) - } - stack = append(stack, className) - buf = append(buf, ``...) - } - startPart = j + 1 - } - } - } else { - i++ - } - } - - if lastPos < len(msg) { - escapeAndAppend(msg[lastPos:], &buf) - } - - for range stack { - buf = append(buf, ""...) - } - - buf = append(buf, ""...) - return buf, nil -} - -func isANSICodeChar(c byte) bool { - return (c >= '0' && c <= '9') || c == ';' -} - -func escapeAndAppend(s string, buf *[]byte) { - for i, r := range s { - switch r { - case '•': - *buf = append(*buf, "·"...) - case '&': - *buf = append(*buf, "&"...) - case '<': - *buf = append(*buf, "<"...) - case '>': - *buf = append(*buf, ">"...) - case '\t': - *buf = append(*buf, " "...) - case '\n': - *buf = append(*buf, "
"...) - *buf = append(*buf, prefixHTML...) - default: - *buf = append(*buf, s[i]) - } - } -} - -func timeNowHTML() []byte { - if !common.IsTest { - return []byte(time.Now().Format(timeFmt)) - } - return []byte(time.Date(2024, 1, 1, 1, 1, 1, 1, time.UTC).Format(timeFmt)) -} - -func FormatLogEntryHTML(level zerolog.Level, message string, buf []byte) []byte { - buf = append(buf, []byte(`
`)...)
-	buf = append(buf, timeNowHTML()...)
-	if level < zerolog.NoLevel {
-		buf = append(buf, levelHTMLFormats[level+1]...)
-	}
-	buf, _ = FormatMessageToHTMLBytes(message, buf)
-	buf = append(buf, []byte("
")...) - return buf -} diff --git a/internal/logging/html_test.go b/internal/logging/html_test.go deleted file mode 100644 index 9c6b578..0000000 --- a/internal/logging/html_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package logging - -import ( - "testing" - - "github.com/rs/zerolog" - . "github.com/yusing/go-proxy/internal/utils/testing" -) - -func TestFormatHTML(t *testing.T) { - buf := make([]byte, 0, 100) - buf = FormatLogEntryHTML(zerolog.InfoLevel, "This is a test.\nThis is a new line.", buf) - ExpectEqual(t, string(buf), `
01-01 01:01 INF This is a test.
`+prefix+`This is a new line.
`) -} - -func TestFormatHTMLANSI(t *testing.T) { - buf := make([]byte, 0, 100) - buf = FormatLogEntryHTML(zerolog.InfoLevel, "This is \x1b[91m\x1b[1ma test.\x1b[0mOK!.", buf) - ExpectEqual(t, string(buf), `
01-01 01:01 INF This is a test.OK!.
`) - buf = buf[:0] - buf = FormatLogEntryHTML(zerolog.InfoLevel, "This is \x1b[91ma \x1b[1mtest.\x1b[0mOK!.", buf) - ExpectEqual(t, string(buf), `
01-01 01:01 INF This is a test.OK!.
`) -} - -func BenchmarkFormatLogEntryHTML(b *testing.B) { - buf := make([]byte, 0, 250) - for range b.N { - FormatLogEntryHTML(zerolog.InfoLevel, "This is \x1b[91ma \x1b[1mtest.\x1b[0mOK!.", buf) - } -} diff --git a/internal/api/v1/mem_logger.go b/internal/logging/memlogger/mem_logger.go similarity index 82% rename from internal/api/v1/mem_logger.go rename to internal/logging/memlogger/mem_logger.go index 7879a53..8647e62 100644 --- a/internal/api/v1/mem_logger.go +++ b/internal/logging/memlogger/mem_logger.go @@ -1,4 +1,4 @@ -package v1 +package memlogger import ( "bytes" @@ -9,7 +9,6 @@ import ( "time" "github.com/coder/websocket" - "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" @@ -31,11 +30,7 @@ type memLogger struct { bufPool sync.Pool // used in hook mode } -type MemLogger interface { - io.Writer - // TODO: hook does not pass in fields, looking for a workaround to do server side log rendering - zerolog.Hook -} +type MemLogger io.Writer type buffer struct { data []byte @@ -85,8 +80,10 @@ func init() { } } -func LogsWS() func(config config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - return memLoggerInstance.ServeHTTP +func LogsWS(config config.ConfigInstance) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + memLoggerInstance.ServeHTTP(config, w, r) + } } func GetMemLogger() MemLogger { @@ -138,29 +135,6 @@ func (m *memLogger) writeBuf(b []byte) (pos int, err error) { return } -// Run implements zerolog.Hook. -func (m *memLogger) Run(e *zerolog.Event, level zerolog.Level, message string) { - bufStruct := m.bufPool.Get().(*buffer) - buf := bufStruct.data - defer func() { - bufStruct.data = bufStruct.data[:0] - m.bufPool.Put(bufStruct) - }() - - buf = logging.FormatLogEntryHTML(level, message, buf) - n := len(buf) - - m.truncateIfNeeded(n) - - pos, err := m.writeBuf(buf) - if err != nil { - // not logging the error here, it will cause Run to be called again = infinite loop - return - } - - m.notifyWS(pos, n) -} - // Write implements io.Writer. func (m *memLogger) Write(p []byte) (n int, err error) { n = len(p) diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index 730422f..538c418 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -17,6 +17,8 @@ type customErrorPage struct{} var CustomErrorPage = NewMiddleware[customErrorPage]() +const StaticFilePathPrefix = "/$gperrorpage/" + // before implements RequestModifier. func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { return !ServeStaticErrorPageFile(w, r) @@ -49,8 +51,8 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo if path != "" && path[0] != '/' { path = "/" + path } - if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) { - filename := path[len(gphttp.StaticFilePathPrefix):] + if strings.HasPrefix(path, StaticFilePathPrefix) { + filename := path[len(StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { logging.Error().Msg("unable to load resource " + filename) diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index 6e27a66..d0e48ab 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -6,6 +6,7 @@ import ( "errors" "io" "log" + "net" "net/http" "time" @@ -45,7 +46,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) { func NewServer(opt Options) (s *Server) { var httpSer, httpsSer *http.Server - logger := logging.With().Str("module", "server").Str("name", opt.Name).Logger() + logger := logging.With().Str("server", opt.Name).Logger() certAvailable := false if opt.CertProvider != nil { @@ -55,7 +56,7 @@ func NewServer(opt Options) (s *Server) { out := io.Discard if common.IsDebug { - out = logging.GetLogger() + out = logger } if opt.HTTPAddr != "" { @@ -107,7 +108,13 @@ func (s *Server) Start(parent task.Parent) { if s.https != nil { go func() { - s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath())) + l, err := net.Listen("tcp", s.https.Addr) + if err != nil { + s.handleErr("https", err) + return + } + defer l.Close() + s.handleErr("https", s.https.Serve(tls.NewListener(l, s.https.TLSConfig))) }() s.httpsStarted = true s.l.Info().Str("addr", s.https.Addr).Msgf("server started") diff --git a/internal/net/http/common.go b/internal/net/http/transport.go similarity index 50% rename from internal/net/http/common.go rename to internal/net/http/transport.go index 8671cf3..84e294f 100644 --- a/internal/net/http/common.go +++ b/internal/net/http/transport.go @@ -7,28 +7,28 @@ import ( "time" ) -var ( - defaultDialer = net.Dialer{ - Timeout: 60 * time.Second, - } - DefaultTransport = &http.Transport{ +var DefaultDialer = net.Dialer{ + Timeout: 5 * time.Second, +} + +func NewTransport() *http.Transport { + return &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: defaultDialer.DialContext, + DialContext: DefaultDialer.DialContext, ForceAttemptHTTP2: true, MaxIdleConnsPerHost: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - DisableCompression: true, // Prevent double compression + // DisableCompression: true, // Prevent double compression ResponseHeaderTimeout: 60 * time.Second, WriteBufferSize: 16 * 1024, // 16KB ReadBufferSize: 16 * 1024, // 16KB } - DefaultTransportNoTLS = func() *http.Transport { - clone := DefaultTransport.Clone() - clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - return clone - }() -) +} -const StaticFilePathPrefix = "/$gperrorpage/" +func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport { + tr := NewTransport() + tr.TLSClientConfig = tlsConfig + return tr +} diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index 8d45c4d..e407d9b 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -99,7 +99,7 @@ func (s *FileServer) Start(parent task.Parent) E.Error { } if s.UseHealthCheck() { - s.Health = monitor.NewFileServerHealthMonitor(s.TargetName(), s.HealthCheck, s.Root) + s.Health = monitor.NewFileServerHealthMonitor(s.HealthCheck, s.Root) if err := s.Health.Start(s.task); err != nil { return err } diff --git a/internal/route/provider/agent.go b/internal/route/provider/agent.go new file mode 100644 index 0000000..be2736f --- /dev/null +++ b/internal/route/provider/agent.go @@ -0,0 +1,34 @@ +package provider + +import ( + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/agent/pkg/agent" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/watcher" +) + +type AgentProvider struct { + *agent.AgentConfig + docker ProviderImpl +} + +func (p *AgentProvider) ShortName() string { + return p.Name() +} + +func (p *AgentProvider) NewWatcher() watcher.Watcher { + return p.docker.NewWatcher() +} + +func (p *AgentProvider) IsExplicitOnly() bool { + return p.docker.IsExplicitOnly() +} + +func (p *AgentProvider) loadRoutesImpl() (route.Routes, E.Error) { + return p.docker.loadRoutesImpl() +} + +func (p *AgentProvider) Logger() *zerolog.Logger { + return p.docker.Logger() +} diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 90c0c16..2bbc4b7 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -29,7 +29,7 @@ const ( var ErrAliasRefIndexOutOfRange = E.New("index out of range") -func DockerProviderImpl(name, dockerHost string) (ProviderImpl, error) { +func DockerProviderImpl(name, dockerHost string) ProviderImpl { if dockerHost == common.DockerHostFromEnv { dockerHost = common.GetEnvString("DOCKER_HOST", client.DefaultDockerHost) } @@ -37,7 +37,7 @@ func DockerProviderImpl(name, dockerHost string) (ProviderImpl, error) { name, dockerHost, logging.With().Str("type", "docker").Str("name", name).Logger(), - }, nil + } } func (p *DockerProvider) String() string { diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index 67fafa9..bff489e 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -258,16 +258,16 @@ func TestPublicIPLocalhost(t *testing.T) { c := &types.Container{Names: dummyNames, State: "running"} r, ok := makeRoutes(c)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PublicIP, "127.0.0.1") - ExpectEqual(t, r.Host, r.Container.PublicIP) + ExpectEqual(t, r.Container.PublicHostname, "127.0.0.1") + ExpectEqual(t, r.Host, r.Container.PublicHostname) } func TestPublicIPRemote(t *testing.T) { c := &types.Container{Names: dummyNames, State: "running"} raw, ok := makeRoutes(c, testIP)["a"] ExpectTrue(t, ok) - ExpectEqual(t, raw.Container.PublicIP, testIP) - ExpectEqual(t, raw.Host, raw.Container.PublicIP) + ExpectEqual(t, raw.Container.PublicHostname, testIP) + ExpectEqual(t, raw.Host, raw.Container.PublicHostname) } func TestPrivateIPLocalhost(t *testing.T) { @@ -283,8 +283,8 @@ func TestPrivateIPLocalhost(t *testing.T) { } r, ok := makeRoutes(c)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateIP, testDockerIP) - ExpectEqual(t, r.Host, r.Container.PrivateIP) + ExpectEqual(t, r.Container.PrivateHostname, testDockerIP) + ExpectEqual(t, r.Host, r.Container.PrivateHostname) } func TestPrivateIPRemote(t *testing.T) { @@ -301,9 +301,9 @@ func TestPrivateIPRemote(t *testing.T) { } r, ok := makeRoutes(c, testIP)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateIP, "") - ExpectEqual(t, r.Container.PublicIP, testIP) - ExpectEqual(t, r.Host, r.Container.PublicIP) + ExpectEqual(t, r.Container.PrivateHostname, "") + ExpectEqual(t, r.Container.PublicHostname, testIP) + ExpectEqual(t, r.Host, r.Container.PublicHostname) } func TestStreamDefaultValues(t *testing.T) { diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 9e132d5..63bb1d3 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -1,7 +1,6 @@ package provider import ( - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/provider/types" @@ -38,25 +37,25 @@ func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) } } - if common.IsDebug { - eventsLog := E.NewBuilder("events") - for _, event := range events { - eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) - } - E.LogDebug(eventsLog.About(), eventsLog.Error(), handler.provider.Logger()) + // if common.IsDebug { + // eventsLog := E.NewBuilder("events") + // for _, event := range events { + // eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) + // } + // E.LogDebug(eventsLog.About(), eventsLog.Error(), handler.provider.Logger()) - oldRoutesLog := E.NewBuilder("old routes") - for k := range oldRoutes { - oldRoutesLog.Adds(k) - } - E.LogDebug(oldRoutesLog.About(), oldRoutesLog.Error(), handler.provider.Logger()) + // oldRoutesLog := E.NewBuilder("old routes") + // for k := range oldRoutes { + // oldRoutesLog.Adds(k) + // } + // E.LogDebug(oldRoutesLog.About(), oldRoutesLog.Error(), handler.provider.Logger()) - newRoutesLog := E.NewBuilder("new routes") - for k := range newRoutes { - newRoutesLog.Adds(k) - } - E.LogDebug(newRoutesLog.About(), newRoutesLog.Error(), handler.provider.Logger()) - } + // newRoutesLog := E.NewBuilder("new routes") + // for k := range newRoutes { + // newRoutesLog.Adds(k) + // } + // E.LogDebug(newRoutesLog.About(), newRoutesLog.Error(), handler.provider.Logger()) + // } for k, oldr := range oldRoutes { newr, ok := newRoutes[k] @@ -85,7 +84,7 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { switch handler.provider.GetType() { - case types.ProviderTypeDocker: + case types.ProviderTypeDocker, types.ProviderTypeAgent: return route.Container.ContainerID == event.ActorID || route.Container.ContainerName == event.ActorName case types.ProviderTypeFile: diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 6e94b28..fda83d0 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rs/zerolog" + "github.com/yusing/go-proxy/agent/pkg/agent" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/provider/types" @@ -64,14 +65,22 @@ func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) } p = newProvider(types.ProviderTypeDocker) - p.ProviderImpl, err = DockerProviderImpl(name, dockerHost) - if err != nil { - return nil, err - } + p.ProviderImpl = DockerProviderImpl(name, dockerHost) p.watcher = p.NewWatcher() return } +func NewAgentProvider(cfg *agent.AgentConfig) *Provider { + p := newProvider(types.ProviderTypeAgent) + agent := &AgentProvider{ + AgentConfig: cfg, + docker: DockerProviderImpl(cfg.Name(), cfg.FakeDockerHost()), + } + p.ProviderImpl = agent + p.watcher = p.NewWatcher() + return p +} + func (p *Provider) GetType() types.ProviderType { return p.t } diff --git a/internal/route/provider/types/provider_type.go b/internal/route/provider/types/provider_type.go index 2907762..6447e39 100644 --- a/internal/route/provider/types/provider_type.go +++ b/internal/route/provider/types/provider_type.go @@ -5,4 +5,5 @@ type ProviderType string const ( ProviderTypeDocker ProviderType = "docker" ProviderTypeFile ProviderType = "file" + ProviderTypeAgent ProviderType = "agent" ) diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index a80b94b..b841b33 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -1,8 +1,11 @@ package route import ( + "crypto/tls" "net/http" + "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/agent/pkg/agentproxy" "github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" @@ -38,20 +41,27 @@ type ( // var globalMux = http.NewServeMux() // TODO: support regex subdomain matching. +// TODO: fix this for agent func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, E.Error) { - trans := gphttp.DefaultTransport httpConfig := base.HTTPConfig + proxyURL := base.ProxyURL - if httpConfig.NoTLSVerify { - trans = gphttp.DefaultTransportNoTLS - } - if httpConfig.ResponseHeaderTimeout > 0 { - trans = trans.Clone() - trans.ResponseHeaderTimeout = httpConfig.ResponseHeaderTimeout + trans := gphttp.NewTransport() + a := base.Agent() + if a != nil { + trans = a.Transport() + proxyURL = agent.HTTPProxyURL + } else { + if httpConfig.NoTLSVerify { + trans.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + if httpConfig.ResponseHeaderTimeout > 0 { + trans.ResponseHeaderTimeout = httpConfig.ResponseHeaderTimeout + } } service := base.TargetName() - rp := reverseproxy.NewReverseProxy(service, base.ProxyURL, trans) + rp := reverseproxy.NewReverseProxy(service, proxyURL, trans) if len(base.Middlewares) > 0 { err := middleware.PatchReverseProxy(rp, base.Middlewares) @@ -60,6 +70,20 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, E.Error) { } } + if a != nil { + headers := &agentproxy.AgentProxyHeaders{ + Host: base.ProxyURL.Host, + IsHTTPS: base.ProxyURL.Scheme == "https", + SkipTLSVerify: httpConfig.NoTLSVerify, + ResponseHeaderTimeout: int(httpConfig.ResponseHeaderTimeout.Seconds()), + } + ori := rp.HandlerFunc + rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + agentproxy.SetAgentProxyHeaders(r, headers) + ori(w, r) + } + } + r := &ReveseProxyRoute{ Route: base, rp: rp, @@ -88,13 +112,13 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) E.Error { if r.IsDocker() { client, err := docker.ConnectClient(r.Idlewatcher.DockerHost) if err == nil { - fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.HealthCheck) + fallback := r.newHealthMonitor() r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.HealthCheck, fallback) r.task.OnCancel("close_docker_client", client.Close) } } if r.HealthMon == nil { - r.HealthMon = monitor.NewHTTPHealthMonitor(r.rp.TargetURL, r.HealthCheck) + r.HealthMon = r.newHealthMonitor() } } @@ -178,6 +202,17 @@ func (r *ReveseProxyRoute) HealthMonitor() health.HealthMonitor { return r.HealthMon } +func (r *ReveseProxyRoute) newHealthMonitor() interface { + health.HealthMonitor + health.HealthChecker +} { + if a := r.Agent(); a != nil { + target := monitor.AgentCheckHealthTargetFromURL(r.ProxyURL) + return monitor.NewAgentRouteMonitor(a, r.HealthCheck, target) + } + return monitor.NewHTTPHealthMonitor(r.ProxyURL, r.HealthCheck) +} + func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) { var lb *loadbalancer.LoadBalancer cfg := r.LoadBalance diff --git a/internal/route/route.go b/internal/route/route.go index 080598c..d6e3c00 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -5,6 +5,7 @@ import ( "strconv" "strings" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" "github.com/yusing/go-proxy/internal/homepage" @@ -159,6 +160,17 @@ func (r *Route) Type() types.RouteType { panic(fmt.Errorf("unexpected scheme %s for alias %s", r.Scheme, r.Alias)) } +func (r *Route) Agent() *agent.AgentConfig { + if r.Container == nil { + return nil + } + return r.Container.Agent +} + +func (r *Route) IsAgent() bool { + return r.Container != nil && r.Container.Agent != nil +} + func (r *Route) HealthMonitor() health.HealthMonitor { return r.impl.HealthMonitor() } @@ -240,24 +252,24 @@ func (r *Route) Finalize() { switch { case !isDocker: r.Host = "localhost" - case cont.PrivateIP != "": - r.Host = cont.PrivateIP - case cont.PublicIP != "": - r.Host = cont.PublicIP + case cont.PrivateHostname != "": + r.Host = cont.PrivateHostname + case cont.PublicHostname != "": + r.Host = cont.PublicHostname } } lp, pp := r.Port.Listening, r.Port.Proxy if isDocker { - if port, ok := common.ServiceNamePortMapTCP[cont.ImageName]; ok { + if port, ok := common.ImageNamePortMapTCP[cont.ImageName]; ok { if pp == 0 { pp = port } if r.Scheme == "" { r.Scheme = "tcp" } - } else if port, ok := common.ImageNamePortMap[cont.ImageName]; ok { + } else if port, ok := common.ImageNamePortMapHTTP[cont.ImageName]; ok { if pp == 0 { pp = port } @@ -268,39 +280,34 @@ func (r *Route) Finalize() { } if pp == 0 { - switch { - case r.Scheme == "https": - pp = 443 - case !isDocker: - pp = 80 - default: + if isDocker { pp = lowestPort(cont.PrivatePortMapping) if pp == 0 { pp = lowestPort(cont.PublicPortMapping) } + } else if r.Scheme == "https" { + pp = 443 + } else { + pp = 80 } } if isDocker { // replace private port with public port if using public IP. - if r.Host == cont.PublicIP { + if r.Host == cont.PublicHostname { if p, ok := cont.PrivatePortMapping[pp]; ok { pp = int(p.PublicPort) + if r.Scheme == "" && p.Type == "udp" { + r.Scheme = "udp" + } } - } - // replace public port with private port if using private IP. - if r.Host == cont.PrivateIP { + } else { + // replace public port with private port if using private IP. if p, ok := cont.PublicPortMapping[pp]; ok { pp = int(p.PrivatePort) - } - } - - if r.Scheme == "" { - switch { - case r.Host == cont.PublicIP && cont.PublicPortMapping[pp].Type == "udp": - r.Scheme = "udp" - case r.Host == cont.PrivateIP && cont.PrivatePortMapping[pp].Type == "udp": - r.Scheme = "udp" + if r.Scheme == "" && p.Type == "udp" { + r.Scheme = "udp" + } } } } @@ -322,13 +329,10 @@ func (r *Route) Finalize() { r.HealthCheck = health.DefaultHealthConfig } + // set or keep at least default if !r.HealthCheck.Disable { - if r.HealthCheck.Interval == 0 { - r.HealthCheck.Interval = common.HealthCheckIntervalDefault - } - if r.HealthCheck.Timeout == 0 { - r.HealthCheck.Timeout = common.HealthCheckTimeoutDefault - } + r.HealthCheck.Interval |= common.HealthCheckIntervalDefault + r.HealthCheck.Timeout |= common.HealthCheckTimeoutDefault } if isDocker && cont.IdleTimeout != "" { diff --git a/internal/route/routes/routequery/query.go b/internal/route/routes/routequery/query.go index 33a2070..d48681d 100644 --- a/internal/route/routes/routequery/query.go +++ b/internal/route/routes/routequery/query.go @@ -125,7 +125,11 @@ func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter st if item.Category == "" { item.Category = "Docker" } - item.SourceType = string(provider.ProviderTypeDocker) + if r.IsAgent() { + item.SourceType = string(provider.ProviderTypeAgent) + } else { + item.SourceType = string(provider.ProviderTypeDocker) + } case r.UseLoadBalance(): if item.Category == "" { item.Category = "Load-balanced" diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 978a9e9..6490987 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -164,7 +164,7 @@ var commands = map[string]struct { if target.Scheme == "" { target.Scheme = "http" } - rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) + rp := reverseproxy.NewReverseProxy("", target, gphttp.NewTransport()) return ReturningCommand(rp.ServeHTTP) }, }, diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index a27fa1a..7c671ce 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -234,7 +234,8 @@ func TestOnCorrectness(t *testing.T) { tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request { return &http.Request{ - Header: http.Header{k: []string{v}}} + Header: http.Header{k: []string{v}}, + } })...) tests = append(tests, genCorrectnessTestCases("query", func(k, v string) *http.Request { return &http.Request{ diff --git a/internal/route/types/route.go b/internal/route/types/route.go index 32fcda1..87a027d 100644 --- a/internal/route/types/route.go +++ b/internal/route/types/route.go @@ -3,6 +3,7 @@ package types import ( "net/http" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" "github.com/yusing/go-proxy/internal/homepage" @@ -31,7 +32,10 @@ type ( HomepageConfig() *homepage.Item ContainerInfo() *docker.Container + Agent() *agent.AgentConfig + IsDocker() bool + IsAgent() bool UseLoadBalance() bool UseIdleWatcher() bool UseHealthCheck() bool diff --git a/internal/utils/fs.go b/internal/utils/fs.go index 4102901..2cc6224 100644 --- a/internal/utils/fs.go +++ b/internal/utils/fs.go @@ -34,3 +34,15 @@ func ListFiles(dir string, maxDepth int, hideHidden ...bool) ([]string, error) { } return files, nil } + +// FileExists checks if a file exists. +// +// If the file does not exist, it returns false and nil, +// otherwise it returns true and any error that is not os.ErrNotExist. +func FileExists(file string) (bool, error) { + _, err := os.Stat(file) + if os.IsNotExist(err) { + return false, nil + } + return true, err +} diff --git a/internal/utils/wait_exit.go b/internal/utils/wait_exit.go new file mode 100644 index 0000000..459472e --- /dev/null +++ b/internal/utils/wait_exit.go @@ -0,0 +1,25 @@ +package utils + +import ( + "os" + "os/signal" + "syscall" + "time" + + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" +) + +func WaitExit(shutdownTimeout int) { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT) + signal.Notify(sig, syscall.SIGTERM) + signal.Notify(sig, syscall.SIGHUP) + + // wait for signal + <-sig + + // gracefully shutdown + logging.Info().Msg("shutting down") + _ = task.GracefulShutdown(time.Second * time.Duration(shutdownTimeout)) +} diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index cc81fe5..d8babfd 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -6,17 +6,13 @@ import ( docker_events "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" - "github.com/rs/zerolog" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/watcher/events" ) type ( DockerWatcher struct { - zerolog.Logger - host string client *D.SharedClient clientOwned bool @@ -56,20 +52,12 @@ func NewDockerWatcher(host string) DockerWatcher { return DockerWatcher{ host: host, clientOwned: true, - Logger: logging.With(). - Str("type", "docker"). - Str("host", host). - Logger(), } } func NewDockerWatcherWithClient(client *D.SharedClient) DockerWatcher { return DockerWatcher{ client: client, - Logger: logging.With(). - Str("type", "docker"). - Str("host", client.DaemonHost()). - Logger(), } } @@ -124,7 +112,6 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList case msg := <-cEventCh: action, ok := events.DockerEventMap[msg.Action] if !ok { - w.Debug().Msgf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"]) continue } event := Event{ diff --git a/internal/watcher/health/monitor/agent_route.go b/internal/watcher/health/monitor/agent_route.go new file mode 100644 index 0000000..c5d42e6 --- /dev/null +++ b/internal/watcher/health/monitor/agent_route.go @@ -0,0 +1,75 @@ +package monitor + +import ( + "encoding/json" + "errors" + "net/http" + "net/url" + + agentPkg "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +type ( + AgentRouteMonior struct { + agent *agentPkg.AgentConfig + endpointURL string + *monitor + } + AgentCheckHealthTarget struct { + Scheme string + Host string + Path string + } +) + +func AgentCheckHealthTargetFromURL(url *types.URL) *AgentCheckHealthTarget { + return &AgentCheckHealthTarget{ + Scheme: url.Scheme, + Host: url.Host, + Path: url.Path, + } +} + +func (target *AgentCheckHealthTarget) buildQuery() string { + query := make(url.Values, 3) + query.Set("scheme", target.Scheme) + query.Set("host", target.Host) + query.Set("path", target.Path) + return query.Encode() +} + +func (target *AgentCheckHealthTarget) displayURL() *types.URL { + return types.NewURL(&url.URL{ + Scheme: target.Scheme, + Host: target.Host, + Path: target.Path, + }) +} + +func NewAgentRouteMonitor(agent *agentPkg.AgentConfig, config *health.HealthCheckConfig, target *AgentCheckHealthTarget) *AgentRouteMonior { + mon := &AgentRouteMonior{ + agent: agent, + endpointURL: agentPkg.EndpointHealth + "?" + target.buildQuery(), + } + mon.monitor = newMonitor(target.displayURL(), config, mon.CheckHealth) + return mon +} + +func (mon *AgentRouteMonior) CheckHealth() (result *health.HealthCheckResult, err error) { + result = new(health.HealthCheckResult) + ctx, cancel := mon.ContextWithTimeout("timeout querying agent") + defer cancel() + data, status, err := mon.agent.Fetch(ctx, mon.endpointURL) + if err != nil { + return result, err + } + switch status { + case http.StatusOK: + err = json.Unmarshal(data, result) + default: + err = errors.New(string(data)) + } + return +} diff --git a/internal/watcher/health/monitor/fileserver.go b/internal/watcher/health/monitor/fileserver.go index e62e392..67a973b 100644 --- a/internal/watcher/health/monitor/fileserver.go +++ b/internal/watcher/health/monitor/fileserver.go @@ -12,10 +12,9 @@ type FileServerHealthMonitor struct { path string } -func NewFileServerHealthMonitor(alias string, config *health.HealthCheckConfig, path string) *FileServerHealthMonitor { +func NewFileServerHealthMonitor(config *health.HealthCheckConfig, path string) *FileServerHealthMonitor { mon := &FileServerHealthMonitor{path: path} mon.monitor = newMonitor(nil, config, mon.CheckHealth) - mon.service = alias return mon } diff --git a/internal/watcher/health/types.go b/internal/watcher/health/types.go index 3ced2c0..4c8c0c5 100644 --- a/internal/watcher/health/types.go +++ b/internal/watcher/health/types.go @@ -11,9 +11,9 @@ import ( type ( HealthCheckResult struct { - Healthy bool - Detail string - Latency time.Duration + Healthy bool `json:"healthy"` + Detail string `json:"detail"` + Latency time.Duration `json:"latency"` } WithHealthInfo interface { Status() Status diff --git a/pkg/args.go b/pkg/args.go new file mode 100644 index 0000000..941951b --- /dev/null +++ b/pkg/args.go @@ -0,0 +1,29 @@ +package pkg + +import ( + "flag" + "log" +) + +type ( + Args struct { + Command string + Args []string + } + CommandValidator interface { + IsCommandValid(cmd string) bool + } +) + +func GetArgs(validator CommandValidator) Args { + var args Args + flag.Parse() + args.Command = flag.Arg(0) + if !validator.IsCommandValid(args.Command) { + log.Fatalf("invalid command: %s", args.Command) + } + if len(flag.Args()) > 1 { + args.Args = flag.Args()[1:] + } + return args +}