diff --git a/agent.compose.yml b/agent.compose.yml deleted file mode 100644 index ebd83cb..0000000 --- a/agent.compose.yml +++ /dev/null @@ -1,15 +0,0 @@ -services: - godoxy-agent: - image: ghcr.io/yusing/godoxy-agent:latest - container_name: godoxy-agent - restart: always - network_mode: host # do not change this - environment: - AGENT_NAME: # defaults to hostname - AGENT_PORT: # defaults to 8890 - # comma separated list of allowed main server IPs or CIDRs - # to register from this agent - REGISTRATION_ALLOWED_HOSTS: - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - ./certs:/app/certs # store Agent CA cert and Agent SSL cert diff --git a/agent/cmd/main.go b/agent/cmd/main.go index b2ddefd..a4c7a28 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,45 +1,45 @@ package main import ( - "fmt" + "os" "github.com/yusing/go-proxy/agent/pkg/agent" - "github.com/yusing/go-proxy/agent/pkg/certs" "github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/server" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/pkg" - "gopkg.in/yaml.v3" ) -func printNewClientHelp() { - ip, ok := agent.MachineIP() - if !ok { - logging.Warn().Msg("No valid network interface found, change to your actual IP") - ip = "" - } else { - logging.Info().Msgf("Detected machine IP: %s, change if needed", ip) - } - - host := fmt.Sprintf("%s:%d", ip, env.AgentPort) - cfgYAML, _ := yaml.Marshal(map[string]any{ - "providers": map[string]any{ - "agents": host, - }, - }) - - logging.Info().Msgf("On main server, run:\n\ndocker exec godoxy /app/run add-agent '%s'\n", host) - logging.Info().Msgf("Then add this host (%s) to main server config like below:\n", host) - logging.Info().Msg(string(cfgYAML)) -} - func main() { - ca, srv, isNew, err := certs.InitCerts() + args := os.Args + if len(args) > 1 && args[1] == "migrate" { + if err := agent.MigrateFromOld(); err != nil { + E.LogFatal("failed to migrate from old docker compose", err) + } + return + } + _ = os.Chmod("/app/compose.yml", 0600) + ca := &agent.PEMPair{} + err := ca.Load(env.AgentCACert) if err != nil { E.LogFatal("init CA error", err) } + caCert, err := ca.ToTLSCert() + if err != nil { + E.LogFatal("init CA error", err) + } + + srv := &agent.PEMPair{} + srv.Load(env.AgentSSLCert) + if err != nil { + E.LogFatal("init SSL error", err) + } + srvCert, err := srv.ToTLSCert() + if err != nil { + E.LogFatal("init SSL error", err) + } logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion()) logging.Info().Msgf("Agent name: %s", env.AgentName) @@ -49,23 +49,15 @@ func main() { Tips: 1. To change the agent name, you can set the AGENT_NAME environment variable. 2. To change the agent port, you can set the AGENT_PORT environment variable. -3. To skip the version check, you can set AGENT_SKIP_VERSION_CHECK to true. -4. If anything goes wrong, you can remove the 'certs' directory and start over. `) t := task.RootTask("agent", false) opts := server.Options{ - CACert: ca, - ServerCert: srv, + CACert: caCert, + ServerCert: srvCert, Port: env.AgentPort, } - if isNew { - logging.Info().Msg("Initialization complete.") - printNewClientHelp() - server.StartRegistrationServer(t, opts) - } - server.StartAgentServer(t, opts) task.WaitExit(3) diff --git a/agent/pkg/agent/agent.compose.yml b/agent/pkg/agent/agent.compose.yml new file mode 100644 index 0000000..c90e25e --- /dev/null +++ b/agent/pkg/agent/agent.compose.yml @@ -0,0 +1,14 @@ +services: + agent: + image: "{{.Image}}" + container_name: godoxy-agent + restart: always + network_mode: host # do not change this + environment: + AGENT_NAME: "{{.Name}}" + AGENT_PORT: "{{.Port}}" + AGENT_CA_CERT: "{{.CACert}}" + AGENT_SSL_CERT: "{{.SSLCert}}" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - ./compose.yml:/app/compose.yml diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index 79d8004..7bcdc52 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -38,7 +38,7 @@ const ( EndpointLogs = "/logs" EndpointSystemInfo = "/system_info" - AgentHost = certs.CertsDNSName + AgentHost = CertsDNSName APIEndpointBase = "/godoxy/agent" APIBaseURL = "https://" + AgentHost + APIEndpointBase @@ -80,20 +80,7 @@ func checkVersion(a, b string) bool { return withoutBuildTime(a) == withoutBuildTime(b) } -func (cfg *AgentConfig) Start(parent task.Parent) E.Error { - certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) - if err != nil { - if os.IsNotExist(err) { - return E.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr) - } - return E.Wrap(err) - } - - ca, crt, key, err := certs.ExtractCert(certData) - if err != nil { - return E.Wrap(err) - } - +func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) E.Error { clientCert, err := tls.X509KeyPair(crt, key) if err != nil { return E.Wrap(err) @@ -109,6 +96,7 @@ func (cfg *AgentConfig) Start(parent task.Parent) E.Error { cfg.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{clientCert}, RootCAs: caCertPool, + ServerName: CertsDNSName, } // create transport and http client @@ -140,6 +128,23 @@ func (cfg *AgentConfig) Start(parent task.Parent) E.Error { return nil } +func (cfg *AgentConfig) Start(parent task.Parent) E.Error { + certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) + if err != nil { + if os.IsNotExist(err) { + return E.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr) + } + return E.Wrap(err) + } + + ca, crt, key, err := certs.ExtractCert(certData) + if err != nil { + return E.Wrap(err) + } + + return cfg.StartWithCerts(parent, ca, crt, key) +} + func (cfg *AgentConfig) NewHTTPClient() *http.Client { return &http.Client{ Transport: cfg.Transport(), diff --git a/agent/pkg/agent/docker_compose.go b/agent/pkg/agent/docker_compose.go new file mode 100644 index 0000000..503df65 --- /dev/null +++ b/agent/pkg/agent/docker_compose.go @@ -0,0 +1,123 @@ +package agent + +import ( + "bytes" + "os" + "path" + "strconv" + "text/template" + + _ "embed" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/utils" + "gopkg.in/yaml.v3" +) + +//go:embed agent.compose.yml +var agentComposeYAML []byte +var agentComposeYAMLTemplate = template.Must(template.New("agent.compose.yml").Parse(string(agentComposeYAML))) + +const ( + DockerImageProduction = "ghcr.io/yusing/godoxy-agent:latest" + DockerImageNightly = "yusing/godoxy-agent-nightly:latest" +) + +type ( + AgentComposeConfig struct { + Image string + Name string + Port int + CACert string + SSLCert string + } +) + +func (c *AgentComposeConfig) Generate() (string, error) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + if err := agentComposeYAMLTemplate.Execute(buf, c); err != nil { + return "", err + } + return buf.String(), nil +} + +func pemPairFromFile(path string) (*PEMPair, error) { + cert, err := os.ReadFile(path + ".crt") + if err != nil { + return nil, err + } + key, err := os.ReadFile(path + ".key") + if err != nil { + return nil, err + } + return &PEMPair{ + Cert: cert, + Key: key, + }, nil +} + +func rmOldCerts(p string) error { + files, err := utils.ListFiles(p, 0) + if err != nil { + return err + } + for _, file := range files { + if err := os.Remove(path.Join(p, file)); err != nil { + return err + } + } + return nil +} + +type dockerCompose struct { + Services struct { + GodoxyAgent struct { + Environment struct { + AGENT_NAME string `yaml:"GODOXY_AGENT_NAME"` + AGENT_PORT string `yaml:"GODOXY_AGENT_PORT"` + } `yaml:"environment"` + } `yaml:"godoxy-agent"` + } `yaml:"services"` +} + +// TODO: remove this +func MigrateFromOld() error { + oldCompose, err := os.ReadFile("/app/compose.yml") + if err != nil { + return err + } + var compose dockerCompose + if err := yaml.Unmarshal(oldCompose, &compose); err != nil { + return err + } + ca, err := pemPairFromFile("/app/certs/ca") + if err != nil { + return err + } + agentCert, err := pemPairFromFile("/app/certs/agent") + if err != nil { + return err + } + var composeConfig AgentComposeConfig + composeConfig.Image = DockerImageNightly + composeConfig.Name = compose.Services.GodoxyAgent.Environment.AGENT_NAME + composeConfig.Port, err = strconv.Atoi(compose.Services.GodoxyAgent.Environment.AGENT_PORT) // ignore error, empty is fine + if composeConfig.Port == 0 { + composeConfig.Port = 8890 + } + composeConfig.CACert = ca.String() + composeConfig.SSLCert = agentCert.String() + composeTemplate, err := composeConfig.Generate() + if err != nil { + return E.Wrap(err, "failed to generate new docker compose") + } + + if err := os.WriteFile("/app/compose.yml", []byte(composeTemplate), 0600); err != nil { + return E.Wrap(err, "failed to write new docker compose") + } + + logging.Info().Msg("Migrated from old docker compose:") + logging.Info().Msg(composeTemplate) + return nil +} diff --git a/agent/pkg/agent/new_agent.go b/agent/pkg/agent/new_agent.go new file mode 100644 index 0000000..7d75328 --- /dev/null +++ b/agent/pkg/agent/new_agent.go @@ -0,0 +1,139 @@ +package agent + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "errors" + "math/big" + "strings" + "time" +) + +const ( + CertsDNSName = "godoxy.agent" + KeySize = 2048 +) + +func toPEMPair(certDER []byte, key *rsa.PrivateKey) *PEMPair { + return &PEMPair{ + Cert: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), + Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), + } +} + +func b64Encode(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func b64Decode(data string) ([]byte, error) { + return base64.StdEncoding.DecodeString(data) +} + +type PEMPair struct { + Cert, Key []byte +} + +func (p *PEMPair) String() string { + return b64Encode(p.Cert) + ";" + b64Encode(p.Key) +} + +func (p *PEMPair) Load(data string) (err error) { + parts := strings.Split(data, ";") + if len(parts) != 2 { + return errors.New("invalid PEM pair") + } + p.Cert, err = b64Decode(parts[0]) + if err != nil { + return err + } + p.Key, err = b64Decode(parts[1]) + if err != nil { + return err + } + return nil +} + +func (p *PEMPair) ToTLSCert() (*tls.Certificate, error) { + cert, err := tls.X509KeyPair(p.Cert, p.Key) + return &cert, err +} + +func NewAgent() (ca, srv, client *PEMPair, err error) { + // Create the CA's certificate + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"GoDoxy"}, + CommonName: CertsDNSName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caKey, err := rsa.GenerateKey(rand.Reader, KeySize) + if err != nil { + return nil, nil, nil, err + } + + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return nil, nil, nil, err + } + + ca = toPEMPair(caDER, caKey) + + // Generate a new private key for the server certificate + serverKey, err := rsa.GenerateKey(rand.Reader, KeySize) + if err != nil { + return nil, nil, nil, err + } + + srvTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Issuer: caTemplate.Subject, + Subject: caTemplate.Subject, + DNSNames: []string{CertsDNSName}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1000, 0, 0), // Add validity period + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + srvCertDER, err := x509.CreateCertificate(rand.Reader, srvTemplate, caTemplate, &serverKey.PublicKey, caKey) + if err != nil { + return nil, nil, nil, err + } + + srv = toPEMPair(srvCertDER, serverKey) + + clientKey, err := rsa.GenerateKey(rand.Reader, KeySize) + if err != nil { + return nil, nil, nil, err + } + + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(3), + Issuer: caTemplate.Subject, + Subject: caTemplate.Subject, + DNSNames: []string{CertsDNSName}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1000, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + return nil, nil, nil, err + } + + client = toPEMPair(clientCertDER, clientKey) + return +} diff --git a/agent/pkg/agent/new_agent_test.go b/agent/pkg/agent/new_agent_test.go new file mode 100644 index 0000000..92537ed --- /dev/null +++ b/agent/pkg/agent/new_agent_test.go @@ -0,0 +1,91 @@ +package agent + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestNewAgent(t *testing.T) { + ca, srv, client, err := NewAgent() + ExpectNoError(t, err) + ExpectTrue(t, ca != nil) + ExpectTrue(t, srv != nil) + ExpectTrue(t, client != nil) +} + +func TestPEMPair(t *testing.T) { + ca, srv, client, err := NewAgent() + ExpectNoError(t, err) + + for i, p := range []*PEMPair{ca, srv, client} { + t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) { + var pp PEMPair + err := pp.Load(p.String()) + ExpectNoError(t, err) + ExpectBytesEqual(t, p.Cert, pp.Cert) + ExpectBytesEqual(t, p.Key, pp.Key) + }) + } +} + +func TestPEMPairToTLSCert(t *testing.T) { + ca, srv, client, err := NewAgent() + ExpectNoError(t, err) + + for i, p := range []*PEMPair{ca, srv, client} { + t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) { + cert, err := p.ToTLSCert() + ExpectNoError(t, err) + ExpectTrue(t, cert != nil) + }) + } +} + +func TestServerClient(t *testing.T) { + ca, srv, client, err := NewAgent() + ExpectNoError(t, err) + + srvTLS, err := srv.ToTLSCert() + ExpectNoError(t, err) + ExpectTrue(t, srvTLS != nil) + + clientTLS, err := client.ToTLSCert() + ExpectNoError(t, err) + ExpectTrue(t, clientTLS != nil) + + caPool := x509.NewCertPool() + ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert)) + + srvTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{*srvTLS}, + ClientCAs: caPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + clientTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{*clientTLS}, + RootCAs: caPool, + ServerName: CertsDNSName, + } + + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + server.TLS = srvTLSConfig + server.StartTLS() + defer server.Close() + + httpClient := &http.Client{ + Transport: &http.Transport{TLSClientConfig: clientTLSConfig}, + } + + resp, err := httpClient.Get(server.URL) + ExpectNoError(t, err) + ExpectEqual(t, resp.StatusCode, http.StatusOK) +} diff --git a/agent/pkg/agent/utils.go b/agent/pkg/agent/utils.go deleted file mode 100644 index d6f0305..0000000 --- a/agent/pkg/agent/utils.go +++ /dev/null @@ -1,30 +0,0 @@ -package agent - -import ( - "net" - "strings" -) - -func MachineIP() (string, bool) { - interfaces, err := net.Interfaces() - if err != nil { - interfaces = []net.Interface{} - } - for _, in := range interfaces { - addrs, err := in.Addrs() - if err != nil { - continue - } - if !strings.HasPrefix(in.Name, "eth") && !strings.HasPrefix(in.Name, "en") { - continue - } - for _, addr := range addrs { - if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - return ipnet.IP.String(), true - } - } - } - } - return "", false -} diff --git a/agent/pkg/certs/certs.go b/agent/pkg/certs/certs.go deleted file mode 100644 index 9051274..0000000 --- a/agent/pkg/certs/certs.go +++ /dev/null @@ -1,201 +0,0 @@ -package certs - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "errors" - "math/big" - "os" - "time" - - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/utils" -) - -const ( - CertsDNSName = "godoxy.agent" - - caCertPath = "certs/ca.crt" - caKeyPath = "certs/ca.key" - srvCertPath = "certs/agent.crt" - srvKeyPath = "certs/agent.key" -) - -func loadCerts(certPath, keyPath string) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - return &cert, err -} - -func write(b []byte, f *os.File) error { - _, err := f.Write(b) - return err -} - -func saveCerts(certDER []byte, key *rsa.PrivateKey, certPath, keyPath string) ([]byte, []byte, error) { - certPEM, keyPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), - pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - - if certPath == "" || keyPath == "" { - return certPEM, keyPEM, nil - } - - certFile, err := os.Create(certPath) - if err != nil { - return nil, nil, err - } - defer certFile.Close() - - keyFile, err := os.Create(keyPath) - if err != nil { - return nil, nil, err - } - defer keyFile.Close() - - return certPEM, keyPEM, errors.Join( - write(certPEM, certFile), - write(keyPEM, keyFile), - ) -} - -func checkExists(certPath, keyPath string) bool { - certExists, err := utils.FileExists(certPath) - if err != nil { - E.LogFatal("cert error", err) - } - keyExists, err := utils.FileExists(keyPath) - if err != nil { - E.LogFatal("key error", err) - } - return certExists && keyExists -} - -func InitCerts() (ca *tls.Certificate, srv *tls.Certificate, isNew bool, err error) { - if checkExists(caCertPath, caKeyPath) && checkExists(srvCertPath, srvKeyPath) { - logging.Info().Msg("Loading existing certs...") - ca, err = loadCerts(caCertPath, caKeyPath) - if err != nil { - return nil, nil, false, err - } - srv, err = loadCerts(srvCertPath, srvKeyPath) - if err != nil { - return nil, nil, false, err - } - - logging.Info().Msg("Verifying agent cert...") - - roots := x509.NewCertPool() - roots.AddCert(ca.Leaf) - - srvCert, err := x509.ParseCertificate(srv.Certificate[0]) - if err != nil { - return nil, nil, false, err - } - - // check if srv is signed by ca - if _, err := srvCert.Verify(x509.VerifyOptions{ - Roots: roots, - }); err == nil { - logging.Info().Msg("OK") - return ca, srv, false, nil - } - logging.Error().Msg("Agent cert and CA cert mismatch, regenerating") - } - - // Create the CA's certificate - caTemplate := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"GoDoxy"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - - caKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, false, err - } - - caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) - if err != nil { - return nil, nil, false, err - } - - certPEM, keyPEM, err := saveCerts(caCertDER, caKey, caCertPath, caKeyPath) - if err != nil { - return nil, nil, false, err - } - - caCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return nil, nil, false, err - } - - ca = &caCert - - // Generate a new private key for the server certificate - serverKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, false, err - } - - srvTemplate := caTemplate - srvTemplate.Issuer = srvTemplate.Subject - srvTemplate.DNSNames = append(srvTemplate.DNSNames, CertsDNSName) - - srvCertDER, err := x509.CreateCertificate(rand.Reader, &srvTemplate, &caTemplate, &serverKey.PublicKey, caKey) - if err != nil { - return nil, nil, false, err - } - - certPEM, keyPEM, err = saveCerts(srvCertDER, serverKey, srvCertPath, srvKeyPath) - if err != nil { - return nil, nil, false, err - } - - agentCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return nil, nil, false, err - } - - srv = &agentCert - - return ca, srv, true, nil -} - -func NewClientCert(ca *tls.Certificate) ([]byte, []byte, error) { - // Generate the SSL's private key - sslKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - // Create the SSL's certificate - sslTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{ - Organization: []string{"GoDoxy"}, - CommonName: CertsDNSName, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - // Sign the certificate with the CA - sslCertDER, err := x509.CreateCertificate(rand.Reader, sslTemplate, ca.Leaf, &sslKey.PublicKey, ca.PrivateKey) - if err != nil { - return nil, nil, err - } - - return saveCerts(sslCertDER, sslKey, "", "") -} diff --git a/agent/pkg/certs/zip.go b/agent/pkg/certs/zip.go index 51faa2b..9349499 100644 --- a/agent/pkg/certs/zip.go +++ b/agent/pkg/certs/zip.go @@ -12,7 +12,7 @@ import ( func writeFile(zipWriter *zip.Writer, name string, data []byte) error { w, err := zipWriter.CreateHeader(&zip.FileHeader{ Name: name, - Method: zip.Deflate, + Method: zip.Store, }) if err != nil { return err @@ -31,8 +31,7 @@ func readFile(f *zip.File) ([]byte, error) { } func ZipCert(ca, crt, key []byte) ([]byte, error) { - data := bytes.NewBuffer(nil) - data.Grow(6144) + data := bytes.NewBuffer(make([]byte, 0, 6144)) zipWriter := zip.NewWriter(data) defer zipWriter.Close() diff --git a/agent/pkg/env/env.go b/agent/pkg/env/env.go index 70c1fd8..5342209 100644 --- a/agent/pkg/env/env.go +++ b/agent/pkg/env/env.go @@ -1,11 +1,7 @@ package env import ( - "log" - "net" "os" - "strings" - "sync" "github.com/yusing/go-proxy/internal/common" ) @@ -24,54 +20,6 @@ var ( AgentRegistrationPort = common.GetEnvInt("AGENT_REGISTRATION_PORT", 8891) AgentSkipClientCertCheck = common.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false) - RegistrationAllowedHosts = common.GetCommaSepEnv("REGISTRATION_ALLOWED_HOSTS", "") - RegistrationAllowedCIDRs []*net.IPNet + AgentCACert = common.GetEnvString("AGENT_CA_CERT", "") + AgentSSLCert = common.GetEnvString("AGENT_SSL_CERT", "") ) - -func init() { - cidrs, err := toCIDRs(RegistrationAllowedHosts) - if err != nil { - log.Fatalf("failed to parse allowed hosts: %v", err) - } - RegistrationAllowedCIDRs = cidrs -} - -func toCIDRs(hosts []string) ([]*net.IPNet, error) { - cidrs := make([]*net.IPNet, 0, len(hosts)) - for _, host := range hosts { - if !strings.Contains(host, "/") { - host += "/32" - } - _, cidr, err := net.ParseCIDR(host) - if err != nil { - return nil, err - } - cidrs = append(cidrs, cidr) - } - return cidrs, nil -} - -var warnOnce sync.Once - -func IsAllowedHost(remoteAddr string) bool { - if len(RegistrationAllowedCIDRs) == 0 { - warnOnce.Do(func() { - log.Println("Warning: REGISTRATION_ALLOWED_HOSTS is empty, allowing all hosts") - }) - return true - } - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - ip = remoteAddr - } - netIP := net.ParseIP(ip) - if netIP == nil { - return false - } - for _, cidr := range RegistrationAllowedCIDRs { - if cidr.Contains(netIP) { - return true - } - } - return false -} diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go index cb2af9d..8d661ce 100644 --- a/agent/pkg/handler/handler.go +++ b/agent/pkg/handler/handler.go @@ -1,22 +1,15 @@ package handler import ( - "crypto/tls" - "encoding/pem" "fmt" "io" "net/http" "github.com/yusing/go-proxy/agent/pkg/agent" - "github.com/yusing/go-proxy/agent/pkg/certs" "github.com/yusing/go-proxy/agent/pkg/env" v1 "github.com/yusing/go-proxy/internal/api/v1" - "github.com/yusing/go-proxy/internal/api/v1/utils" - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/metrics/systeminfo" - "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -54,46 +47,3 @@ func NewAgentHandler() http.Handler { mux.ServeMux.HandleFunc("/", DockerSocketHandler()) return mux } - -// NewRegistrationHandler creates a new registration handler -// It checks if the request is coming from an allowed host -// Generates a new client certificate and zips it -// Sends the zipped certificate to the client -// its run only once on agent first start. -func NewRegistrationHandler(task *task.Task, ca *tls.Certificate) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if !env.IsAllowedHost(r.RemoteAddr) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - - if r.URL.Path == "/done" { - logging.Info().Msg("registration done") - task.Finish(nil) - w.WriteHeader(http.StatusOK) - return - } - - logging.Info().Msgf("received registration request from %s", r.RemoteAddr) - - caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Certificate[0]}) - - crt, key, err := certs.NewClientCert(ca) - if err != nil { - utils.HandleErr(w, r, E.Wrap(err, "failed to generate client certificate")) - return - } - - zipped, err := certs.ZipCert(caPEM, crt, key) - if err != nil { - utils.HandleErr(w, r, E.Wrap(err, "failed to zip certificate")) - return - } - - w.Header().Set("Content-Type", "application/zip") - if _, err := w.Write(zipped); err != nil { - logging.Error().Err(err).Msg("failed to respond to registration request") - return - } - } -} diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index 3409f67..20a2fc3 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -77,32 +77,3 @@ func StartAgentServer(parent task.Parent, opt Options) { } }() } - -func StartRegistrationServer(parent task.Parent, opt Options) { - t := parent.Subtask("registration_server") - - logger := logging.GetLogger() - registrationServer := &http.Server{ - Addr: fmt.Sprintf(":%d", opt.Port), - Handler: handler.NewRegistrationHandler(t, opt.CACert), - ErrorLog: log.New(logger, "", 0), - } - - go func() { - err := registrationServer.ListenAndServe() - server.HandleError(logger, err, "failed to serve registration server") - }() - - logging.Info().Int("port", opt.Port).Msg("registration server started") - - defer t.Finish(nil) - <-t.Context().Done() - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - err := registrationServer.Shutdown(ctx) - server.HandleError(logger, err, "failed to shutdown registration server") - - logging.Info().Int("port", opt.Port).Msg("registration server stopped") -} diff --git a/internal/api/handler.go b/internal/api/handler.go index 1683d4a..6ec1e47 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -77,7 +77,9 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true) mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true) mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true) - mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true) + mux.HandleFunc("GET", "/v1/agents", v1.ListAgents, true) + mux.HandleFunc("GET", "/v1/agents/new", v1.NewAgent, true) + mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) diff --git a/internal/api/v1/agents.go b/internal/api/v1/agents.go index 4c39d12..d722bee 100644 --- a/internal/api/v1/agents.go +++ b/internal/api/v1/agents.go @@ -11,7 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) -func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { +func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { if httpheaders.IsWebsocket(r.Header) { U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error { wsjson.Write(r.Context(), conn, cfg.ListAgents()) diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 751f1e8..2e861c1 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -28,7 +28,6 @@ const ( ListHomepageCategories = "homepage_categories" ListIcons = "icons" ListTasks = "tasks" - ListAgents = "agents" ) func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { @@ -78,8 +77,6 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { U.RespondJSON(w, r, icons) case ListTasks: U.RespondJSON(w, r, task.DebugTaskList()) - case ListAgents: - U.RespondJSON(w, r, cfg.ListAgents()) default: U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) } diff --git a/internal/api/v1/new_agent.go b/internal/api/v1/new_agent.go new file mode 100644 index 0000000..c328bdd --- /dev/null +++ b/internal/api/v1/new_agent.go @@ -0,0 +1,135 @@ +package v1 + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + + _ "embed" + + "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/agent/pkg/certs" + U "github.com/yusing/go-proxy/internal/api/v1/utils" + config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func NewAgent(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + name := q.Get("name") + if name == "" { + U.RespondError(w, U.ErrMissingKey("name")) + return + } + host := q.Get("host") + if host == "" { + U.RespondError(w, U.ErrMissingKey("host")) + return + } + portStr := q.Get("port") + if portStr == "" { + U.RespondError(w, U.ErrMissingKey("port")) + return + } + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + U.RespondError(w, U.ErrInvalidKey("port")) + return + } + hostport := fmt.Sprintf("%s:%d", host, port) + if _, ok := config.GetInstance().GetAgent(hostport); ok { + U.RespondError(w, U.ErrAlreadyExists("agent", hostport), http.StatusConflict) + return + } + t := q.Get("type") + switch t { + case "docker": + break + case "system": + U.RespondError(w, U.Errorf("system agent is not supported yet"), http.StatusNotImplemented) + return + case "": + U.RespondError(w, U.ErrMissingKey("type")) + return + default: + U.RespondError(w, U.ErrInvalidKey("type")) + return + } + + nightly := strutils.ParseBool(q.Get("nightly")) + var image string + if nightly { + image = agent.DockerImageNightly + } else { + image = agent.DockerImageProduction + } + + ca, srv, client, err := agent.NewAgent() + if err != nil { + U.HandleErr(w, r, err) + return + } + + cfg := agent.AgentComposeConfig{ + Image: image, + Name: name, + Port: port, + CACert: ca.String(), + SSLCert: srv.String(), + } + + template, err := cfg.Generate() + if err != nil { + U.HandleErr(w, r, err) + return + } + + U.RespondJSON(w, r, map[string]any{ + "compose": template, + "ca": ca, + "client": client, + }) +} + +func AddAgent(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + clientPEMData, err := io.ReadAll(r.Body) + if err != nil { + U.HandleErr(w, r, err) + return + } + + var data struct { + Host string `json:"host"` + CA agent.PEMPair `json:"ca"` + Client agent.PEMPair `json:"client"` + } + + if err := json.Unmarshal(clientPEMData, &data); err != nil { + U.RespondError(w, err, http.StatusBadRequest) + return + } + + nRoutesAdded, err := config.GetInstance().AddAgent(data.Host, data.CA, data.Client) + if err != nil { + U.RespondError(w, err) + return + } + + zip, err := certs.ZipCert(data.CA.Cert, data.Client.Cert, data.Client.Key) + if err != nil { + U.HandleErr(w, r, err) + return + } + + if err := os.WriteFile(certs.AgentCertsFilename(data.Host), zip, 0600); err != nil { + U.HandleErr(w, r, err) + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf("Added %d routes", nRoutesAdded))) +} diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 31c0669..422f1c9 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -2,7 +2,6 @@ package utils import ( "context" - "encoding/json" "errors" "net/http" "syscall" @@ -42,25 +41,26 @@ func RespondError(w http.ResponseWriter, err error, code ...int) { if len(code) == 0 { code = []int{http.StatusBadRequest} } - buf, err := json.Marshal(err) - if err != nil { // just in case - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - http.Error(w, ansi.StripANSI(err.Error()), code[0]) - return - } - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(code[0]) - _, _ = w.Write(buf) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + http.Error(w, ansi.StripANSI(err.Error()), code[0]) +} + +func Errorf(format string, args ...any) error { + return E.Errorf(format, args...) } func ErrMissingKey(k string) error { - return E.New("missing key '" + k + "' in query or request body") + return E.New(k + " is required") } func ErrInvalidKey(k string) error { - return E.New("invalid key '" + k + "' in query or request body") + return E.New(k + " is invalid") +} + +func ErrAlreadyExists(k, v string) error { + return E.Errorf("%s %q already exists", k, v) } func ErrNotFound(k, v string) error { - return E.Errorf("key %q with value %q not found", k, v) + return E.Errorf("%s %q not found", k, v) } diff --git a/internal/config/agent_pool.go b/internal/config/agent_pool.go index 9325803..5088b37 100644 --- a/internal/config/agent_pool.go +++ b/internal/config/agent_pool.go @@ -2,6 +2,8 @@ package config import ( "github.com/yusing/go-proxy/agent/pkg/agent" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -27,6 +29,22 @@ func (cfg *Config) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, b return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost)) } +func (cfg *Config) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, E.Error) { + var agentCfg agent.AgentConfig + agentCfg.Addr = host + err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key) + if err != nil { + return 0, err + } + + provider := provider.NewAgentProvider(&agentCfg) + if err := cfg.errIfExists(provider); err != nil { + return 0, err + } + cfg.storeProvider(provider) + return provider.NumRoutes(), nil +} + func (cfg *Config) ListAgents() []*agent.AgentConfig { agents := make([]*agent.AgentConfig, 0, agentPool.Size()) agentPool.RangeAll(func(key string, value *agent.AgentConfig) { diff --git a/internal/config/types/config.go b/internal/config/types/config.go index e41dd4b..7788d78 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -42,6 +42,7 @@ type ( RouteProviderList() []string Context() context.Context GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) + AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, E.Error) ListAgents() []*agent.AgentConfig } ) diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index 40c39c9..752cec7 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -130,7 +130,7 @@ func (s *Server) stop() { return } - ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) + ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second) defer cancel() if s.http != nil && s.httpStarted {